diff --git a/.gitignore b/.gitignore index 5afe375f46f07b3b557ae23f75740b337517d3bd..1ef4c297ee4f369775c13b32a46a55887de719e7 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ __pycache__ *.swp .vscode/ cmake_build/ +tensorflow/contrib/cmake/_build/ .idea/** /build/ [Bb]uild/ @@ -30,6 +31,7 @@ Podfile.lock xcuserdata/** /api_init_files_list.txt /estimator_api_init_files_list.txt +*.whl # Android .gradle diff --git a/CODEOWNERS b/CODEOWNERS index b9f0313cc6d59d3fbdcd014e1a528126d863075a..113eaf798f7a0abf1a9ad3fed6308f234f8efe75 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,53 +1,62 @@ -# NOTE: Disabled temporarily because it's too noisy on pushes. # Where component owners are known, add them here. -# /tensorflow/core/platform/windows/ @mrry -# /tensorflow/java/ @asimshankar -# /tensorflow/tensorboard/ @jart @dandelionmane -# /tensorflow/tools/docs/ @markdaoust +/tenosrflow/core/debug @caisq +/tensorflow/core/platform/windows/ @mrry +/tensorflow/go @asimshankar +/tensorflow/java/ @asimshankar +/tensorflow/python/debug @caisq +/tensorflow/python/tools/api/generator/ @annarev +/tensorflow/tensorboard/ @jart +/tensorflow/tools/docs/ @markdaoust # contrib -# NEED OWNER: /tensorflow/contrib/avro/ -# /tensorflow/contrib/batching/ @alextp @chrisolston -# /tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon -# /tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva -# /tensorflow/contrib/cmake/ @mrry @benoitsteiner -# /tensorflow/contrib/copy_graph/ @tucker @poxvoculi -# /tensorflow/contrib/crf/ @kentonl -# /tensorflow/contrib/data/ @mrry -# /tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi -# /tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo -# /tensorflow/contrib/ffmpeg/ @fredbertsch -# NEED OWNER: /tensorflow/contrib/framework/ -# /tensorflow/contrib/graph_editor/ @purpledog +# NEED OWNER: /tensorflow/contrib/all_reduce +/tensorflow/contrib/batching/ @alextp @chrisolston +/tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon +/tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva +/tensorflow/contrib/checkpoint/ @allenlavoie +/tensorflow/contrib/contrib/cluster_resolver/ @frankchn +/tensorflow/contrib/cmake/ @mrry +/tensorflow/contrib/copy_graph/ @tucker @poxvoculi +/tensorflow/contrib/crf/ @kentonl +/tensorflow/contrib/data/ @mrry +/tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn +/tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi +/tensorflow/contrib/eager @alextp @asimshankar +/tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo +/tensorflow/contrib/ffmpeg/ @fredbertsch +/tensorflow/contrib/framework/ @ebrevdo +/tensorflow/contrib/gan/ @joel-shor +/tensorflow/contrib/graph_editor/ @purpledog # NEED OWNER: /tensorflow/contrib/grid_rnn/ -# /tensorflow/contrib/hvx/ @satok16 -# /tensorflow/contrib/integrate/ @shoyer -# /tensorflow/contrib/kernel_methods/ @petrosmol -# /tensorflow/contrib/ios_examples/ @petewarden -# /tensorflow/contrib/labeled_tensor/ @shoyer -# /tensorflow/contrib/layers/ @fchollet @martinwicke -# /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp -# /tensorflow/contrib/linalg/ @langmore -# /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis -# /tensorflow/contrib/lookup/ @ysuematsu @andreasst -# /tensorflow/contrib/losses/ @alextp @ispirmustafa -# /tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg -# /tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa -# /tensorflow/contrib/nccl/ @cwhipkey @zheng-xq -# /tensorflow/contrib/opt/ @strategist333 -# /tensorflow/contrib/pi_examples/ @maciekcc -# /tensorflow/contrib/quantization/ @petewarden @cwhipkey @keveman -# /tensorflow/contrib/rnn/ @ebrevdo -# /tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh -# /tensorflow/contrib/seq2seq/ @lukaszkaiser -# /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh -# /tensorflow/contrib/slim/ @sguada @thenbasilmanran -# /tensorflow/contrib/stateless/ @girving -# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank -# /tensorflow/contrib/testing/ @dandelionmane -# /tensorflow/contrib/timeseries/ @allenlavoie -# /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu -# /tensorflow/contrib/training/ @joel-shor @ebrevdo -# /tensorflow/contrib/util/ @sherrym +/tensorflow/contrib/hvx/ @satok16 +/tensorflow/contrib/integrate/ @shoyer +/tensorflow/contrib/kernel_methods/ @petrosmol +/tensorflow/contrib/ios_examples/ @petewarden +/tensorflow/contrib/labeled_tensor/ @shoyer +/tensorflow/contrib/layers/ @fchollet @martinwicke +/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp +/tensorflow/contrib/linalg/ @langmore +/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis +/tensorflow/contrib/lookup/ @ysuematsu @andreasst +/tensorflow/contrib/losses/ @alextp @ispirmustafa +/tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg +/tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa +/tensorflow/contrib/nccl/ @cwhipkey @zheng-xq +/tensorflow/contrib/opt/ @strategist333 @alextp +/tensorflow/contrib/pi_examples/ @maciekcc +/tensorflow/contrib/quantization/ @petewarden +/tensorflow/contrib/rnn/ @ebrevdo @scottzhu +/tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh @allenl +/tensorflow/contrib/seq2seq/ @ebrevdo @lmthang +/tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh +/tensorflow/contrib/slim/ @sguada @thenbasilmanran +/tensorflow/contrib/stateless/ @girving @alextp +/tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst @yupbank +/tensorflow/contrib/tensorrt/ @laigd +# NEED OWNER: /tensorflow/contrib/testing/ +/tensorflow/contrib/timeseries/ @allenlavoie +/tensorflow/contrib/tpu/ @frankchn @saeta @jhseu @sourabhbajaj +/tensorflow/contrib/training/ @joel-shor @ebrevdo +/tensorflow/contrib/util/ @sherrym \ No newline at end of file diff --git a/README.md b/README.md index 16d354ca7b150814f11fd825d6a22c84cebc2a01..91f49f8e95cc25fc9bd052ccd13a3c1cae232740 100644 --- a/README.md +++ b/README.md @@ -100,16 +100,16 @@ The TensorFlow project strives to abide by generally accepted best practices in | **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | | **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA | | **Linux CPU with Intel® MKL-DNN** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | -| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.9.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp27-cp27mu-linux_x86_64.whl)
[1.9.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp35-cp35m-linux_x86_64.whl)
[1.9.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.9.0-cp36-cp36m-linux_x86_64.whl) | +| **Linux CPU with Intel® MKL-DNN** Python 2.7
**Linux CPU with Intel® MKL-DNN** Python 3.5
**Linux CPU with Intel® MKL-DNN** Python 3.6 | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild)|[1.10.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp27-cp27mu-linux_x86_64.whl)
[1.10.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp35-cp35m-linux_x86_64.whl)
[1.10.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.10.0-cp36-cp36m-linux_x86_64.whl) | ## For more information -* [Tensorflow Blog](https://medium.com/tensorflow) +* [TensorFlow Blog](https://medium.com/tensorflow) * [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) * [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) -* [Tensorflow Twitter](https://twitter.com/tensorflow) +* [TensorFlow Twitter](https://twitter.com/tensorflow) * [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index b8adf6c1279e72d0c2056368253aa0cb470216e5..173bbea596a4276559f5cd67824e5cc75313985c 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -1240,7 +1240,7 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, const char* value, size_t length) { tensorflow::NameAttrList func_name; - func_name.set_name(std::string(value, value + length)); + func_name.set_name(string(value, value + length)); desc->node_builder.Attr(attr_name, func_name); } @@ -2065,7 +2065,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, for (int i = 0; i < size; ++i) { TensorId id = results.missing_unused_input_map_keys[i]; - tf_results->missing_unused_key_names_data.push_back(std::string(id.first)); + tf_results->missing_unused_key_names_data.emplace_back(id.first); tf_results->missing_unused_key_names[i] = tf_results->missing_unused_key_names_data.back().c_str(); tf_results->missing_unused_key_indexes[i] = id.second; diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index aa2a537f03be31ae45ff3d6f7815b449d661cf9c..03516c39dc970aa23967107d3a0446da94669465 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -259,8 +259,8 @@ TEST(CAPI, DeprecatedSession) { TF_Run(session, run_options, nullptr, nullptr, 0, nullptr, nullptr, 0, nullptr, 0, run_metadata, s); EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)) << TF_Message(s); - EXPECT_EQ(std::string("Session was not created with a graph before Run()!"), - std::string(TF_Message(s))); + EXPECT_EQ("Session was not created with a graph before Run()!", + string(TF_Message(s))); TF_DeleteBuffer(run_metadata); TF_DeleteBuffer(run_options); @@ -1224,8 +1224,8 @@ class CApiColocationTest : public ::testing::Test { TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_); if (expected.empty()) { ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); - EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."), - std::string(TF_Message(s_))); + EXPECT_EQ("Operation 'add' has no attr named '_class'.", + string(TF_Message(s_))); return; } EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); @@ -1369,16 +1369,16 @@ TEST(CAPI, SavedModel) { input.flat()(i) = example.SerializeAsString(); } - const tensorflow::string input_op_name = - std::string(tensorflow::ParseTensorName(input_name).first); + const tensorflow::string input_op_name( + tensorflow::ParseTensorName(input_name).first); TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); csession.SetInputs({{input_op, TF_TensorFromTensor(input, s)}}); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - const tensorflow::string output_op_name = - std::string(tensorflow::ParseTensorName(output_name).first); + const tensorflow::string output_op_name( + tensorflow::ParseTensorName(output_name).first); TF_Operation* output_op = TF_GraphOperationByName(graph, output_op_name.c_str()); ASSERT_TRUE(output_op != nullptr); diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index 74bc25a491ac01cb725d1c004197e48727c30230..d3311f0cd06f2b151c3567735eb41b5baf72e102 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() { const auto& slice_proto = entry.slices(i); CHECK(filtered_keys .insert(EncodeTensorNameSlice( - std::string(v2_reader_->key()) /* full var's name */, + string(v2_reader_->key()) /* full var's name */, TensorSlice(slice_proto))) .second); } @@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() { new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { - if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue; + if (filtered_keys.count(string(v2_reader_->key())) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - string key = std::string(v2_reader_->key()); + string key(v2_reader_->key()); (*var_to_shape_map)[key] = TensorShape(entry.shape()); (*var_to_data_type_map)[key] = DataType(entry.dtype()); } diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc old mode 100644 new mode 100755 index dfb1c9a37644c726e1eabab775593596d5b556b9..1ccae3f138920b1908f18387ea87b11388115d37 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -244,8 +244,8 @@ void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, } void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options, - unsigned char async) { - options->async = async; + unsigned char enable) { + options->async = enable; } void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { @@ -253,9 +253,9 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( } TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, - unsigned char async, + unsigned char enable, TF_Status* status) { - status->status = ctx->context.SetAsyncForThread(async); + status->status = ctx->context.SetAsyncForThread(enable); } void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h old mode 100644 new mode 100755 index a0ebc6fa0a22ed61be91c2974352c2988fb4cd92..eec2750d6eb3bceed8da3ed44812ac2e8fd5c877 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -76,7 +76,7 @@ typedef enum TFE_ContextDevicePlacementPolicy { // Sets the default execution mode (sync/async). Note that this can be // overridden per thread using TFE_ContextSetAsyncForThread. TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, - unsigned char async); + unsigned char enable); TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); @@ -114,7 +114,7 @@ TFE_ContextGetDevicePlacementPolicy(TFE_Context*); // Overrides the execution mode (sync/async) for the current thread. TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context*, - unsigned char async, + unsigned char enable, TF_Status* status); // A tensorflow.ServerDef specifies remote workers (in addition to the current diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index c20ea95a15e3f53b9b26716ed7b624fa853017c9..a32d1b1eb50fc715084f5ee663a732770db1883c 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -466,7 +466,7 @@ string AvoidCPPKeywords(StringPiece name) { if (IsCPPKeyword(name)) { return strings::StrCat(name, "_"); } - return std::string(name); + return string(name); } void InferArgAttributes(const OpDef::ArgDef& arg, diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 8c886f31711eb014fb9e9d600c9c78cf22073f71..7f6ac4cae78d8d6e118837fce9ae5270336cdc89 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -225,7 +225,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( for (const string& entry : node_constraints) { StringPiece s(entry); if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { - current_constraints.insert(std::string(s)); + current_constraints.emplace(s); } } } else { diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 3830416159158cca8bfb8422c2959b49fa42406d..c6abe2f41b9b5ec2faee6f65b429ff606f8ac08e 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -148,7 +148,7 @@ Status RunMainOp(const RunOptions& run_options, const string& export_dir, AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; const StringPiece main_op_name = main_op_it->second.node_list().value(0); - return RunOnce(run_options, inputs, {}, {main_op_name.ToString()}, + return RunOnce(run_options, inputs, {}, {string(main_op_name)}, nullptr /* outputs */, &run_metadata, session); } return Status::OK(); @@ -182,12 +182,12 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, variables_path_tensor.scalar()() = variables_path; std::vector> inputs = { - {variable_filename_const_op_name.ToString(), variables_path_tensor}}; + {string(variable_filename_const_op_name), variables_path_tensor}}; AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); RunMetadata run_metadata; - return RunOnce(run_options, inputs, {}, {restore_op_name.ToString()}, + return RunOnce(run_options, inputs, {}, {string(restore_op_name)}, nullptr /* outputs */, &run_metadata, session); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 2220d0786d3757abc378d1a3d0ddc704bba6a4f3..59b961cdd9dac8a1c305a3f5f520ca1b68148cca 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -32,7 +32,6 @@ cc_library( deps = [ ":embedded_protocol_buffers", "//tensorflow/compiler/tf2xla", - "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/tf2xla:tf2xla_proto", "//tensorflow/compiler/tf2xla:tf2xla_util", @@ -56,6 +55,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -72,6 +72,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", "@llvm//:support", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep ], @@ -100,6 +101,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -195,6 +197,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:support", "@llvm//:target", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 44291d977f8e97bdcba8131363e65956cad60cb7..e77a8fecf09fa037726b0baf5d2f38aeae0ef155 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,9 +20,10 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" #include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" -#include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -142,7 +142,7 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, } rewrites->push_back({"{{I}}", strings::StrCat(i)}); rewrites->push_back({"{{TYPE}}", type}); - rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")}); + rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); rewrites->push_back({"{{INDICES}}", indices}); return Status::OK(); @@ -158,8 +158,9 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, // text-templating mechanism. string RewriteWithName(const string& name, string code, const std::vector>& rewrites) { - str_util::ReplaceAllPairs(&code, rewrites); - return str_util::StringReplace(code, "{{NAME}}", name, /*replace_all=*/true); + absl::StrReplaceAll(rewrites, &code); + absl::StrReplaceAll({{"{{NAME}}", name}}, &code); + return code; } // Generate methods for args (inputs). @@ -571,11 +572,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, - {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")}, + {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, {"{{DECLS_FROM_OBJ_FILE}}", - str_util::Join(metadata_result.header_variable_decls, "\n")}, + absl::StrJoin(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", metadata_result.hlo_profile_printer_data_access_shim}, @@ -595,8 +596,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, {"{{BUFFER_INFOS_AS_STRING}}", - str_util::Join(buffer_infos_as_strings, ",\n")}}; - str_util::ReplaceAllPairs(header, rewrites); + absl::StrJoin(buffer_infos_as_strings, ",\n")}}; + absl::StrReplaceAll(rewrites, header); return Status::OK(); } diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 60d59ae996e8f7ec490c98aeab05182626e61976..e3a53edb7368c209bea16a9e34b1f452a8ff4bf8 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include +#include "absl/strings/match.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" @@ -34,9 +34,9 @@ namespace { using ::tensorflow::cpu_function_runtime::BufferInfo; -void ExpectErrorContains(const Status& status, StringPiece str) { +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 8fb2fad31c680c5dbbd058a1b9a9265607224429..1401aae7586bfd40ec209b0ae591d6ab69d0a26b 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_replace.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/LLVMContext.h" @@ -27,7 +28,6 @@ limitations under the License. #include "llvm/Support/TargetRegistry.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Target/TargetOptions.h" -#include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/util.h" @@ -65,14 +65,13 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, " return proto;\n" " }()"; - str_util::ReplaceAllPairs( - &code, + return absl::StrReplaceAll( + code, { {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)}, {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)}, {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)}, }); - return code; } static StatusOr CodegenModule(llvm::TargetMachine* target_machine, @@ -97,7 +96,7 @@ static StatusOr> GetTargetMachineFromTriple(StringPiece target_triple) { std::string error; std::string normalized_triple = - llvm::Triple::normalize(AsStringRef(target_triple)); + llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple))); const llvm::Target* target = llvm::TargetRegistry::lookupTarget(normalized_triple, error); if (target == nullptr) { diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 0ecc3feeb6fef1dd691ab2785b3221075a79ba88..723e9bec8afcfbf7ceeeb59c63e4e12442fdb7ab 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -187,6 +187,9 @@ tf_library( cpp_class = "MatMulAndAddCompWithProfiling", enable_xla_hlo_profiling = True, graph = "test_graph_tfmatmulandadd.pb", + tags = [ + "manual", + ], ) tf_library( @@ -226,5 +229,6 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 0c0c676ece78565e03578d3e33633c7e23b77669..dd2b151098f2054571ac32b8b506cbc00659588a 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #define EIGEN_USE_CUSTOM_THREAD_POOL +#include "absl/strings/str_split.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -546,7 +546,7 @@ TEST(TFCompileTest, HloProfiling) { VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; std::vector hlo_profile_lines = - tensorflow::str_util::Split(hlo_profile_as_string, '\n'); + absl::StrSplit(hlo_profile_as_string, '\n'); auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 839e1588b7be6c91cf30c87bbaf75402446bd169..f3c44e9dda8ce96a268420a7f4d0f22e50ddfe41 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" @@ -34,7 +36,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -55,7 +56,7 @@ const char kUsageHeader[] = "\n"; Status ReadProtoFile(const string& fname, protobuf::Message* proto) { - if (str_util::EndsWith(fname, ".pbtxt")) { + if (absl::EndsWith(fname, ".pbtxt")) { return ReadTextProto(Env::Default(), fname, proto); } else { return ReadBinaryProto(Env::Default(), fname, proto); @@ -75,7 +76,7 @@ Status Main(const MainFlags& flags) { for (const tf2xla::Fetch& fetch : config.fetch()) { nodes.insert(fetch.id().node_name()); } - std::cout << str_util::Join(nodes, ","); + std::cout << absl::StrJoin(nodes, ","); return Status::OK(); } diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 2466c218c82dbd504043dbfff70fb3ba88d38e3b..df81f3c23e38a2ec2cea827cd0adb123855e7714 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -310,6 +310,51 @@ tf_cc_test( ], ) +cc_library( + name = "resource_operation_safety_analysis", + srcs = ["resource_operation_safety_analysis.cc"], + hdrs = ["resource_operation_safety_analysis.h"], + deps = [ + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "resource_operation_safety_analysis_test", + srcs = ["resource_operation_safety_analysis_test.cc"], + deps = [ + ":common", + ":resource_operation_safety_analysis", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "compilation_passes", srcs = [ @@ -335,11 +380,10 @@ cc_library( ":union_find", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", - "//tensorflow/compiler/jit/kernels:parallel_check_op", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", - "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", @@ -351,6 +395,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/strings", ], ) @@ -359,6 +404,7 @@ cc_library( srcs = ["xla_cluster_util.cc"], hdrs = ["xla_cluster_util.h"], deps = [ + ":resource_operation_safety_analysis", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", @@ -437,6 +483,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", "//tensorflow/compiler/jit/kernels:xla_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -448,6 +495,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", ], ) @@ -528,6 +576,9 @@ tf_cuda_cc_test( ":common", ":xla_cluster_util", ":xla_fusion_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/core:graph", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 1b1ce78ed2b79d0948b6fc951f82a2cebe8009e5..56b034a30b7bddb023e54ead22c91a7a18095d2d 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -126,7 +126,8 @@ Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, const DataTypeVector& arg_types = (*fbody)->arg_types; std::vector const_args(arg_types.size()); // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *((*fbody)->graph), &const_args, /*compile_time_const_nodes=*/nullptr)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { @@ -208,8 +209,13 @@ Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, // device memory. // XlaLaunch kernel keeps all outputs (including constants, which it copies), - // in device memory + // in device memory except for resources. MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + for (int i = 0; i < fbody->ret_types.size(); ++i) { + if (fbody->ret_types[i] == DT_RESOURCE) { + output_memory_types[i] = HOST_MEMORY; + } + } // Create the kernel. NameAttrList function; diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0ca0f949dcd13992ccd9504d75ca65d2aff72a19..fe28502f69d34e7c075bdf85afd2473024b4081d 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" @@ -153,7 +154,7 @@ class AndPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", str_util::Join(operands_str, " & "), ")"); + return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); } Kind kind() const override { return Kind::kAnd; } @@ -182,7 +183,7 @@ class OrPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", str_util::Join(operands_str, " | "), ")"); + return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); } Kind kind() const override { return Kind::kOr; } diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index cc9f1023985560be0bce5971931d2ec8e742b377..28a56044d5e3795fc3ecf5d1092491b87cb90f01 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index f150bf1819d407e1c6a279673a89de4307b5426b..2788102620546d8eab657c519f078c5b03e265cc 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -2504,7 +2504,8 @@ Status EncapsulateSubgraphsPass::Run( const int num_args = input_permutation->size(); std::vector const_args(num_args); - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + **subgraph, &const_args, /*compile_time_const_nodes=*/nullptr)); DataTypeVector arg_types(num_args); TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types)); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index c0543a00792235c5dd090e81930d8c219dc7f1a3..b3600fc48b9daa0e901e2b01cdc121aef0a1e8af 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/equal_graph_def.h" @@ -124,8 +124,8 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, std::unordered_set control_input_a; std::unordered_set control_input_b; for (int i = 0; i < a.input_size(); ++i) { - if (str_util::StartsWith(a.input(i), "^")) { - if (!str_util::StartsWith(b.input(i), "^")) { + if (absl::StartsWith(a.input(i), "^")) { + if (!absl::StartsWith(b.input(i), "^")) { if (diff) { *diff = strings::StrCat( diff_preamble, " mismatch for node ", a.name(), " input ", i, @@ -768,7 +768,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - str_util::StartsWith(n->name(), "const")) { + absl::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { @@ -813,7 +813,7 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { Graph* graph = graph_ptr->get(); for (const Node* n : graph->nodes()) { if (n->type_string() == "_Arg" && - str_util::StartsWith(n->name(), "const")) { + absl::StartsWith(n->name(), "const")) { ++guaranteed_consts; EXPECT_TRUE(HasGuaranteeConstAttr(*n)); } else { diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 8f78c110cb15f3cbc0344d102764241996b0d7de..253a5d254792a19d98b75310ea6848f42597c0c7 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -29,16 +29,3 @@ cc_library( ], alwayslink = 1, ) - -cc_library( - name = "parallel_check_op", - srcs = ["parallel_check_op.cc"], - visibility = ["//tensorflow/compiler/jit:friends"], - deps = [ - "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/jit/kernels/parallel_check_op.cc b/tensorflow/compiler/jit/kernels/parallel_check_op.cc deleted file mode 100644 index bd4eefbc0bb960f8ddc1d238057e73a29a098f26..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/kernels/parallel_check_op.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h" -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace { - -// Inputs 2*N tensors, outputs the first N inputs. -// Logs errors if input tensor i and i + N are not (near) identical -// in any position. -class ParallelCheckOp : public OpKernel { - public: - explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - template - int CompareTensors(DataType dtype, const char* v0, const char* v1, - int64 num_elts, int input_idx) { - int failed = 0; - const T* p0 = reinterpret_cast(v0); - const T* p1 = reinterpret_cast(v1); - double rtol; - legacy_flags::ParallelCheckOpFlags* flags = - legacy_flags::GetParallelCheckOpFlags(); - if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(), - &rtol)) { - LOG(ERROR) << "can't convert parallel_check_rtol " - << flags->parallel_check_rtol << " to double"; - } - double atol; - if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(), - &atol)) { - LOG(ERROR) << "can't convert parallel_check_atol " - << flags->parallel_check_atol << " to double"; - } - for (int i = 0; i < num_elts; ++i) { - bool ok = (p0[i] == p1[i]); - VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i]; - if (!ok) { - if (std::is_same::value || std::is_same::value) { - float tolerance = - std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i]))); - T diff = p0[i] - p1[i]; - if (diff < 0) diff = 0 - diff; - ok = (diff <= tolerance); - } - if (ok) continue; - LOG(ERROR) << "Op " << name() << " fails equality at output " - << input_idx << " type " << DataTypeString(dtype) - << " element " << i << ": std_val=" << p0[i] - << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]); - if (++failed > 10) break; - } - } - return failed; - } - - void Compute(OpKernelContext* ctx) override { - VLOG(1) << "Compute " << name(); - const int num_pairs = ctx->num_inputs() / 2; - for (int i = 0; i < num_pairs; ++i) { - CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs)); - Tensor t0 = ctx->input(i); - Tensor t1 = ctx->input(i + num_pairs); - int64 num_elts = t0.NumElements(); - CHECK_EQ(num_elts, t1.NumElements()); - - // Compare inputs elementwise for near-exact equality. - const char* v0 = t0.tensor_data().data(); - const char* v1 = t1.tensor_data().data(); - int failed = 0; - switch (ctx->input_dtype(i)) { - case DT_INT32: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_INT64: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_FLOAT: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_DOUBLE: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - case DT_BOOL: - failed = - CompareTensors(ctx->input_dtype(i), v0, v1, num_elts, i); - break; - default: - LOG(FATAL) << "unimpl: " << ctx->input_dtype(i); - } - if (failed > 0) { - LOG(ERROR) << "check failed for " << name() << " output " << i - << " num_elts: " << num_elts; - legacy_flags::ParallelCheckOpFlags* flags = - legacy_flags::GetParallelCheckOpFlags(); - if (flags->parallel_check_failfast) { - LOG(QFATAL) << "failfast on first parallel-check failure"; - } - } else { - VLOG(1) << "check passed for " << name() << " output " << i - << " num_elts: " << num_elts; - } - - // Propagate the std value. - if (IsRefType(ctx->input_dtype(i))) { - ctx->forward_ref_input_to_ref_output(i, i); - } else { - ctx->set_output(i, ctx->input(i)); - } - } - } - - TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp); -}; - -REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU), - ParallelCheckOp); - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index ddb27a38ae3b6749a82f86ba8be88ec68e733006..fde4135bf7f5f7bdede170d47fb2a76d1d6b3ae9 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -187,7 +187,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK( ctx, cache->Compile(options, function_, constant_args, variables, ctx, - &kernel, &executable, &compile_options)); + &kernel, &executable, compile_options)); VLOG(1) << "Executing XLA Computation..."; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 11bd5eec238c0e542814f22bc7a33a90abd0ec28..4e4abade3278089a1c7f8fdee46a34b8ce503651 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -27,7 +27,9 @@ limitations under the License. #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" @@ -40,6 +42,7 @@ limitations under the License. #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -74,18 +77,40 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } +bool HasResourceOutput(const Node& node) { + return std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); +} + +bool HasResourceInput(const Node& node) { + return std::find(node.input_types().begin(), node.input_types().end(), + DT_RESOURCE) != node.input_types().end(); +} + +// Returns true if `node` is a resource operation recognized by tf2xla that +// operates on something other than resource variables. +bool IsNonResourceVarResourceOp(const Node& node) { + // TODO(b/112837194): We can't cluster these because we only support + // snapshotting resource variables (and we can't e.g. snapshot stacks). This + // limitation may be fixable with some work. + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string()); + return op_info && op_info->resource_kind() != XlaResourceKind::kVariable; +} + // Make sure we don't recurse infinitely on recursive functions. const int kMaxRecursionDepth = 10; bool IsCompilableCall(const NodeDef& call_def, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime); // Tests whether 'while_node' is a completely compilable loop. // Every operator in the condition and body functions must be compilable for a // while loop to be compilable. bool IsCompilableWhile(const Node& while_node, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime) { const NameAttrList* name_attr; NodeDef call; @@ -100,7 +125,8 @@ bool IsCompilableWhile(const Node& while_node, call.set_name("while_cond"); call.set_op(cond_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { + if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + lib_runtime)) { VLOG(2) << "Rejecting While " << while_node.name() << ": can't compile loop condition: " << cond_func; return false; @@ -115,7 +141,8 @@ bool IsCompilableWhile(const Node& while_node, call.set_name("while_body"); call.set_op(body_func); *call.mutable_attr() = name_attr->attr(); - if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { + if (!IsCompilableCall(call, jit_device_type, allow_resource_ops, depth + 1, + lib_runtime)) { VLOG(2) << "Rejecting While " << while_node.name() << ": can't compile loop body: " << body_func; return false; @@ -127,7 +154,8 @@ bool IsCompilableWhile(const Node& while_node, // Every operator in the function must be compilable for a function to be // compilable. bool IsCompilableCall(const NodeDef& call_def, - const DeviceType& jit_device_type, int depth, + const DeviceType& jit_device_type, + bool allow_resource_ops, int depth, FunctionLibraryRuntime* lib_runtime) { if (depth > kMaxRecursionDepth) { VLOG(2) << "Rejecting " << call_def.op() @@ -167,12 +195,17 @@ bool IsCompilableCall(const NodeDef& call_def, if (node->type_string() == "_Arg" || node->type_string() == "_Retval") continue; if (node->type_string() == "While") { - // Handle functional While loop (not in open source build). - return IsCompilableWhile(*node, jit_device_type, depth + 1, lib_runtime); + // Handle functional While loop. + return IsCompilableWhile(*node, jit_device_type, allow_resource_ops, + depth + 1, lib_runtime); + } + if (!allow_resource_ops && + (HasResourceInput(*node) || HasResourceOutput(*node))) { + return false; } if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, depth + 1, - lib_runtime)) { + !IsCompilableCall(node->def(), jit_device_type, allow_resource_ops, + depth + 1, lib_runtime)) { VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " << node->name() << ": " << node->def().ShortDebugString(); return false; @@ -343,6 +376,10 @@ Status FindCompilationCandidates( flib_def, opts)); FunctionLibraryRuntime* lib_runtime = pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + std::vector compile_time_const_nodes(graph.num_node_ids(), false); + TF_RETURN_IF_ERROR( + BackwardsConstAnalysis(graph, /*compile_time_const_arg_indices=*/nullptr, + &compile_time_const_nodes)); int64& fuel = legacy_flags::GetMarkForCompilationPassFlags()->tf_xla_clustering_fuel; @@ -386,19 +423,46 @@ Status FindCompilationCandidates( XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); if (!HasXLAKernel(*node, jit_device_type) && - !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) { + !IsCompilableCall(node->def(), jit_device_type, + registration->compile_resource_ops, 0, lib_runtime)) { VLOG(2) << "Rejecting " << node->name() << ": unsupported op " << node->type_string(); continue; } if (!registration->compile_resource_ops && - HasResourceInputOrOutput(*node)) { - VLOG(2) << "Rejecting: " << node->name() << ": resource input/output " + (HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) { + // We don't have a way of returning values of type DT_RESOURCE from XLA + // computations so we avoid auto-clustering nodes producing DT_RESOURCE. + // XlaLaunchOp also cannot snapshot resources that are not resource + // variables so we avoid clustering resource operations that operate on + // non-resource variables. + VLOG(2) << "Rejecting: " << node->name() << ": resource output " << node->type_string(); continue; } + if (compile_time_const_nodes[node->id()] && + !registration->requires_compilation) { + const OpDef* op_def; + TF_RETURN_IF_ERROR( + OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def)); + if (op_def->is_stateful()) { + // We need to be able to constant fold the nodes in + // compile_time_const_nodes given constant inputs (required by XLA) and + // therefore can't auto-cluster stateful ops since these can never be + // constant folded. + VLOG(2) << "Rejecting " << node->name() + << ": must-be-constant stateful op"; + continue; + } + } + // We don't auto-cluster functional control flow nodes containing resource + // operations because safety checks are trickier in this case. + // registration->compile_resource_ops is true for XLA_CPU/XLA_GPU but not + // for CPU/GPU. if (node->type_string() == "While" && - !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { + !IsCompilableWhile(*node, jit_device_type, + registration->compile_resource_ops, 0, + lib_runtime)) { continue; } // _Arg nodes in a top-level function represent feeds. @@ -457,7 +521,11 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), ®istration)); DeviceType jit_device_type(registration->compilation_device_name); - return IsCompilableCall(ndef, jit_device_type, 0, flr); + + // We can always *compile* resource operations, even if we are sometimes + // unable to auto-cluster them. + const bool compile_resource_ops = true; + return IsCompilableCall(ndef, jit_device_type, compile_resource_ops, 0, flr); } Status MarkForCompilationPass::Run( @@ -600,6 +668,82 @@ static void VLogClusteringSummary(const Graph& g) { VLOG(3) << " " << pair.first << ": " << pair.second << " instances"; } } + + struct EdgeInfo { + StringPiece node_name; + absl::optional cluster_name; + + StringPiece GetClusterName() const { + return cluster_name ? *cluster_name : "[none]"; + } + + std::pair> AsPair() const { + return {node_name, cluster_name}; + } + + bool operator<(const EdgeInfo& other) const { + return AsPair() < other.AsPair(); + } + }; + + using EdgeInfoMap = std::map>; + + EdgeInfoMap incoming_edge_infos; + EdgeInfoMap outgoing_edge_infos; + + std::set cluster_names_to_print; + + for (const Edge* e : g.edges()) { + const Node* from = e->src(); + absl::optional from_cluster_name = GetXlaClusterForNode(*from); + + const Node* to = e->dst(); + absl::optional to_cluster_name = GetXlaClusterForNode(*to); + + if (to_cluster_name == from_cluster_name) { + continue; + } + + if (to_cluster_name) { + incoming_edge_infos[*to_cluster_name] + [EdgeInfo{from->name(), from_cluster_name}]++; + cluster_names_to_print.insert(*to_cluster_name); + } + + if (from_cluster_name) { + outgoing_edge_infos[*from_cluster_name][{to->name(), to_cluster_name}]++; + cluster_names_to_print.insert(*from_cluster_name); + } + } + + VLOG(2) << "*** Inter-Cluster edges:"; + if (cluster_names_to_print.empty()) { + VLOG(2) << " [none]"; + } + + auto print_edge_info_set_for_cluster = [&](StringPiece cluster_name, + const EdgeInfoMap& edge_info_map, + StringPiece desc) { + auto it = edge_info_map.find(cluster_name); + if (it != edge_info_map.end()) { + VLOG(2) << " " << it->second.size() << " " << desc << " edges"; + for (const auto& edge_info_count_pair : it->second) { + VLOG(2) << " " << edge_info_count_pair.first.GetClusterName() << " " + << edge_info_count_pair.first.node_name << " # " + << edge_info_count_pair.second; + } + } else { + VLOG(2) << " No " << desc << " edges."; + } + }; + + for (StringPiece cluster_name : cluster_names_to_print) { + VLOG(2) << " ** Cluster " << cluster_name; + print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos, + "incoming"); + print_edge_info_set_for_cluster(cluster_name, outgoing_edge_infos, + "outgoing"); + } } // Is 'node' an operator that consumes only the shape of its input, not the @@ -609,6 +753,43 @@ static bool IsShapeConsumerOp(const Node& node) { node.type_string() == "Size"; } +static Status IgnoreResourceOpForSafetyAnalysis(const Node& n, bool* ignore) { + // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then + // ignore it during resource operation safety analysis. We need this hack + // because of two reasons: + // + // 1. Operations assigned to XLA_CPU and XLA_GPU have to always be compiled. + // 2. We don't support live-out values of type DT_RESOURCE and live-in values + // of type DT_RESOURCE that are not resource variables. + // + // Together these imply we cannot let resource variable safety analysis + // constrain e.g. a TensorArrayV3->TensorArrayAssignV3 edge to be in different + // clusters: both of them will have to be clustered because of (1) and we + // won't be able to keep the edge between the two as neither the input to the + // second XLA cluster nor the output from the first XLA cluster are supported + // because of (2). + // + // TODO(b/113100872): This can be fixed if the TensorFlow representation for + // TensorArray and Stack on the XLA_{C|G}PU devices were the same in XLA; then + // (2) would no longer hold. + + if (n.assigned_device_name().empty()) { + *ignore = false; + return Status::OK(); + } + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n.assigned_device_name(), &device_type)); + + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + *ignore = true; + } else { + *ignore = registration->compile_resource_ops; + } + return Status::OK(); +} + // Sequence number generator to ensure clusters have unique names. static std::atomic cluster_sequence_num; @@ -637,6 +818,8 @@ Status MarkForCompilationPass::RunImpl( GraphCycles cycles; TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( + graph, options.flib_def, IgnoreResourceOpForSafetyAnalysis, &cycles)); // Each compilation candidate belongs to a cluster. The cluster's // representative @@ -675,7 +858,7 @@ Status MarkForCompilationPass::RunImpl( string to_scope; for (int to : cycles.Successors(from)) { if (to >= graph->num_node_ids()) { - // Node is a "frame" node that is present only in the cycle detection + // Node is a fictitious node that is present only in the cycle detection // graph. No clustering is possible. continue; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 9d7ac0d609eea370b8100e1eb53b0b0b3d9f2382..807ab51fd3c133b95915ea88e0bf99dbb8661452 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,10 +15,12 @@ limitations under the License. #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" @@ -26,11 +28,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -48,9 +50,35 @@ std::unordered_map GetClusters(const Graph& graph) { ids[node->name()] = cluster; } } + + if (VLOG_IS_ON(2)) { + VLOG(2) << "Clusters:"; + for (const auto& p : ids) { + VLOG(2) << " " << p.first << " -> " << p.second; + } + } return ids; } +gtl::FlatMap> GetClusterSets( + const Graph& g, std::vector* cluster_names = nullptr) { + CHECK(cluster_names == nullptr || cluster_names->empty()); + gtl::FlatMap> cluster_sets; + for (const auto& p : GetClusters(g)) { + cluster_sets[p.second].push_back(p.first); + } + for (auto& p : cluster_sets) { + if (cluster_names != nullptr) { + cluster_names->push_back(p.first); + } + std::sort(p.second.begin(), p.second.end()); + } + if (cluster_names != nullptr) { + std::sort(cluster_names->begin(), cluster_names->end()); + } + return cluster_sets; +} + TEST(XlaCompilationTest, Chains) { std::unique_ptr graph(new Graph(OpRegistry::Global())); GraphDef graphdef; @@ -501,38 +529,104 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { EXPECT_EQ(clusters["B"], clusters["C"]); } -REGISTER_OP("ResourceInput").Input("a: resource").Output("o: float"); -REGISTER_OP("ResourceOutput").Input("a: float").Output("o: resource"); - namespace { +Node* MakeRead(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output read = + ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT); + return read.node(); +} -class DummyOp : public XlaOpKernel { - using XlaOpKernel::XlaOpKernel; - void Compile(XlaOpKernelContext* ctx) override {} -}; - -REGISTER_XLA_OP(Name("ResourceInput"), DummyOp); -REGISTER_XLA_OP(Name("ResourceOutput"), DummyOp); +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignment" + id), + var_handle, value_to_write); + return assign_op.operation.node(); +} +Node* MakeNeutral(const Scope& scope, const string& id) { + return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); +} } // namespace -TEST(XlaCompilationTest, Resources) { +TEST(XlaCompilationTest, ResourcesClusteringAllowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, write); + + FixupSourceAndSinkEdges(root.graph()); std::unique_ptr graph(new Graph(OpRegistry::Global())); - GraphDef graphdef; - { - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - Node* a = - ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); - Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); - // We should not form clusters with resource ops by default. - Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C")); - Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D")); - ops::UnaryOp("Relu", d, builder.opts().WithName("E")); - TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); - } + TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - auto clusters = GetClusters(*graph); - EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. + gtl::FlatMap> cluster_sets = + GetClusterSets(*graph); + ASSERT_EQ(cluster_sets.size(), 1); + std::vector expected_clustered_nodes = {"AssignmentW", "ReadR", + "ValueToAssignW"}; + ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); +} + +TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + + FixupSourceAndSinkEdges(root.graph()); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + gtl::FlatMap> cluster_sets = + GetClusterSets(*graph); + ASSERT_EQ(cluster_sets.size(), 1); + std::vector expected_clustered_nodes = {"AssignmentW", + "ValueToAssignW"}; + ASSERT_EQ(cluster_sets.begin()->second, expected_clustered_nodes); +} + +TEST(XlaCompilationTest, ChainOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* neutral_0 = MakeNeutral(root, "N0"); + Node* read_0 = MakeRead(root, "R0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral_1 = MakeNeutral(root, "N1"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral_0); + root.graph()->AddControlEdge(neutral_0, read_0); + root.graph()->AddControlEdge(read_0, write_1); + root.graph()->AddControlEdge(write_1, neutral_1); + root.graph()->AddControlEdge(neutral_1, read_1); + + FixupSourceAndSinkEdges(root.graph()); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_EXPECT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::vector cluster_names; + gtl::FlatMap> cluster_sets = + GetClusterSets(*graph, &cluster_names); + + ASSERT_EQ(cluster_sets.size(), 2); + + std::vector expected_clustered_nodes_a = {"AssignmentW0", "ConstN0", + "ValueToAssignW0"}; + ASSERT_EQ(cluster_sets[cluster_names[0]], expected_clustered_nodes_a); + + std::vector expected_clustered_nodes_b = { + "AssignmentW1", "ConstN1", "ReadR0", "ValueToAssignW1"}; + ASSERT_EQ(cluster_sets[cluster_names[1]], expected_clustered_nodes_b); } TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { @@ -562,11 +656,11 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.ToString(), - "Edge from c to a would create a cycle.\n" - "+-> a\n" - "| b\n" - "+-- c\n")); + EXPECT_TRUE(absl::StrContains(status.ToString(), + "Edge from c to a would create a cycle.\n" + "+-> a\n" + "| b\n" + "+-- c\n")); } TEST(XlaCompilationTest, Retval) { @@ -731,5 +825,27 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { EXPECT_EQ(clusters, expected_clusters); } +TEST(XlaCompilationTest, RandomShape) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape_shape = ops::Const(root.WithOpName("shape_shape"), {2}, {1}); + Output shape = + ops::RandomUniformInt(root.WithOpName("shape"), shape_shape, + ops::Const(root.WithOpName("minval"), 1), + ops::Const(root.WithOpName("maxval"), 20)); + Output reshape_input = + ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["shape"], ""); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index c9e46bc1475aed0e35a48765ad70eef4362e8281..13804c6a0575b921839f99ef7d142e0871693b5a 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -10,10 +10,3 @@ cc_library( deps = ["//tensorflow/core:framework"], alwayslink = 1, ) - -cc_library( - name = "parallel_check_op", - srcs = ["parallel_check_op.cc"], - deps = ["//tensorflow/core:framework"], - alwayslink = 1, -) diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 08a956e4c6478ff76a0fe8f1f60d94824daf535c..f61a955c222dd7ce11a177cd54bb8851a5400496 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc new file mode 100644 index 0000000000000000000000000000000000000000..1ba4a5ef7399111e512da8c4966f5899ed828b17 --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -0,0 +1,336 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// ALGORITHM OVERVIEW +// ================== +// +// An XLA cluster hoists all resource reads to be beginning of the cluster +// execution and all the resource writes to the end. This means it cannot +// enforce arbitrary ordering dependencies (via control or data edges) between +// resource operations. Since all resource reads happen before all resource +// writes, edges constraining resource reads to happen before resource writes +// are fine, but all other kinds of edges are problematic. This analysis +// computes the set of pairs of resource operations that cannot be put in the +// same cluster because XLA cannot respect the dependencies between them in the +// TensorFlow program. +// +// TODO(b/112856632): We can, in theory, support Read->Read and Write->Write +// dependencies. +// +// Specifically the result computed by this analysis contains the edge {W, R} +// iff all of these hold true: +// +// - In the graph (g - {edges from NextIteration to Merge}) there is a path +// from W to R. +// - IsEdgeSafe(W, R) == False [defined below] +// - W != R (note: some resource operations both read from and write to +// resource variables). +// +// The result is incorrect around loops because we ignore edges from +// NextIteration to Merge, but that should be fine because we don't cluster +// these edges. For instance, in: +// +// Init -----> Merge <-------+ +// | | +// v | +// Read | +// | | +// v | +// Write | +// | | +// v | +// NextIteration --+ +// +// we won't put (Read, Write) in the returned set. This is fine if +// auto-clustering can only cluster the Read->Write edge, but it is a problem if +// it clusters the Write->NextIteration->Merge->Read edges instead. The same +// problem is present for the functional version of the loop above. We rely on +// auto-clustering to not cluster control flow edges like NextIteration->Merge. +// This is enough to avoid the explicit-control-flow problem shown above. One +// way to think about this is that we only care about cases where two nodes, A +// and B, would normally have been put in the same cluster but cannot legally be +// in the same cluster because of resourcevar-dependencies. If A and B would +// normally have been put in the same cluster then all paths between A and B +// would have to be clusterable (otherwise we'd have introduced a cycle). Ergo +// there could not have been a NextIteration->Merge edge between A and B since +// we don't cluster these edges. +// +// We also rely on auto-clustering to not cluster functional control flow nodes +// that contain resource operations. +// +// IMPLEMENTATION +// -------------- +// +// We traverse the graph minus backedges in reverse post order, mapping each +// node to the set of resource operation reaching that node. Since we visit +// producers before consumers, we can construct the set of reaching operations +// by taking the union of the operations reaching the input nodes. These +// "reaching resource operations" can then be used to create the pairs of +// incompatible nodes using `IsEdgeSafe`. + +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { +// Returns true if `n` may call a function. +Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def, + bool* out_result) { + if (flib_def->Contains(n.type_string())) { + *out_result = true; + } else { + *out_result = + std::any_of(n.def().attr().begin(), n.def().attr().end(), + [](const std::pair& name_attr_pair) { + return name_attr_pair.second.has_func(); + }); + } + + return Status::OK(); +} + +// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is +// not a resource operation recognized by XLA then sets `out_resource_op_kind` +// to nullopt. +Status XlaResourceOpKindForNode( + const Node& n, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + absl::optional* out_resource_op_kind) { + bool should_ignore = false; + if (resource_ops_to_ignore) { + TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore)); + } + if (should_ignore) { + *out_resource_op_kind = absl::nullopt; + return Status::OK(); + } + + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string()); + if (op_info) { + *out_resource_op_kind = op_info->kind(); + return Status::OK(); + } + + // We conservatively assume that functions will both read and write resource + // variables. In the future we may consider doing some form of + // inter-procedural analysis. + bool may_call_function; + TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function)); + if (may_call_function) { + *out_resource_op_kind = XlaResourceOpKind::kReadWrite; + } else { + *out_resource_op_kind = absl::nullopt; + } + + return Status::OK(); +} + +// Returns true if a control or data dependence from a TensorFlow operation of +// resource op kind `from` to a TensorFlow operation of resource op kind `to` +// can be represented by an XLA cluster and needs no special handling around +// auto-jit. +bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { + // XLA clusters forces all reads to happen before all writes, which means the + // kinds of edges it can faithfully represent are: Read->Write, Read->Modify, + // Modify->Write, Read->Read, Write->Write. + // + // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write + // dependencies. + return from == XlaResourceOpKind::kRead && to == XlaResourceOpKind::kWrite; +} + +using ResourceOp = std::pair; + +string ResourceOpToString(const ResourceOp& resource_op) { + return strings::StrCat( + resource_op.first, ": ", + XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); +} + +// A copy-on-write set used to store the set of ResourceOps reaching a node in a +// TensorFlow graph. +// +// TODO(sanjoy): It may be useful to pull this out into its own header at some +// point. +class ResourceOpSet { + private: + using Impl = gtl::FlatSet; + + public: + ResourceOpSet() = default; + + // Adds all ResourceOp s in `other` to this set. + void Add(const ResourceOpSet& other) { + CHECK(!frozen_); + if (other.impl_ == impl_) { + other.frozen_ = true; + return; + } + + if (!impl_) { + other.frozen_ = true; + impl_ = other.impl_; + return; + } + + for (ResourceOp resource_op : other) { + Add(resource_op); + } + } + + void Add(const ResourceOp& resource_op) { + CHECK(!frozen_); + if (!IsCopy() && Contains(resource_op)) { + // We can avoid the copy if the item we want to insert already exists. + return; + } + + EnsureIsCopied(); + impl_->insert(resource_op); + } + + Impl::const_iterator begin() const { + return impl_ ? impl_->begin() : GetEmptyImpl()->begin(); + } + + Impl::const_iterator end() const { + return impl_ ? impl_->end() : GetEmptyImpl()->end(); + } + + bool Contains(const ResourceOp& resource_op) const { + return impl_ != nullptr && impl_->count(resource_op); + } + + private: + bool IsCopy() const { return storage_ != nullptr; } + + void EnsureIsCopied() { + if (storage_ == nullptr) { + storage_ = absl::make_unique(); + for (ResourceOp op : *this) { + storage_->insert(op); + } + impl_ = storage_.get(); + } + } + + static Impl* GetEmptyImpl() { + static Impl* empty_impl = new Impl; + return empty_impl; + } + + Impl* impl_ = nullptr; + std::unique_ptr storage_; + + // frozen_ is true if there is another set pointing to this set's impl_. We + // can no longer add elements to this set in that case since the sets pointing + // to this set expect the contents of this set to be stable. + mutable bool frozen_ = false; + + TF_DISALLOW_COPY_AND_ASSIGN(ResourceOpSet); +}; + +string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { + std::vector elements_debug_string; + std::transform(resource_op_set.begin(), resource_op_set.end(), + std::back_inserter(elements_debug_string), ResourceOpToString); + return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); +} + +string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { + return strings::StrCat( + "[", n.name(), ": ", n.type_string(), "(", + XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); +} +} // namespace + +Status ComputeIncompatibleResourceOperationPairs( + const Graph& g, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + std::vector>* result) { + CHECK(result->empty()); + + std::vector rpo; + GetReversePostOrder(g, &rpo, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + auto resource_op_set_for_node = + absl::make_unique(g.num_node_ids()); + + const bool vlog = VLOG_IS_ON(2); + + for (Node* n : rpo) { + absl::optional op_kind; + TF_RETURN_IF_ERROR(XlaResourceOpKindForNode( + *n, flib_def, resource_ops_to_ignore, &op_kind)); + + ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()]; + + // Merge the reaching resource operations for all the incoming edges to + // create the set of all possible resource ops reaching `n`. + for (const Edge* e : n->in_edges()) { + if (n->IsMerge() && e->src()->IsNextIteration()) { + // Ignore back-edges (see file comment). + continue; + } + + const ResourceOpSet& incoming_op_set = + resource_op_set_for_node[e->src()->id()]; + resource_op_set->Add(incoming_op_set); + } + + // Add to the "incompatible resource ops" set if necessary. + if (op_kind) { + for (ResourceOp incoming_op : *resource_op_set) { + if (IsEdgeSafe(incoming_op.second, *op_kind)) { + continue; + } + + if (vlog) { + VLOG(2) << "Unsafe edge: " + << NodeToString(*g.FindNodeId(incoming_op.first), + incoming_op.second) + << " -> " << NodeToString(*n, *op_kind); + } + result->push_back({incoming_op.first, n->id()}); + } + + resource_op_set->Add({n->id(), *op_kind}); + } + + if (vlog) { + VLOG(3) << n->name() << " -> " << ResourceOpSetToString(*resource_op_set); + } + } + + std::sort(result->begin(), result->end()); + CHECK(std::unique(result->begin(), result->end()) == result->end()); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h new file mode 100644 index 0000000000000000000000000000000000000000..ae8cfeecad9b9cd631db3e9865bb3c3ff28a2e48 --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +// An XLA cluster hoists all resource reads to be beginning of the cluster +// execution and all the resource writes to the end. This means it cannot +// enforce arbitrary ordering dependencies (via control or data edges) between +// resource operations. Since all resource reads happen before all resource +// writes, edges constraining resource reads to happen before resource writes +// are fine, but all other kinds of edges are problematic. This analysis +// returns the set of pairs of resource operations that cannot be put in the +// same cluster because XLA cannot respect the dependencies between them in the +// TensorFlow program. +// +// The restrictions are not transitive: it is fine to put A and C in the same +// cluster even if the returned set contains (A,B) and (B,C). +// +// In other words, if these pairs are seen as edges in an undirected graph of +// the nodes in `g` then auto-clustering is at least as constrained as the graph +// coloring problem on this graph. +// +// +// For instance if we auto-cluster all operations in this TensorFlow graph: +// +// ReadVariablepOp0 -> ReadVariableOp1 +// | +// v +// AssignVariableOp0 -> AssignVariableOp1 +// +// we will lose the ReadVariablepOp0 -> ReadVariableOp1 and the +// AssignVariableOp0 -> AssignVariableOp1 dependencies. I.e. it is possible for +// XlaLaunchOp to issue ReadVariableOp1 before ReadVariablepOp0 since it reads +// all the resource variables when the cluster starts executing without any +// particular ordering between them; same holds for the AssignVariableOp0 -> +// AssignVariableOp1 edge. The ReadVariableOp1 -> AssignVariableOp0 edge will +// be respected by XlaLaunchOp though because all reads happen before all +// writes. +// +// +// NB! The result computed by this analysis assumes that we don't auto-cluster +// back-edges (i.e. the edges from NextIteration to Merge). +// +// NB! The result computed by this analysis assumes that we don't auto-cluster +// functional control flow nodes containing resource operations. +// +// If `resource_ops_to_ignore` is set then nodes for which it returns true are +// ignored (we pretend these nodes are not resource operations). +Status ComputeIncompatibleResourceOperationPairs( + const Graph& g, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + std::vector>* result); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e54b547abcfea698fe79e81dce547ea7858ff829 --- /dev/null +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc @@ -0,0 +1,540 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +Node* MakeRead(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output read = + ops::ReadVariableOp(scope.WithOpName("Read" + id), var_handle, DT_FLOAT); + return read.node(); +} + +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle, + value_to_write); + return assign_op.operation.node(); +} + +Node* MakeModify(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = ops::Const(scope.WithOpName("Increment" + id), 1.0f); + ops::AssignAddVariableOp assign_add_op(scope.WithOpName("Increment" + id), + var_handle, value_to_write); + return assign_add_op.operation.node(); +} + +Node* MakeNeutral(const Scope& scope, const string& id) { + return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); +} + +Status ComputeIncompatiblePairs(Graph* g, + std::vector>* result) { + FixupSourceAndSinkEdges(g); + return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {}, + result); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], write_read_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 0); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadWriteNoEdges) { + Scope root = Scope::NewRootScope().ExitOnError(); + + MakeRead(root, "R"); + MakeWrite(root, "W"); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 0); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + + root.graph()->AddControlEdge(read, modify); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 1); + std::pair read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ModifyRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + + root.graph()->AddControlEdge(modify, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair modify_read_pair = {modify->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_read_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ModifyWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(modify, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 1); + std::pair modify_write_pair = {modify->id(), write->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_write_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, modify); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_modify_pair = {write->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], write_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadModifyWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(read, modify); + root.graph()->AddControlEdge(modify, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + EXPECT_EQ(incompatible_pairs.size(), 2); + std::pair modify_write_pair = {modify->id(), write->id()}; + std::pair read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); + EXPECT_EQ(incompatible_pairs[1], modify_write_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteModifyRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, modify); + root.graph()->AddControlEdge(modify, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 3); + + std::pair write_modify_pair = {write->id(), modify->id()}; + std::pair modify_read_pair = {modify->id(), read->id()}; + std::pair write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], modify_read_pair); + EXPECT_EQ(incompatible_pairs[1], write_read_pair); + EXPECT_EQ(incompatible_pairs[2], write_modify_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteReadModify) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* read = MakeRead(root, "R"); + Node* modify = MakeModify(root, "M"); + Node* write = MakeWrite(root, "W"); + + root.graph()->AddControlEdge(write, read); + root.graph()->AddControlEdge(read, modify); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 3); + + std::pair write_modify_pair = {write->id(), modify->id()}; + std::pair write_read_pair = {write->id(), read->id()}; + std::pair read_modify_pair = {read->id(), modify->id()}; + EXPECT_EQ(incompatible_pairs[0], read_modify_pair); + EXPECT_EQ(incompatible_pairs[1], write_read_pair); + EXPECT_EQ(incompatible_pairs[2], write_modify_pair); +} + +FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, + /*attr_def*/ + {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)}, + /*ret_def=*/{{"out", "out:output:0"}}); + *flib_def.add_function() = std::move(func); + return flib_def; +} + +Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name, + Status* status) { + NodeDef call_node; + call_node.set_name(node_name); + call_node.set_op(callee_name); + return graph->AddNode(call_node, status); +} + +TEST(ResourceOperationSafetyAnalysisTest, CallRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(call, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair call_read_edge = {call->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], call_read_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, ReadCall) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(read, call); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair read_call_edge = {read->id(), call->id()}; + EXPECT_EQ(incompatible_pairs[0], read_call_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, CallWrite) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(call, write); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair call_write_edge = {call->id(), write->id()}; + EXPECT_EQ(incompatible_pairs[0], call_write_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteCall) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + Status status; + Node* call = MakeCall(root.graph(), "Const_func", "C", &status); + TF_ASSERT_OK(status); + + root.graph()->AddControlEdge(write, call); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_call_edge = {write->id(), call->id()}; + EXPECT_EQ(incompatible_pairs[0], write_call_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, SymbolicGradientRead) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* read = MakeRead(root, "R"); + NameAttrList fn; + fn.set_name("Const_func"); + Node* symbolic_gradient = + ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)}, + /*Tout=*/{DT_FLOAT}, fn) + .output[0] + .node(); + + root.graph()->AddControlEdge(symbolic_gradient, read); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair symbolic_gradient_read_edge = {symbolic_gradient->id(), + read->id()}; + EXPECT_EQ(incompatible_pairs[0], symbolic_gradient_read_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, WriteSymbolicGradient) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("Const_func"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + + Node* write = MakeWrite(root, "W"); + NameAttrList fn; + fn.set_name("Const_func"); + Node* symbolic_gradient = + ops::SymbolicGradient(root, /*input=*/{ops::Const(root, 1.0f)}, + /*Tout=*/{DT_FLOAT}, fn) + .output[0] + .node(); + + root.graph()->AddControlEdge(write, symbolic_gradient); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + std::pair write_symbolic_gradient_edge = {write->id(), + symbolic_gradient->id()}; + EXPECT_EQ(incompatible_pairs[0], write_symbolic_gradient_edge); +} + +TEST(ResourceOperationSafetyAnalysisTest, ChainOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* neutral_0 = MakeNeutral(root, "N0"); + Node* read_0 = MakeRead(root, "R0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral_1 = MakeNeutral(root, "N1"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral_0); + root.graph()->AddControlEdge(neutral_0, read_0); + root.graph()->AddControlEdge(read_0, write_1); + root.graph()->AddControlEdge(write_1, neutral_1); + root.graph()->AddControlEdge(neutral_1, read_1); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 5); + std::pair write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair write_1_read_1_pair = {write_1->id(), read_1->id()}; + std::pair write_0_write_1_pair = {write_0->id(), write_1->id()}; + std::pair read_0_read_1_pair = {read_0->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_write_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[3], read_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[4], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, DagOfOps) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral = MakeNeutral(root, "N"); + Node* read_0 = MakeRead(root, "R0"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral); + root.graph()->AddControlEdge(write_1, neutral); + root.graph()->AddControlEdge(neutral, read_0); + root.graph()->AddControlEdge(neutral, read_1); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 4); + std::pair write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair write_1_read_0_pair = {write_1->id(), read_0->id()}; + std::pair write_1_read_1_pair = {write_1->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair); + EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, DagOfOpsWithRepeatedPaths) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Node* write_0 = MakeWrite(root, "W0"); + Node* write_1 = MakeWrite(root, "W1"); + Node* neutral = MakeNeutral(root, "N"); + Node* read_0 = MakeRead(root, "R0"); + Node* read_1 = MakeRead(root, "R1"); + + root.graph()->AddControlEdge(write_0, neutral); + root.graph()->AddControlEdge(write_1, neutral); + root.graph()->AddControlEdge(neutral, read_0); + root.graph()->AddControlEdge(neutral, read_1); + root.graph()->AddControlEdge(write_1, read_1); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 4); + std::pair write_0_read_0_pair = {write_0->id(), read_0->id()}; + std::pair write_0_read_1_pair = {write_0->id(), read_1->id()}; + std::pair write_1_read_0_pair = {write_1->id(), read_0->id()}; + std::pair write_1_read_1_pair = {write_1->id(), read_1->id()}; + + EXPECT_EQ(incompatible_pairs[0], write_0_read_0_pair); + EXPECT_EQ(incompatible_pairs[1], write_0_read_1_pair); + EXPECT_EQ(incompatible_pairs[2], write_1_read_0_pair); + EXPECT_EQ(incompatible_pairs[3], write_1_read_1_pair); +} + +TEST(ResourceOperationSafetyAnalysisTest, Loop) { + Scope root = Scope::NewRootScope().ExitOnError(); + + Output init_value = ops::Placeholder(root.WithOpName("init"), DT_FLOAT); + Output loop_cond = ops::Placeholder(root.WithOpName("init"), DT_BOOL); + Output enter_value = + ops::internal::Enter(root.WithOpName("enter"), init_value, "fr"); + ops::Merge iv(root.WithOpName("iv"), {enter_value, enter_value}); + ops::Switch latch(root.WithOpName("latch"), iv.output, loop_cond); + ops::internal::Exit exit(root.WithOpName("exit"), iv.output); + Output next_iteration = + ops::NextIteration(root.WithOpName("next_iteration"), latch.output_true); + TF_ASSERT_OK( + root.graph()->UpdateEdge(next_iteration.node(), 0, iv.output.node(), 1)); + + Node* write = MakeWrite(root, "W"); + Node* read = MakeRead(root, "R"); + + root.graph()->AddControlEdge(iv.output.node(), write); + root.graph()->AddControlEdge(write, read); + root.graph()->AddControlEdge(read, next_iteration.node()); + + std::vector> incompatible_pairs; + TF_ASSERT_OK(ComputeIncompatiblePairs(root.graph(), &incompatible_pairs)); + + ASSERT_EQ(incompatible_pairs.size(), 1); + + std::pair write_read_pair = {write->id(), read->id()}; + EXPECT_EQ(incompatible_pairs[0], write_read_pair); +} + +bool IsResourceArgDef(const OpDef::ArgDef& arg_def) { + return arg_def.type() == DT_RESOURCE; +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 38adacd93bc43ef17734b909af862e063574e986..4f2fabd658330b8ab182e13e02ed0bca41641e46 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" @@ -207,4 +208,27 @@ bool HasResourceInputOrOutput(const Node& node) { void RemoveFromXlaCluster(NodeDef* node_def) { node_def->mutable_attr()->erase(kXlaClusterAttr); } + +Status AdjustCycleDetectionGraphForResourceOps( + const Graph* graph, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + GraphCycles* cycles) { + std::vector> unsafe_deps; + TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs( + *graph, flib_def, resource_ops_to_ignore, &unsafe_deps)); + + // An edge {P,Q} in `unsafe_deps` denotes that P and Q, both of which are + // operations that interact with resource variables, must not be put in the + // same cluster. We enforce this constraint by creating a phantom node, X, + // and adding edges P->X and X->Q. MarkForCompilation then cannot cluster P + // and Q together since that would create a cycle with X. + + for (std::pair unsafe_dep : unsafe_deps) { + int phantom_node_id = cycles->NewNode(); + CHECK(cycles->InsertEdge(unsafe_dep.first, phantom_node_id)); + CHECK(cycles->InsertEdge(phantom_node_id, unsafe_dep.second)); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index 662a53d89eb37128be54abc95e1748c6d5f9081f..b0439a63ca6476b6b1d63e65308712270381dd9f 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -55,6 +55,13 @@ void RemoveFromXlaCluster(NodeDef* node_def); // Returns true if `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node); +// Adds edges to `cycles` to prevent clustering resource operations that cannot +// be legally clustered. +Status AdjustCycleDetectionGraphForResourceOps( + const Graph* graph, const FunctionLibraryDefinition* flib_def, + const std::function& resource_ops_to_ignore, + GraphCycles* cycles); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index 2cb351e1ecdb4523a8652886af156540e4736b18..65bbf3efe85ba30f44531ff6d54b041786dca0a5 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 7140d47a9421ec73d0144e855b490f89569e6ae9..ef6b0e67d3c4007f86dc7eef89cacb4cea98fc15 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -230,7 +230,7 @@ Status XlaCompilationCache::Compile( const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options) { + const XlaCompiler::CompileOptions& compile_options) { return CompileImpl(options, function, constant_args, variable_args, ctx, compilation_result, executable, compile_options, false); } @@ -241,7 +241,7 @@ Status XlaCompilationCache::CompileSingleOp( const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options) { + const XlaCompiler::CompileOptions& compile_options) { const NodeDef& def = ctx->op_kernel().def(); NameAttrList name; name.set_name(def.op()); @@ -256,7 +256,7 @@ Status XlaCompilationCache::CompileImpl( const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options, + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op) { CHECK_NE(executable, nullptr); VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); @@ -324,13 +324,12 @@ Status XlaCompilationCache::CompileImpl( entry->compiled = true; if (compile_single_op) { - entry->compilation_status = compiler.CompileSingleOp( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - signature.name, ctx, args, &entry->compilation_result); + entry->compilation_status = + compiler.CompileSingleOp(compile_options, signature.name, ctx, args, + &entry->compilation_result); } else { entry->compilation_status = compiler.CompileFunction( - compile_options ? *compile_options : XlaCompiler::CompileOptions(), - function, args, &entry->compilation_result); + compile_options, function, args, &entry->compilation_result); } TF_RETURN_IF_ERROR(entry->compilation_status); CHECK_EQ(entry->executable.get(), nullptr); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index fc5f008f4f52c32d97e680784082d0e7bcb7d8eb..10ad87e38cc4d614e869782329f84351bc3b1f0b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -70,7 +70,7 @@ class XlaCompilationCache : public ResourceBase { OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options); + const XlaCompiler::CompileOptions& compile_options); // As above, but calls XlaCompiler::CompileSingleOp instead of // XlaCompiler::CompileFunction. @@ -80,7 +80,7 @@ class XlaCompilationCache : public ResourceBase { const std::map& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options); + const XlaCompiler::CompileOptions& compile_options); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } @@ -96,7 +96,7 @@ class XlaCompilationCache : public ResourceBase { OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions* compile_options, + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op); // Takes `result` which has been compiled from a Tensorflow subgraph to a diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index dd84fb34c171f8d2174444ddd3b3b476e7142718..3ba48e8c318f84a4691fb74434bc009fdd0d81bf 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -177,7 +177,7 @@ Status XlaCompileOnDemandOp::Compile( std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, &compile_options); + result, executable, compile_options); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 70e6d0be0f2cffe98fd77fddac5866789c411a51..50c902fdfc06e9fb2cbcd9dd44640a7d40d0fe81 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -365,11 +365,7 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); - // When Xprof profiling is off (which is the default), constructing the - // activity is simple enough that its overhead is negligible. - tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); - op_kernel->Compute(context); + TracingDevice::Compute(op_kernel, context); } void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 2027ec7737895557a46df38fe6d0ddb372e3bb67..ee07c5c9643ef1119b9077326c1cf7c83930e90c 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -184,18 +184,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, return; } status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); - if (status.ok()) { - xla_tensor->set_host_tensor(*cpu_tensor); - host_to_device_stream_->ThenDoHostCallback([this, done]() { - // We must not call the done closure directly from DoHostCallback - // to avoid a deadlock. If done() is the callback that ends an - // Executor's run, the Executor may call XlaDevice::Sync() inside the - // callback. This deadlocks, because XlaDevice::Sync() waits for all - // stream activity to complete. - thread_pool_->Schedule([done]() { done(Status::OK()); }); - }); - return; - } } else { se::DeviceMemoryBase dev_dst_ptr = XlaTensor::DeviceMemoryFromTensor(*device_tensor); @@ -208,8 +196,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, host_to_device_stream_.get(), block_status.error_message().c_str()); } } - xla_tensor->set_host_tensor(*cpu_tensor); - + if (status.ok()) { + xla_tensor->set_host_tensor(*cpu_tensor); + } done(status); } diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 4b499b161371ecece14447b29fbf809b6e8857db..07cfab615157650aea0e15cdafa8c9b0925f9e5f 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -41,8 +41,8 @@ static bool IsShapeConsumerOp(const Node& node) { } // Returns true if the op can be decomposed into XLA ops for which -// there are fusable elemental implementations. -bool IsXlaFusable(const NodeDef& node) { +// there are fusible elemental implementations. +static bool IsXlaFusible(const NodeDef& node) { static const std::unordered_set* elementwise_ops = new std::unordered_set( {// tf2xla/kernels/aggregate_ops.cc @@ -176,9 +176,9 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); if (device_type.type_string().find("XLA") != string::npos) continue; - // Assume all fusable ops are registered. + // Assume all fusible ops are registered. // TODO(hpucha): Check for registration if possible. - if (!IsXlaFusable(node->def())) { + if (!IsXlaFusible(node->def())) { continue; } @@ -208,6 +208,8 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, GraphCycles cycles; TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + TF_RETURN_IF_ERROR(AdjustCycleDetectionGraphForResourceOps( + &graph, &graph.flib_def(), /*resource_ops_to_ignore=*/{}, &cycles)); // TODO(hpucha): Make clustering more robust. There are two known issues that // we need to mitigate: (a) Non-resource variables can cause deadlocks diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc index 5736760a878dc857a8558093054d0adc0f727398..68e19c8a135735a79fcabf121e619157fa22b4d8 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/xla_fusion_optimizer.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/graph/graph_def_builder.h" @@ -71,7 +73,7 @@ TEST_F(XlaFusionOptimizerTest, Chains) { EXPECT_TRUE(clusters.find("D") == clusters.cend()); } -TEST_F(XlaFusionOptimizerTest, FusableOps) { +TEST_F(XlaFusionOptimizerTest, FusibleOps) { GraphDef graph; { GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); @@ -179,5 +181,28 @@ TEST_F(XlaFusionOptimizerTest, CompilableCycles) { EXPECT_EQ(clusters["A"], clusters["C"]); } +TEST_F(XlaFusionOptimizerTest, ResourcesClusteringDisallowed) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output var_handle = + ops::VarHandleOp(root.WithOpName("Var"), DT_FLOAT, TensorShape({})); + Output to_assign = ops::Const(root.WithOpName("Const"), 10.0f); + Output begin = ops::Const(root.WithOpName("begin"), 0); + Output end = ops::Const(root.WithOpName("end"), 1); + Output strides = ops::Const(root.WithOpName("strides"), 1); + ops::ResourceStridedSliceAssign assign_1( + root.WithOpName("assign_1"), var_handle, begin, end, strides, to_assign); + ops::ResourceStridedSliceAssign assign_2( + root.WithOpName("assign_2"), var_handle, begin, end, strides, to_assign); + root.graph()->AddControlEdge(assign_1.operation.node(), + assign_2.operation.node()); + grappler::GrapplerItem item; + root.graph()->ToGraphDef(&item.graph); + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_NE(clusters["assign_1"], clusters["assign_2"]); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 2ffce9298d99e1e136e15e9a4b0e3f5b26121bd5..affeab4a8c43b63ac0e2b8ef40de5223ce39d410 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -271,31 +271,36 @@ Status XlaComputationLaunchContext::PopulateOutputs( } } else { const TensorShape& shape = kernel->outputs[i].shape; - VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); - - se::DeviceMemoryBase buffer = output.buffer({output_num}); - if (allocate_xla_tensors_) { - Tensor* output_tensor; - TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); - XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); - if (xla_tensor) { - xla_tensor->set_shaped_buffer(ScopedShapedBuffer( - ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); - if (use_multiple_streams_) { - xla_tensor->SetDefinedOn(stream, definition_event); + const DataType& type = kernel->outputs[i].type; + VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " + << DataTypeString(type); + if (type == DT_RESOURCE) { + ctx->set_output(i, ctx->input(kernel->outputs[i].input_index)); + } else { + se::DeviceMemoryBase buffer = output.buffer({output_num}); + if (allocate_xla_tensors_) { + Tensor* output_tensor; + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); + XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); + if (xla_tensor) { + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); + if (use_multiple_streams_) { + xla_tensor->SetDefinedOn(stream, definition_event); + } + } else { + // xla_tensor wasn't valid, which must mean this is a zero-element + // tensor. + CHECK_EQ(output_tensor->TotalBytes(), 0); } } else { - // xla_tensor wasn't valid, which must mean this is a zero-element - // tensor. - CHECK_EQ(output_tensor->TotalBytes(), 0); + Tensor output_tensor = XlaTensorBuffer::MakeTensor( + ctx->expected_output_dtype(i), shape, buffer, allocator); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); + ctx->set_output(i, output_tensor); } - } else { - Tensor output_tensor = XlaTensorBuffer::MakeTensor( - ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(xla::OwningDeviceMemory(), {output_num}); - ctx->set_output(i, output_tensor); + ++output_num; } - ++output_num; } if (VLOG_IS_ON(3)) { diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 47311d2630175ee8c7eb52d587f138128bcad3df..cf02926e0675e94381462f9579c36909c3bf7de9 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -72,7 +72,7 @@ py_test( tf_xla_py_test( name = "adadelta_test", - size = "medium", + size = "large", srcs = ["adadelta_test.py"], deps = [ ":xla_test", @@ -728,6 +728,7 @@ tf_xla_py_test( "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", ], ) @@ -1190,3 +1191,19 @@ tf_xla_py_test( "//tensorflow/python:platform_test", ], ) + +tf_xla_py_test( + name = "xla_ops_test", + size = "small", + srcs = ["xla_ops_test.py"], + disabled_backends = ["cpu_ondemand"], + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "@absl_py//absl/testing:parameterized", + ], +) diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py index 3e3c09c66e72c4de141b64cea3c4693fabb7b2a2..b7b7fda293b69d6f0cec61d0d234277636a3670d 100644 --- a/tensorflow/compiler/tests/adadelta_test.py +++ b/tensorflow/compiler/tests/adadelta_test.py @@ -33,7 +33,7 @@ class AdadeltaOptimizerTest(xla_test.XLATestCase): def testBasic(self): num_updates = 4 # number of ADADELTA steps to perform for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): for grad in [0.2, 0.1, 0.01]: for lr in [1.0, 0.5, 0.1]: var0_init = [1.0, 2.0] diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py index dc1625793aa44b96d3b96e175237caf96e7d7e74..69fb3ec2964a09508e612515b9e291fc14121d68 100644 --- a/tensorflow/compiler/tests/adagrad_da_test.py +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -33,7 +33,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithoutRegularizationBasic1(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) @@ -69,7 +69,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAwithoutRegularizationBasic2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) @@ -100,7 +100,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithL1(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) @@ -131,7 +131,7 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): def testAdagradDAWithL1_L2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): global_step = resource_variable_ops.ResourceVariable( 0, dtype=dtypes.int64) var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py index d775850a80e9f83f7b2c9f1cf8997dd50e229635..ab69319c59fb07e7ce56c3c287a50a6290effdfd 100644 --- a/tensorflow/compiler/tests/adagrad_test.py +++ b/tensorflow/compiler/tests/adagrad_test.py @@ -32,7 +32,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -57,7 +57,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -83,7 +83,7 @@ class AdagradOptimizerTest(xla_test.XLATestCase): def testSharing(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py index c4fdbc5974319db9243eb2c323746cbaaea795f6..3ed1d41b7121f44dd7470f61180f7a7055369174 100644 --- a/tensorflow/compiler/tests/adamax_test.py +++ b/tensorflow/compiler/tests/adamax_test.py @@ -49,7 +49,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): def testBasic(self): for i, dtype in enumerate(self.float_types): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 @@ -100,7 +100,7 @@ class AdaMaxOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py index 9ec5a964cbb4dd98d2ef2d0b684872292118800f..1bc07ace23ccdc83103abe71ee11b72994c75a6d 100644 --- a/tensorflow/compiler/tests/addsign_test.py +++ b/tensorflow/compiler/tests/addsign_test.py @@ -63,7 +63,7 @@ class AddSignTest(xla_test.XLATestCase): alpha=1.0, beta=0.9): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): # Initialize variables for numpy implementation. m0, m1 = 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index 9d3a889b1f54c813e881bb03b5275f809af1b3c8..4155342787fbbdeaf5c5958c44d007b1ea0660ed 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -40,7 +40,7 @@ class ArgMinMaxTest(xla_test.XLATestCase): op_input: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. """ - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pinp = array_ops.placeholder( dtypes.as_dtype(op_input.dtype), op_input.shape, name="a") diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 5b7001b5a463ae0bd4e8f07032256717aab70d49..17280e445b329d1541aaed78ec106f8f282cbc74 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -36,7 +36,7 @@ class BinaryOpsTest(xla_test.XLATestCase): """Test cases for binary operators.""" def _testBinary(self, op, a, b, expected, equality_test=None): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") @@ -1010,7 +1010,38 @@ class BinaryOpsTest(xla_test.XLATestCase): [7, 7, 7, 7, 7, 7]], dtype=dtype)) - def testMirrorPad(self): + def testSymmetricMirrorPad(self): + mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC") + for dtype in self.numeric_types: + self._testBinary( + mirror_pad, + np.array( + [ + [1, 2, 3], # + [4, 5, 6], # + ], + dtype=dtype), + np.array([[ + 2, + 2, + ], [3, 3]], dtype=np.int32), + expected=np.array( + [ + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [6, 5, 4, 4, 5, 6, 6, 5, 4], # + [3, 2, 1, 1, 2, 3, 3, 2, 1], # + ], + dtype=dtype)) + self._testBinary( + mirror_pad, + np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array([[0, 0], [0, 0]], dtype=np.int32), + expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) + + def testReflectMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") for dtype in self.numeric_types: self._testBinary( @@ -1372,5 +1403,40 @@ class BinaryOpsTest(xla_test.XLATestCase): [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]], dtype=dtype)) + def testBroadcastTo(self): + for dtype in self.all_types: + x = np.random.randint(0, high=100, size=[2, 3]) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([2, 3], dtype=np.int32), + expected=x) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([6, 6], dtype=np.int32), + expected=np.tile(x, [3, 2])) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 4, 3], dtype=np.int32), + expected=np.tile(x, [7, 2, 1])) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 0, 3], dtype=np.int32), + expected=np.zeros([7, 0, 3], dtype=dtype)) + self._testBinary( + array_ops.broadcast_to, + x, + np.array([7, 1, 2, 9], dtype=np.int32), + expected=np.tile(x, [7, 1, 1, 3])) + self._testBinary( + array_ops.broadcast_to, + np.zeros([2, 0], dtype=dtype), + np.array([4, 0], dtype=np.int32), + expected=np.zeros([4, 0], dtype=dtype)) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py index ef4d5f6322b7ae79b051795b5af7e6f7f1e55550..5c24db539bce5df701d8229290ddb4c20997d40a 100644 --- a/tensorflow/compiler/tests/bucketize_op_test.py +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class BucketizationOpTest(xla_test.XLATestCase): def testInt(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.int32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) @@ -38,7 +38,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]})) def testFloat(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.float32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.]) @@ -48,7 +48,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]})) def test2DInput(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.float32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) @@ -58,7 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase): {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) def testInvalidBoundariesOrder(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.int32) with self.test_scope(): op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11]) @@ -67,7 +67,7 @@ class BucketizationOpTest(xla_test.XLATestCase): sess.run(op, {p: [-5, 0]}) def testBoundariesNotList(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(TypeError, "Expected list.*"): p = array_ops.placeholder(dtypes.int32) with self.test_scope(): diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index a4e7f75081dfd07fd4b5c94c33908aab8e7d8aa9..a57d1dc81ea2c9c188b0a3005904738aa8156bf3 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -56,7 +56,7 @@ class CategoricalTest(xla_test.XLATestCase): Returns: Frequencies from sampled classes; shape [batch_size, num_classes]. """ - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): random_seed.set_random_seed(1618) op = random_ops.multinomial(logits, num_samples, output_dtype=dtypes.int32) @@ -79,7 +79,7 @@ class CategoricalTest(xla_test.XLATestCase): def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = rng(dtype, output_dtype) @@ -107,7 +107,7 @@ class CategoricalTest(xla_test.XLATestCase): def testCategoricalIsInRange(self): for dtype in self.float_types: for output_dtype in self.output_dtypes(): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = random_ops.multinomial( array_ops.ones(shape=[1, 20], dtype=dtype), 1000, diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py index ed532db0ee5553a275192e6cc3ebf394075fa0e1..d1896a50f7037f2972cba8a4fa16cc1e2cd4fe3e 100644 --- a/tensorflow/compiler/tests/cholesky_op_test.py +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -54,7 +54,7 @@ class CholeskyOpTest(xla_test.XLATestCase): def _verifyCholesky(self, x, atol=1e-6): # Verify that LL^T == x. - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder( dtypes.as_dtype(x.dtype), shape=x.shape) with self.test_scope(): diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py index e42ebf8f9e01dab13cde15979ffc42b7c0fbc57b..88bd58b2da6b2892f898ad10f3467d8ce39d6388 100644 --- a/tensorflow/compiler/tests/clustering_test.py +++ b/tensorflow/compiler/tests/clustering_test.py @@ -38,7 +38,7 @@ class ClusteringTest(xla_test.XLATestCase): val1 = np.array([4, 3, 2, 1], dtype=np.float32) val2 = np.array([5, 6, 7, 8], dtype=np.float32) expected = val1 + val2 - with self.test_session(): + with self.cached_session(): with self.test_scope(): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") @@ -50,7 +50,7 @@ class ClusteringTest(xla_test.XLATestCase): val1 = np.array([4, 3, 2, 1]).astype(np.float32) val2 = np.array([5, 6, 7, 8]).astype(np.float32) expected = val1 + val2 - with self.test_session(): + with self.cached_session(): with ops.device(CPU_DEVICE): input1 = constant_op.constant(val1, name="const1") input2 = constant_op.constant(val2, name="const2") @@ -68,7 +68,7 @@ class ClusteringTest(xla_test.XLATestCase): # where x and z are placed on the CPU and y and w are placed on the XLA # device. If y and w are clustered for compilation, then the graph will # deadlock since the clustered graph will contain a self-loop. - with self.test_session() as sess: + with self.cached_session() as sess: with ops.device(CPU_DEVICE): x = array_ops.placeholder(dtypes.float32, [2]) with self.test_scope(): @@ -81,7 +81,7 @@ class ClusteringTest(xla_test.XLATestCase): self.assertAllClose(result, [12., 2.], rtol=1e-3) def testHostMemory(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.int32) with self.test_scope(): y = x + 1 diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py index d9ad4281477e87f79f2ecb52989ae86a5030d0cc..37e5318bb54c5d8ecdedc7bb346e89765f2adf35 100644 --- a/tensorflow/compiler/tests/concat_ops_test.py +++ b/tensorflow/compiler/tests/concat_ops_test.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest class ConcatTest(xla_test.XLATestCase): def testHStack(self): - with self.test_session(): + with self.cached_session(): p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) with self.test_scope(): @@ -49,7 +49,7 @@ class ConcatTest(xla_test.XLATestCase): self.assertAllEqual(result[4:, :], params[p2]) def testVStack(self): - with self.test_session(): + with self.cached_session(): p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4]) with self.test_scope(): @@ -65,7 +65,7 @@ class ConcatTest(xla_test.XLATestCase): self.assertAllEqual(result[:, 4:], params[p2]) def testInt32(self): - with self.test_session(): + with self.cached_session(): p1 = np.random.rand(2, 3).astype("i") p2 = np.random.rand(2, 3).astype("i") x1 = constant_op.constant(p1) @@ -88,7 +88,7 @@ class ConcatTest(xla_test.XLATestCase): dtype_feed = dtypes.float32 else: dtype_feed = dtype - with self.test_session(): + with self.cached_session(): p = [] for i in np.arange(num_tensors): input_shape = shape @@ -130,7 +130,7 @@ class ConcatTest(xla_test.XLATestCase): self._testRandom(dtypes.int32) def _testGradientsSimple(self): - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -157,7 +157,7 @@ class ConcatTest(xla_test.XLATestCase): self._testGradientsSimple() def _testGradientsFirstDim(self): - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -185,7 +185,7 @@ class ConcatTest(xla_test.XLATestCase): self._testGradientsFirstDim() def _testGradientsLastDim(self): - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -220,7 +220,7 @@ class ConcatTest(xla_test.XLATestCase): # Random dim to concat on concat_dim = np.random.randint(5) concat_dim_sizes = np.random.randint(1, 5, size=num_tensors) - with self.test_session(): + with self.cached_session(): inp = [] inp_tensors = [] with self.test_scope(): @@ -254,7 +254,7 @@ class ConcatTest(xla_test.XLATestCase): def DISABLED_testZeroSize(self): # Verify that concat doesn't crash and burn for zero size inputs np.random.seed(7) - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): for shape0 in (), (2,): axis = len(shape0) @@ -276,14 +276,14 @@ class ConcatTest(xla_test.XLATestCase): def testConcatTuple(self): c1 = np.random.rand(4, 4).astype(np.float32) c2 = np.random.rand(4, 4).astype(np.float32) - with self.test_session(): + with self.cached_session(): with self.test_scope(): concat_list_t = array_ops.concat([c1, c2], 0) concat_tuple_t = array_ops.concat((c1, c2), 0) self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) def testConcatNoScalars(self): - with self.test_session(): + with self.cached_session(): with self.test_scope(): scalar = constant_op.constant(7) dim = array_ops.placeholder(dtypes.int32) @@ -295,7 +295,7 @@ class ConcatTest(xla_test.XLATestCase): class ConcatOffsetTest(xla_test.XLATestCase): def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) @@ -309,7 +309,7 @@ class ConcatOffsetTest(xla_test.XLATestCase): class PackTest(xla_test.XLATestCase): def testBasic(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) @@ -319,7 +319,7 @@ class PackTest(xla_test.XLATestCase): self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]]) def testScalars(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): s0 = constant_op.constant(2, dtypes.int32) s1 = constant_op.constant(3, dtypes.int32) @@ -329,7 +329,7 @@ class PackTest(xla_test.XLATestCase): self.assertAllEqual(ans, [2, 3, 5]) def testEmpty(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): s0 = constant_op.constant([[]], dtypes.int32) s1 = constant_op.constant([[]], dtypes.int32) diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index f9db103f6d0f9ea0e393a0971593552ec5c14079..af00ff287d43a8542b5a3d14eedc00c3d7aef1b7 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -87,7 +87,7 @@ class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) with self.test_scope(): @@ -288,7 +288,7 @@ class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): @@ -586,7 +586,7 @@ class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase): dilations = test_utils.PermuteDimsBetweenDataFormats( dilations, data_format_src, data_format_dst) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py index 31ee41f04f27d387415e9fa2c4fa70b33cab7b04..33fd983b5485e503c2fcc96db2dfdecfc41e309f 100644 --- a/tensorflow/compiler/tests/conv3d_test.py +++ b/tensorflow/compiler/tests/conv3d_test.py @@ -36,7 +36,7 @@ from tensorflow.python.platform import googletest class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): def testGradient(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): for padding in ["SAME", "VALID"]: for stride in [1, 2]: np.random.seed(1) @@ -69,7 +69,7 @@ class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase): class Conv3DTransposeTest(xla_test.XLATestCase): def testConv3DTransposeSingleStride(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): strides = [1, 1, 1, 1, 1] # Input, output: [batch, depth, height, width, channel] @@ -119,7 +119,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): self.assertAllClose(target, value[n, d, h, w, k]) def testConv3DTransposeSame(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): strides = [1, 2, 2, 2, 1] # Input, output: [batch, depth, height, width, depth] @@ -157,7 +157,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): self.assertAllClose(target, value[n, d, h, w, k]) def testConv3DTransposeValid(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): strides = [1, 2, 2, 2, 1] # Input, output: [batch, depth, height, width, depth] @@ -217,7 +217,7 @@ class Conv3DTransposeTest(xla_test.XLATestCase): np.random.seed(1) # Make it reproducible. x_val = np.random.random_sample(x_shape).astype(np.float64) f_val = np.random.random_sample(f_shape).astype(np.float64) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): x = constant_op.constant(x_val, name="x", dtype=dtypes.float32) f = constant_op.constant(f_val, name="f", dtype=dtypes.float32) output = nn_ops.conv3d_transpose( diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 865f60ccab46ec6829e49409508303052944e13b..04f3b3ef4905984b0432a536c3b1c275738ede17 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -86,7 +86,7 @@ class DenseLayerTest(test.TestCase): XlaLaunch op by XLA. """ - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32) with jit_scope(): y = layers.dense(x, 3) @@ -113,7 +113,7 @@ class DenseLayerTest(test.TestCase): cluster, causing dense layer to be split into TWO XlaLaunch ops. """ - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) with jit_scope(): y = layers.dense(x, 3) diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 98dc73e189f99b7b811487756659d89dacb97d8a..6ef8a68ca5d35d3d2f78f0cb491e7bb98ff97ac9 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -151,7 +151,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): dtype=data_type).reshape(tensor_in_sizes) x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)], dtype=data_type).reshape(filter_in_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: if data_type == np.float32: tolerance = 1e-4 else: @@ -247,7 +247,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): dtype=np.float32).reshape(tensor_in_sizes) x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)], dtype=np.float32).reshape(filter_in_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32) t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=np.float32) with self.test_scope(): @@ -321,7 +321,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): x2 = np.random.rand(*output_sizes).astype(np.float32) def _GetVal(use_xla): - with self.test_session(): + with self.cached_session(): t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)]) t1 = array_ops.placeholder(np.float32, shape=filter_sizes) t2 = array_ops.placeholder(np.float32, shape=output_sizes) @@ -356,7 +356,7 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): x2 = np.random.rand(*output_sizes).astype(np.float32) def _GetVal(use_xla): - with self.test_session(): + with self.cached_session(): t0 = array_ops.placeholder(np.float32, shape=input_sizes) t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) t2 = array_ops.placeholder(np.float32, shape=output_sizes) diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 154e36b10e6da409606ae6022aaf53e34c8e37cc..5f01e128f0b0fa725d99b00ba3406bd50a1b8962 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import test class DynamicUpdateSliceOpsTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py index edd78153b56bb5bf1c268936fb82a60581389733..50b04daa6b9f4159a3c4bdeecaf900a5b35a833c 100644 --- a/tensorflow/compiler/tests/dynamic_stitch_test.py +++ b/tensorflow/compiler/tests/dynamic_stitch_test.py @@ -30,7 +30,7 @@ from tensorflow.python.platform import googletest class DynamicStitchTest(xla_test.XLATestCase): def _AssertDynamicStitchResultIs(self, indices, data, expected): - with self.test_session() as session: + with self.cached_session() as session: index_placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices ] diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 3d21fb5864c22a6f449c54d03abc0f234e28dab1..63cee550fde9d9d4314b1541fba191df776a4da2 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -101,7 +101,7 @@ class EagerTest(xla_test.XLATestCase): self.assertAllEqual(15, product) # Run some ops graphly - with context.graph_mode(), self.test_session() as sess: + with context.graph_mode(), self.cached_session() as sess: with self.test_scope(): three = constant_op.constant(3) five = constant_op.constant(5) @@ -351,6 +351,38 @@ class EagerFunctionTest(xla_test.XLATestCase): var = f(v) self.assertEqual(2.0, var.numpy()) + def testReturnResourceHandle(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable([[1.0, 2.0], [3.0, 4.0]]) + + def f(v): + return v.handle + + f = function.defun(f) + handle = f(v) + self.assertAllEqual(v.numpy(), + resource_variable_ops.read_variable_op( + handle, dtypes.float32).numpy()) + + def testReturnMultipleResourceHandles(self): + with self.test_scope(): + v1 = resource_variable_ops.ResourceVariable(1.25) + v2 = resource_variable_ops.ResourceVariable(2.0) + + def f(v): + return v.handle, 3.0 * v, v2.handle, v + v2 + + f = function.defun(f) + v1_handle, v1_times_3, v2_handle, variable_sum = f(v1) + self.assertAllEqual(v1.numpy(), + resource_variable_ops.read_variable_op( + v1_handle, dtypes.float32).numpy()) + self.assertEqual(3.75, v1_times_3.numpy()) + self.assertAllEqual(v2.numpy(), + resource_variable_ops.read_variable_op( + v2_handle, dtypes.float32).numpy()) + self.assertEqual(3.25, variable_sum.numpy()) + def testAllArgumentKinds(self): """Test a complex function that takes different argument kinds. @@ -457,6 +489,72 @@ class EagerFunctionTest(xla_test.XLATestCase): y = two_x_plus_1(x) self.assertAllEqual([5, 7, 9], y.numpy()) + def testNestedDefunWithVariable(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def g(x): + x = v0 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + y = f(x) + + self.assertEqual(75, y.numpy()) + + def testNestedDefunInGradientTape(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun + def g(x): + x = v0 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + y = f(x) + dy = tape.gradient(y, v0) + + self.assertEqual(75, y.numpy()) + self.assertEqual(30, dy.numpy()) + + def testNestedDefunInGradientTapeDifferentVars(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + v1 = resource_variable_ops.ResourceVariable(3.0) + + @function.defun + def g(x): + x = v1 * x + return x + + @function.defun + def f(x): + x = g(v0 * x) + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape(persistent=True) as tape: + y = f(x) + dy_v0 = tape.gradient(y, v0) + dy_v1 = tape.gradient(y, v1) + + self.assertEqual(45, y.numpy()) + self.assertEqual(9, dy_v0.numpy()) + self.assertEqual(15, dy_v1.numpy()) + class ExcessivePaddingTest(xla_test.XLATestCase): """Test that eager execution works with TPU flattened tensors. diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py index 5529fdbb090315e1d7f47589777d8a538c90db2b..37061e91d161db352b388a965eb72c9c32d3d752 100644 --- a/tensorflow/compiler/tests/extract_image_patches_op_test.py +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -44,7 +44,7 @@ class ExtractImagePatches(xla_test.XLATestCase): strides = [1] + strides + [1] rates = [1] + rates + [1] - with self.test_session(): + with self.cached_session(): image_placeholder = array_ops.placeholder(dtypes.float32) with self.test_scope(): out_tensor = array_ops.extract_image_patches( diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py index c48ab178bf53558084fb500b2811c6f0b77a7943..2178c4455609550226c89ceb185837768be1f622 100644 --- a/tensorflow/compiler/tests/fake_quant_ops_test.py +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -107,7 +107,7 @@ class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase): ], dtype=np.float32) - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): input_placeholder = array_ops.placeholder( dtypes.float32, inputs.shape, name="inputs") @@ -198,7 +198,7 @@ class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase): [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], dtype=np.float32) - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): gradient_placeholder = array_ops.placeholder( dtypes.float32, gradients.shape, name="gradients") @@ -306,7 +306,7 @@ class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase): ], dtype=np.float32) - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): input_placeholder = array_ops.placeholder( dtypes.float32, inputs.shape, name="inputs") @@ -406,7 +406,7 @@ class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase): expected_backprops_wrt_min = 1.0 + 2.0 expected_backprops_wrt_max = 10.0 + 11.0 - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): gradient_placeholder = array_ops.placeholder( dtypes.float32, gradients.shape, name="gradients") diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py index c64ea249ecb97991952a960a6d16e1bb3be35b17..b3e13fbaa6b33bdaa1be123be558059e96de282e 100644 --- a/tensorflow/compiler/tests/fft_test.py +++ b/tensorflow/compiler/tests/fft_test.py @@ -71,7 +71,7 @@ class FFTTest(xla_test.XLATestCase): data = np.reshape(data.astype(np.float32).view(np.complex64), shape) data = to_32bit(complex_to_input(data)) expected = to_32bit(input_to_expected(data)) - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) @@ -93,7 +93,7 @@ class FFTTest(xla_test.XLATestCase): data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2] expected = np.swapaxes(expected, -1, -2) expected *= window.sum() # scipy divides by window sum - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): ph = array_ops.placeholder( dtypes.as_dtype(data.dtype), shape=data.shape) diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py index 0f64cc87cde77fbbef6c4e570879e992bc34bafa..8c7edfd277c992c35a81dd5f261256a86352254e 100644 --- a/tensorflow/compiler/tests/fifo_queue_test.py +++ b/tensorflow/compiler/tests/fifo_queue_test.py @@ -31,13 +31,13 @@ from tensorflow.python.platform import test class FIFOQueueTest(xla_test.XLATestCase): def testEnqueue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) enqueue_op = q.enqueue((10.0,)) enqueue_op.run() def testEnqueueWithShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2)) enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],)) enqueue_correct_op.run() @@ -46,7 +46,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual(1, q.size().eval()) def testMultipleDequeues(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue([1])) self.evaluate(q.enqueue([2])) @@ -55,7 +55,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertAllEqual(set([1, 2, 3]), set([a, b, c])) def testQueuesDontShare(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) self.evaluate(q.enqueue(1)) q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()]) @@ -64,13 +64,13 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertAllEqual(self.evaluate(q.dequeue()), 1) def testEnqueueDictWithoutNames(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) with self.assertRaisesRegexp(ValueError, "must have names"): q.enqueue({"a": 12.0}) def testParallelEnqueue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -95,7 +95,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertItemsEqual(elems, results) def testParallelDequeue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -119,7 +119,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertItemsEqual(elems, results) def testDequeue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) elems = [10.0, 20.0, 30.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -133,7 +133,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([elems[i]], vals) def testEnqueueAndBlockingDequeue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32) elems = [10.0, 20.0, 30.0] enqueue_ops = [q.enqueue((x,)) for x in elems] @@ -163,7 +163,7 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([elem], result) def testMultiEnqueueAndDequeue(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32)) elems = [(5, 10.0), (10, 20.0), (15, 30.0)] enqueue_ops = [q.enqueue((x, y)) for x, y in elems] @@ -179,12 +179,12 @@ class FIFOQueueTest(xla_test.XLATestCase): self.assertEqual([y], y_val) def testQueueSizeEmpty(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) self.assertEqual([0], q.size().eval()) def testQueueSizeAfterEnqueueAndDequeue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32) enqueue_op = q.enqueue((10.0,)) dequeued_t = q.dequeue() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 1da97fd51217a0f28d4b3ba2ccfae3f6b094e65b..7ca50b02d9bf3203cbd460c8de13a16defd974a3 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -112,7 +112,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlwithoutRegularization(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -146,7 +146,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlwithoutRegularization2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -174,7 +174,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlWithL1(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -202,7 +202,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testFtrlWithL1_L2(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -236,7 +236,7 @@ class FtrlOptimizerTest(xla_test.XLATestCase): weights will tend to have smaller magnitudes with this parameter set. """ for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.2], dtype=dtype) @@ -273,9 +273,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testEquivAdagradwithoutRegularization(self): steps = 5 for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.equivAdagradTest_FtrlPart(steps, dtype) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4, half_rtol=1e-2) @@ -284,9 +284,9 @@ class FtrlOptimizerTest(xla_test.XLATestCase): def testEquivGradientDescentwithoutRegularization(self): steps = 5 for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.equivGradientDescentTest_FtrlPart(steps, dtype) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.equivGradientDescentTest_GradientDescentPart( steps, dtype) diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index 04fba444460e714ce96205361ac02ed492206b04..b1891b918c6584abce9da382088ed0037f5319fb 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -40,7 +40,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -66,7 +66,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = APlus2B(aval, bval) - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -90,7 +90,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32) expected = Func(aval, bval) - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.float32) def Foo(a, b): @@ -105,7 +105,7 @@ class FunctionTest(xla_test.XLATestCase): def testCompileTimeConstantsInDefun(self): """Tests that XLA handles compile-time constants in defuns.""" - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.float32, dtypes.int32, dtypes.int32) def Foo(a, c, d): @@ -140,7 +140,7 @@ class FunctionTest(xla_test.XLATestCase): bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32) expected = aval + bval * 2 - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): a = array_ops.placeholder(dtypes.float32, name="a") b = array_ops.placeholder(dtypes.float32, name="b") diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 132e42ac7a28d0769b0de12ea0cee6eae752b245..8c018cccb83a05babb0b7f73b80b4f9de7267c98 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -83,7 +83,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format_src) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): # To avoid constant folding x_val_converted = test_utils.ConvertBetweenDataFormats( x_val, data_format_src, data_format) @@ -126,7 +126,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format_src) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): # To avoid constant folding x_val_converted = test_utils.ConvertBetweenDataFormats( x_val, data_format_src, data_format) @@ -210,7 +210,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad( x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): grad_val_converted = test_utils.ConvertBetweenDataFormats( grad_val, data_format_src, data_format) x_val_converted = test_utils.ConvertBetweenDataFormats( @@ -260,7 +260,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): var_val = np.random.random_sample(scale_shape).astype(np.float32) data_format_src = "NHWC" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): grad_val_converted = test_utils.ConvertBetweenDataFormats( grad_val, data_format_src, data_format) x_val_converted = test_utils.ConvertBetweenDataFormats( diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py index 23b0aed34fb460f50c241e5a920cb4f6f613b947..7161f4ab339b6f4069dd2b02ddbc6a89973e0074 100644 --- a/tensorflow/compiler/tests/gather_nd_op_test.py +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class GatherNdTest(xla_test.XLATestCase): def _runGather(self, params, indices): - with self.test_session(): + with self.cached_session(): paramsp = array_ops.placeholder(params.dtype) indicesp = array_ops.placeholder(indices.dtype) with self.test_scope(): @@ -46,7 +46,7 @@ class GatherNdTest(xla_test.XLATestCase): np.array([[4], [4], [0]], np.int32))) def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): - with self.test_session(): + with self.cached_session(): params = np.ones((3, 3), dtype=np.float32) indices_empty = np.empty((0, 2), dtype=np.int32) diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index e9c8ef7c91a728b7dfc948fd9b315e6c9102f6a3..089d95daab7e502b4ba13796fadc2ba3f209759b 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -42,7 +42,7 @@ class GatherTest(xla_test.XLATestCase): return data def testScalar1D(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([0, 1, 2, 3, 7, 5]) for dtype in self.all_tf_types: for indices in 4, [4], [1, 2, 2, 4, 5]: @@ -55,7 +55,7 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(np_val, gather_val) def testScalar2D(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) for dtype in self.all_tf_types: @@ -69,7 +69,7 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(expected, gather_val) def testSimpleTwoD32(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) for dtype in self.all_tf_types: @@ -87,7 +87,7 @@ class GatherTest(xla_test.XLATestCase): if np.int64 not in self.int_types: return - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) # The indices must be in bounds for any axis. @@ -114,7 +114,7 @@ class GatherTest(xla_test.XLATestCase): for axis in 0, 1, 2, 3, -1, -2: params = self._buildParams(np.random.randn(*shape), dtype) indices = np.random.randint(shape[axis], size=indices_shape) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): tf_params = array_ops.placeholder(dtype=dtype) tf_indices = constant_op.constant(indices, dtype=dtypes.int32) gather = array_ops.gather(tf_params, tf_indices, axis=axis) @@ -123,7 +123,7 @@ class GatherTest(xla_test.XLATestCase): self.assertAllEqual(gather_np, gather_value) def testIndicesWithDifferentDimensions(self): - with self.test_session(): + with self.cached_session(): for dtype in self.numeric_tf_types: params = array_ops.placeholder(dtype=dtype) indices = array_ops.placeholder(dtype=np.int32) @@ -137,7 +137,7 @@ class GatherTest(xla_test.XLATestCase): [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) def testGatherPrecision(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0], [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]]) indices = np.array([1, 2, 3, 1]) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index bf986ade06b11358552ee92df3169f965ce3f534..6fe5a66e0e6717ec738dded9196eef6ba1e2114d 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -54,7 +54,7 @@ class RGBToHSVTest(xla_test.XLATestCase): inp = GenerateNumpyRandomRGB(shape).astype(nptype) # Convert to HSV and back, as a batch and individually - with self.test_session() as sess: + with self.cached_session() as sess: batch0 = array_ops.placeholder(nptype, shape=shape) with self.test_scope(): batch1 = image_ops.rgb_to_hsv(batch0) @@ -78,7 +78,7 @@ class RGBToHSVTest(xla_test.XLATestCase): data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] for nptype in self.float_types: rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255. - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(nptype) with self.test_scope(): hsv = image_ops.rgb_to_hsv(placeholder) @@ -97,7 +97,7 @@ class RGBToHSVTest(xla_test.XLATestCase): for r, g, b in rgb_flat ]) hsv_np = hsv_np.reshape(4, 4, 4, 3) - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(nptype) with self.test_scope(): hsv_op = image_ops.rgb_to_hsv(placeholder) @@ -108,7 +108,7 @@ class RGBToHSVTest(xla_test.XLATestCase): class AdjustContrastTest(xla_test.XLATestCase): def _testContrast(self, x_np, y_np, contrast_factor): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_np.shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -146,7 +146,7 @@ class AdjustContrastTest(xla_test.XLATestCase): return y_np def _adjustContrastTf(self, x_np, contrast_factor): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(np.float32) with self.test_scope(): y = image_ops.adjust_contrast(x, contrast_factor) @@ -180,7 +180,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -198,7 +198,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -216,7 +216,7 @@ class AdjustHueTest(xla_test.XLATestCase): y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) flt_x = image_ops.convert_image_dtype(x, dtypes.float32) with self.test_scope(): @@ -244,7 +244,7 @@ class AdjustHueTest(xla_test.XLATestCase): return y_v.reshape(x_np.shape) def _adjustHueTf(self, x_np, delta_h): - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(dtypes.float32) with self.test_scope(): y = gen_image_ops.adjust_hue(x, delta_h) @@ -324,7 +324,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): y_rgb_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128] y_np = np.array(y_rgb_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) y = self._adjust_saturation(x, saturation_factor) y_tf = y.eval({x: x_np}) @@ -339,7 +339,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0] y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) - with self.test_session(): + with self.cached_session(): x = array_ops.placeholder(x_np.dtype, shape=x_shape) y = self._adjust_saturation(x, saturation_factor) y_tf = y.eval({x: x_np}) @@ -378,7 +378,7 @@ class AdjustSaturationTest(xla_test.XLATestCase): "gb_same", "rgb_same", ] - with self.test_session(): + with self.cached_session(): for x_shape in x_shapes: for test_style in test_styles: x_np = np.random.rand(*x_shape) * 255. @@ -410,13 +410,14 @@ class ResizeBilinearTest(xla_test.XLATestCase): image_np, target_shape, expected=None, - large_tolerance=False): + large_tolerance=False, + align_corners=True): if expected is None: self.fail("expected must be specified") - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): image = array_ops.placeholder(image_np.dtype) resized = gen_image_ops.resize_bilinear( - image, target_shape, align_corners=True) + image, target_shape, align_corners=align_corners) out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) if large_tolerance: self.assertAllClose( @@ -433,7 +434,7 @@ class ResizeBilinearTest(xla_test.XLATestCase): self.fail("input_shape must be specified") if expected is None: self.fail("expected must be specified") - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): dtype = dtype or np.float32 grads = array_ops.placeholder(np.float32) resized = gen_image_ops.resize_bilinear_grad( @@ -579,6 +580,27 @@ class ResizeBilinearTest(xla_test.XLATestCase): dtype=np.float32)), large_tolerance=True) + def testNonAlignCorners3x2To6x4(self): + input_data = [[64, 32], [32, 64], [50, 100]] + expected_data = [[64.0, 48.0, 32.0, 32.0], [48.0, 48.0, 48.0, 48.0], + [32.0, 48.0, 64.0, 64.0], [41.0, 61.5, 82.0, 82.0], + [50.0, 75.0, 100.0, 100.0], [50.0, 75.0, 100.0, 100.0]] + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array(input_data, dtype=dtype), [6, 4], + expected=np.array(expected_data, dtype=np.float32), + align_corners=False) + + def testNonAlignCorners6x4To3x2(self): + input_data = [[127, 127, 64, 64], [127, 127, 64, 64], [64, 64, 127, 127], + [64, 64, 127, 127], [50, 50, 100, 100], [50, 50, 100, 100]] + expected_data = [[127, 64], [64, 127], [50, 100]] + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array(input_data, dtype=dtype), [3, 2], + expected=np.array(expected_data, dtype=dtype), + align_corners=False) + class NonMaxSuppressionTest(xla_test.XLATestCase): @@ -596,7 +618,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.0, dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -639,7 +661,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.0, dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, @@ -686,7 +708,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): iou_threshold_np = np.array(0.5, dtype=np.float32) score_threshold_np = np.array(0.4, dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py index 45a04f0cf56e88946b946bedacb25ce6da3121b4..58622114e4f552fb71db9b040a39b57d7da0037c 100644 --- a/tensorflow/compiler/tests/listdiff_op_test.py +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -33,7 +33,7 @@ class ListDiffTest(xla_test.XLATestCase): def _testListDiff(self, x, y, out, idx): for dtype in [dtypes.int32, dtypes.int64]: for index_dtype in [dtypes.int32, dtypes.int64]: - with self.test_session() as sess: + with self.cached_session() as sess: x_tensor = ops.convert_to_tensor(x, dtype=dtype) y_tensor = ops.convert_to_tensor(y, dtype=dtype) with self.test_scope(): diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py index 253b45902fba2df64e5234f135b373cd2a0a7e2a..c6ad67993e8bc196a74c9a328df8c9200c92c575 100644 --- a/tensorflow/compiler/tests/lrn_ops_test.py +++ b/tensorflow/compiler/tests/lrn_ops_test.py @@ -58,7 +58,7 @@ class LRNTest(xla_test.XLATestCase): return output def _RunAndVerify(self, dtype): - with self.test_session(): + with self.cached_session(): # random shape shape = np.random.randint(1, 16, size=4) # Make depth at least 2 to make it meaningful @@ -110,7 +110,7 @@ class LRNTest(xla_test.XLATestCase): alpha = 1.0 * np.random.rand() beta = 1.0 * np.random.rand() - with self.test_session(): + with self.cached_session(): in_image = constant_op.constant(in_image_vals, shape=shape) out_image = constant_op.constant(out_image_vals, shape=shape) out_grads = constant_op.constant(out_grads_vals, shape=shape) diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py index 31093c65713df55390c3130b8654fdcb10fbc133..265c0b6d1412de7be3a5bf5e79129cb330ceb162 100644 --- a/tensorflow/compiler/tests/lstm_test.py +++ b/tensorflow/compiler/tests/lstm_test.py @@ -73,7 +73,7 @@ class LSTMTest(test.TestCase): def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar, pad_scalar): - with self.test_session() as sess: + with self.cached_session() as sess: num_inputs = 1 num_nodes = 1 @@ -156,7 +156,7 @@ class LSTMTest(test.TestCase): def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar, pad_scalar): - with self.test_session() as sess: + with self.cached_session() as sess: num_inputs = 1 num_nodes = 1 seq_length = 3 diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py index 0d9f99f8a6803ecae5f9233518a1768109161ac0..9222db4b7ebf020c8cee1c0af81e05129fb33c4d 100644 --- a/tensorflow/compiler/tests/matrix_band_part_test.py +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class MatrixBandPartTest(xla_test.XLATestCase): def _testMatrixBandPart(self, dtype, shape): - with self.test_session(): + with self.cached_session(): batch_shape = shape[:-2] mat = np.ones(shape).astype(dtype) batch_mat = np.tile(mat, batch_shape + [1, 1]) diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 2bb8a97bdaf5836a05501ab9754433e29ae34675..94cd3eeb3179da9b920ea9f03216d602b042a639 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -54,7 +54,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): clean_a = np.tril(a) if lower else np.triu(a) - with self.test_session() as sess: + with self.cached_session() as sess: placeholder_a = MakePlaceholder(a) placeholder_ca = MakePlaceholder(clean_a) placeholder_b = MakePlaceholder(b) diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index c2592c54cf83d41f0e3bdbc1f4dc9ff276ddb078..f77521a7c49dba39849869ddceb7c0e885147722 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -41,7 +41,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -95,7 +95,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testNesterovMomentum(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype) var0_np = np.array([0.1, 0.2], dtype=dtype) @@ -120,7 +120,7 @@ class MomentumOptimizerTest(xla_test.XLATestCase): def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index da08225e9fc0d5a8ec21ee9961c4758fa38628b4..a1c07fce732d3b91a7c0550545a03fdab67644d3 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest class NAryOpsTest(xla_test.XLATestCase): def _testNAry(self, op, args, expected, equality_fn=None): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -126,7 +126,7 @@ class NAryOpsTest(xla_test.XLATestCase): [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32)) def testOneHot(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): indices = array_ops.constant(np.array([[2, 3], [0, 1]], dtype=np.int32)) op = array_ops.one_hot(indices, np.int32(4), @@ -148,7 +148,7 @@ class NAryOpsTest(xla_test.XLATestCase): self.assertAllEqual(output, expected) def testSplitV(self): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): output = session.run( array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]], diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index 2f9122645d3c5ccabc8130ac30a3f09cf4bc2de7..f985c5d2d96e06fc0117f3935d61b19c9e8562b1 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -29,14 +29,14 @@ from tensorflow.python.platform import googletest class NullaryOpsTest(xla_test.XLATestCase): def _testNullary(self, op, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): output = op() result = session.run(output) self.assertAllClose(result, expected, rtol=1e-3) def testNoOp(self): - with self.test_session(): + with self.cached_session(): with self.test_scope(): output = control_flow_ops.no_op() # This should not crash. diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py index d68d32057a367776d5b70d5ac21d5618297c605d..7635f89249b7b71e5353e0b7cb1cea5c1f7bca1d 100644 --- a/tensorflow/compiler/tests/oom_test.py +++ b/tensorflow/compiler/tests/oom_test.py @@ -46,7 +46,7 @@ class OutOfMemoryTest(xla_test.XLATestCase): def test_loop(): size = int(2e8) while True: - with self.test_session(): + with self.cached_session(): # Force the compiled code to not be constant by feeding in a # parameter. p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1]) diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py index a75d99189b5b673261c9e48f1c5998ea0c575594..77bb839409f0c323ff6ed2c8d6bd105d3003b398 100644 --- a/tensorflow/compiler/tests/placeholder_test.py +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -28,7 +28,7 @@ from tensorflow.python.platform import googletest class PlaceholderTest(xla_test.XLATestCase): def test_placeholder_with_default_default(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(4.0) ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 @@ -36,7 +36,7 @@ class PlaceholderTest(xla_test.XLATestCase): self.assertEqual(8.0, sess.run(out)) def test_placeholder_with_default_fed(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(4.0) ph = array_ops.placeholder_with_default(v, shape=[]) out = ph * 2 diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index 17f860db61aeda98326a6820771d67ee948b6dda..b6cdd38345b9a9f6b03e8799587e3f6ffe07b407 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -62,7 +62,7 @@ class Pooling3DTest(xla_test.XLATestCase): # numbers from 1. x = np.arange(1.0, total_size + 1, dtype=np.float32) x = x.reshape(input_sizes) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): inputs = array_ops.placeholder(dtypes.float32) t = pool_func( inputs, @@ -210,7 +210,7 @@ class Pooling3DTest(xla_test.XLATestCase): strides = [1] + strides + [1] total_size = np.prod(input_sizes) x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). with ops.device("CPU"): diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 9fc94752ea660f7fb8b2c792180f01485ad04419..d03bd4fdbb7694bc36291faf9b845ec48e26a386 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -89,7 +89,7 @@ class PoolingTest(xla_test.XLATestCase): # numbers from 1. x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32) x = x.reshape(input_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): inputs = array_ops.placeholder(dtypes.float32) t = inputs @@ -324,7 +324,7 @@ class PoolGradTest(xla_test.XLATestCase): # TODO(b/74222344): Fix nan handling for max pool grad. # x[np.random.choice(total_size)] = np.nan x = x.reshape(input_sizes) - with self.test_session() as sess: + with self.cached_session() as sess: # Use the forward pool function to compute some corresponding outputs # (needed for the CPU device, and we need the shape in both cases). with ops.device(self.CPU_DEVICE): diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py index 5fa7706d7294f2cffb7d24a56851be02d759335a..86536da7fed0e2309beb32fee9c7c605491592ed 100644 --- a/tensorflow/compiler/tests/powersign_test.py +++ b/tensorflow/compiler/tests/powersign_test.py @@ -64,7 +64,7 @@ class PowerSignTest(xla_test.XLATestCase): base=math.e, beta=0.9): for dtype in self.float_types: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): # Initialize variables for numpy implementation. m0, m1 = 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype) diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py index cde87db63dbfd7c8d823c6fd0e41eee8b23735bb..c41b4171e26af4f7ad0237d7407a5b3691299595 100644 --- a/tensorflow/compiler/tests/proximal_adagrad_test.py +++ b/tensorflow/compiler/tests/proximal_adagrad_test.py @@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_adagrad class ProximalAdagradOptimizerTest(xla_test.XLATestCase): def testResourceProximalAdagradwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -60,7 +60,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertEqual(2, len(opt_vars)) def testProximalAdagradwithoutRegularization2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -84,7 +84,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval()) def testProximalAdagradWithL1(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -108,7 +108,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval()) def testProximalAdagradWithL1_L2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -151,7 +151,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): return var0.eval(), var1.eval() def testEquivAdagradwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.applyOptimizer( proximal_adagrad.ProximalAdagradOptimizer( 3.0, @@ -159,7 +159,7 @@ class ProximalAdagradOptimizerTest(xla_test.XLATestCase): l1_regularization_strength=0.0, l2_regularization_strength=0.0)) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.applyOptimizer( adagrad.AdagradOptimizer( 3.0, initial_accumulator_value=0.1)) diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py index 11eb76871133eba8fcd24621afb03e16614fb005..3d808e6b8a71ef9fa60b671d07bfd907e9f58efc 100644 --- a/tensorflow/compiler/tests/proximal_gradient_descent_test.py +++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py @@ -32,7 +32,7 @@ from tensorflow.python.training import proximal_gradient_descent class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): def testResourceProximalGradientDescentwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([0.0, 0.0]) var1 = resource_variable_ops.ResourceVariable([0.0, 0.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -53,7 +53,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([-0.09, -0.18]), var1.eval()) def testProximalGradientDescentwithoutRegularization2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -75,7 +75,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.91, 2.82]), var1.eval()) def testProximalGradientDescentWithL1(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -97,7 +97,7 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): self.assertAllClose(np.array([3.67, 2.37]), var1.eval()) def testProximalGradientDescentWithL1_L2(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): var0 = resource_variable_ops.ResourceVariable([1.0, 2.0]) var1 = resource_variable_ops.ResourceVariable([4.0, 3.0]) grads0 = constant_op.constant([0.1, 0.2]) @@ -137,14 +137,14 @@ class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase): return var0.eval(), var1.eval() def testEquivGradientDescentwithoutRegularization(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val0, val1 = self.applyOptimizer( proximal_gradient_descent.ProximalGradientDescentOptimizer( 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0)) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): val2, val3 = self.applyOptimizer( gradient_descent.GradientDescentOptimizer(3.0)) diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py index 1b969ee2b3886fca6ec9951d1621ca5af6a673d8..3a268978bfd72d08a7d3a7cc61a116dac543cda5 100644 --- a/tensorflow/compiler/tests/qr_op_test.py +++ b/tensorflow/compiler/tests/qr_op_test.py @@ -71,7 +71,7 @@ class QrOpTest(xla_test.XLATestCase, parameterized.TestCase): x_np = np.random.uniform( low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype) - with self.test_session() as sess: + with self.cached_session() as sess: x_tf = array_ops.placeholder(dtype) with self.test_scope(): q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices) diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 8c4e16e4e075726d741f6ff8cdfb6b1aad6cd33e..6e183441179ebf2e8c063b333f9328d6fa86cc88 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -39,7 +39,7 @@ class RandomOpsTest(xla_test.XLATestCase): def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = rng(dtype) @@ -79,7 +79,7 @@ class RandomOpsTest(xla_test.XLATestCase): if (self.device in ["XLA_GPU", "XLA_CPU" ]) and (dtype in [dtypes.bfloat16, dtypes.half]): continue - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = random_ops.random_uniform( shape=[1000], dtype=dtype, minval=-2, maxval=33) @@ -99,7 +99,7 @@ class RandomOpsTest(xla_test.XLATestCase): count = 10000000 # TODO(b/34339814): implement inverse erf support for non-F32 types. for dtype in [dtypes.float32]: - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) y = sess.run(x) @@ -147,7 +147,7 @@ class RandomOpsTest(xla_test.XLATestCase): # TODO(b/26783907): this test requires the CPU backend to implement sort. if self.device in ["XLA_CPU"]: return - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = math_ops.range(1 << 16) shuffle = random_ops.random_shuffle(x) @@ -158,7 +158,7 @@ class RandomOpsTest(xla_test.XLATestCase): self.assertAllEqual(set(result), set(expected)) def testShuffle2d(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = array_ops.diag(math_ops.range(20)) shuffle = random_ops.random_shuffle(x) diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index cea2ec816f85e88b11e6e80c91c14fca9015f45c..5ae5b1bc1df76e6d0267a9a9ac18e7bc4725ec7b 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import functools import itertools +from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test @@ -30,22 +31,24 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest -class ReduceOpsTest(xla_test.XLATestCase): - +@parameterized.named_parameters(('32_bit_index', dtypes.int32), + ('64_bit_index', dtypes.int64)) +class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs, + index_dtype, rtol=1e-4, atol=1e-4): """Tests that the output of 'tf_reduce_fn' matches numpy's output.""" for test_input in test_inputs: - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): a = array_ops.placeholder(dtype) - index = array_ops.placeholder(dtypes.int32) + index = array_ops.placeholder(index_dtype) out = tf_reduce_fn(a, index) result = sess.run(out, {a: test_input, index: [0]}) self.assertAllClose( @@ -89,22 +92,23 @@ class ReduceOpsTest(xla_test.XLATestCase): np.array([[False, True, False], [True, True, False]]), ] - def testReduceSumF32(self): - self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA) + def testReduceSumF32(self, index_dtype): + self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA, + index_dtype) - def testReduceSumC64(self): + def testReduceSumC64(self, index_dtype): self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, - self.COMPLEX_DATA) + self.COMPLEX_DATA, index_dtype) - def testReduceProdF32(self): + def testReduceProdF32(self, index_dtype): self._testReduction(math_ops.reduce_prod, np.prod, np.float32, - self.REAL_DATA) + self.REAL_DATA, index_dtype) - def testReduceProdC64(self): + def testReduceProdC64(self, index_dtype): self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, - self.COMPLEX_DATA) + self.COMPLEX_DATA, index_dtype) - def testReduceMin(self): + def testReduceMin(self, index_dtype): def reference_min(dtype, inp, axis): """Wrapper around np.amin that returns +infinity for an empty input.""" @@ -119,9 +123,9 @@ class ReduceOpsTest(xla_test.XLATestCase): [np.float32, np.int32, np.int64]): self._testReduction(math_ops.reduce_min, functools.partial(reference_min, dtype), dtype, - self.REAL_DATA) + self.REAL_DATA, index_dtype) - def testReduceMax(self): + def testReduceMax(self, index_dtype): def reference_max(dtype, inp, axis): """Wrapper around np.amax that returns -infinity for an empty input.""" @@ -137,23 +141,25 @@ class ReduceOpsTest(xla_test.XLATestCase): [np.float32, np.int32, np.int64]): self._testReduction(math_ops.reduce_max, functools.partial(reference_max, dtype), dtype, - self.REAL_DATA) + self.REAL_DATA, index_dtype) - def testReduceMeanF32(self): + def testReduceMeanF32(self, index_dtype): # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when # reducing across zero inputs. self._testReduction(math_ops.reduce_mean, np.mean, np.float32, - self.NONEMPTY_REAL_DATA) + self.NONEMPTY_REAL_DATA, index_dtype) - def testReduceMeanC64(self): + def testReduceMeanC64(self, index_dtype): self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, - self.NONEMPTY_COMPLEX_DATA) + self.NONEMPTY_COMPLEX_DATA, index_dtype) - def testReduceAll(self): - self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA) + def testReduceAll(self, index_dtype): + self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA, + index_dtype) - def testReduceAny(self): - self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) + def testReduceAny(self, index_dtype): + self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA, + index_dtype) class ReduceOpPrecisionTest(xla_test.XLATestCase): @@ -178,7 +184,7 @@ class ReduceOpPrecisionTest(xla_test.XLATestCase): """ for test_input in test_inputs: - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): a = array_ops.placeholder(dtype) index = array_ops.placeholder(dtypes.int32) diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py index c69b6837b0f88ced844faf3713a29a1c14c8790d..ff20ea3f4287b4666684501fa4920435a77b4183 100644 --- a/tensorflow/compiler/tests/reduce_window_test.py +++ b/tensorflow/compiler/tests/reduce_window_test.py @@ -32,7 +32,7 @@ class ReduceWindowTest(xla_test.XLATestCase): """Test cases for xla.reduce_window.""" def _reduce_window(self, operand, init, reducer, **kwargs): - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(operand.dtype) with self.test_scope(): output = xla.reduce_window(placeholder, init, reducer, **kwargs) diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py index 32ab5d08f0b925ee6b7b641ddba6b950149a6d20..392290fd92d0c7c928581422433892147374b2dd 100644 --- a/tensorflow/compiler/tests/reverse_ops_test.py +++ b/tensorflow/compiler/tests/reverse_ops_test.py @@ -51,7 +51,7 @@ class ReverseOpsTest(xla_test.XLATestCase): def _AssertReverseEqual(self, revdims, shape): np.random.seed(120) pval = np.random.randint(0, 100, size=shape).astype(float) - with self.test_session(): + with self.cached_session(): with self.test_scope(): p = array_ops.placeholder(dtypes.int32, shape=shape) axis = constant_op.constant( diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index ccfa63001653537c4d1b7140e3d745c126f9034b..60c2337743b44e9bad61c4d65280eb2b1a1ad9ea 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -35,7 +35,7 @@ class ReverseSequenceTest(xla_test.XLATestCase): seq_lengths, truth, expected_err_re=None): - with self.test_session(): + with self.cached_session(): p = array_ops.placeholder(dtypes.as_dtype(x.dtype)) lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype)) with self.test_scope(): diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py index ff8bbac911abe73f946464663984ff1626302882..8840a1329a907bddc6ef1cb6dd1c2a6d234def5c 100644 --- a/tensorflow/compiler/tests/rmsprop_test.py +++ b/tensorflow/compiler/tests/rmsprop_test.py @@ -55,7 +55,7 @@ class RmspropTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: for centered in [False, True]: - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): # Initialize variables for numpy implementation. var0_np = np.array([1.0, 2.0], dtype=dtype) grads0_np = np.array([0.1, 0.1], dtype=dtype) diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 4292352e76ebcef7dbf41df7b857d2604a468117..897db384b7e8067b0460b5f344201f101a4d8479 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -78,7 +78,7 @@ class CumsumTest(xla_test.XLATestCase): def _compare(self, x, axis, exclusive, reverse): np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( feed_dict={p: x}) @@ -100,7 +100,7 @@ class CumsumTest(xla_test.XLATestCase): for dtype in self.valid_dtypes: x = np.arange(1, 6).reshape([5]).astype(dtype) for axis_dtype in self.axis_dtypes(): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) axis = constant_op.constant(0, axis_dtype) math_ops.cumsum(p, axis).eval(feed_dict={p: x}) @@ -131,7 +131,7 @@ class CumsumTest(xla_test.XLATestCase): def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): input_tensor = ops.convert_to_tensor(x) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, @@ -156,7 +156,7 @@ class CumprodTest(xla_test.XLATestCase): def _compare(self, x, axis, exclusive, reverse): np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) prod = math_ops.cumprod(p, axis, exclusive, reverse) tf_out = prod.eval(feed_dict={p: x}) @@ -178,7 +178,7 @@ class CumprodTest(xla_test.XLATestCase): for dtype in self.valid_dtypes: x = np.arange(1, 6).reshape([5]).astype(dtype) for axis_dtype in self.axis_dtypes(): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): p = array_ops.placeholder(x.dtype) axis = constant_op.constant(0, axis_dtype) math_ops.cumprod(x, axis).eval(feed_dict={p: x}) @@ -209,7 +209,7 @@ class CumprodTest(xla_test.XLATestCase): def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): input_tensor = ops.convert_to_tensor(x) with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index f606f88545d0b6f0b52cee9b93083a6bd91169bc..693f8513bc54e30060a2e963abd504768535a50a 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -119,7 +119,7 @@ class ScatterNdTest(xla_test.XLATestCase): self._VariableRankTest(np_scatter, tf_scatter, vtype, itype) def _runScatterNd(self, indices, updates, shape): - with self.test_session(): + with self.cached_session(): updates_placeholder = array_ops.placeholder(updates.dtype) indices_placeholder = array_ops.placeholder(indices.dtype) with self.test_scope(): diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 772c20fd424577c3e06eeae409f424b77b52aa8a..287bb0d84e24de3bdcde3aa4c61acee00626e88f 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -32,7 +32,7 @@ class SegmentReductionOpsTest(xla_test.XLATestCase): """Test cases for segment reduction ops.""" def _segmentReduction(self, op, data, indices, num_segments): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): d = array_ops.placeholder(data.dtype, shape=data.shape) if isinstance(indices, int): i = array_ops.placeholder(np.int32, shape=[]) diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 6c4890565d2083a9493abc59bd563c4dd9fdb186..2c611a959e1d71c53e44bc92c31258153d01507d 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -29,7 +29,7 @@ class SliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.slice(i, [2], [4]) @@ -40,9 +40,22 @@ class SliceTest(xla_test.XLATestCase): self.assertAllEqual([2, 3, 4, 5], result) + def testZeroSlice(self): + for dtype in self.numeric_types: + with self.cached_session(): + i = array_ops.placeholder(dtype, shape=[2]) + with self.test_scope(): + o = array_ops.slice(i, [0], [0]) + params = { + i: [0, 1], + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([], result) + def test3D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) with self.test_scope(): o = array_ops.slice(i, [1, 2, 2], [1, 1, 4]) @@ -64,7 +77,7 @@ class SliceTest(xla_test.XLATestCase): def test3DWithDynamicBegin(self): """Tests a slice where the start offset is not known at compile time.""" for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) begin = array_ops.placeholder(dtypes.int32, shape=[3]) with self.test_scope(): @@ -88,7 +101,7 @@ class SliceTest(xla_test.XLATestCase): def test3DWithDynamicBeginAndNegativeSize(self): """Tests a slice where `begin` is fed dynamically and `size` contains -1.""" for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) begin = array_ops.placeholder(dtypes.int32, shape=[3]) with self.test_scope(): @@ -114,7 +127,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test1D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.strided_slice(i, [2], [6], [2]) @@ -127,7 +140,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test1DNegativeStride(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[10]) with self.test_scope(): o = array_ops.strided_slice(i, [6], [2], [-2]) @@ -140,7 +153,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test2DDegenerate(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): o = array_ops.strided_slice(i, [-1, 0], [0, 3]) @@ -154,7 +167,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test2DDegenerateNegativeStride(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[2, 3]) with self.test_scope(): o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1]) @@ -168,7 +181,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test3D(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 3, 10]) with self.test_scope(): o = array_ops.strided_slice(i, [0, 2, 2], [2, 3, 6], [1, 1, 2]) @@ -189,7 +202,7 @@ class StridedSliceTest(xla_test.XLATestCase): def test3DNegativeStride(self): for dtype in self.numeric_types: - with self.test_session(): + with self.cached_session(): i = array_ops.placeholder(dtype, shape=[3, 4, 10]) with self.test_scope(): o = array_ops.strided_slice(i, [2, 2, 6], [0, 0, 2], [-1, -1, -2]) diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 7ff01be3cb4848d6bb85b8ab96b3ee1db6889791..51c04b5c4796474700a92a8b23a1cbdf533fcbb4 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -32,7 +32,7 @@ from tensorflow.python.platform import test class XlaSortOpTest(xla_test.XLATestCase): def _assertOpOutputMatchesExpected(self, op, args, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): placeholders = [ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) @@ -131,7 +131,7 @@ class XlaSortOpTest(xla_test.XLATestCase): if bfloat16 not in self.numeric_types: return - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.bfloat16) with self.test_scope(): topk = nn_ops.top_k(p, k=4) @@ -153,7 +153,7 @@ class XlaSortOpTest(xla_test.XLATestCase): if bfloat16 not in self.numeric_types: return - with self.test_session() as sess: + with self.cached_session() as sess: p = array_ops.placeholder(dtypes.bfloat16) with self.test_scope(): topk = nn_ops.top_k(p, k=6) diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index c685bc548f9f6f8f7723c6f94dfd45f5420b4a67..33b84cec7188c85a3bacb20a6df29c73adbd107c 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -72,7 +72,7 @@ class SpaceToBatchTest(xla_test.XLATestCase): """Tests input-output pairs for the SpaceToBatch and BatchToSpace ops.""" def _testPad(self, inputs, paddings, block_size, outputs): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self.float_types: # outputs = space_to_batch(inputs) placeholder = array_ops.placeholder(dtype) @@ -155,7 +155,7 @@ class SpaceToBatchNDTest(xla_test.XLATestCase): def _testPad(self, inputs, block_shape, paddings, outputs): block_shape = np.array(block_shape) paddings = np.array(paddings).reshape((len(block_shape), 2)) - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self.float_types: # TODO(b/68813416): Skip bfloat16's as the input type for direct is # float32 and results in a mismatch, while making testDirect provide the diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py index 3db8101c4bfbb1b53c7318a36519612984d6f179..07afd1ab3fb78d5accc52ee2382af0b9fb8079d3 100644 --- a/tensorflow/compiler/tests/sparse_to_dense_op_test.py +++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py @@ -45,32 +45,32 @@ def _SparseToDense(sparse_indices, class SparseToDenseTest(xla_test.XLATestCase): def testInt(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1, 0) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def testFloat(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0) np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32) self.assertAllClose(np_ans, tf_ans) def testSetValue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1) np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def testSetSingleValue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([1, 3], [5], 1, -1) np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32) self.assertAllClose(np_ans, tf_ans) def test2d(self): # pylint: disable=bad-whitespace - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1) np_ans = np.array([[-1, -1, -1, -1], [-1, -1, -1, 1], @@ -78,12 +78,12 @@ class SparseToDenseTest(xla_test.XLATestCase): self.assertAllClose(np_ans, tf_ans) def testZeroDefault(self): - with self.test_session(): + with self.cached_session(): x = sparse_ops.sparse_to_dense(2, [4], 7).eval() self.assertAllEqual(x, [0, 0, 7, 0]) def test3d(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1) np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1 np_ans[1, 3, 0] = 1 @@ -91,25 +91,25 @@ class SparseToDenseTest(xla_test.XLATestCase): self.assertAllClose(np_ans, tf_ans) def testBadShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): _SparseToDense([1, 3], [[5], [3]], 1, -1) def testBadValue(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesOpError( r"sparse_values has incorrect shape \[2,1\], " r"should be \[\] or \[2\]"): _SparseToDense([1, 3], [5], [[5], [3]], -1) def testBadNumValues(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesOpError( r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): _SparseToDense([1, 3], [5], [1, 2, 3], -1) def testBadDefault(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): with self.assertRaisesOpError("default_value should be a scalar"): _SparseToDense([1, 3], [5], [1, 2], [0]) diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py index b7dd787feff2b22a9cfb5d43a4ba6ceb6eb0b301..720595a159eea997be2246c4c7dad49612b257eb 100644 --- a/tensorflow/compiler/tests/stack_ops_test.py +++ b/tensorflow/compiler/tests/stack_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import test class StackOpTest(xla_test.XLATestCase): def testStackPushPop(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): size = array_ops.placeholder(dtypes.int32) v = array_ops.placeholder(dtypes.float32) h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") @@ -41,7 +41,7 @@ class StackOpTest(xla_test.XLATestCase): self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]})) def testStackPushPopSwap(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): a = np.arange(2000) x = array_ops.placeholder(dtypes.float32) h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") @@ -51,7 +51,7 @@ class StackOpTest(xla_test.XLATestCase): self.assertAllClose(a, c1.eval({x: a})) def testMultiStack(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): v = array_ops.placeholder(dtypes.float32) h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") c1 = gen_data_flow_ops.stack_push_v2(h1, v) @@ -66,7 +66,7 @@ class StackOpTest(xla_test.XLATestCase): def testSameNameStacks(self): """Different stacks with the same name do not interfere.""" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v1 = array_ops.placeholder(dtypes.float32) v2 = array_ops.placeholder(dtypes.float32) h1 = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") @@ -84,14 +84,14 @@ class StackOpTest(xla_test.XLATestCase): self.assertAllClose(out2, 5.0) def testCloseStack(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): size = array_ops.placeholder(dtypes.int32) h = gen_data_flow_ops.stack_v2(size, dtypes.float32, stack_name="foo") c1 = gen_data_flow_ops.stack_close_v2(h) sess.run(c1, {size: 5}) def testPushCloseStack(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): v = array_ops.placeholder(dtypes.float32) h = gen_data_flow_ops.stack_v2(5, dtypes.float32, stack_name="foo") c = gen_data_flow_ops.stack_push_v2(h, v) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index d162675ef840131485128414b4a29e3cd89c8761..1bea7d9355e40c5a71f848dabc0fa7fa760429d2 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -38,7 +38,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) seeds = [(x, y) for x in range(5) for y in range(5)] * 3 for stateless_op in [ @@ -55,7 +55,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertEqual(s0 == s1, np.all(v0 == v1)) def testRandomUniformIsInRange(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_uniform( @@ -74,7 +74,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testDistributionOfStatelessRandomUniform(self): """Use Pearson's Chi-squared test to test for uniformity.""" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 @@ -88,7 +88,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertTrue(self._chi_squared(y, 10) < 16.92) def testRandomNormalIsFinite(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) x = stateless.stateless_random_uniform( @@ -111,7 +111,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testDistributionOfStatelessRandomNormal(self): """Use Anderson-Darling test to test distribution appears normal.""" - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 1000 @@ -126,7 +126,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): def testTruncatedNormalIsInRange(self): # TODO(b/34339814): implement inverse erf support for non-F32 types. for dtype in [dtypes.float32]: - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 10000000 x = stateless.stateless_truncated_normal( diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index f332aa2e9b97e13654cf9b10588c18fed32f7ad4..78244d0b366d9128a4c59f786e4c5ac12e743b75 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -44,7 +44,7 @@ def _make_converter(dtype): class TensorArrayTest(xla_test.XLATestCase): def testTensorArrayWriteRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -66,7 +66,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([], flow_val.shape) def _testTensorArrayWritePack(self, tf_dtype): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -86,7 +86,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayWritePack(dtype) def testEmptyTensorArrayPack(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) @@ -100,7 +100,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([3, 0, 1], c0.eval().shape) def _testTensorArrayWriteConcat(self, tf_dtype): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -121,7 +121,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayWriteConcat(dtype) def _testTensorArrayUnpackRead(self, tf_dtype): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -176,7 +176,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayUnpackReadMaybeLegacy() def _testTensorArraySplitRead(self, tf_dtype): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=tf_dtype, tensor_array_name="foo", size=3) @@ -228,7 +228,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArraySplitRead(dtype) def testTensorGradArrayWriteRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -261,7 +261,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[-2.0]], g_d2) def testTensorGradArrayDynamicWriteRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -300,7 +300,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(3, g_vs) def testTensorGradAccessTwiceReceiveSameObject(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3, element_shape=[1, 2]) @@ -317,7 +317,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[4.0, 5.0]], d_r1_0) def testTensorArrayWriteWrongIndexOrDataTypeFails(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) @@ -331,7 +331,7 @@ class TensorArrayTest(xla_test.XLATestCase): # the first type, but try to read the other type. if len(self.float_types) > 1: dtype1, dtype2 = list(self.float_types)[:2] - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtype1, tensor_array_name="foo", size=3) @@ -347,7 +347,7 @@ class TensorArrayTest(xla_test.XLATestCase): w0.read(1) def testTensorArraySplitIncompatibleShapesFails(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -379,7 +379,7 @@ class TensorArrayTest(xla_test.XLATestCase): ta.split([1.0], [1]).flow.eval() def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtype, tensor_array_name="foo", size=3, infer_shape=False) @@ -410,7 +410,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayWriteGradientAddMultipleAdds(dtype) def testMultiTensorArray(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): h1 = tensor_array_ops.TensorArray( size=1, dtype=dtypes.float32, tensor_array_name="foo") w1 = h1.write(0, 4.0) @@ -425,7 +425,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllClose(9.0, r.eval()) def _testTensorArrayGradientWriteReadType(self, dtype): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.as_dtype(dtype), tensor_array_name="foo", @@ -478,7 +478,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayGradientWriteReadType(dtype) def _testTensorArrayGradientWritePackConcatAndRead(self): - with self.test_session() as sess, self.test_scope(): + with self.cached_session() as sess, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -513,7 +513,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayGradientWritePackConcatAndRead() def testTensorArrayReadTwice(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]]) ta_readtwice = tensor_array_ops.TensorArray( @@ -529,7 +529,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([1.0, -1.0], r1_readtwice.eval()) def _testTensorArrayGradientUnpackRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -557,7 +557,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayGradientUnpackRead() def testTensorArrayGradientSplitConcat(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=2) @@ -581,21 +581,21 @@ class TensorArrayTest(xla_test.XLATestCase): grad_vals[0]) def testCloseTensorArray(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) c1 = ta.close() session.run(c1) def testSizeTensorArray(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) s = ta.size() self.assertAllEqual(3, s.eval()) def testWriteCloseTensorArray(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -608,7 +608,7 @@ class TensorArrayTest(xla_test.XLATestCase): # TODO(phawkins): implement while loops. # def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): # np_dtype = dtype.as_numpy_dtype - # with self.test_session() as session, self.test_scope(): + # with self.cached_session() as session, self.test_scope(): # v0 = array_ops.identity(np.arange(3 * 5, dtype=np_dtype).reshape(3, 5)) # var = variables.Variable(np.arange(100, 105, dtype=np_dtype)) # state0 = array_ops.identity(np.array([1] * 5, dtype=np_dtype)) @@ -692,7 +692,7 @@ class TensorArrayTest(xla_test.XLATestCase): # dynamic_size=True, dtype=dtypes.float32) # def testGradSerialTwoLoops(self): - # with self.test_session(), self.test_scope(): + # with self.cached_session(), self.test_scope(): # num_steps = 100 # acc = tensor_array_ops.TensorArray( # dtype=dtypes.float32, @@ -725,7 +725,7 @@ class TensorArrayTest(xla_test.XLATestCase): # self.assertAllClose(31.0, grad.eval()) def testSumOfTwoReadVariablesWithoutRepeatGrad(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): a = array_ops.identity( np.arange( 3 * 5, dtype=np.float32).reshape(3, 5) + 1) @@ -757,7 +757,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(joint_grad_b_t, g0) def testWriteShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) c0 = constant_op.constant([4.0, 5.0]) @@ -781,7 +781,7 @@ class TensorArrayTest(xla_test.XLATestCase): w0.write(0, c2) def testPartlyUnknownShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=6) @@ -821,7 +821,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list()) def _testUnpackShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -846,7 +846,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testUnpackShape() def testSplitShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -867,7 +867,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) def testWriteUnknownShape(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -879,7 +879,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(r0.get_shape(), tensor_shape.unknown_shape()) def _testGradientWhenNotAllComponentsRead(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2) x = constant_op.constant([2.0, 3.0]) w = ta.unstack(x) @@ -893,7 +893,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testGradientWhenNotAllComponentsRead() def _testTensorArrayEvalEmpty(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=0, infer_shape=False) with self.assertRaisesOpError( @@ -906,7 +906,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayEvalEmpty() def _testTensorArrayEvalEmptyWithDefault(self): - with self.test_session(), self.test_scope(): + with self.cached_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, size=0, infer_shape=True) self.assertEqual(0, ta.size().eval()) @@ -921,7 +921,7 @@ class TensorArrayTest(xla_test.XLATestCase): self._testTensorArrayEvalEmptyWithDefault() def testTensorArrayScatterReadAndGradients(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -946,7 +946,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0]) def testTensorArrayWriteGatherAndGradients(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", @@ -974,7 +974,7 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(expected_grad, grad_vals[0]) def testTensorArrayIdentity(self): - with self.test_session() as session, self.test_scope(): + with self.cached_session() as session, self.test_scope(): ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2, infer_shape=False) ta1 = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=4, diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index effa5a59fee7dda543b2c409dfaa27a972a55808..55a992195f2df72677b77757ae86171fa662439f 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import googletest class TernaryOpsTest(xla_test.XLATestCase): def _testTernary(self, op, a, b, c, expected): - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 124cf9da813861fb3774e3bb29ad947af1598059..5b0e57f83ff4b5a8d1891bef0675074bd67addce 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -65,7 +65,7 @@ class UnaryOpsTest(xla_test.XLATestCase): rtol: relative tolerance for equality test. atol: absolute tolerance for equality test. """ - with self.test_session() as session: + with self.cached_session() as session: with self.test_scope(): pinp = array_ops.placeholder( dtypes.as_dtype(inp.dtype), inp.shape, name="a") @@ -202,7 +202,7 @@ class UnaryOpsTest(xla_test.XLATestCase): # Disable float16 testing for now if dtype != np.float16: x = np.arange(-10, 10, 1).astype(dtype) - with self.test_session() as session: + with self.cached_session() as session: erf_x = session.run(math_ops.erf(x)) erfc_x = session.run(math_ops.erfc(x)) diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index b637cf31cfc303ebe84ce8307ef4ad8b0b5cd720..4ee144beb7f3243be069d59ee4a613484fe183b3 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -43,7 +43,7 @@ class WhileTest(xla_test.XLATestCase): def loop_cond(step): return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) with self.test_scope(): loop_outputs = xla.while_loop([init_index], loop_cond, loop_body) @@ -65,7 +65,7 @@ class WhileTest(xla_test.XLATestCase): del rsum return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) init_sum = array_ops.placeholder(dtypes.float32, []) with self.test_scope(): @@ -91,7 +91,7 @@ class WhileTest(xla_test.XLATestCase): del rsum return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) init_sum = array_ops.placeholder(dtypes.complex64, []) with self.test_scope(): @@ -117,7 +117,7 @@ class WhileTest(xla_test.XLATestCase): del x return step < 10 - with self.test_session() as sess: + with self.cached_session() as sess: init_index = array_ops.placeholder(dtypes.int32, []) with self.test_scope(): loop_outputs = xla.while_loop([init_index, 42], loop_cond, loop_body) diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index 85084bb1240cf05f6eabfbea772df113cabe613c..28d61fb07dcb665fa0dbe3f3e566e291e24fa662 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -37,7 +37,7 @@ class XlaDeviceTest(xla_test.XLATestCase): [16384, 1], [1, 16384], [1, 20000, 1, 1]] for dtype in self.numeric_types: for shape in shapes: - with self.test_session() as sess: + with self.cached_session() as sess: with ops.device("CPU"): x = array_ops.placeholder(dtype, shape) with self.test_scope(): @@ -58,7 +58,7 @@ class XlaDeviceTest(xla_test.XLATestCase): ]) shape = (10, 10) for unsupported_dtype in test_types - self.all_types: - with self.test_session() as sess: + with self.cached_session() as sess: with ops.device("CPU"): x = array_ops.placeholder(unsupported_dtype, shape) with self.test_scope(): @@ -78,7 +78,7 @@ class XlaDeviceTest(xla_test.XLATestCase): pass def testControlTrigger(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.test_scope(): x = gen_control_flow_ops.control_trigger() sess.run(x) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f026df6c0c28fcbceaa0493871bc12c2d23b1f --- /dev/null +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -0,0 +1,301 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA op wrappers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): + + def _assertOpOutputMatchesExpected(self, op, args, expected, + equality_fn=None): + with self.test_session() as session: + with self.test_scope(): + placeholders = [ + array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape) + for arg in args + ] + feeds = {placeholders[i]: args[i] for i in range(0, len(args))} + output = op(*placeholders) + result = session.run(output, feeds) + if not equality_fn: + equality_fn = self.assertAllClose + equality_fn(result, expected, rtol=1e-3) + + def testAdd(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.add, + args=(np.array([1, 2, 3], dtype=dtype), + np.array([4, 5, 6], dtype=dtype)), + expected=np.array([5, 7, 9], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(0,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 9], [14, 15]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x, y: xla.add(x, y, broadcast_dims=(1,)), + args=(np.array([[1, 2], [3, 4]], dtype=dtype), + np.array([7, 11], dtype=dtype)), + expected=np.array([[8, 13], [10, 15]], dtype=dtype)) + + def testBroadcast(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.broadcast(x, (7, 42)), + args=(v,), + expected=np.tile(v, (7, 42, 1, 1))) + + def testShiftRightLogical(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_logical, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0x0FFFFFFF, 1], dtype=np.uint32)) + + def testShiftRightArithmetic(self): + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([-1, 16], dtype=np.int32), np.int32(4)), + expected=np.array([-1, 1], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + xla.shift_right_arithmetic, + args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), + expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32)) + + PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT, + xla_data_pb2.PrecisionConfigProto.HIGH, + xla_data_pb2.PrecisionConfigProto.HIGHEST) + + @parameterized.parameters(*PRECISION_VALUES) + def testConv(self, precision): + for dtype in set(self.float_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + def conv_1d_fn(lhs, rhs): + dnums = xla_data_pb2.ConvolutionDimensionNumbers() + num_spatial_dims = 1 + dnums.input_batch_dimension = 0 + dnums.input_feature_dimension = 1 + dnums.output_batch_dimension = 0 + dnums.output_feature_dimension = 1 + dnums.kernel_output_feature_dimension = 0 + dnums.kernel_input_feature_dimension = 1 + dnums.input_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.kernel_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config.operand_precision.extend([precision, precision]) + return xla.conv( + lhs, + rhs, + window_strides=(1,), + padding=((2, 1),), + lhs_dilation=(1,), + rhs_dilation=(2,), + dimension_numbers=dnums) + + self._assertOpOutputMatchesExpected( + conv_1d_fn, + args=( + np.array([[[3, 4, 5, 6]]], dtype=dtype), + np.array([[[-2, -3]]], dtype=dtype), + ), + expected=np.array([[[-9, -12, -21, -26, -10]]], dtype=dtype)) + + @parameterized.parameters(*PRECISION_VALUES) + def testDotGeneral(self, precision): + for dtype in self.float_types: + + def dot_fn(lhs, rhs): + dnums = xla_data_pb2.DotDimensionNumbers() + dnums.lhs_contracting_dimensions.append(2) + dnums.rhs_contracting_dimensions.append(1) + dnums.lhs_batch_dimensions.append(0) + dnums.rhs_batch_dimensions.append(0) + precision_config = None + if precision: + precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config.operand_precision.extend([precision, precision]) + return xla.dot_general( + lhs, + rhs, + dimension_numbers=dnums, + precision_config=precision_config) + + lhs = np.array( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + ], dtype=dtype) + rhs = np.array( + [ + [[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]], + ], dtype=dtype) + self._assertOpOutputMatchesExpected( + dot_fn, + args=(lhs, rhs), + expected=np.array( + [ + [[9, 12, 15], [19, 26, 33]], + [[95, 106, 117], [129, 144, 159]], + ], + dtype=dtype)) + + def testNeg(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + xla.neg, + args=(np.array([1, 2, 3], dtype=dtype),), + expected=np.array([-1, -2, -3], dtype=dtype)) + + def testPad(self): + for dtype in self.numeric_types: + + def pad_fn(x): + return xla.pad( + x, + padding_value=7, + padding_low=[2, 1], + padding_high=[1, 2], + padding_interior=[1, 0]) + + self._assertOpOutputMatchesExpected( + pad_fn, + args=(np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]),), + expected=np.array( + [[7, 7, 7, 7, 7], [7, 7, 7, 7, 7], [7, 0, 1, 7, 7], + [7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]], + dtype=dtype)) + + def testReduce(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def sum_reducer(x, y): + return x + y + + def sum_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=0, dimensions_to_reduce=dims, reducer=sum_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4])) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([12, 15, 18, 21], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([6, 22, 38], dtype=dtype)) + self._assertOpOutputMatchesExpected( + sum_reduction(dims=[0, 1]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=dtype(66)) + + @function.Defun(dtype, dtype) + def mul_reducer(x, y): + return x * y + + def mul_reduction(dims): + + def fn(x): + return xla.reduce( + x, init_value=1, dimensions_to_reduce=dims, reducer=mul_reducer) + + return fn + + self._assertOpOutputMatchesExpected( + mul_reduction(dims=[0]), + args=(np.arange(12, dtype=np.int32).astype(dtype).reshape([3, 4]),), + expected=np.array([0, 45, 120, 231], dtype=dtype)) + + def testSelectAndScatter(self): + for dtype in set(self.numeric_types).intersection( + set([dtypes.bfloat16.as_numpy_dtype, np.float32])): + + @function.Defun(dtype, dtype) + def add_scatter(x, y): + return x + y + + @function.Defun(dtype, dtype) + def ge_select(x, y): + return x >= y + + def test_fn(operand, source): + return xla.select_and_scatter( + operand, + window_dimensions=[2, 3, 1, 1], + window_strides=[2, 2, 1, 1], + padding=[[0, 0]] * 4, + source=source, + init_value=0, + select=ge_select, + scatter=add_scatter) + + self._assertOpOutputMatchesExpected( + test_fn, + args=(np.array( + [[7, 2, 5, 3, 8], [3, 8, 9, 3, 4], [1, 5, 7, 5, 6], + [0, 6, 2, 10, 2]], + dtype=dtype).reshape((4, 5, 1, 1)), + np.array([[2, 6], [3, 1]], dtype=dtype).reshape((2, 2, 1, 1))), + expected=np.array( + [[0, 0, 0, 0, 0], [0, 0, 8, 0, 0], [0, 0, 3, 0, 0], + [0, 0, 0, 1, 0]], + dtype=dtype).reshape((4, 5, 1, 1))) + + def testTranspose(self): + for dtype in self.numeric_types: + v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2]) + self._assertOpOutputMatchesExpected( + lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 85fd0c9217d8e56be564f915e5c950d4fadc4e59..92e577bb7b930f5b9139e361cafb8628daede455 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -39,6 +39,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -88,6 +89,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -221,13 +223,11 @@ cc_library( srcs = [ "literal_util.cc", "shape_util.cc", - "str_util.cc", "type_util.cc", ], hdrs = [ "literal_util.h", "shape_util.h", - "str_util.h", "type_util.h", ], visibility = [":friends"], @@ -256,6 +256,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -307,6 +308,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -374,19 +376,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - ], -) - -tf_cc_test( - name = "str_util_test", - srcs = [ - "str_util_test.cc", - ], - deps = [ - ":common", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -459,6 +449,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -482,6 +473,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:graph", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -609,3 +601,30 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "resource_operation_table", + srcs = ["resource_operation_table.cc"], + hdrs = ["resource_operation_table.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", + ], +) + +tf_cc_test( + name = "resource_operation_table_test", + srcs = ["resource_operation_table_test.cc"], + deps = [ + ":resource_operation_table", + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index de1008803d69fefa415c7bdbe6c27a62e625b417..e8673d77903bd5a1a85412e9dfa86437f73d56bc 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -23,11 +23,11 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" namespace tensorflow { - // Backwards dataflow analysis that finds arguments to a graph that must be // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, - std::vector* compile_time_const_args) { + std::vector* compile_time_const_args, + std::vector* compile_time_const_nodes) { // Operators that don't look at the data of their inputs, just the shapes. const std::unordered_set metadata_ops = { "Rank", @@ -36,9 +36,16 @@ Status BackwardsConstAnalysis(const Graph& g, "Size", }; + std::vector compile_time_const_nodes_impl; + if (compile_time_const_nodes) { + CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids()); + } else { + compile_time_const_nodes_impl.resize(g.num_node_ids()); + compile_time_const_nodes = &compile_time_const_nodes_impl; + } + Status status; - std::unordered_set must_be_const; - auto visit = [&status, &metadata_ops, &must_be_const, + auto visit = [&status, &metadata_ops, compile_time_const_nodes, compile_time_const_args](Node* node) { if (!status.ok()) return; @@ -47,17 +54,19 @@ Status BackwardsConstAnalysis(const Graph& g, // If this node must be const, and it isn't a metadata op, then all of its // parents must be const. - if (must_be_const.find(node) != must_be_const.end()) { + if ((*compile_time_const_nodes)[node->id()]) { if (node->type_string() == "_Arg") { int index; status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; - compile_time_const_args->at(index) = true; + if (compile_time_const_args) { + (*compile_time_const_args)[index] = true; + } return; } for (const Edge* pred : node->in_edges()) { if (!pred->IsControlEdge()) { - must_be_const.insert(pred->src()); + (*compile_time_const_nodes)[pred->src()->id()] = true; } } return; @@ -80,7 +89,7 @@ Status BackwardsConstAnalysis(const Graph& g, for (Edge const* edge : node->in_edges()) { if (edge->dst_input() >= name_range->second.first && edge->dst_input() < name_range->second.second) { - must_be_const.insert(edge->src()); + (*compile_time_const_nodes)[edge->src()->id()] = true; } } } diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h index 634b97d7e3760c0344c948a56353ade243284aa6..af57e5a4033248e3fd32dabeda252c4ca0a44050 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.h +++ b/tensorflow/compiler/tf2xla/const_analysis.h @@ -23,10 +23,18 @@ limitations under the License. namespace tensorflow { -// Backwards dataflow analysis that finds arguments (_Arg nodes) to a graph that -// must be compile-time constants. +// Backwards dataflow analysis that finds nodes in a graph that must be +// compile-time constants for us to be able to lower the graph to XLA. +// +// The indices of the arguments to `graph` that must be constant are returned in +// `compile_time_const_arg_indices`, if `compile_time_const_arg_indices` is not +// null. +// +// The ids of the nodes in `graph` that must be constant are returned in +// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null. Status BackwardsConstAnalysis(const Graph& graph, - std::vector* compile_time_const_args); + std::vector* compile_time_const_arg_indices, + std::vector* compile_time_const_nodes); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc index 992b12c06db5efc0ae54284d0ea77017c1c79aca..56065be894697bc72ecc0089c665c19aafee7bf8 100644 --- a/tensorflow/compiler/tf2xla/const_analysis_test.cc +++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -38,17 +39,23 @@ TEST(ConstAnalysisTest, Basics) { auto c = ops::Reshape(root, arg2, b); auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3)); - Graph graph(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(&graph)); + FixupSourceAndSinkEdges(root.graph()); std::vector const_args(4, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + std::vector const_nodes(root.graph()->num_node_ids(), false); + TF_ASSERT_OK( + BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes)); // Arg 0 doesn't need to be constant since the graph only uses its shape. // Arg 1 must be constant because it flows to the shape argument of a Reshape. // Arg 2 is used only as the value input to a Reshape and need not be const. // Arg 3 is used as the reduction-indices argument to Sum and must be const. EXPECT_EQ(const_args, std::vector({false, true, false, true})); + + EXPECT_FALSE(const_nodes[arg0.node()->id()]); + EXPECT_TRUE(const_nodes[arg1.node()->id()]); + EXPECT_FALSE(const_nodes[arg2.node()->id()]); + EXPECT_TRUE(const_nodes[arg3.node()->id()]); } // Regression test for a case where the backward const analysis did @@ -73,7 +80,8 @@ TEST(ConstAnalysisTest, TopologicalOrder) { TF_ASSERT_OK(root.ToGraph(&graph)); std::vector const_args(3, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr)); EXPECT_EQ(const_args, std::vector({true, true, false})); } @@ -93,7 +101,8 @@ TEST(ConstAnalysisTest, DontFollowControlDependencies) { TF_ASSERT_OK(root.ToGraph(&graph)); std::vector const_args(2, false); - TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args)); + TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args, + /*compile_time_const_nodes=*/nullptr)); EXPECT_EQ(const_args, std::vector({false, true})); } diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index f14cfca4eaf654abc1d37c8abd34fbdae2bd24d7..b5667ca0d3ba35bea9da2d702b5b49fb38fe6f02 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" @@ -52,11 +53,10 @@ string DebugString(CondStateMap::CondId cond_state) { if (cond_state == nullptr || cond_state->empty()) return "[]"; return strings::StrCat( "[", - tensorflow::str_util::Join( - *cond_state, ", ", - [](string* output, const CondStateMap::CondNode& node) { - strings::StrAppend(output, node.ToString()); - }), + absl::StrJoin(*cond_state, ", ", + [](string* output, const CondStateMap::CondNode& node) { + strings::StrAppend(output, node.ToString()); + }), "]"); } @@ -169,10 +169,10 @@ using CondArgNodes = std::vector; string DebugString(const CondArgNodes& nodes) { return strings::StrCat( "[", - tensorflow::str_util::Join(nodes, ", ", - [](string* output, const CondArgNode& node) { - strings::StrAppend(output, node.ToString()); - }), + absl::StrJoin(nodes, ", ", + [](string* output, const CondArgNode& node) { + strings::StrAppend(output, node.ToString()); + }), "]"); } @@ -387,8 +387,9 @@ Status Conditional::BuildArgumentNodes() { } if (!has_input) { return errors::Internal( - "Failed to functionalize control flow with merge '", m->name(), - "' that doesn't have input on ", Branch_Name(branch), " branch."); + "Failed to functionalize control flow with merge ", + FormatNodeForError(*m), " that doesn't have input on ", + Branch_Name(branch), " branch."); } } } @@ -469,8 +470,8 @@ Status Conditional::ExtractBodies(Graph* graph) { // but revisit to improve the testing to enable making this an // error. LOG(WARNING) << errors::InvalidArgument( - "Graph contains node ", src->name(), " that feeds into node ", - dst->name(), + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), " but these nodes are in different control contexts (", DebugString(src_id), " vs ", DebugString(dst_id), " (detected during out edge testing)"); @@ -512,8 +513,8 @@ Status Conditional::ExtractBodies(Graph* graph) { node_map.at(src->id()) = output->CopyNode(src); } else { return errors::InvalidArgument( - "Graph contains node ", src->name(), " that feeds into node ", - dst->name(), + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), " but these nodes are in different control contexts (", DebugString(src_id), " vs ", DebugString(dst_id), " (detected during in edge testing)"); @@ -675,7 +676,8 @@ Status Conditional::AddOutputEdges(Graph* graph) { int dst_input = edge->dst_input(); if (edge->src_output() > 0) { return errors::Unimplemented("Output of index (", edge->src_output(), - ") of merge node ", node->name()); + ") of merge node ", + FormatNodeForError(*node)); } bool control_edge = edge->IsControlEdge(); @@ -1060,7 +1062,8 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { CondStateMap::CondId prop = StateAlongEdge(e); auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst)); - TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name()); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); cond_state_map_.ResetId(dst, id_or.ValueOrDie()); } @@ -1090,7 +1093,8 @@ Status FunctionalizeCond::DetermineCondState(Node* dst) { // Joining the state between the current and propagated state. CondStateMap::CondId prop = StateAlongEdge(e); auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst)); - TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", dst->name()); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); cond_state_map_.ResetId(dst, id_or.ValueOrDie()); } } @@ -1117,7 +1121,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { } if (non_dead_edge == nullptr) { - return errors::InvalidArgument("Merge node ", node->name(), + return errors::InvalidArgument("Merge node ", FormatNodeForError(*node), " has no non-dead inputs."); } cond_state_map_.MarkDead(node); @@ -1169,7 +1173,8 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { if (IsMerge(dst_node)) { auto id_or = JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node)); - TF_RETURN_IF_ERROR(id_or.status()); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst_node)); cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); } else { auto id_or = diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index a0544b69e9ea3a1bd16dcd08bc4b4638a8fc31fb..61940e3586c59ffc660eaac8f8d035fbbbdfeffd 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/graph/graph.h" @@ -43,11 +44,11 @@ xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); template string NodesToString(const T& nodes) { return strings::StrCat("{", - str_util::Join(nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, - node->name()); - }), + absl::StrJoin(nodes, ",", + [](string* output, const Node* node) { + strings::StrAppend(output, + node->name()); + }), "}"); } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index e4fdf0a6186eb69a2e3413838c91616b992ef2d6..1ed1fb3b021b27be00086b2e71cc9309e3d76049 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -57,7 +57,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, std::vector compile_time_constant_flags(expressions.size()); TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*graph, &compile_time_constant_flags)); + BackwardsConstAnalysis(*graph, &compile_time_constant_flags, + /*compile_time_const_nodes=*/nullptr)); args->resize(expressions.size()); for (int i = 0; i < args->size(); ++i) { @@ -145,6 +146,7 @@ Status GraphCompiler::Compile() { } OpKernelContext op_context(¶ms, n->num_outputs()); + VLOG(3) << "Translating " << params.op_kernel->name(); if (IsFunctional(n)) { TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context)); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index b1366e9e31e28406c5bf1a808b9c5670558ed9c7..c1438f893f6d3c46dd7f6c39b6aa3367a79789f0 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -22,6 +22,7 @@ tf_kernel_library( "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", + "broadcast_to_op.cc", "bucketize_op.cc", "cast_op.cc", "categorical_op.cc", @@ -100,6 +101,12 @@ tf_kernel_library( "unary_ops.cc", "unpack_op.cc", "variable_ops.cc", + "xla_broadcast_helper_op.cc", + "xla_conv_op.cc", + "xla_dot_op.cc", + "xla_pad_op.cc", + "xla_reduce_op.cc", + "xla_select_and_scatter_op.cc", ], hdrs = [ "index_ops.h", @@ -108,6 +115,8 @@ tf_kernel_library( deps = [ ":if_op", ":while_op", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index ba3b1c9dab79a387c48e8e25e4804917f328f8a0..2e383b1473590403823863f89264e5381d8e8806 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -16,6 +16,7 @@ limitations under the License. // XLA-specific Ops for broadcasting used in gradient // code. +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -51,8 +52,8 @@ class BCastArgsOp : public XlaOpKernel { BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( - "Incompatible shapes: [", str_util::Join(shapes[0], ","), - "] vs. [", str_util::Join(shapes[1], ","), "]")); + "Incompatible shapes: [", absl::StrJoin(shapes[0], ","), + "] vs. [", absl::StrJoin(shapes[1], ","), "]")); const int64 len = bcast.output_shape().size(); Tensor output(DT_INT32, TensorShape({len})); @@ -105,8 +106,8 @@ class BCastGradArgsOp : public XlaOpKernel { BCast bcast(shapes[0], shapes[1]); OP_REQUIRES(ctx, bcast.IsValid(), errors::InvalidArgument( - "Incompatible shapes: [", str_util::Join(shapes[0], ","), - "] vs. [", str_util::Join(shapes[1], ","), "]")); + "Incompatible shapes: [", absl::StrJoin(shapes[0], ","), + "] vs. [", absl::StrJoin(shapes[1], ","), "]")); Output(ctx, 0, bcast.grad_x_reduce_idx()); Output(ctx, 1, bcast.grad_y_reduce_idx()); } diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4bd7c74dca2a7cbb51f2a329ac575d635f314516 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { +namespace { + +class BroadcastToOp : public XlaOpKernel { + public: + explicit BroadcastToOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + TensorShape output_shape; + OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); + + OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(), + errors::InvalidArgument( + "Input rank (", input_shape.dims(), + ") must be less than or equal to the output rank (", + output_shape.dims(), ")")); + + auto input_dims = input_shape.dim_sizes(); + auto output_dims = output_shape.dim_sizes(); + + // Broadcasting is done right-to-left on right-aligned dimensions; reverse + // the two vectors so elements to be broadcast are aligned. + absl::c_reverse(input_dims); + absl::c_reverse(output_dims); + + std::vector broadcast_dims; + std::vector broadcast_shape; + for (int i = 0; i < output_shape.dims(); ++i) { + if (i < input_shape.dims()) { + OP_REQUIRES( + context, + (output_dims[i] == 0 && input_dims[i] == 0) || + (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0), + errors::InvalidArgument("invalid shape to broadcast from ", + input_shape.DebugString(), " to ", + output_shape.DebugString())); + + broadcast_dims.push_back(broadcast_shape.size()); + if (output_dims[i] == input_dims[i] || input_dims[i] == 1) { + broadcast_shape.push_back(output_dims[i]); + } + if (output_dims[i] != input_dims[i]) { + // Add dimensions [I, O/I], which we will later flatten to just + // [O]. We must do this in two phases since XLA broadcasting does not + // support tiling. + broadcast_shape.push_back(input_dims[i]); + broadcast_shape.push_back(output_dims[i] / input_dims[i]); + } + } else { + broadcast_shape.push_back(output_dims[i]); + } + } + absl::c_reverse(broadcast_dims); + int broadcast_shape_size = broadcast_shape.size(); + for (int64& broadcast_dim : broadcast_dims) { + broadcast_dim = broadcast_shape_size - broadcast_dim - 1; + } + absl::c_reverse(broadcast_shape); + xla::XlaOp output = xla::Reshape( + xla::BroadcastInDim(context->Input(0), + xla::ShapeUtil::MakeShape( + context->input_xla_type(0), broadcast_shape), + broadcast_dims), + output_shape.dim_sizes()); + context->SetOutput(0, output); + } +}; + +REGISTER_XLA_OP(Name("BroadcastTo").CompileTimeConstInput("shape"), + BroadcastToOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ed44ad218b6dc073583ec339da082b6881ad672d..70c3eaf66bbd6470734d1e5fc9978510022ac7bc 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -178,7 +178,7 @@ class MatrixDiagOp : public XlaOpKernel { int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); tensorflow::gtl::ArraySlice other_dims(dims); - other_dims.pop_back(); + other_dims.remove_suffix(1); xla::XlaOp input = ctx->Input(0); xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims, diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 8d75624e74028ea083c3facc4f9578ec14c50e6d..8e071bf0b7ae638888818ea8cd5d63b5d543342e 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -32,13 +32,13 @@ namespace { // // 1. S := (N - 1) / gcd(N-1, R-1) // 2. k := (R - 1) / gcd(N-1, R-1) -// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1) +// 3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1) // // For example, to Scale from 7x7 -> 15x15: // // 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3 // 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7 -// 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2) +// 3. Convolution(15x15, stride=3, lhs_dilation=7, padding=2) // // // The 7x7 -> 15x15 case is much too large to write out in full as an @@ -65,6 +65,8 @@ namespace { // 1/9 * 3 6 9 6 3 // 2 4 6 4 2 // 1 2 3 2 1 +// Note that the convolution kernel matrix is separable and thus we can instead +// use 2 consecutive 1D kernel of the dimension 2k-1, along each axis. // Computes the size of the convolutional kernel and stride to use when resizing // from in_size to out_size. @@ -76,7 +78,8 @@ struct ResizeConvolutionDims { std::vector stride; }; ResizeConvolutionDims ComputeResizeConvolutionParameters( - gtl::ArraySlice in_size, gtl::ArraySlice out_size) { + gtl::ArraySlice in_size, gtl::ArraySlice out_size, + bool align_corners) { CHECK_EQ(in_size.size(), out_size.size()); int num_spatial_dims = in_size.size(); ResizeConvolutionDims dims; @@ -92,15 +95,32 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( // entry before resizing. dims.stride[i] = dims.kernel_size[i] = 1; } else { - int64 gcd = MathUtil::GCD(static_cast(in_size[i] - 1), - static_cast(out_size[i] - 1)); - dims.stride[i] = (in_size[i] - 1) / gcd; - dims.kernel_size[i] = (out_size[i] - 1) / gcd; + // The scaling factor changes depending on the alignment of corners. + const int64 in_size_factor = align_corners ? in_size[i] - 1 : in_size[i]; + const int64 out_size_factor = + align_corners ? out_size[i] - 1 : out_size[i]; + + int64 gcd = MathUtil::GCD(static_cast(in_size_factor), + static_cast(out_size_factor)); + dims.stride[i] = in_size_factor / gcd; + dims.kernel_size[i] = out_size_factor / gcd; } } return dims; } +// The upper padding of the input needed by ConvGeneralDilated calls is +// determined by solving two related relationships (assuming rhs_dilation == 0): +// 1. dilated_input_dim = lower_padding + upper_padding +// + lhs_dilation * (in_size - 1) + 1 +// 2. dilated_input_dim = (2 * dims.kernel-size - 1) +// + dims.stride * (out_size - 1) +int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, + int64 stride) { + return (2 * kernel_size - 1) + (out_size - 1) * stride - (kernel_size - 1) - + 1 - (kernel_size * (in_size - 1)); +} + // Form a 2D convolution kernel like: // 1 2 3 2 1 // 2 4 6 4 2 @@ -171,7 +191,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, const int num_spatial_dims, std::vector in_size, std::vector out_size, - const int64 channels) { + const int64 channels, + const bool align_corners) { // Picture for a 1x3 to 1x4 resize: // stride = 2, kernel size = 3 // Input: @@ -196,27 +217,82 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); ResizeConvolutionDims dims = - ComputeResizeConvolutionParameters(in_size, out_size); + ComputeResizeConvolutionParameters(in_size, out_size, align_corners); xla::XlaOp output; - // Split convolutions into independent dimensions if they wmuld be a very + + // Concatenation and padding below currently assumes num_spatial_dims is 2 to + // prevent needless code complexity. + CHECK_EQ(num_spatial_dims, 2) + << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently."; + std::vector upper_padding(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + upper_padding[i] = dims.kernel_size[i] - 1; + } + xla::XlaOp input_data = input; + + if (!align_corners) { + // When Tensorflow does not align_corners, the resize indexing can access + // beyond the upper bound and is instead clamped to prevent out of bounds + // reads. This is conceptually the same as extending the edges of the input. + // We emulate this by copying the last row/column of the input. + // Calculate what padding would be needed then determine how far to extend + // the border before lhs dilation. + std::vector num_extended(num_spatial_dims); + upper_padding[0] = CalculateUpperPadding( + in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]); + upper_padding[1] = CalculateUpperPadding( + in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]); + num_extended[0] = upper_padding[0] / (dims.kernel_size[0]); + num_extended[1] = upper_padding[1] / (dims.kernel_size[1]); + + if (num_extended[0] > 0) { + auto slice = + xla::Slice(input_data, {0, in_size[0] - 1, 0, 0}, + {1, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); + for (int i = 0; i < num_extended[0]; i++) { + input_data = xla::ConcatInDim(builder, {input_data, slice}, 1); + } + } + + if (num_extended[1] > 0) { + auto slice = + xla::Slice(input_data, {0, 0, in_size[1] - 1, 0}, + {1, in_size[0] + num_extended[0], in_size[1], channels}, + {1, 1, 1, 1}); + for (int i = 0; i < num_extended[1]; i++) { + input_data = xla::ConcatInDim(builder, {input_data, slice}, 2); + } + } + + // Setting in_size to (in_size + num_extended) due to the above Slice and + // ConcatInDim. Recalculate needed padding after the above Slice/Concat. + upper_padding[0] = + CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0], + dims.kernel_size[0], dims.stride[0]); + upper_padding[1] = + CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1], + dims.kernel_size[1], dims.stride[1]); + } + + // Split convolutions into independent dimensions if they would be a very // large kernel. if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - output = xla::ConvGeneralDilated( - input, kernel, dims.stride, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + output = + xla::ConvGeneralDilated(input_data, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, upper_padding[0]}, + {dims.kernel_size[1] - 1, upper_padding[1]}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); output = xla::ConvGeneralDilated( - input, kernel0, {dims.stride[0], 1}, + input_data, kernel0, {dims.stride[0], 1}, /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}}, /*lhs_dilation=*/{dims.kernel_size[0], 1}, /*rhs_dilation=*/{1, 1}, dimension_numbers); xla::XlaOp kernel1 = @@ -224,7 +300,7 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, output = xla::ConvGeneralDilated( output, kernel1, {1, dims.stride[1]}, /*padding=*/ - {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}}, /*lhs_dilation=*/{1, dims.kernel_size[1]}, /*rhs_dilation=*/{1, 1}, dimension_numbers); } @@ -245,9 +321,10 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, const int num_spatial_dims, std::vector in_size, std::vector grad_size, - const int64 channels) { + const int64 channels, + const bool align_corners) { ResizeConvolutionDims dims = - ComputeResizeConvolutionParameters(in_size, grad_size); + ComputeResizeConvolutionParameters(in_size, grad_size, align_corners); // To form the backward convolution, we keep the kernel unchanged (it is // already symmetric) and swap the roles of strides and LHS dilation. @@ -341,10 +418,6 @@ class ResizeBilinearOp : public XlaOpKernel { public: explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); - OP_REQUIRES( - ctx, align_corners_ == true, - errors::Unimplemented( - "ResizeBilinear with align_corners=False is not yet implemented")); } void Compile(XlaOpKernelContext* ctx) override { @@ -377,20 +450,19 @@ class ResizeBilinearOp : public XlaOpKernel { // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in // dimension i. - std::vector slice_size = in_size; bool slice_input = false; for (int i = 0; i < num_spatial_dims; ++i) { if (in_size[i] > 1 && out_size[i] == 1) { // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first // entry before resizing. slice_input = true; - slice_size[i] = 1; + in_size[i] = 1; } } if (slice_input) { - input = xla::Slice(input, {0, 0, 0, 0}, - {batch, slice_size[0], slice_size[1], channels}, - {1, 1, 1, 1}); + input = + xla::Slice(input, {0, 0, 0, 0}, + {batch, in_size[0], in_size[1], channels}, {1, 1, 1, 1}); } // Output is always type float. @@ -406,6 +478,9 @@ class ResizeBilinearOp : public XlaOpKernel { // operations along different dimensions. // Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd. + // This does not work in the case of align_corners_=false because of special + // padding requirements that cause multiple resizes to be very different + // from a single resize. // // This makes the convolutions kernels smaller and the operation faster. xla::XlaOp output = input; @@ -415,21 +490,24 @@ class ResizeBilinearOp : public XlaOpKernel { (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && - k[0] > 1 && k[1] > 1) { + k[0] > 1 && k[1] > 1 && align_corners_) { std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; - output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, next_out_size, channels); + output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, + in_size, next_out_size, + channels, align_corners_); input = output; in_size = next_out_size; } else { - output = ResizeUsingDilationAndConvolution( - b, input, num_spatial_dims, in_size, out_size, channels); + output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, + in_size, out_size, + channels, align_corners_); in_size = out_size; } } else { output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, - in_size, out_size, channels); + in_size, out_size, channels, + align_corners_); in_size = out_size; } } @@ -509,17 +587,20 @@ class ResizeBilinearGradOp : public XlaOpKernel { std::vector next_grad_size = {(in_size[0] - 1) * 2 + 1, (in_size[1] - 1) * 2 + 1}; output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, next_grad_size, channels); + b, grad, num_spatial_dims, in_size, next_grad_size, channels, + align_corners_); grad = output; in_size = next_grad_size; } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels); + b, grad, num_spatial_dims, in_size, grad_size, channels, + align_corners_); in_size = grad_size; } } else { output = ResizeUsingDilationAndConvolutionGradOp( - b, grad, num_spatial_dims, in_size, grad_size, channels); + b, grad, num_spatial_dims, in_size, grad_size, channels, + align_corners_); in_size = grad_size; } } diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index eedfc3c9140d7b1ccc1944611de98c1d49fbdaf2..2a42eeaf76ab3aa88ff3a93ef7eb7ab217964bb6 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -29,7 +29,14 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr DoMirrorPad(const xla::XlaOp& t, const xla::Shape& original_shape, const xla::LiteralSlice& pad_literal, + const MirrorPadMode mode, xla::XlaBuilder* b) { + // The difference in the semantics of REFLECT and SYMMETRIC is that REFLECT + // will not mirror the border values while symmetric does. + // e.g. input is [1, 2, 3] and paddings is [0, 2], then the output is: + // - [1, 2, 3, 2, 1] in reflect mode + // - [1, 2, 3, 3, 2] in symmetric mode. + int64 excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0; xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { @@ -39,9 +46,19 @@ class MirrorPadOp : public XlaOpKernel { TF_ASSIGN_OR_RETURN(int64 rhs_padding, pad_literal.GetIntegralAsS64({dimno, 1})); int64 dim_size = original_shape.dimensions(dimno); - auto lhs_pad = xla::SliceInDim(t_rev, dim_size - 1 - lhs_padding, - dim_size - 1, 1, dimno); - auto rhs_pad = xla::SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno); + + // Padding amounts on each side must be no more than the size of the + // original shape. + TF_RET_CHECK(lhs_padding >= 0 && + lhs_padding <= dim_size - excluded_edges); + TF_RET_CHECK(rhs_padding >= 0 && + rhs_padding <= dim_size - excluded_edges); + + auto lhs_pad = + xla::SliceInDim(t_rev, dim_size - excluded_edges - lhs_padding, + dim_size - excluded_edges, 1, dimno); + auto rhs_pad = xla::SliceInDim(t_rev, excluded_edges, + excluded_edges + rhs_padding, 1, dimno); accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno); } return accum; @@ -53,9 +70,10 @@ class MirrorPadOp : public XlaOpKernel { MirrorPadMode mode; OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "mode", &mode)); - OP_REQUIRES(ctx, mode == MirrorPadMode::REFLECT, - xla::Unimplemented( - "Only REFLECT MirrorPad mode is currently supported")); + OP_REQUIRES( + ctx, mode == MirrorPadMode::REFLECT || mode == MirrorPadMode::SYMMETRIC, + xla::Unimplemented("Unsupported MirrorPad mode. Only SYMMETRIC and " + "REFLECT modes are currently supported")); const int dims = input_shape.dims(); OP_REQUIRES( @@ -83,7 +101,7 @@ class MirrorPadOp : public XlaOpKernel { xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); xla::StatusOr accum_status = - DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b); + DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, mode, b); OP_REQUIRES_OK(ctx, accum_status.status()); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index d4d180aff806f12875f0e43f111ee090f6607ef6..f6f158a73be42ea2602811ad64a2a2c655dab088 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -199,59 +199,6 @@ class MaxPool3DOp : public MaxPoolOp { }; REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); -// Divide each element of an image by the count of elements that contributed to -// that element during pooling. -static xla::XlaOp AvgPoolDivideByCount( - XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, - const TensorShape& input_shape, xla::Padding padding, - const std::vector& ksize, const std::vector& stride, - int num_spatial_dims, TensorFormat data_format) { - if (padding == xla::Padding::kValid) { - // In VALID padding, all windows have the same number of elements - // contributing to each average. Divide by the window size everywhere to - // get the average. - int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1, - [](int64 a, int64 b) { return a * b; }); - - auto divisor = - XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size); - return xla::Div(output, divisor); - } else { - // For SAME padding, the padding shouldn't be included in the - // counts. We use another ReduceWindow to find the right counts. - - // TODO(phawkins): use a less brute-force way to compute this. Only - // the boundary regions will have interesting values here. - - std::vector input_dim_sizes(num_spatial_dims); - std::vector window_dims(num_spatial_dims); - std::vector window_ksize(num_spatial_dims); - std::vector window_stride(num_spatial_dims); - for (int i = 0; i < num_spatial_dims; ++i) { - int dim = GetTensorSpatialDimIndex(num_spatial_dims + 2, data_format, i); - input_dim_sizes[i] = input_shape.dim_size(dim); - window_dims[i] = dim; - window_ksize[i] = ksize[dim]; - window_stride[i] = stride[dim]; - } - - // Build a matrix of all 1s, with the same width/height as the input. - const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto ones = xla::Broadcast( - XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes); - - // Perform a ReduceWindow with the same window size, strides, and padding - // to count the number of contributions to each result element. - auto reduce = xla::ReduceWindow( - ones, XlaHelpers::Zero(ctx->builder(), accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride, - xla::Padding::kSame); - auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype); - - return xla::Div(output, counts, window_dims); - } -} - class AvgPoolOp : public PoolingOp { public: AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims) @@ -463,78 +410,31 @@ class AvgPoolGradOp : public XlaOpKernel { errors::InvalidArgument("out_backprop must be ", num_dims(), "-dimensional")); - int depth_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - int64 depth = out_backprop_shape.dim_size(depth_dim); - - // We can think of average-pooling as: - // * a convolution with a kernel consisting entirely of 1s, where the - // input feature and output feature are equal, and 0s everywhere else. - // * followed by dividing by the counts. - // - // This then gives us an algorithm to build the gradient: - // * divide out_backprop by the counts, followed by - // * Conv2DBackpropInput specialized for that kernel, which simplifies to - // a Pad and a ReduceWindow. - // - // For an explanation of backpropagation for convolution, see the comments - // in third_party/tensorflow/core/kernels/conv_grad_ops.h - - // TF filter shape is [ H, W, ..., inC, outC ] - std::vector filter_dims(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - filter_dims[i] = ksize_[dim]; - } - filter_dims[num_dims() - 2] = depth; - filter_dims[num_dims() - 1] = depth; - TensorShape filter_shape(filter_dims); - - // Reuse the logic from Conv2DBackpropInput to compute padding. - ConvBackpropDimensions dims; - OP_REQUIRES_OK( - ctx, ConvBackpropComputeDimensions( - type_string(), /*num_spatial_dims=*/num_spatial_dims_, - gradients_shape, filter_shape, out_backprop_shape, stride_, - padding_, data_format_, &dims)); - - // The input gradients are computed by a convolution of the output gradients - // and the filter, with some appropriate padding. See the comment at the top - // of conv_grad_ops.h for details. - xla::XlaBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); - auto dtype = input_type(1); + std::vector stride_int64s(stride_.begin(), stride_.end()); xla::Padding xla_padding = (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; - - // Divide the out_backprop values by the counts for each spatial position. - std::vector stride_int64s(stride_.begin(), stride_.end()); - auto out_backprop_div = AvgPoolDivideByCount( - ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_, - stride_int64s, num_spatial_dims_, data_format_); - - // Pad the gradients in the spatial dimensions. We use the same padding - // as Conv2DBackpropInput. - xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - auto* padding = padding_config.mutable_dimensions(dim); - padding->set_edge_padding_low(dims.spatial_dims[i].pad_before); - padding->set_edge_padding_high(dims.spatial_dims[i].pad_after); - padding->set_interior_padding(dims.spatial_dims[i].stride - 1); - } - - auto zero = XlaHelpers::Zero(b, dtype); - auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config); - - // in_backprop = padded_gradients ones - std::vector ones(num_dims(), 1LL); - auto accumulation_type = XlaHelpers::SumAccumulationType(dtype); - auto in_backprop = xla::ReduceWindow( - XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type), - XlaHelpers::Zero(b, accumulation_type), - *ctx->GetOrCreateAdd(accumulation_type), ksize_, - /* window_strides=*/ones, xla::Padding::kValid); - ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, in_backprop, dtype)); + xla::PrimitiveType xla_reduction_type; + auto reduction_type = XlaHelpers::SumAccumulationType(ctx->input_type(1)); + OP_REQUIRES_OK( + ctx, DataTypeToPrimitiveType(reduction_type, &xla_reduction_type)); + auto converted_out_backprop = + xla::ConvertElementType(out_backprop, xla_reduction_type); + auto xla_data_format = + XlaTensorFormat(data_format_, gradients_shape.dims() - 2); + auto padding_values = + MakeSpatialPadding(gradients_shape.dim_sizes(), ksize_, stride_int64s, + xla_padding, xla_data_format); + auto in_backprop = + xla::AvgPoolGrad(converted_out_backprop, gradients_shape.dim_sizes(), + ksize_, stride_int64s, padding_values, xla_data_format, + /*counts_include_padding=*/padding_ == VALID); + // Convert the pooling result back to the input type before returning it. + xla::PrimitiveType xla_out_backprop_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), + &xla_out_backprop_type)); + ctx->SetOutput(0, + xla::ConvertElementType(in_backprop, xla_out_backprop_type)); } protected: diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index b11a4ce36da9907ce8fe377c075023a4540797fa..8102faad28db71075fb8da269c55edbdb667193e 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -32,41 +32,30 @@ class ReduceWindowOp : public XlaOpKernel { explicit ReduceWindowOp(OpKernelConstruction* context) : XlaOpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("computation", &computation_)); - OP_REQUIRES_OK(context, - context->GetAttr("window_dimensions", &window_dimensions_)); - OP_REQUIRES_OK(context, - context->GetAttr("window_strides", &window_strides_)); - OP_REQUIRES_OK(context, context->GetAttr("padding_low", &padding_low_)); - OP_REQUIRES_OK(context, context->GetAttr("padding_high", &padding_high_)); } void Compile(XlaOpKernelContext* context) override { const TensorShape input_shape = context->InputShape(0); const DataType dtype = context->input_type(0); + std::vector window_dimensions; + std::vector window_strides; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dimensions", &window_dimensions)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + const int rank = input_shape.dims(); - OP_REQUIRES(context, rank == window_dimensions_.size(), + OP_REQUIRES(context, rank == window_dimensions.size(), errors::InvalidArgument( "The size of window_dimensions must be equal to the input " "rank (", - window_dimensions_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == window_strides_.size(), + window_dimensions.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_strides.size(), errors::InvalidArgument( "The size of window_strides must be equal to the input " "rank (", - window_strides_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == padding_low_.size(), - errors::InvalidArgument( - "The size of padding_low must be equal to the input " - "rank (", - padding_low_.size(), " vs. ", rank, ")")); - OP_REQUIRES(context, rank == padding_high_.size(), - errors::InvalidArgument( - "The size of padding_high must be equal to the input " - "rank (", - padding_high_.size(), " vs. ", rank, ")")); - - xla::XlaBuilder* builder = context->builder(); + window_strides.size(), " vs. ", rank, ")")); // Build the reducer function. XlaCompiler::Argument reducer_arg; @@ -78,6 +67,7 @@ class ReduceWindowOp : public XlaOpKernel { compile_options.use_tuple_arg = false; compile_options.resolve_compile_time_constants = false; compile_options.is_entry_computation = false; + compile_options.always_return_tuple = false; XlaCompiler::CompilationResult reducer; OP_REQUIRES_OK(context, context->compiler()->CompileFunction( compile_options, *computation_, @@ -86,51 +76,47 @@ class ReduceWindowOp : public XlaOpKernel { xla::Shape scalar_shape; OP_REQUIRES_OK(context, TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of ReduceWindow reducer. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + const TensorShape padding_shape = context->InputShape("padding"); OP_REQUIRES(context, - xla::ShapeUtil::Compatible( - reducer.xla_output_shape, - xla::ShapeUtil::MakeTupleShape({scalar_shape})), + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, errors::InvalidArgument( - "Invalid output shape of ReduceWindow reducer. Expected ", - xla::ShapeUtil::HumanString(scalar_shape), " got ", - xla::ShapeUtil::HumanString(reducer.xla_output_shape))); - - // Wraps the reducer in a computation that unpacks the output tuple. - xla::XlaComputation wrapper; - { - std::unique_ptr cb = - builder->CreateSubBuilder("wrapper"); - auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x"); - auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y"); - auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y}); - xla::GetTupleElement(outputs, 0); - xla::StatusOr result = cb->Build(); - OP_REQUIRES_OK(context, result.status()); - wrapper = std::move(result.ValueOrDie()); - } - - std::vector> padding(rank); - for (int i = 0; i < rank; ++i) { - padding[i] = {padding_low_[i], padding_high_[i]}; + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get({i, 0}), + padding_literal.Get({i, 1})}; } xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( - context->Input(0), context->Input(1), wrapper, window_dimensions_, - window_strides_, padding); + context->Input(0), context->Input(1), *reducer.computation, + window_dimensions, window_strides, padding); context->SetOutput(0, output); } private: const NameAttrList* computation_; - std::vector window_dimensions_; - std::vector window_strides_; - std::vector padding_low_; - std::vector padding_high_; TF_DISALLOW_COPY_AND_ASSIGN(ReduceWindowOp); }; -REGISTER_XLA_OP(Name("XlaReduceWindow"), ReduceWindowOp); +REGISTER_XLA_OP(Name("XlaReduceWindow") + .CompileTimeConstInput("window_dimensions") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("padding"), + ReduceWindowOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index b52f0a0ab6290f2019bb58120be5c2364ec15bb6..598248563bb93146e6dea3016822d26b8bf368e7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific reduction Ops. +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -29,9 +30,6 @@ namespace tensorflow { XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type) : XlaOpKernel(ctx), reduction_type_(reduction_type) { - const DataType dt = BaseType(input_type(0)); - OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt})); - OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); OP_REQUIRES_OK( ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_)); @@ -58,20 +56,24 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { return; } + OP_REQUIRES(ctx, axes_tensor_shape.dims() <= 1, + errors::InvalidArgument( + "Expected scalar or vector as index argument, got ", + axes_tensor_shape.DebugString())); + // Evaluate the constant, reshaping to a 1-vector if it is a scalar. + std::vector axes; xla::Literal axes_literal; - OP_REQUIRES_OK( - ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()}, - &axes_literal)); + OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector(1, &axes)); VLOG(1) << "data shape: " << data_shape.DebugString(); - VLOG(1) << "axes : " << axes_literal.ToString(); + VLOG(1) << "axes : " << absl::StrJoin(axes, ","); gtl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { - int32 index = axes_literal.Get({i}); + int64 index = axes[i]; OP_REQUIRES(ctx, !(index < -data_shape.dims() || index >= data_shape.dims()), errors::InvalidArgument("Invalid reduction dimension (", index, diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 64900e4709fd3e16d21096b0cfff8922906cb0d4..e172c649325adb6f7761ce0be141f21e8d545bc1 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -48,6 +48,15 @@ class RetvalOp : public XlaOpKernel { } else { xla::XlaOp input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); + DataType input_type = ctx->input_type(0); + XlaContext& tc = XlaContext::Get(ctx); + + if (input_type == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + ctx->SetStatus(tc.AddResourceRetval(index_, resource)); + return; + } auto is_constant = ctx->builder()->IsConstant(input); if (!is_constant.ok()) { @@ -55,7 +64,6 @@ class RetvalOp : public XlaOpKernel { return; } - XlaContext& tc = XlaContext::Get(ctx); if (tc.resolve_compile_time_constants() && (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { xla::Literal literal; @@ -104,7 +112,8 @@ class RetvalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_XLA_OP(Name("_Retval").CompilationOnly(), RetvalOp); +REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(), + RetvalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 6ce50efb4aa6e3434a7c6009cf9f52f6cff9cc9f..d9578eca5bf11110e9770b66a4dab82c597da6ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -67,7 +67,7 @@ class SelectOp : public XlaOpKernel { // to get the dimensions in the right order. const auto dim_sizes = then_shape.dim_sizes(); gtl::ArraySlice bdims = dim_sizes; - bdims.pop_front(); + bdims.remove_prefix(1); cond_handle = xla::Broadcast(cond_handle, bdims); std::vector dim_order(then_shape.dims()); diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 025ba827410f1a9f993a8a1855558a2daa86609b..d6bd927135c013ac1ec3f6547aef358dc2741896 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Ops for softmax. +#include "absl/strings/match.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace { @@ -33,7 +33,7 @@ namespace { class SoftmaxOp : public XlaOpKernel { public: explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - log_ = str_util::StartsWith(type_string(), "Log"); + log_ = absl::StartsWith(type_string(), "Log"); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..412afeaaad96842521fbd306f5b666e837e675fd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +class XlaBroadcastHelperOp : public XlaOpKernel { + public: + explicit XlaBroadcastHelperOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp lhs = context->Input(0); + xla::XlaOp rhs = context->Input(1); + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + + const bool broadcast_lhs = lhs_shape.dims() < rhs_shape.dims(); + const TensorShape* min_rank_shape = broadcast_lhs ? &lhs_shape : &rhs_shape; + const TensorShape* max_rank_shape = broadcast_lhs ? &rhs_shape : &lhs_shape; + + std::vector broadcast_dims; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("broadcast_dims", + &broadcast_dims)); + if (broadcast_dims.empty()) { + OP_REQUIRES( + context, + lhs_shape.dims() == rhs_shape.dims() || lhs_shape.dims() == 0 || + rhs_shape.dims() == 0, + errors::InvalidArgument( + "If broadcast_dims is empty, both " + "arguments must have equal rank; " + "argument shapes, or at least one argument must be a scalar: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + context->SetOutput(0, lhs); + context->SetOutput(1, rhs); + return; + } + + OP_REQUIRES( + context, broadcast_dims.size() == min_rank_shape->dims(), + errors::InvalidArgument( + "broadcast_dims must have size equal to the smaller argument rank; " + "broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + std::vector sorted_broadcast_dims = broadcast_dims; + absl::c_sort(sorted_broadcast_dims); + std::set dims_set(broadcast_dims.begin(), broadcast_dims.end()); + OP_REQUIRES(context, + dims_set.size() == broadcast_dims.size() && + broadcast_dims == sorted_broadcast_dims, + errors::InvalidArgument( + "Duplicate or nonmonotonic dimension in broadcast_dims; " + "broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]")); + + std::vector broadcast_shape(max_rank_shape->dims(), 1LL); + for (int i = 0; i < broadcast_dims.size(); ++i) { + const int dim = broadcast_dims[i]; + OP_REQUIRES( + context, dim >= 0 && dim < broadcast_shape.size(), + errors::InvalidArgument( + "Invalid broadcast dimension (", dim, "); broadcast_dims: [", + absl::StrJoin(broadcast_dims, ","), "]; argument shapes: ", + lhs_shape.DebugString(), " and ", rhs_shape.DebugString())); + broadcast_shape[dim] = min_rank_shape->dim_size(i); + } + xla::PrimitiveType type = context->input_xla_type(0); + xla::Shape broadcast_xla_shape = + xla::ShapeUtil::MakeShape(type, broadcast_shape); + if (broadcast_lhs) { + lhs = xla::BroadcastInDim(lhs, broadcast_xla_shape, broadcast_dims); + } else { + rhs = xla::BroadcastInDim(rhs, broadcast_xla_shape, broadcast_dims); + } + context->SetOutput(0, lhs); + context->SetOutput(1, rhs); + } + + private: + xla::DotDimensionNumbers dnums_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaBroadcastHelperOp); +}; + +REGISTER_XLA_OP( + Name("XlaBroadcastHelper").CompileTimeConstInput("broadcast_dims"), + XlaBroadcastHelperOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8848623868091f8d19b1622f23ba23c68689d90d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -0,0 +1,101 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaConvOp : public XlaOpKernel { + public: + explicit XlaConvOp(OpKernelConstruction* context) : XlaOpKernel(context) { + string dnums_attr; + OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); + OP_REQUIRES( + context, dnums_.ParsePartialFromString(dnums_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + string precision_config_attr; + OP_REQUIRES_OK( + context, context->GetAttr("precision_config", &precision_config_attr)); + OP_REQUIRES( + context, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + const TensorShape padding_shape = context->InputShape("padding"); + std::vector window_strides; + std::vector lhs_dilation; + std::vector rhs_dilation; + int64 feature_group_count; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("lhs_dilation", + &lhs_dilation)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("rhs_dilation", + &rhs_dilation)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar( + "feature_group_count", &feature_group_count)); + + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, + errors::InvalidArgument( + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get({i, 0}), + padding_literal.Get({i, 1})}; + } + + // We do only minimal checking, relying on XLA to check the shape + // invariants. + xla::XlaOp output = xla::ConvGeneralDilated( + context->Input(0), context->Input(1), window_strides, padding, + lhs_dilation, rhs_dilation, dnums_, feature_group_count, + &precision_config_); + context->SetOutput(0, output); + } + + private: + xla::ConvolutionDimensionNumbers dnums_; + xla::PrecisionConfigProto precision_config_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp); +}; + +REGISTER_XLA_OP(Name("XlaConv") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("lhs_dilation") + .CompileTimeConstInput("rhs_dilation") + .CompileTimeConstInput("feature_group_count") + .CompileTimeConstInput("padding"), + XlaConvOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2fed53e5c072e1a50e0f07f45357ee86c90f986f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaDotOp : public XlaOpKernel { + public: + explicit XlaDotOp(OpKernelConstruction* context) : XlaOpKernel(context) { + string dnums_attr; + OP_REQUIRES_OK(context, context->GetAttr("dimension_numbers", &dnums_attr)); + OP_REQUIRES( + context, dnums_.ParsePartialFromString(dnums_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + string precision_config_attr; + OP_REQUIRES_OK( + context, context->GetAttr("precision_config", &precision_config_attr)); + OP_REQUIRES( + context, + precision_config_.ParsePartialFromString(precision_config_attr), + errors::InvalidArgument("Error parsing convolution dimension numbers")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape lhs_shape = context->InputShape(0); + const TensorShape rhs_shape = context->InputShape(1); + + // We do only minimal checking, relying on XLA to check the shape + // invariants. + xla::XlaOp output = xla::DotGeneral(context->Input(0), context->Input(1), + dnums_, &precision_config_); + context->SetOutput(0, output); + } + + private: + xla::DotDimensionNumbers dnums_; + xla::PrecisionConfigProto precision_config_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp); +}; + +REGISTER_XLA_OP(Name("XlaDot"), XlaDotOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..59502d83c7338bd1b05b3323a97761fff2da186a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc @@ -0,0 +1,105 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaPadOp : public XlaOpKernel { + public: + explicit XlaPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape("input"); + const TensorShape padding_value_shape = + context->InputShape("padding_value"); + + std::vector padding_low; + std::vector padding_high; + std::vector padding_interior; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_low", + &padding_low)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("padding_high", + &padding_high)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "padding_interior", &padding_interior)); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(padding_value_shape), + errors::InvalidArgument("padding_value must be a scalar")); + const int rank = input_shape.dims(); + OP_REQUIRES(context, rank == padding_low.size(), + errors::InvalidArgument( + "The size of padding_low must be equal to the input " + "rank (", + padding_low.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_high.size(), + errors::InvalidArgument( + "The size of padding_high must be equal to the input " + "rank (", + padding_high.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == padding_interior.size(), + errors::InvalidArgument( + "The size of padding_interior must be equal to the input " + "rank (", + padding_interior.size(), " vs. ", rank, ")")); + + auto non_negative = [](int64 x) { return x >= 0; }; + OP_REQUIRES( + context, absl::c_all_of(padding_low, non_negative), + errors::InvalidArgument("padding_low must be non-negative, got [", + absl::StrJoin(padding_low, ","), "]")); + OP_REQUIRES( + context, absl::c_all_of(padding_high, non_negative), + errors::InvalidArgument("padding_high must be non-negative, got [", + absl::StrJoin(padding_high, ","), "]")); + OP_REQUIRES( + context, absl::c_all_of(padding_interior, non_negative), + errors::InvalidArgument("padding_interior must be non-negative, got [", + absl::StrJoin(padding_interior, ","), "]")); + + xla::PaddingConfig padding_config; + for (int i = 0; i < rank; ++i) { + auto* dim = padding_config.add_dimensions(); + dim->set_edge_padding_low(padding_low[i]); + dim->set_edge_padding_high(padding_high[i]); + dim->set_interior_padding(padding_interior[i]); + } + + xla::XlaOp output = + xla::Pad(context->Input("input"), context->Input("padding_value"), + padding_config); + context->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(XlaPadOp); +}; + +REGISTER_XLA_OP(Name("XlaPad") + .CompileTimeConstInput("padding_low") + .CompileTimeConstInput("padding_high") + .CompileTimeConstInput("padding_interior"), + XlaPadOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc2425f37bfa793ce3a106b635c9dffd15b975ff --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_reduce_op.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaReduceOp : public XlaOpKernel { + public: + explicit XlaReduceOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("reducer", &reducer_)); + OP_REQUIRES_OK(context, context->GetAttr("dimensions_to_reduce", + &dimensions_to_reduce_)); + std::set dims_set(dimensions_to_reduce_.begin(), + dimensions_to_reduce_.end()); + OP_REQUIRES( + context, dims_set.size() == dimensions_to_reduce_.size(), + errors::InvalidArgument("Duplicate dimension in dimensions_to_reduce " + "argument to XlaReduce")); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape("input"); + const TensorShape init_value_shape = context->InputShape("init_value"); + const DataType dtype = context->input_type(0); + + const int rank = input_shape.dims(); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(init_value_shape), + errors::InvalidArgument("init_value must be a scalar")); + + auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; }; + OP_REQUIRES(context, + rank >= dimensions_to_reduce_.size() && + absl::c_all_of(dimensions_to_reduce_, dim_in_range), + errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaReduce")); + + // Build the reducer function. + XlaCompiler::Argument reducer_arg; + reducer_arg.kind = XlaCompiler::Argument::kParameter; + reducer_arg.type = dtype; + reducer_arg.shape = TensorShape(); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.always_return_tuple = false; + compile_options.resolve_compile_time_constants = false; + compile_options.is_entry_computation = false; + XlaCompiler::CompilationResult reducer; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *reducer_, + {reducer_arg, reducer_arg}, &reducer)); + + xla::Shape scalar_shape; + OP_REQUIRES_OK(context, + TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(reducer.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of XlaReduce reducer. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(reducer.xla_output_shape))); + + xla::XlaOp output = + xla::Reduce(context->Input("input"), context->Input("init_value"), + *reducer.computation, dimensions_to_reduce_); + context->SetOutput(0, output); + } + + private: + const NameAttrList* reducer_; + std::vector dimensions_to_reduce_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaReduceOp); +}; + +REGISTER_XLA_OP(Name("XlaReduce"), XlaReduceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..089776fcf74fcf6b363dfff5de8d86d7449eacd6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/xla_select_and_scatter_op.cc @@ -0,0 +1,147 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/while_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class XlaSelectAndScatterOp : public XlaOpKernel { + public: + explicit XlaSelectAndScatterOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("select", &select_computation_)); + OP_REQUIRES_OK(context, context->GetAttr("scatter", &scatter_computation_)); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const DataType dtype = context->input_type(0); + + std::vector window_dimensions; + std::vector window_strides; + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dimensions", &window_dimensions)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", + &window_strides)); + + const int rank = input_shape.dims(); + OP_REQUIRES(context, rank == window_dimensions.size(), + errors::InvalidArgument( + "The size of window_dimensions must be equal to the input " + "rank (", + window_dimensions.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_strides.size(), + errors::InvalidArgument( + "The size of window_strides must be equal to the input " + "rank (", + window_strides.size(), " vs. ", rank, ")")); + + XlaCompiler::CompileOptions compile_options; + compile_options.use_tuple_arg = false; + compile_options.resolve_compile_time_constants = false; + compile_options.is_entry_computation = false; + compile_options.always_return_tuple = false; + + // Build the select function. + XlaCompiler::Argument select_arg; + select_arg.kind = XlaCompiler::Argument::kParameter; + select_arg.type = dtype; + select_arg.shape = TensorShape(); + + XlaCompiler::CompilationResult select; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *select_computation_, + {select_arg, select_arg}, &select)); + + xla::Shape select_output_shape = xla::ShapeUtil::MakeShape(xla::PRED, {}); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(select.xla_output_shape, + select_output_shape), + errors::InvalidArgument( + "Invalid output shape of XlaSelectAndScatter select. Expected ", + xla::ShapeUtil::HumanString(select_output_shape), " got ", + xla::ShapeUtil::HumanString(select.xla_output_shape))); + + // Build the scatter function. + XlaCompiler::Argument scatter_arg; + scatter_arg.kind = XlaCompiler::Argument::kParameter; + scatter_arg.type = dtype; + scatter_arg.shape = TensorShape(); + + XlaCompiler::CompilationResult scatter; + OP_REQUIRES_OK(context, context->compiler()->CompileFunction( + compile_options, *scatter_computation_, + {scatter_arg, scatter_arg}, &scatter)); + + xla::Shape scalar_shape; + OP_REQUIRES_OK(context, + TensorShapeToXLAShape(dtype, TensorShape(), &scalar_shape)); + OP_REQUIRES( + context, + xla::ShapeUtil::Compatible(scatter.xla_output_shape, scalar_shape), + errors::InvalidArgument( + "Invalid output shape of scatter. Expected ", + xla::ShapeUtil::HumanString(scalar_shape), " got ", + xla::ShapeUtil::HumanString(scatter.xla_output_shape))); + + const TensorShape padding_shape = context->InputShape("padding"); + OP_REQUIRES(context, + TensorShapeUtils::IsMatrix(padding_shape) && + padding_shape.dim_size(1) == 2, + errors::InvalidArgument( + "padding must be a matrix with minor dimension 2, got ", + padding_shape.DebugString())); + xla::Literal padding_literal; + OP_REQUIRES_OK(context, context->ConstantInputAsInt64Literal( + "padding", &padding_literal)); + std::vector> padding(padding_shape.dim_size(0)); + for (int i = 0; i < padding.size(); ++i) { + padding[i] = {padding_literal.Get({i, 0}), + padding_literal.Get({i, 1})}; + } + + xla::XlaOp output = xla::SelectAndScatterWithGeneralPadding( + context->Input("operand"), *select.computation, window_dimensions, + window_strides, padding, context->Input("source"), + context->Input("init_value"), *scatter.computation); + context->SetOutput(0, output); + } + + private: + const NameAttrList* select_computation_; + const NameAttrList* scatter_computation_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaSelectAndScatterOp); +}; + +REGISTER_XLA_OP(Name("XlaSelectAndScatter") + .CompileTimeConstInput("window_dimensions") + .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("padding"), + XlaSelectAndScatterOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index cb7a40e23d539f758d963791f1c2b4d37374ade5..99511e991422014c877fb5f6b7fb6a914e730f40 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", ], ) @@ -44,8 +44,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/core:lib", ], @@ -78,8 +78,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:math", @@ -119,6 +119,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:constants", diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index f666d22ea44216beef74608bb4d9f33fb2fe82c6..d8c050d09e871c80e128989c9fbdb57c266b19ed 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -27,7 +27,8 @@ limitations under the License. namespace tensorflow { xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, - bool transpose_y, bool conjugate_x, bool conjugate_y) { + bool transpose_y, bool conjugate_x, bool conjugate_y, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); @@ -95,6 +96,10 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, y = xla::Conj(y); } + xla::PrecisionConfigProto precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + // If there are no batch dimensions, use a regular Dot. // TODO(b/69062148) Remove this code when Dot emitters can be passed // dimensions to transpose directly (i.e. without requiring a Transpose @@ -102,7 +107,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, if (batch_dimension_numbers.empty()) { auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x; auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y; - return xla::Dot(lhs, rhs); + return xla::Dot(lhs, rhs, &precision_proto); } xla::DotDimensionNumbers dot_dnums; @@ -112,7 +117,8 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - return xla::DotGeneral(x, y, dot_dnums); + + return xla::DotGeneral(x, y, dot_dnums, &precision_proto); }); } diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 8757b16a1ca6a8cec5e3c801c885e7bbbb2f2c76..6cfccd55530ff40a309673d57d1fe61fc8264316 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -45,7 +45,9 @@ namespace tensorflow { // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false); + bool conjugate_y = false, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::DEFAULT); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 87d73eb3f07ebd7dfa4fef50ebe76cad0c4ed117..67fb56510cbd0677a2b78e2090f98b602539c6bd 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -49,7 +49,8 @@ namespace { // l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / // l[..., j, j] // return l -xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { +xla::XlaOp CholeskyUnblocked(xla::XlaOp a, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -101,7 +102,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { // np.dot(row, np.swapaxes(row, -1, -2)) auto diag_dot = BatchDot(row, row, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, // np.swapaxes(row, -1, -2))) auto l_ii = @@ -121,7 +123,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { // r.T) auto dot = BatchDot(body_l, row, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); // np.dot(l[..., i+1:, :i], r.T) auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot); @@ -145,7 +148,8 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a) { } // namespace -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -181,14 +185,15 @@ xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) { auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); a = UpdateSliceInMinorDims(a, before - delta, {i, i}); } // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); - auto factorized = CholeskyUnblocked(x); + auto factorized = CholeskyUnblocked(x, precision); l = UpdateSliceInMinorDims(l, factorized, {i, i}); if (i + k < n) { diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 1bef9bb166c576ec665bb48265b4da200ddca2a0..60cd7ded53fe862f29ca2bb68b175fcd1c89b70c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -30,7 +30,9 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256); +xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index fc0c1ee838190b1f1a7ca5b901c97e0a35232a97..b6f30d8d49bf05813fa6fccc4544b0631f866490 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -149,7 +149,8 @@ struct QRBlockResult { xla::XlaOp taus; // Shape: [..., n] xla::XlaOp vs; // Shape: [..., m, n] }; -xla::StatusOr QRBlock(xla::XlaOp a) { +xla::StatusOr QRBlock( + xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -190,8 +191,12 @@ xla::StatusOr QRBlock(xla::XlaOp a) { auto v_broadcast = xla::Reshape(v, shape); // a[:, :] -= tau * np.dot(v[:, np.newaxis], // np.dot(v[np.newaxis, :], a[:, :])) - auto vva = BatchDot(v_broadcast, a); - vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true); + auto vva = + BatchDot(v_broadcast, a, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + vva = + BatchDot(v_broadcast, vva, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); a = a - xla::Mul(tau, vva, /*broadcast_dimensions=*/batch_dim_indices); @@ -251,7 +256,8 @@ xla::StatusOr QRBlock(xla::XlaOp a) { // vs. xla::StatusOr ComputeWYRepresentation( xla::PrimitiveType type, gtl::ArraySlice batch_dims, xla::XlaOp vs, - xla::XlaOp taus, int64 m, int64 n) { + xla::XlaOp taus, int64 m, int64 n, + xla::PrecisionConfigProto::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; @@ -272,9 +278,12 @@ xla::StatusOr ComputeWYRepresentation( auto beta = DynamicSliceInMinorDims(taus, {j}, {1}); // yv has shape [..., n, 1] - auto yv = BatchDot(y, v, /*transpose_x=*/true); + auto yv = BatchDot(y, v, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); // wyv has shape [..., m, 1] - auto wyv = BatchDot(w, yv); + auto wyv = + BatchDot(w, yv, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); auto z = xla::Mul( -beta, v + wyv, @@ -321,8 +330,9 @@ xla::StatusOr ComputeWYRepresentation( // return (q, a) // TODO(phawkins): consider using UT transformations (in the form I - V U V') // rather than WY transformations. -xla::StatusOr QRDecomposition(xla::XlaOp a, - int64 block_size) { +xla::StatusOr QRDecomposition( + xla::XlaOp a, int64 block_size, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -352,29 +362,36 @@ xla::StatusOr QRDecomposition(xla::XlaOp a, int64 k = std::min(block_size, p - i); auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k}); - TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block)); + TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision)); a = UpdateSliceInMinorDims(a, qr_block.r, {i, i}); // Compute the I-WY block representation of a product of Householder // matrices. - TF_ASSIGN_OR_RETURN(auto w, - ComputeWYRepresentation(type, batch_dims, qr_block.vs, - qr_block.taus, m - i, k)); + TF_ASSIGN_OR_RETURN( + auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs, + qr_block.taus, m - i, k, precision)); auto y = qr_block.vs; // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:])) auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n}); - auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true); - a_update = BatchDot(y, a_update); + auto a_update = + BatchDot(w, a_panel, /*transpose_x=*/true, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + a_update = + BatchDot(y, a_update, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); a_panel = a_panel + a_update; a = UpdateSliceInMinorDims(a, a_panel, {i, i + k}); // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T)) auto q_panel = SliceInMinorDims(q, {0, i}, {m, m}); - auto q_update = BatchDot(q_panel, w); - q_update = - BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true); + auto q_update = + BatchDot(q_panel, w, /*transpose_x=*/false, /*transpose_y=*/false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); + q_update = BatchDot(q_update, y, /*transpose_x=*/false, + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); q_panel = q_panel + q_update; q = UpdateSliceInMinorDims(q, q_panel, {0, i}); } diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index abd2316ac961f583dd29f90f43cf6209de30bd6a..05565477b6062618a75f929b69c38938ddfd7a5a 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -32,8 +33,10 @@ struct QRDecompositionResult { xla::XlaOp r; }; -xla::StatusOr QRDecomposition(xla::XlaOp a, - int64 block_size = 128); +xla::StatusOr QRDecomposition( + xla::XlaOp a, int64 block_size = 128, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index ba22eff73abab11abeb57283c63318b2e50a9ca1..bafe5099f2d494fd3549fae41397ffc5a22f5cb7 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -58,7 +58,7 @@ xla::StatusOr XlaScatter( ") must be <= the rank of the buffer (shape: ", xla::ShapeUtil::HumanString(buffer_shape), ")"); } - indices_dims.pop_back(); + indices_dims.remove_suffix(1); } int64 num_indices = 1; diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index febb638e5e8a87d78919f1eaa556d9c05ee40112..37b2240b45b4ae6a587c827cfdfa1096b4e1737e 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -110,8 +110,9 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, - bool transpose_a, bool conjugate_a) { +xla::XlaOp InvertDiagonalBlocks( + xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = diag_blocks.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { // Input is a batch of square lower triangular square matrices. Its shape is @@ -215,7 +216,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - auto update = -DotGeneral(input_row, body_out, dnums); + xla::PrecisionConfigProto precision_proto; + precision_proto.add_operand_precision(precision); + precision_proto.add_operand_precision(precision); + auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); body_out = DynamicUpdateSlice(body_out, update, start_indices); @@ -238,10 +242,10 @@ xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, }); } -xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, - xla::XlaOp inv_diag_blocks, - bool left_side, bool lower, - bool transpose_a, bool conjugate_a) { +xla::XlaOp SolveWithInvertedDiagonalBlocks( + xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, + bool lower, bool transpose_a, bool conjugate_a, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, @@ -307,9 +311,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, auto a_row = MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a); if (left_side) { - remainder = b_row - BatchDot(a_row, x, transpose_a, false); + remainder = b_row - BatchDot(a_row, x, transpose_a, false, + /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); } else { - remainder = b_row - BatchDot(x, a_row, false, transpose_a); + remainder = b_row - BatchDot(x, a_row, false, transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/false, precision); } } @@ -319,9 +327,13 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, xla::ConstantR0WithType(builder, xla::S32, j * block_size); std::vector update_starts = {start_index, zero}; if (left_side) { - x_update = BatchDot(inv_block, remainder, transpose_a, false); + x_update = + BatchDot(inv_block, remainder, transpose_a, false, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); } else { - x_update = BatchDot(remainder, inv_block, false, transpose_a); + x_update = + BatchDot(remainder, inv_block, false, transpose_a, + /*conjugate_x=*/false, /*conjugate_y=*/false, precision); std::swap(update_starts[0], update_starts[1]); } x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); @@ -333,7 +345,8 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks(xla::XlaOp a, xla::XlaOp b, xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - int64 block_size) { + int64 block_size, + xla::PrecisionConfigProto::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -388,12 +401,13 @@ xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, auto diag_blocks = DiagonalBlocks(a, block_size); // We invert these blocks in parallel using batched matrix-vector products - auto inv_diag_blocks = - InvertDiagonalBlocks(diag_blocks, lower, transpose_a, conjugate_a); + auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a, + conjugate_a, precision); // We now find the solution using GEMMs - auto x = SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, - lower, transpose_a, conjugate_a); + auto x = + SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower, + transpose_a, conjugate_a, precision); return x; }); diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 555760b7efabddfb25c9135b109a1c48b487415e..ac42a4835295b7cb52697710d738f4728d3983d1 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace tensorflow { @@ -59,7 +59,9 @@ namespace tensorflow { // blocking is used. xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - int64 block_size = 128); + int64 block_size = 128, + xla::PrecisionConfigProto::Precision precision = + xla::PrecisionConfigProto::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD index ace6fd1d8eeaf439509a7b75d8d986997c392e73..4dce0a2102cf9c782850ccc7af4f14b59bd51e53 100644 --- a/tensorflow/compiler/tf2xla/ops/BUILD +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -11,6 +11,8 @@ cc_library( srcs = ["xla_ops.cc"], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index a59c77f5c3a309abe8f6fbab1e48455d54e8fae5..2cd9ae799f06afdcbae5429ef8caffd3b4d29c29 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -13,11 +13,97 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/algorithm/container.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { +namespace { + +// Helper shape function for operators that return an output with the same rank +// as their first input. +Status UnchangedRank(shape_inference::InferenceContext* c) { + if (c->RankKnown(c->input(0))) { + c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); + } else { + c->set_output(0, c->input(0)); + } + return Status::OK(); +} + +REGISTER_OP("XlaBroadcastHelper") + .Input("lhs: T") + .Input("rhs: T") + .Input("broadcast_dims: Tindices") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Output("lhs_output: T") + .Output("rhs_output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Helper operator for performing XLA-style broadcasts + +Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to +whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules +for binary operators. + +lhs: the LHS input tensor +rhs: the RHS input tensor +broadcast_dims: an XLA-style broadcast dimension specification +lhs_output: the broadcasted LHS tensor +rhs_output: the broadcasted RHS tensor +)doc"); + +REGISTER_OP("XlaConv") + .Input("lhs: T") + .Input("rhs: T") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("lhs_dilation: Tindices") + .Input("rhs_dilation: Tindices") + .Input("feature_group_count: Tindices") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA ConvGeneralDilated operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution +. + +lhs: the input tensor +rhs: the kernel tensor +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +lhs_dilation: dilation to apply between input elements +rhs_dilation: dilation to apply between kernel elements +feature_group_count: number of feature groups for grouped convolution. +dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. +precision_config: a serialized xla::PrecisionConfigProto proto. +)doc"); + +REGISTER_OP("XlaDot") + .Input("lhs: T") + .Input("rhs: T") + .Attr("T: numbertype") + .Attr("dimension_numbers: string") + .Attr("precision_config: string") + .Output("output: T") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Wraps the XLA ConvGeneralDilated operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral +. + +lhs: the LHS tensor +rhs: the RHS tensor +dimension_numbers: a serialized xla::DotDimensionNumbers proto. +precision_config: a serialized xla::PrecisionConfigProto proto. +)doc"); REGISTER_OP("XlaDynamicUpdateSlice") .Input("input: T") @@ -73,6 +159,29 @@ else_branch: A function takes 'inputs' and returns a list of tensors. whose types are the same as what then_branch returns. )doc"); +REGISTER_OP("XlaPad") + .Input("input: T") + .Input("padding_value: T") + .Input("padding_low: Tindices") + .Input("padding_high: Tindices") + .Input("padding_interior: Tindices") + .Output("output: T") + .Attr("T: type") + .Attr("Tindices: {int32, int64}") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA Pad operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#pad +. + +input: A `Tensor` of type T. +padding_value: A scalar `Tensor` of type T. +padding_low: the padding to apply at the start of each input dimensions +padding_high: the padding to apply at the end of each input dimension. +padding_interior: the padding to apply between each input element. +output: A `Tensor` of type T. +)doc"); + REGISTER_OP("XlaRecv") .Output("tensor: dtype") .Attr("dtype: type") @@ -98,17 +207,58 @@ tensor_name: A string key that identifies the channel. shape: The shape of the tensor. )doc"); +REGISTER_OP("XlaReduce") + .Input("input: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("dimensions_to_reduce: list(int)") + .Attr("reducer: func") + .Output("output: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + if (c->RankKnown(c->input(0))) { + int rank = c->Rank(c->input(0)); + std::vector dimensions_to_reduce; + TF_RETURN_IF_ERROR( + c->GetAttr("dimensions_to_reduce", &dimensions_to_reduce)); + std::set dims_set(dimensions_to_reduce.begin(), + dimensions_to_reduce.end()); + auto dim_in_range = [rank](int64 dim) { + return dim >= 0 && dim < rank; + }; + if (rank < dimensions_to_reduce.size() || + dims_set.size() != dimensions_to_reduce.size() || + !absl::c_all_of(dimensions_to_reduce, dim_in_range)) { + return errors::InvalidArgument( + "Invalid dimensions_to_reduce argument to XlaReduce"); + } + c->set_output( + 0, c->UnknownShapeOfRank(rank - dimensions_to_reduce.size())); + } else { + c->set_output(0, c->input(0)); + } + return Status::OK(); + }) + .Doc(R"doc( +Wraps the XLA Reduce operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#reduce . + +input: the input tensor +init_value: a scalar representing the initial value for the reduction +reducer: a reducer function to apply +dimensions_to_reduce: dimension numbers over which to reduce +)doc"); + REGISTER_OP("XlaReduceWindow") .Input("input: T") .Input("init_value: T") + .Input("window_dimensions: Tindices") + .Input("window_strides: Tindices") + .Input("padding: Tindices") .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") .Attr("computation: func") - .Attr("window_dimensions: list(int)") - .Attr("window_strides: list(int)") - .Attr("padding_low: list(int)") - .Attr("padding_high: list(int)") .Output("output: T") - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn(UnchangedRank) .Doc(R"doc( Wraps the XLA ReduceWindow operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . @@ -118,8 +268,35 @@ init_value: a scalar representing the initial value for the reduction computation: a reducer function to apply window_dimensions: the shape of the window window_strides: the inter-window strides -padding_low: the padding to apply at the start of each input dimensions -padding_high: the padding to apply at the end of each input dimension. +padding: the padding to apply at the start and end of each input dimensions +)doc"); + +REGISTER_OP("XlaSelectAndScatter") + .Input("operand: T") + .Input("window_dimensions: Tindices") + .Input("window_strides: Tindices") + .Input("padding: Tindices") + .Input("source: T") + .Input("init_value: T") + .Attr("T: numbertype") + .Attr("Tindices: {int32, int64}") + .Attr("select: func") + .Attr("scatter: func") + .Output("output: T") + .SetShapeFn(UnchangedRank) + .Doc(R"doc( +Wraps the XLA SelectAndScatter operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter +. + +operand: the input tensor +window_dimensions: the shape of the window +window_strides: the inter-window strides +padding: the padding to apply at the start and end of each input dimensions +source: a tensor of values to scatter +init_value: a scalar representing the initial value for the output tensor +select: a selection function to apply +scatter: a scatter function to apply )doc"); REGISTER_OP("XlaSend") @@ -179,4 +356,5 @@ body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. )doc"); +} // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 42b6292f79ffddd155c05758a1420a2a583eb0c6..69ca39436013ec5cf09ba502a1540d5df322e213 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -28,5 +28,6 @@ py_library( srcs = ["xla.py"], deps = [ "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + "//tensorflow/compiler/xla:xla_data_proto_py", ], ) diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 2fc47dffb8f5f16f24e3beb1ff75aeed3e857c58..3626de375ea9ac12e40ea5b5b591bb6d5262adbc 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -15,11 +15,12 @@ """Experimental library that exposes XLA operations directly in TensorFlow. It is sometimes useful to be able to build HLO programs directly from -TensorFlow. This file provides Tensorflow operators that map as closely as -possible to HLO operators. +TensorFlow. This file provides Tensorflow operators that mirror the semantics of +HLO operators as closely as possible. -There is no promise of backward or forward compatibility for operators defined -in this module. +Note: There is no promise of backward or forward compatibility for operators +defined in this module. This is primarily because the underlying HLO operators +do not promise backward or forward compatibility. """ from __future__ import absolute_import @@ -27,11 +28,298 @@ from __future__ import division from __future__ import print_function from tensorflow.compiler.tf2xla.ops import gen_xla_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + +# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing +# ops include: +# infeed/outfeed (available via tf.contrib.tpu) +# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu) +# conditional +# gather/scatter +# collapse + +# This file reuses builtin names (following XLA's names, so we can call things +# like xla.max), so we capture the builtin versions here. +# pylint: disable=redefined-builtin +_max = max +_min = min +_slice = slice # pylint: disable=invalid-name + +constant = constant_op.constant + +# Unary operators. + +# For most arithmetic operators there is a TensorFlow operator +# that exactly corresponds to each XLA operator. Rather than defining +# XLA-specific variants, we reuse the corresponding TensorFlow operator. +# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1 +# wrap every HLO operator, because that would allow us to be confident that the +# semantics match. + + +def _unary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def unary_op_wrapper(x, name=None): + return fn(x, name=name) + + return unary_op_wrapper + + +abs = _unary_op(math_ops.abs) +# TODO(phawkins): implement clz. +conj = _unary_op(math_ops.conj) +cos = _unary_op(math_ops.cos) +ceil = _unary_op(math_ops.ceil) +digamma = _unary_op(math_ops.digamma) +erf = _unary_op(math_ops.erf) +erfc = _unary_op(math_ops.erfc) +# TODO(phawkins): implement erfinv +exp = _unary_op(math_ops.exp) +expm1 = _unary_op(math_ops.expm1) +floor = _unary_op(math_ops.floor) +imag = _unary_op(math_ops.imag) +is_finite = _unary_op(math_ops.is_finite) +lgamma = _unary_op(math_ops.lgamma) +log = _unary_op(math_ops.log) +log1p = _unary_op(math_ops.log1p) +logical_not = _unary_op(math_ops.logical_not) +neg = _unary_op(math_ops.neg) +real = _unary_op(math_ops.real) +# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for +# numbers halfway between two integers. +round = _unary_op(math_ops.round) +sin = _unary_op(math_ops.sin) +sign = _unary_op(math_ops.sign) +tanh = _unary_op(math_ops.tanh) + +# Binary operators + +# The main difference between TensorFlow and XLA binary ops is the broadcasting +# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA +# requires an explicit specification of which dimensions to broadcast if the +# arguments have different ranks. + + +def _broadcasting_binary_op(fn): + """Wraps a binary Tensorflow operator and performs XLA-style broadcasting.""" + + def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None): + """Inner wrapper function.""" + broadcast_dims = broadcast_dims or [] + broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64) + # Rather than relying on having static shape information in the TensorFlow + # graph, we use an XlaBroadcastHelper op that can compute the correct shapes + # at JIT compilation time. + x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims) + return fn(x, y, name=name) + + return broadcasting_binary_op_wrapper + + +# Map from TF signed types to TF unsigned types. +_SIGNED_TO_UNSIGNED_TABLE = { + dtypes.int8: dtypes.uint8, + dtypes.int16: dtypes.uint16, + dtypes.int32: dtypes.uint32, + dtypes.int64: dtypes.uint64, +} + +# Map from TF unsigned types to TF signed types. +_UNSIGNED_TO_SIGNED_TABLE = { + dtypes.uint8: dtypes.int8, + dtypes.uint16: dtypes.int16, + dtypes.uint32: dtypes.int32, + dtypes.uint64: dtypes.int64, +} + + +def _shift_right_logical_helper(x, y, name=None): + """Performs an integer right logical shift irrespective of input type.""" + assert y.dtype == x.dtype + dtype = x.dtype + signed = dtype in _SIGNED_TO_UNSIGNED_TABLE + if signed: + unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype] + x = math_ops.cast(x, unsigned_dtype) + y = math_ops.cast(y, unsigned_dtype) + output = bitwise_ops.right_shift(x, y, name=name) + if signed: + output = math_ops.cast(output, dtype) + return output + + +def _shift_right_arithmetic_helper(x, y, name=None): + """Performs an integer right arithmetic shift irrespective of input type.""" + assert y.dtype == x.dtype + dtype = x.dtype + unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE + if unsigned: + signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype] + x = math_ops.cast(x, signed_dtype) + y = math_ops.cast(y, signed_dtype) + output = bitwise_ops.right_shift(x, y, name=name) + if unsigned: + output = math_ops.cast(output, dtype) + return output + + +add = _broadcasting_binary_op(math_ops.add) +sub = _broadcasting_binary_op(math_ops.sub) +mul = _broadcasting_binary_op(math_ops.mul) +div = _broadcasting_binary_op(math_ops.div) +rem = _broadcasting_binary_op(gen_math_ops.mod) +max = _broadcasting_binary_op(math_ops.maximum) +min = _broadcasting_binary_op(math_ops.minimum) +atan2 = _broadcasting_binary_op(math_ops.atan2) +complex = _broadcasting_binary_op(math_ops.complex) +logical_and = _broadcasting_binary_op(math_ops.logical_and) +logical_or = _broadcasting_binary_op(math_ops.logical_or) +logical_xor = _broadcasting_binary_op(math_ops.logical_xor) +eq = _broadcasting_binary_op(math_ops.equal) +ne = _broadcasting_binary_op(math_ops.not_equal) +ge = _broadcasting_binary_op(math_ops.greater_equal) +gt = _broadcasting_binary_op(math_ops.greater) +le = _broadcasting_binary_op(math_ops.less_equal) +lt = _broadcasting_binary_op(math_ops.less) +pow = _broadcasting_binary_op(math_ops.pow) +shift_left = _broadcasting_binary_op(bitwise_ops.left_shift) +shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper) +shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper) + + +def _binary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def binary_op_wrapper(x, y, name=None): + return fn(x, y, name=name) + + return binary_op_wrapper + + +transpose = _binary_op(array_ops.transpose) +rev = _binary_op(array_ops.reverse) + +bitcast_convert_type = array_ops.bitcast + + +def broadcast(x, dims, name=None): + x = ops.convert_to_tensor(x) + shape = array_ops.concat( + [constant_op.constant(dims), + array_ops.shape(x)], axis=0) + return array_ops.broadcast_to(x, shape, name=name) + + +def clamp(a, x, b, name=None): + return min(max(a, x, name=name), b, name=name) + + +concatenate = array_ops.concat + + +def conv(lhs, + rhs, + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers, + feature_group_count=1, + precision_config=None, + name=None): + """Wraps the XLA ConvGeneralDilated operator. + + ConvGeneralDilated is the most general form of XLA convolution and is + documented at + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution + + Args: + lhs: the input tensor + rhs: the kernel tensor + window_strides: the inter-window strides + padding: the padding to apply at the start and end of each input dimensions + lhs_dilation: dilation to apply between input elements + rhs_dilation: dilation to apply between kernel elements + dimension_numbers: a `ConvolutionDimensionNumbers` proto. + feature_group_count: number of feature groups for grouped convolution. + precision_config: a `PrecisionConfigProto` proto. + name: an optional name for the operator + + Returns: + A tensor representing the output of the convolution. + """ + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_conv( + lhs, + rhs, + window_strides=window_strides, + padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers.SerializeToString(), + precision_config=precision_config_proto, + name=name) + + +convert_element_type = math_ops.cast + + +def dot(lhs, rhs, name=None): + return math_ops.tensordot(lhs, rhs, axes=1, name=name) + + +def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None): + precision_config_proto = "" + if precision_config: + precision_config_proto = precision_config.SerializeToString() + return gen_xla_ops.xla_dot( + lhs, + rhs, + dimension_numbers=dimension_numbers.SerializeToString(), + precision_config=precision_config_proto, + name=name) + + +def dynamic_slice(x, starts, sizes, name=None): + # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not + # a compile-time constant. This doesn't exactly mimic the semantics of dynamic + # slice if the slice is out of bounds. + return array_ops.slice(x, starts, sizes, name=name) -# TODO(phawkins): provide wrappers for all XLA operators. dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice +# TODO(phawkins): generalize tf.pad to support interior padding, and then remove +# the XLA-specific pad operator. +pad = gen_xla_ops.xla_pad + + +def random_normal(mu, sigma, dims, name=None): + mu = ops.convert_to_tensor(mu) + return random_ops.random_normal( + dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name) + + +def random_uniform(minval, maxval, dims, name=None): + minval = ops.convert_to_tensor(minval) + return random_ops.random_uniform( + dims, minval, maxval, dtype=minval.dtype, name=name) + + +recv = gen_xla_ops.xla_recv +reduce = gen_xla_ops.xla_reduce + def reduce_window(operand, init, @@ -61,22 +349,38 @@ def reduce_window(operand, """ window_strides = window_strides or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) - padding_low = [x for (x, _) in padding] - padding_high = [y for (_, y) in padding] return gen_xla_ops.xla_reduce_window( - operand, - init, - reducer, - window_dimensions, - window_strides, - padding_low, - padding_high, + input=operand, + init_value=init, + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, + computation=reducer, name=name) -recv = gen_xla_ops.xla_recv +def reshape(x, new_sizes, dimensions=None, name=None): + if dimensions is not None: + x = array_ops.transpose(x, dimensions) + x = array_ops.reshape(x, new_sizes, name=name) + return x + + +def select(condition, x, y, name=None): + return array_ops.where(condition, x, y, name) + + +select_and_scatter = gen_xla_ops.xla_select_and_scatter send = gen_xla_ops.xla_send -sort = gen_xla_ops.xla_sort +def slice(x, start_dims, limit_dims, strides): + spec = [ + _slice(start, limit, stride) + for (start, limit, stride) in zip(start_dims, limit_dims, strides) + ] + return x[tuple(spec)] + + +sort = gen_xla_ops.xla_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc new file mode 100644 index 0000000000000000000000000000000000000000..32ba6df2e6daa2add468a1bc0559d42606d1a9a6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "absl/algorithm/container.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { +/*static*/ StringPiece XlaResourceOpInfo::XlaResourceOpKindToString( + XlaResourceOpKind op_kind) { + switch (op_kind) { + case XlaResourceOpKind::kRead: + return "Read"; + case XlaResourceOpKind::kWrite: + return "Write"; + case XlaResourceOpKind::kReadWrite: + return "Modify"; + } +} + +static gtl::FlatMap* CreateResourceOpInfoMap() { + gtl::FlatMap* result = + new gtl::FlatMap; + + auto add = [&](StringPiece op, XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) { + auto insert_result = + result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); + CHECK(insert_result.second); + }; + + auto kRead = XlaResourceOpKind::kRead; + auto kWrite = XlaResourceOpKind::kWrite; + auto kReadWrite = XlaResourceOpKind::kReadWrite; + + auto kVariable = XlaResourceKind::kVariable; + auto kStack = XlaResourceKind::kStack; + auto kTensorArray = XlaResourceKind::kTensorArray; + + // clang-format off + add("AssignAddVariableOp" , kReadWrite, kVariable); + add("AssignSubVariableOp" , kReadWrite, kVariable); + add("AssignVariableOp" , kWrite, kVariable); + add("ReadVariableOp" , kRead, kVariable); + add("ResourceApplyAdaMax" , kReadWrite, kVariable); + add("ResourceApplyAdadelta" , kReadWrite, kVariable); + add("ResourceApplyAdagrad" , kReadWrite, kVariable); + add("ResourceApplyAdagradDA" , kReadWrite, kVariable); + add("ResourceApplyAdam" , kReadWrite, kVariable); + add("ResourceApplyAddSign" , kReadWrite, kVariable); + add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable); + add("ResourceApplyFtrl" , kReadWrite, kVariable); + add("ResourceApplyFtrlV2" , kReadWrite, kVariable); + add("ResourceApplyGradientDescent" , kReadWrite, kVariable); + add("ResourceApplyMomentum" , kReadWrite, kVariable); + add("ResourceApplyPowerSign" , kReadWrite, kVariable); + add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable); + add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable); + add("ResourceApplyRMSProp" , kReadWrite, kVariable); + add("ResourceGather" , kRead, kVariable); + add("ResourceScatterAdd" , kReadWrite, kVariable); + add("ResourceScatterDiv" , kReadWrite, kVariable); + add("ResourceScatterMax" , kReadWrite, kVariable); + add("ResourceScatterMin" , kReadWrite, kVariable); + add("ResourceScatterMul" , kReadWrite, kVariable); + add("ResourceScatterNdAdd" , kReadWrite, kVariable); + add("ResourceScatterNdUpdate" , kReadWrite, kVariable); + add("ResourceScatterSub" , kReadWrite, kVariable); + add("ResourceScatterUpdate" , kReadWrite, kVariable); + add("ResourceStridedSliceAssign" , kReadWrite, kVariable); + add("VarIsInitializedOp" , kRead, kVariable); + add("VariableShape" , kRead, kVariable); + + add("StackV2" , kWrite, kStack); + add("StackCloseV2" , kRead, kStack); + add("StackPopV2" , kReadWrite, kStack); + add("StackPushV2" , kReadWrite, kStack); + + add("TensorArrayV3" , kWrite, kTensorArray); + add("TensorArrayConcatV3" , kRead, kTensorArray); + add("TensorArrayGatherV3" , kRead, kTensorArray); + add("TensorArrayScatterV3" , kWrite, kTensorArray); + add("TensorArrayGradV3" , kRead, kTensorArray); + add("TensorArrayCloseV3" , kRead, kTensorArray); + add("TensorArrayReadV3" , kRead, kTensorArray); + add("TensorArraySizeV3" , kRead, kTensorArray); + add("TensorArraySplitV3" , kWrite, kTensorArray); + add("TensorArrayWriteV3" , kWrite, kTensorArray); + // clang-format on + + return result; +} + +static const gtl::FlatMap& +GetStaticResourceOpInfoMap() { + static gtl::FlatMap* op_info_map = + CreateResourceOpInfoMap(); + return *op_info_map; +} + +const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op) { + const gtl::FlatMap& op_infos = + GetStaticResourceOpInfoMap(); + auto it = op_infos.find(op); + return it == op_infos.end() ? nullptr : &it->second; +} + +namespace resource_op_table_internal { +std::vector GetKnownResourceOps() { + std::vector result; + for (const auto& p : GetStaticResourceOpInfoMap()) { + result.push_back(p.first); + } + absl::c_sort(result); + return result; +} +} // namespace resource_op_table_internal +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h new file mode 100644 index 0000000000000000000000000000000000000000..7f627a64c6e8298a427cd87d25d4ba24835bf542 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ + +#include +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" + +// Exposes information about the resource operations supported by tf2xla in a +// structured form. + +namespace tensorflow { +enum class XlaResourceOpKind { + kRead, // Only reads from resources. + kWrite, // Only writes to resources. + kReadWrite // Reads from and writes to resources. +}; + +enum class XlaResourceKind { + kVariable, // Operates on resource variables. + kStack, // Operates on stacks. + kTensorArray // Operates on tensor arrays. +}; + +class XlaResourceOpInfo { + public: + explicit XlaResourceOpInfo(XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) + : op_kind_(op_kind), resource_kind_(resource_kind) {} + + XlaResourceOpKind kind() const { return op_kind_; } + XlaResourceKind resource_kind() const { return resource_kind_; } + + static StringPiece XlaResourceOpKindToString(XlaResourceOpKind op_kind); + + private: + XlaResourceOpKind op_kind_; + XlaResourceKind resource_kind_; +}; + +// Returns a XlaResourceOpInfo describing `op` if it is a resource operation +// supported by tf2xla, otherwise returns null (i.e. if this returns null then +// `op` is either not a resource operation or is unsupported by XLA). +const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op); + +namespace resource_op_table_internal { +// NB! Implementation detail exposed for unit testing, do not use. +// +// Returns the set of resource operations known by this module. +std::vector GetKnownResourceOps(); +} // namespace resource_op_table_internal + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0343f80de9fed114a0097b981233277c3e12b378 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +bool IsResourceArgDef(const OpDef::ArgDef& arg_def) { + return arg_def.type() == DT_RESOURCE; +} + +bool HasResourceInputOrOutput(const OpDef& op_def) { + return absl::c_any_of(op_def.input_arg(), IsResourceArgDef) || + absl::c_any_of(op_def.output_arg(), IsResourceArgDef); +} + +TEST(ResourceOperationTableTest, HaveAllResourceOps) { + gtl::FlatMap known_resource_ops; + for (StringPiece known_resource_op : + resource_op_table_internal::GetKnownResourceOps()) { + ASSERT_TRUE( + known_resource_ops.insert({string(known_resource_op), false}).second); + } + + std::vector xla_op_names = XlaOpRegistry::GetAllRegisteredOps(); + for (const string& xla_op_name : xla_op_names) { + const OpDef* op_def; + TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def)); + if (HasResourceInputOrOutput(*op_def)) { + EXPECT_EQ(known_resource_ops.count(xla_op_name), 1) + << "Unknown resource op " << xla_op_name; + known_resource_ops[xla_op_name] = true; + } + } + + std::vector unnecessary_resource_ops; + for (const auto& pair : known_resource_ops) { + if (!pair.second) { + unnecessary_resource_ops.push_back(pair.first); + } + } + + EXPECT_TRUE(unnecessary_resource_ops.empty()) + << "Stale resource ops:\n" + << absl::StrJoin(unnecessary_resource_ops, "\n"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 66835e69b23a9bf58c2212abcf6b532a2696bc10..2d7eb8b915b8245ba6573c30b2eb15b12fc3a1b4 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "absl/strings/match.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" @@ -65,8 +65,8 @@ xla::StatusOr> ParseShardingFromDevice( if (explicit_sharding.has_value()) { return explicit_sharding; } else if (!parsed_device.has_type || !parsed_device.has_id || - !str_util::StrContains(parsed_device.type, - kDeviceSuffixReplicatedCore)) { + !absl::StrContains(parsed_device.type, + kDeviceSuffixReplicatedCore)) { return absl::optional(); } else { const int core = parsed_device.id; diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc deleted file mode 100644 index 2b0834fe7b6c4d2199267dbe0ec1f7c2785aa9c7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/str_util.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/str_util.h" - -#include -#include -#include - -namespace tensorflow { -namespace str_util { - -static void ReplaceAll(string* text, StringPiece from, StringPiece to) { - size_t pos = 0; - while ((pos = text->find(from.data(), pos, from.size())) != string::npos) { - text->replace(pos, from.size(), to.data(), to.size()); - pos += to.size(); - if (from.empty()) { - pos++; // Match at the beginning of the text and after every byte - } - } -} - -void ReplaceAllPairs(string* text, - const std::vector>& replace) { - for (const std::pair& from_to : replace) { - ReplaceAll(text, from_to.first, from_to.second); - } -} - -} // namespace str_util -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h deleted file mode 100644 index 51f25009d7003db0d72296619a469ecbbbb1808d..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/str_util.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// String utilities that are esoteric enough that they don't belong in -// third_party/tensorflow/core/lib/strings/str_util.h, but are still generally -// useful under xla. - -#ifndef TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ -#define TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ - -#include -#include -#include - -#include "tensorflow/core/lib/core/stringpiece.h" - -namespace tensorflow { -namespace str_util { - -// Replace all non-overlapping occurrences of the given (from,to) pairs in-place -// in text. If from is empty, it matches at the beginning of the text and after -// every byte. Each (from,to) replacement pair is processed in the order it is -// given. -void ReplaceAllPairs(string* text, - const std::vector>& replace); - -} // namespace str_util -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc deleted file mode 100644 index 8817f6902a8e58e796ca5240a9a24d7506d38793..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/str_util_test.cc +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/str_util.h" - -#include -#include -#include - -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace str_util { - -class ReplaceAllPairsTest : public ::testing::Test { - protected: - void ExpectReplaceAllPairs( - string text, const std::vector>& replace, - StringPiece want) { - ReplaceAllPairs(&text, replace); - EXPECT_EQ(text, want); - } -}; - -TEST_F(ReplaceAllPairsTest, Simple) { - ExpectReplaceAllPairs("", {}, ""); - ExpectReplaceAllPairs("", {{"", ""}}, ""); - ExpectReplaceAllPairs("", {{"", "X"}}, "X"); - ExpectReplaceAllPairs("", {{"", "XYZ"}}, "XYZ"); - ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}}, "_X_Y_Z_"); - ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}, {"_Y_", "a"}}, "_XaZ_"); - ExpectReplaceAllPairs("banana", {}, "banana"); - ExpectReplaceAllPairs("banana", {{"", ""}}, "banana"); - ExpectReplaceAllPairs("banana", {{"", "_"}}, "_b_a_n_a_n_a_"); - ExpectReplaceAllPairs("banana", {{"", "__"}}, "__b__a__n__a__n__a__"); - ExpectReplaceAllPairs("banana", {{"a", "a"}}, "banana"); - ExpectReplaceAllPairs("banana", {{"a", ""}}, "bnn"); - ExpectReplaceAllPairs("banana", {{"a", "X"}}, "bXnXnX"); - ExpectReplaceAllPairs("banana", {{"a", "XX"}}, "bXXnXXnXX"); - ExpectReplaceAllPairs("banana", {{"a", "XX"}, {"XnX", "z"}}, "bXzzX"); - ExpectReplaceAllPairs("a{{foo}}b{{bar}}c{{foo}}", - {{"{{foo}}", "0"}, {"{{bar}}", "123456789"}}, - "a0b123456789c0"); -} - -} // namespace str_util -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 48568c825b7a0f13011d3d6e8e62ec5db026760f..f34af2d67debe8bfa4abcad19e42c55ea40c4e82 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -197,8 +197,8 @@ Status RewriteAndPruneGraph( if (!missing_feeds.empty() || !missing_fetches.empty()) { return errors::Aborted( "Post graph-pruning", - ", missing feeds: ", str_util::Join(missing_feeds, ", "), - ", missing fetches: ", str_util::Join(missing_fetches, ", ")); + ", missing feeds: ", absl::StrJoin(missing_feeds, ", "), + ", missing fetches: ", absl::StrJoin(missing_fetches, ", ")); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc index 7aca889a266439538c4cd1c153460e6cc871b246..567d212b5eee493d29a1817987cbd7759575386e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -54,10 +54,10 @@ void PrintSupportedOps(const string& device, const string& regen_run) { } std::sort(types.begin(), types.end()); constraints.push_back("`" + constraint.name() + "={" + - str_util::Join(types, ",") + "}`"); + absl::StrJoin(types, ",") + "}`"); } std::cout << "`" << kdef->op() << "` | " - << str_util::Join(constraints, "
") << std::endl; + << absl::StrJoin(constraints, "
") << std::endl; } std::cout << "\nTo regenerate this table, run:\n\n```shell\n" @@ -76,7 +76,7 @@ void SupportedOpsMain(int argc, char** argv, const char* regen_run) { {"device", &device, "Name of the compilation device for which to print supported ops, " "one of: " + - str_util::Join(device_names, ",")}, + absl::StrJoin(device_names, ",")}, }; string usage = Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index ebdf2fd741a49c5eb578e733218bd332ee480522..e284e0b191ac09f9491973166c80b731c8ea51a5 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -233,7 +233,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, // Push input nodes of the currently visited node to name_queue. for (const string& in_edge : map_entry.second->input()) { auto id = ParseTensorName(in_edge); - const string node_name = std::string(id.first); + const string node_name = string(id.first); if (feed_tensors.find(std::make_pair(node_name, id.second)) == feed_tensors.end()) { name_queue.push(node_name); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index ae51446204baf14dc03fc6305641048dbf3872b0..2b1f724dc7b2e2bb6d06115827f92bf0670955b3 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -25,16 +26,15 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { -void ExpectErrorContains(const Status& status, StringPiece str) { +void ExpectErrorContains(const Status& status, absl::string_view str) { EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(str_util::StrContains(status.error_message(), str)) + EXPECT_TRUE(absl::StrContains(status.error_message(), str)) << "expected error: " << status.error_message() << " to contain: " << str; } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 8e7aad26865eb458a4f530133347dc909b0895f7..aa2a521d984b4f7169980241c71018afc86cb430 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -361,6 +361,9 @@ Status BuildComputation( if (retval.has_constant_value()) { output.is_constant = true; output.constant_value = retval.constant_value(); + } else if (retval.resource() != nullptr) { + output.is_constant = false; + output.input_index = retval.resource()->arg_num(); } else { output.is_constant = false; elems.push_back(retval.handle()); @@ -465,8 +468,6 @@ Status XlaCompiler::BuildArguments( // XLA computation as runtime parameters. input_mapping->clear(); input_mapping->reserve(args.size()); - std::vector resources; - resources.reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. for (std::vector::size_type i = 0; i < args.size(); @@ -485,8 +486,9 @@ Status XlaCompiler::BuildArguments( /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); arg_expression.set_resource(resource); if (arg.initialized) { - resources.push_back(i); + input_mapping->push_back(i); } + break; case XlaCompiler::Argument::kParameter: { input_mapping->push_back(i); @@ -496,14 +498,11 @@ Status XlaCompiler::BuildArguments( arg_expression.set_constant_value(arg.constant_value); break; case XlaCompiler::Argument::kInvalid: - return errors::Internal("Unreachable case in BuildArguments()"); + return errors::Internal( + "Unreachable case in BuildArguments() while filling constant args"); } } - // Append parameters containing variable values after the other runtime - // parameters. - input_mapping->insert(input_mapping->end(), resources.begin(), - resources.end()); if (input_mapping->empty()) { return Status::OK(); } @@ -620,7 +619,8 @@ Status XlaCompiler::BuildArguments( break; case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: - return errors::Internal("Unreachable case in BuildArguments()"); + return errors::Internal( + "Unreachable case in BuildArguments() while filling handles"); } } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index fde47dbdec8161b4563645fc7386b985e1fee9d2..9e2c64fd4210b56b591e11bc3113d8b52c1d50fd 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -183,6 +183,8 @@ class XlaCompiler { struct OutputDescription { // Type and shape of the output. The shape is the unflattened shape. + // When `type` is DT_RESOURCE, `shape` is the shape of the resource + // variable's value. DataType type; TensorShape shape; @@ -190,6 +192,10 @@ class XlaCompiler { // 'Tensor' is in host memory. bool is_constant = false; Tensor constant_value; + + // When this output is a resource, i.e. `type == DT_RESOURCE`, this is + // the index of the input that contains the resource. + int input_index; }; // Describes a variable write side effect of the computation. @@ -212,9 +218,9 @@ class XlaCompiler { struct CompilationResult { // Vector that maps from the parameters of the XLA computation to their - // original argument positions. To handle compile-time constant inputs and - // resources, the parameters to the XLA computation may be a subset of the - // original arguments, and are not necessarily in the same order.) + // original argument positions. To handle compile-time constant inputs, the + // parameters to the XLA computation may be a subset of the original + // arguments. The relative ordering of parameters are maintained. std::vector input_mapping; // Input shapes of the computation. If we are flattening inputs, these are diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 7227df96499f6e8f1b5f09ad5e27aa5f7b63e8c8..be3c93ae47bf16a67ed4fac34a99997cc7888559 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/version.h" @@ -280,6 +280,54 @@ TEST_F(XlaCompilerTest, OutOfOrderGraph) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*param0_literal, *actual_literal)); } +// Tests that the compiler doesn't reorder the parameters. +TEST_F(XlaCompilerTest, MixedOrderArguments) { + for (bool swap_order : {false, true}) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto var = + ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, swap_order ? 0 : 1); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, swap_order ? 1 : 0); + // Adds an identity op around the resource to make sure identity ops + // propagate resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + if (swap_order) { + // Even after swapping arguments, the compiler should maintain the new + // ordering of parameters. + std::swap(args[0], args[1]); + } + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompileOptions compile_options; + compile_options.always_return_tuple = false; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1)); + } +} + TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { // Builds a graph that adds reshapes a tensor, but with the shape not // statically known. @@ -309,10 +357,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::move(graph), args, &result); EXPECT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "depends on a parameter")) + absl::StrContains(status.error_message(), "depends on a parameter")) << status.error_message(); EXPECT_TRUE( - str_util::StrContains(status.error_message(), "[[{{node C}} = Reshape")) + absl::StrContains(status.error_message(), "[[{{node C}} = Reshape")) << status.error_message(); } @@ -727,8 +775,7 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) { compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); EXPECT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "is not defined.")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) << status.error_message(); } @@ -807,15 +854,40 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { ASSERT_FALSE(status.ok()); // Flib lookup failure. - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "is not defined.")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined.")) << status.error_message(); // Local flib lookup failure. - EXPECT_TRUE(str_util::StrContains(StringPiece(status.error_message()), - "Attr T is not found")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found")) << status.error_message(); } +void RunAndCheckVariablesComputation( + xla::Client* client, const XlaCompiler::CompilationResult& result) { + std::unique_ptr param0_literal = + xla::LiteralUtil::CreateR1({7, 42}); + std::unique_ptr param1_literal = + xla::LiteralUtil::CreateR1({-3, 101}); + std::unique_ptr param0_data = + client->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::LiteralUtil::CreateR1({5, 144}); + std::unique_ptr expected1 = + xla::LiteralUtil::CreateR1({4, 143}); + std::unique_ptr expected_literal = + xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + // Tests a simple graph that reads and writes a variable. TEST_F(XlaCompilerTest, Variables) { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -847,36 +919,90 @@ TEST_F(XlaCompilerTest, Variables) { // Compiles the graph. XlaCompiler compiler(DefaultOptions()); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + RunAndCheckVariablesComputation(client_, result); +} + +// Tests a simple graph that reads and writes a variable. +TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0); + auto d = ops::_Retval(scope.WithOpName("D"), var, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kVariable; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + XlaCompiler::CompilationResult result; TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", std::move(graph), args, &result)); // Tests that the generated computation works. - std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); std::unique_ptr param1_literal = xla::LiteralUtil::CreateR1({-3, 101}); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); std::unique_ptr actual = - client_ - ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + client_->Execute(*result.computation, {param1_data.get()}) .ConsumeValueOrDie(); std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR1({5, 144}); - std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({4, 143}); std::unique_ptr expected_literal = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralUtil::MakeTuple({}); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +TEST_F(XlaCompilerTest, ReturnResourceHandle) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + // Adds an identity op around the resource to make sure identity ops propagate + // resources correctly. + auto identity = ops::Identity(scope.WithOpName("VIdentity"), var); + auto write = ops::AssignAddVariableOp(scope, identity, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto r = ops::_Retval(scope.WithOpName("R"), var, 0); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + RunAndCheckVariablesComputation(client_, result); +} + xla::StatusOr> BuildTestGraph() { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); @@ -1078,9 +1204,9 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp")) << status.error_message(); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node fill_fn}}")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}")) << status.error_message(); } @@ -1103,10 +1229,10 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", std::move(graph), args, &result); ASSERT_FALSE(status.ok()); - EXPECT_TRUE(str_util::StrContains(status.error_message(), - "is not in the list of allowed values")) + EXPECT_TRUE(absl::StrContains(status.error_message(), + "is not in the list of allowed values")) << status.error_message(); - EXPECT_TRUE(str_util::StrContains(status.error_message(), "{{node Shape}}")) + EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}")) << status.error_message(); } @@ -1130,9 +1256,9 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { std::move(graph_copy), args, &result); ASSERT_FALSE(status.ok()); EXPECT_TRUE( - str_util::StrContains(status.error_message(), - "The following nodes are unreachable " - "from the source in the graph: {{node NoOp}}")) + absl::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: {{node NoOp}}")) << status.error_message(); } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index b24e3aabbe6ba858a8bfb4dd435726984cc7b0f5..e36039ada5f5a655ccecc8a2c15bd9824b70518c 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -107,6 +107,19 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, return Status::OK(); } +Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { + VLOG(1) << "Adding retval index " << retval_index << " with resource " + << resource->name() << ":" << resource->shape().DebugString() + << " to XLA computation"; + if (retvals_.size() <= retval_index) { + retvals_.resize(retval_index + 1); + } + XlaExpression e; + e.set_resource(resource); + retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e}; + return Status::OK(); +} + xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 3db37afdba71342cfb20af8841a40cb54709ca73..4da891634e97dd67af0ef09ef33dbc7a4d19743b 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -86,6 +86,9 @@ class XlaContext : public ResourceBase { Status AddConstRetval(int retval_index, DataType dtype, const xla::LiteralSlice& literal); + // As for Retval, but for return values that are resource handles. + Status AddResourceRetval(int retval_index, XlaResource* resource); + // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` // constructor for a description of the remaining arguments. diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 82028c8b9ca9f65a73f8b50edc0a47c7068aba9a..9e8f5f2a1adc4dd0dadf6c8f88c5e18dd0d1dc00 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -99,6 +99,25 @@ Status XlaOpKernelContext::ConstantInput(int index, index, context_->input(index).shape().dim_sizes(), constant_literal); } +static xla::StatusOr InputIndex(XlaOpKernelContext* context, + StringPiece name) { + int start, stop; + TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); + if (stop != start + 1) { + return errors::InvalidArgument("OpKernel used list-valued input name '", + name, + "' when single-valued input was " + "expected"); + } + return start; +} + +Status XlaOpKernelContext::ConstantInput(StringPiece name, + xla::Literal* constant_literal) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInput(index, constant_literal); +} + Status XlaOpKernelContext::ConstantInputReshaped( int index, gtl::ArraySlice new_dims, xla::Literal* constant_literal) { @@ -246,6 +265,12 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) { return LiteralToInt64Scalar(literal, out); } +Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name, + int64* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsIntScalar(index, out); +} + Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); @@ -280,6 +305,20 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index, return LiteralToInt64Vector(literal, out); } +Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name, + std::vector* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsIntVector(index, out); +} + +Status XlaOpKernelContext::ConstantInputReshapedToIntVector( + int index, std::vector* out) { + xla::Literal literal; + TF_RETURN_IF_ERROR(ConstantInputReshaped( + index, {InputShape(index).num_elements()}, &literal)); + return LiteralToInt64Vector(literal, out); +} + Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, xla::Literal* out) { xla::Literal literal; @@ -305,6 +344,12 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, } } +Status XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name, + xla::Literal* out) { + TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); + return ConstantInputAsInt64Literal(index, out); +} + // TODO(phawkins): validate that the dimensions form a valid shape, fail // gracefully if they do not. Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index ac9dfe3369078df7392a4ef04679f7d7beacf8bb..3e26ba4f015ee81d1e880f9c4ee1e1a3665af452 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -106,6 +106,7 @@ class XlaOpKernelContext { // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. Status ConstantInput(int index, xla::Literal* constant_literal); + Status ConstantInput(StringPiece name, xla::Literal* constant_literal); // Evaluates input `index`, reshapes it to `new_shape` if new_shape != // InputShape(index), and stores it in `*constant_literal`. If the input @@ -117,15 +118,22 @@ class XlaOpKernelContext { // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); + Status ConstantInputAsIntScalar(StringPiece name, int64* out); // Converts a constant scalar float32 or float64 tensor into a float64. Status ConstantInputAsFloatScalar(int index, double* out); // Converts a constant 1D int32 or int64 tensor into a vector of int64s. Status ConstantInputAsIntVector(int index, std::vector* out); + Status ConstantInputAsIntVector(StringPiece name, std::vector* out); + + // Reshapes and converts a constant int32 or int64 tensor into a vector of + // int64s. + Status ConstantInputReshapedToIntVector(int index, std::vector* out); // Converts a constant int32 or int64 Tensor into an xla int64 Literal. Status ConstantInputAsInt64Literal(int index, xla::Literal* out); + Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out); // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 46785bc1f0a1279bfd67a55844fe238d9797382b..2f3a4cd3b57fd4a1dd8959f78fb51cc3c16db1ac 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -325,6 +325,17 @@ std::vector XlaOpRegistry::DeviceKernels( return kernels; } +/*static*/ std::vector XlaOpRegistry::GetAllRegisteredOps() { + std::vector ops; + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + for (const auto& pair : registry.ops_) { + ops.push_back(pair.first); + } + std::sort(ops.begin(), ops.end()); + return ops; +} + /* static */ const std::unordered_set* XlaOpRegistry::CompileTimeConstantInputs(const string& op) { XlaOpRegistry& registry = Instance(); @@ -362,7 +373,7 @@ XlaOpRegistry& XlaOpRegistry::Instance() { XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { registration_.reset(new XlaOpRegistry::OpRegistration); - registration_->name = std::string(name); + registration_->name = string(name); } XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { @@ -374,14 +385,14 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( gtl::ArraySlice devices) { registration_->has_device_whitelist = true; for (StringPiece device : devices) { - registration_->device_whitelist.insert(std::string(device)); + registration_->device_whitelist.emplace(device); } return *this; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { registration_->has_device_whitelist = true; - registration_->device_whitelist.insert(std::string(device)); + registration_->device_whitelist.emplace(device); return *this; } @@ -398,7 +409,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, DataType allowed) { std::set& types = - registration_->type_constraints[std::string(attr_name)]; + registration_->type_constraints[string(attr_name)]; types.insert(allowed); return *this; } @@ -406,7 +417,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, gtl::ArraySlice allowed) { std::set& types = - registration_->type_constraints[std::string(attr_name)]; + registration_->type_constraints[string(attr_name)]; for (DataType t : allowed) { types.insert(t); } @@ -415,7 +426,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( StringPiece input_name) { - registration_->compile_time_constant_inputs.insert(std::string(input_name)); + registration_->compile_time_constant_inputs.emplace(input_name); return *this; } @@ -444,7 +455,7 @@ XlaBackendRegistrar::XlaBackendRegistrar( StringPiece name, gtl::ArraySlice types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); - registry.RegisterBackend(std::string(name), types, op_filter); + registry.RegisterBackend(string(name), types, op_filter); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index fc14834ca6441ea785eacc57e1f502086f36657e..6ce0e2580b1a9b75fe72fba931d80c96b3870fce 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -128,6 +128,9 @@ class XlaOpRegistry { const string& compilation_device_name, bool include_compilation_only_kernels); + // Returns all operations for which there are XLA kernels on any device. + static std::vector GetAllRegisteredOps(); + // Returns the set of compile-time constant inputs to 'op'. Returns nullptr // if the op is not registered. static const std::unordered_set* CompileTimeConstantInputs( diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 1a8fa627a02ec737b941ca9d7f5c6f46e78834d9..ddeba1d91d0872a95bf8af252e43180ca19c0567 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -113,6 +113,7 @@ cc_library( ":statusor", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -173,6 +174,8 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -237,11 +240,11 @@ cc_library( ":types", ":util", ":xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -259,6 +262,7 @@ tf_cc_test( ":xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -301,6 +305,8 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -320,6 +326,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -341,6 +348,7 @@ cc_library( ":xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -359,6 +367,8 @@ cc_library( ":literal_util", ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -370,6 +380,8 @@ cc_library( deps = [ ":util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -379,8 +391,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -391,6 +403,7 @@ cc_library( ":status", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -413,6 +426,7 @@ cc_library( ":types", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -457,6 +471,7 @@ cc_library( ":array2d", ":types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -510,6 +525,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -529,6 +545,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -559,6 +576,7 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -629,6 +647,7 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 2d5d078aa77423cc18bab053b80a7576acbd849e..c8e483712efb48e49135f8775ef079497f68776f 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -27,12 +27,12 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -507,9 +507,7 @@ class Array { } } - pieces.push_back( - tensorflow::strings::AlphaNum(values_[calculate_index(index)]) - .data()); + pieces.push_back(absl::StrCat(values_[calculate_index(index)])); // Emit comma if it isn't the last element if (index.back() != sizes_.back() - 1) { @@ -527,7 +525,7 @@ class Array { } } } while (next_index(&index)); - return tensorflow::str_util::Join(pieces, ""); + return absl::StrJoin(pieces, ""); } private: diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 340f94fab72a24fb39cf1dfc1d722e2ee6c3685a..782c966b4c57672d137569a318fb20ace14d493b 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -25,11 +25,10 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index a75fffc605aa0df3e1e2eeb6d3129718cbbba0e4..8557bb8fe47c8e633a59f3b802b964a45aff8823 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -26,13 +26,11 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index c8b2a1ac730f79d882e15ac8e84b20ee8a95bc68..2638dea1bdbf6554802f99491b81037a8c82b421 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -77,6 +77,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -90,6 +91,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", ], ) @@ -216,6 +219,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 25608d6616f687825db0fb3d739e52f1ade9ce52..1fdf8f6260d3f00db43647a4d4de2842d69bf833 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -400,7 +400,7 @@ StatusOr Client::ExecutionStatsAsString( int64 nanoseconds = profile.compute_time_ns(); int64 cycle_count = profile.compute_cycle_count(); double gflops = total_flops / nanoseconds; - return tensorflow::strings::StrCat( + return absl::StrCat( "[Execution Statistics] flop count: ", computation_stats.flop_count(), ", transcendental count: ", computation_stats.transcendental_count(), ", compute execution time: ", nanoseconds, " nsec", diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index b6012a0352069917063084c5c5f022ef3e8c27a1..040344c9a65de122a21831b0eb79504ab4401772 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -41,7 +41,7 @@ CompileOnlyClient::CompileAheadOfTime( metadata); } -int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { +int64 CompileOnlyClient::PointerSizeForTriple(absl::string_view triple) { llvm::Triple llvm_triple( llvm::Triple::normalize(llvm::StringRef(triple.data(), triple.size()))); if (llvm_triple.isArch64Bit()) { diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index a551edeab0943ec5213c5cb035644c02c3cf54d7..d0c83cbfccb99755f8f5b7fa2e179f25fb73d3d1 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -57,7 +57,7 @@ class CompileOnlyClient : public Client { std::unique_ptr* metadata = nullptr); // Returns the size of a pointer in bytes for a given triple. - static int64 PointerSizeForTriple(tensorflow::StringPiece triple); + static int64 PointerSizeForTriple(absl::string_view triple); private: CompileOnlyService* compiler_service_; diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 75610a381811b2bf0f6849e0d4c39c6132105ce6..0f1745366b7c33e573aff2e66d85431b01488c49 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -59,10 +59,10 @@ string ExecutableBuildOptions::ToString() const { if (generate_hlo_graph_.has_value()) { generate_hlo_graph = generate_hlo_graph_.value(); } - return tensorflow::strings::Printf( + return absl::StrFormat( "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " "generate_hlo_graph=%s}", - device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str()); + device_ordinal_, result_layout, generate_hlo_graph); } ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( @@ -77,8 +77,8 @@ const absl::optional& ExecutableBuildOptions::generate_hlo_graph() } ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to( - tensorflow::StringPiece dirpath) { - dump_optimized_hlo_proto_to_ = dirpath.ToString(); + absl::string_view dirpath) { + dump_optimized_hlo_proto_to_ = string(dirpath); return *this; } @@ -89,8 +89,8 @@ ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { ExecutableBuildOptions& ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( - tensorflow::StringPiece dirpath) { - dump_unoptimized_hlo_proto_to_ = dirpath.ToString(); + absl::string_view dirpath) { + dump_unoptimized_hlo_proto_to_ = string(dirpath); return *this; } @@ -100,8 +100,8 @@ ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { } ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( - tensorflow::StringPiece dirpath) { - dump_per_pass_hlo_proto_to_ = dirpath.ToString(); + absl::string_view dirpath) { + dump_per_pass_hlo_proto_to_ = string(dirpath); return *this; } diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 904d230981c9c31177f619f7ca0c444364504b18..888d2f28ebb2cfc73a58ba07d58d10405fb76832 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -62,19 +62,19 @@ class ExecutableBuildOptions { // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO // protobuf to (as in DebugOptions). ExecutableBuildOptions& set_dump_optimized_hlo_proto_to( - tensorflow::StringPiece dirpath); + absl::string_view dirpath); const absl::optional& dump_optimized_hlo_proto_to() const; // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO // protobuf to (as in DebugOptions). ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( - tensorflow::StringPiece dirpath); + absl::string_view dirpath); const absl::optional& dump_unoptimized_hlo_proto_to() const; // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs // to (as in DebugOptions). ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( - tensorflow::StringPiece dirpath); + absl::string_view dirpath); const absl::optional& dump_per_pass_hlo_proto_to() const; // If true, specifies that we should record an HLO profile during execution @@ -83,7 +83,7 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_hlo_profile(bool enabled); absl::optional hlo_profile() const; - void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) { + void add_disabled_hlo_pass(absl::string_view pass_name) { disabled_hlo_passes_.push_back(std::string(pass_name)); } const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 4d233741bd2a26fa3f275a2043c2c2a53016bed6..8736f18dcfa678f35ba9c749d373d2d4ad6a9bd6 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -31,7 +31,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -221,5 +221,6 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 9225b1acd69c214d6f08a45372a8082ed789c18c..e86c10f030f3990d67e5a6638100640f73c82307 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { @@ -39,7 +39,7 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, b = builder->CreateSubBuilder(name); } else { b = builder->CreateSubBuilder( - tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); + absl::StrCat(name, "_", PrimitiveType_Name(type))); } const Shape scalar = ShapeUtil::MakeShape(type, {}); diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc index 031d62e4ffef188082303a28866bbc72a154e9b1..1ada7b4a964ccf7ca400b937abbe425bef083468 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -56,7 +56,7 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { std::numeric_limits::epsilon()); default: return builder->ReportError(InvalidArgument( - "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str())); + "Invalid type for Epsilon (%s).", PrimitiveType_Name(type))); } } diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 0c8a9b8cc02ba0c1ebdf6a060d4b99262dceb178..81624614c1e3599dfe116eb61d9e2edcd5230684 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -37,13 +37,13 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { primitive_util::IsComplexType(type))) { return builder->ReportError(InvalidArgument( "Invalid cast from floating point type to %s in ConstantR0WithType.", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } if (std::is_same::value && !primitive_util::IsComplexType(type)) { return builder->ReportError(InvalidArgument( "Invalid cast from complex type to %s in ConstantR0WithType.", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } switch (type) { case F16: @@ -71,7 +71,7 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { default: return builder->ReportError( InvalidArgument("Invalid type for ConstantR0WithType (%s).", - PrimitiveType_Name(type).c_str())); + PrimitiveType_Name(type))); } } diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc index 1c91237ae1574f92cda78c9bddc6f4ac1d68f47c..02bed8016213a12300af3183a911bb6d41c85db1 100644 --- a/tensorflow/compiler/xla/client/lib/numeric.cc +++ b/tensorflow/compiler/xla/client/lib/numeric.cc @@ -65,9 +65,8 @@ XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) { case C64: return MakeIota(builder, size); default: - return builder->ReportError( - InvalidArgument("Unimplemented type for Iota: %s.", - PrimitiveType_Name(type).c_str())); + return builder->ReportError(InvalidArgument( + "Unimplemented type for Iota: %s.", PrimitiveType_Name(type))); } } diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 081fec7ad92958aa285e4be41394d7b1876e0815..6861521acc0db1d640666a6793b898a183ab6a17 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -61,8 +61,7 @@ XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { - XlaBuilder b( - tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); + XlaBuilder b(absl::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); XlaComputation computation = b.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1cd3e9b22f9cf3383cfcbc19c79acba0e5938190..db7a8fc04751bdbb4f4414948627617641f5bd90 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -59,7 +59,7 @@ Status LocalExecutable::ValidateExecutionOptions( // Check argument number, shapes, and layouts. if (arguments.size() != computation_layout.parameter_count()) { return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", + "invalid number of arguments for computation: expected %d, got %u", computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { @@ -71,9 +71,9 @@ Status LocalExecutable::ValidateExecutionOptions( "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) - .c_str(), - ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); + ShapeUtil::HumanString( + computation_layout.parameter_layout(i).shape()), + ShapeUtil::HumanString(arguments[i]->on_host_shape())); } } @@ -88,8 +88,7 @@ Status LocalExecutable::ValidateExecutionOptions( if (stream_platform != backend_->platform()) { return InvalidArgument( "stream is for platform %s, but service targets platform %s", - stream_platform->Name().c_str(), - backend_->platform()->Name().c_str()); + stream_platform->Name(), backend_->platform()->Name()); } // Cannot specify device_ordinal with a stream. The stream determines these @@ -120,10 +119,10 @@ Status LocalExecutable::ValidateExecutionOptions( return InvalidArgument( "executable is built for device %s of type \"%s\"; cannot run it on " "device %s of type \"%s\"", - backend_->device_name(build_device_ordinal()).c_str(), - build_executor->GetDeviceDescription().name().c_str(), - backend_->device_name(run_device_ordinal).c_str(), - run_executor->GetDeviceDescription().name().c_str()); + backend_->device_name(build_device_ordinal()), + build_executor->GetDeviceDescription().name(), + backend_->device_name(run_device_ordinal), + run_executor->GetDeviceDescription().name()); } if (!run_options.allocator()) { @@ -133,8 +132,8 @@ Status LocalExecutable::ValidateExecutionOptions( if (run_options.allocator()->platform() != backend.platform()) { return InvalidArgument( "allocator platform (%s) does not match service platform (%s)", - run_options.allocator()->platform()->Name().c_str(), - backend.platform()->Name().c_str()); + run_options.allocator()->platform()->Name(), + backend.platform()->Name()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 6a9cf466ac0a43ce214ef0e6aae9e6295f137b0f..ed4dc8e9f6d0861adcf2fd3b45ab16a43abf56e9 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -31,8 +31,8 @@ Status ValidatePaddingValues( input_dimensions.size() == window_strides.size(); if (!ok) { return InvalidArgument( - "Want input dimensions size %zu = window dimensions size %zu = window " - "strides size %zu", + "Want input dimensions size %u = window dimensions size %u = window " + "strides size %u", input_dimensions.size(), window_dimensions.size(), window_strides.size()); } diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 4e7ef66dc596f1f2c53302ecc3b08a7c110a97c1..819d3249276e984329ba8b449fd07a42fe4b3123 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -23,6 +23,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" @@ -31,12 +34,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" namespace xla { -using tensorflow::strings::StrCat; +using absl::StrCat; namespace { @@ -70,7 +72,7 @@ XlaOp operator>>(const XlaOp& x, const XlaOp& y) { if (!ShapeUtil::ElementIsIntegral(shape)) { return InvalidArgument( "Argument to >> operator does not have an integral type (%s).", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } if (ShapeUtil::ElementIsSigned(shape)) { return ShiftRightArithmetic(x, y); @@ -223,8 +225,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() { auto build_status = Build(); if (!build_status.ok()) { parent_builder_->ReportError( - AddStatus(build_status.status(), - tensorflow::strings::StrCat("error from: ", name_))); + AddStatus(build_status.status(), absl::StrCat("error from: ", name_))); return {}; } return build_status.ConsumeValueOrDie(); @@ -465,6 +466,19 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { }); } +XlaOp XlaBuilder::IotaGen(const Shape& shape, int64 iota_dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + instr.add_dimensions(iota_dimension); + return AddInstruction(std::move(instr), HloOpcode::kIota); + }); +} + +XlaOp XlaBuilder::IotaGen(PrimitiveType type, int64 size) { + return IotaGen(ShapeUtil::MakeShape(type, {size}), /*iota_dimension=*/0); +} + XlaOp XlaBuilder::Call(const XlaComputation& computation, tensorflow::gtl::ArraySlice operands) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -491,7 +505,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; if (!parameter_numbers_.insert(parameter_number).second) { - return InvalidArgument("parameter %lld already registered", + return InvalidArgument("parameter %d already registered", parameter_number); } instr.set_parameter_number(parameter_number); @@ -705,8 +719,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand)); VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape); - VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dimensions, ","); + VLOG(3) << "dims to collapse: " << absl::StrJoin(dimensions, ","); std::vector new_sizes; for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) { @@ -717,8 +730,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } } - VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") - << "]"; + VLOG(3) << "new sizes: [" << absl::StrJoin(new_sizes, ",") << "]"; return Reshape(operand, new_sizes); }); @@ -767,7 +779,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { if (!ShapeUtil::IsTuple(tuple_shape)) { return InvalidArgument( "Operand to GetTupleElement() is not a tuple; got %s", - ShapeUtil::HumanString(tuple_shape).c_str()); + ShapeUtil::HumanString(tuple_shape)); } *instr.mutable_shape() = ShapeUtil::GetTupleElementShape(tuple_shape, index); @@ -848,16 +860,14 @@ Status XlaBuilder::VerifyConvolution( return InvalidArgument( "Convolution arguments must have same number of " "dimensions. Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str()); + ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } int num_dims = ShapeUtil::Rank(lhs_shape); if (num_dims < 2) { return InvalidArgument( "Convolution expects argument arrays with >= 3 dimensions. " "Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str()); + ShapeUtil::HumanString(lhs_shape), ShapeUtil::HumanString(rhs_shape)); } int num_spatial_dims = num_dims - 2; @@ -871,7 +881,7 @@ Status XlaBuilder::VerifyConvolution( } for (int i = 0; i < numbers.size(); ++i) { if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { - return InvalidArgument("Convolution %s[%d] is out of bounds: %lld", + return InvalidArgument("Convolution %s[%d] is out of bounds: %d", field_name, i, numbers.Get(i)); } } @@ -1013,12 +1023,11 @@ StatusOr XlaBuilder::MakeWindow( return Status::OK(); } else { return InvalidArgument( - "%s", tensorflow::strings::StrCat( + "%s", absl::StrCat( "Window has different number of window dimensions than of ", x_name, "\nNumber of window dimensions: ", window_dimensions.size(), - "\nNumber of ", x_name, ": ", x, "\n") - .c_str()); + "\nNumber of ", x_name, ": ", x, "\n")); } }; TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); @@ -1194,8 +1203,8 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "Outfeed shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } *instr.mutable_outfeed_shape() = shape_with_layout; @@ -1247,8 +1256,8 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "Outfeed shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } *instr.mutable_outfeed_shape() = shape_with_layout; @@ -1283,11 +1292,11 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - if (tensorflow::str_util::StartsWith(call_target_name, "$")) { + if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( "Invalid custom_call_target \"%s\": Call targets that start with '$' " "are reserved for internal use.", - call_target_name.c_str()); + call_target_name); } *instr.mutable_shape() = shape; instr.set_custom_call_target(call_target_name); @@ -1591,7 +1600,7 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution, if (parameters.size() != 2) { return InvalidArgument( "RNG distribution (%s) expects 2 parameters, but got %ld", - RandomDistribution_Name(distribution).c_str(), parameters.size()); + RandomDistribution_Name(distribution), parameters.size()); } break; default: @@ -1975,6 +1984,27 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension, }); } +XlaOp XlaBuilder::CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCollectivePermuteShape(operand_shape)); + + for (const auto& pair : source_target_pairs) { + auto* proto_pair = instr.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); + } + + return AddInstruction(std::move(instr), HloOpcode::kCollectivePermute, + {operand}); + }); +} + XlaOp XlaBuilder::SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, @@ -2141,13 +2171,13 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token, if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) { return InvalidArgument( "SendToHost shape %s must be compatible with operand shape %s", - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(), - ShapeUtil::HumanStringWithLayout(operand_shape).c_str()); + ShapeUtil::HumanStringWithLayout(shape_with_layout), + ShapeUtil::HumanStringWithLayout(operand_shape)); } // TODO(b/111544877): Support tuple shapes. if (!ShapeUtil::IsArray(operand_shape)) { return InvalidArgument("SendToHost only supports array shapes, shape: %s", - ShapeUtil::HumanString(operand_shape).c_str()); + ShapeUtil::HumanString(operand_shape)); } if (handle.type() != ChannelHandle::DEVICE_TO_HOST) { @@ -2186,7 +2216,7 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape, if (!ShapeUtil::IsArray(shape)) { return InvalidArgument( "RecvFromHost only supports array shapes, shape: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } if (handle.type() != ChannelHandle::HOST_TO_DEVICE) { @@ -2241,7 +2271,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( "of being evaluated at XLA compile time.\n\n" "Please file a usability bug with the framework being used (e.g. " "TensorFlow).", - op_string.c_str()); + op_string); } TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, @@ -2349,8 +2379,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the input are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the input are not unique: (%d, %d, %d, " + "%d)", dnum.input_batch_dimension(), dnum.input_feature_dimension(), dnum.input_spatial_dimensions(0), dnum.input_spatial_dimensions(1)); } @@ -2360,8 +2390,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.kernel_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the weight are not unique: (%d, %d, %d, " + "%d)", dnum.kernel_output_feature_dimension(), dnum.kernel_input_feature_dimension(), dnum.kernel_spatial_dimensions(0), dnum.kernel_spatial_dimensions(1)); @@ -2372,8 +2402,8 @@ XlaBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dnum.output_spatial_dimensions(1)}) .size() != 4) { return FailedPrecondition( - "dimension numbers for the output are not unique: (%lld, %lld, %lld, " - "%lld)", + "dimension numbers for the output are not unique: (%d, %d, %d, " + "%d)", dnum.output_batch_dimension(), dnum.output_feature_dimension(), dnum.output_spatial_dimensions(0), dnum.output_spatial_dimensions(1)); } @@ -2393,13 +2423,11 @@ StatusOr XlaBuilder::AddInstruction( } for (const auto& operand : operands) { if (operand.builder_ == nullptr) { - return InvalidArgument("invalid XlaOp with handle %lld", - operand.handle()); + return InvalidArgument("invalid XlaOp with handle %d", operand.handle()); } if (operand.builder_ != this) { return InvalidArgument("Do not add XlaOp from builder %s to builder %s", - operand.builder_->name().c_str(), - this->name().c_str()); + operand.builder_->name(), this->name()); } instr.add_operand_ids(operand.handle()); } @@ -2429,18 +2457,18 @@ StatusOr XlaBuilder::LookUpInstruction( if (op.builder_ == nullptr) { return InvalidArgument( - "invalid XlaOp with handle %lld; the builder of this op is freed", + "invalid XlaOp with handle %d; the builder of this op is freed", op.handle()); } if (op.builder_ != this) { return InvalidArgument( - "XlaOp with handle %lld is built by builder '%s', but is trying to use " + "XlaOp with handle %d is built by builder '%s', but is trying to use " "it in builder '%s'", - op.handle(), op.builder_->name().c_str(), this->name().c_str()); + op.handle(), op.builder_->name(), this->name()); } if (op.handle() >= instructions_.size() || op.handle() < 0) { - return InvalidArgument("no XlaOp value %lld", op.handle()); + return InvalidArgument("no XlaOp value %d", op.handle()); } return &instructions_[op.handle()]; } @@ -2788,6 +2816,12 @@ XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, split_count, replica_groups); } +XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs) { + return operand.builder()->CollectivePermute(operand, source_target_pairs); +} + XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, @@ -3002,10 +3036,11 @@ XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale, } XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size) { - HloInstructionProto instr; - *instr.mutable_shape() = ShapeUtil::MakeShape(type, {size}); - return builder->ReportErrorOrReturn( - builder->AddInstruction(std::move(instr), HloOpcode::kIota)); + return builder->IotaGen(type, size); +} + +XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension) { + return builder->IotaGen(shape, iota_dimension); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index adb62f5f02a11f7c432d342d7db76d7ddc793df7..193d8ed07198f0785cad4b2008b72e173f41643f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" @@ -711,12 +711,16 @@ class XlaBuilder { const absl::optional& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. - // - // TODO(b/110096724): This is NOT YET ready to use. XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); + // Enqueues an operation that do an CollectivePermute of the operand cross + // cores. + XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -796,6 +800,12 @@ class XlaBuilder { // entry was NaN. XlaOp IsFinite(const XlaOp& operand); + // Enqueues an iota operation onto the computation. + XlaOp IotaGen(const Shape& shape, int64 iota_dimension); + + // Enqueues a rank-1 iota operation onto the computation. + XlaOp IotaGen(PrimitiveType type, int64 size); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, @@ -1262,6 +1272,9 @@ class XlaBuilder { friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups); + friend XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); friend XlaOp SelectAndScatter( const XlaOp& operand, const XlaComputation& select, tensorflow::gtl::ArraySlice window_dimensions, @@ -1297,6 +1310,8 @@ class XlaBuilder { friend XlaOp IsFinite(const XlaOp& operand); // TODO(b/64798317): Finish CPU & GPU implementation, then replace xla::Iota // in xla/client/lib/numeric.h with this (renamed to xla::Iota). + friend XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, + int64 iota_dimension); friend XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); friend XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); @@ -1859,12 +1874,22 @@ XlaOp CrossReplicaSum( const absl::optional& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. -// -// TODO(b/110096724): This is NOT YET ready to use. XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, const std::vector& replica_groups = {}); +// Enqueues an collective operation that sends and receives data cross replicas. +// +// - `source_target_pair`: a list of (source_replica_id, target_replica_id) +// pairs. For each pair, the operand is sent from source replica to target +// replica. Note that, 1) any two pairs should not have the same target replica +// id, and they should not have the same source replica id; 2) if a replica id +// is not a target in any pair, then the output on that replica is a tensor +// consists of 0(s) with the same shape as the input. +XlaOp CollectivePermute( + const XlaOp& operand, + const std::vector>& source_target_pairs); + // Enqueues an operation that scatters the `source` array to the selected // indices of each window. XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select, @@ -1943,6 +1968,12 @@ XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs, // entry was NaN. XlaOp IsFinite(const XlaOp& operand); +// Enqueues an iota operation onto the computation. +XlaOp IotaGen(XlaBuilder* builder, const Shape& shape, int64 iota_dimension); + +// Enqueues a rank-1 iota operation onto the computation. +XlaOp IotaGen(XlaBuilder* builder, PrimitiveType type, int64 size); + // Enqueues a convert instruction onto the computation that changes the // element type of the operand array to primitive_type. XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type); diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 49a15ec3b449bdec07aa6ecfbc40b7b9f62c3f4e..7c37ed00cd3dcc214fb0b36c0161d3c39a5bf8c8 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -320,6 +320,15 @@ TEST_F(XlaBuilderTest, AllToAll) { ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8}))); } +TEST_F(XlaBuilderTest, CollectivePermute) { + XlaBuilder b(TestName()); + auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); + CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}}); + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute); +} + TEST_F(XlaBuilderTest, ReportError) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x"); diff --git a/tensorflow/compiler/xla/device_util.h b/tensorflow/compiler/xla/device_util.h index 1a51fdee680721a4a03fa5de79a81746d92af76b..6d51126d882f87a84b054e9db599b995868824bf 100644 --- a/tensorflow/compiler/xla/device_util.h +++ b/tensorflow/compiler/xla/device_util.h @@ -21,8 +21,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -30,8 +30,8 @@ namespace xla { // Returns a string that represents the device in terms of platform and ordinal; // e.g. the first CUDA device will be "cuda:0" string DeviceIdentifier(se::StreamExecutor* stream_exec) { - return tensorflow::strings::StrCat(stream_exec->platform()->Name(), ":", - stream_exec->device_ordinal()); + return absl::StrCat(stream_exec->platform()->Name(), ":", + stream_exec->device_ordinal()); } } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index ffd1fb79e986f82e1c2721f0eefbf3b4c0838e41..693dcb3a3eef37f92533f1add850395e51d4b910 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -36,7 +36,7 @@ namespace xla { DCHECK_GE(multi_index[i], 0); DCHECK_LT(multi_index[i], shape.dimensions(i)) << "indexing beyond extent in dimension " << i << ":" - << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",") + << "\n\tindex: " << absl::StrJoin(multi_index, ",") << "\n\tshape: " << ShapeUtil::HumanString(shape); } diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index b72d190d54591384392e79e73e90cf52df04a902..cce1838ef35865bc54d2d01365949dfd6b6f3a54 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,8 +33,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -169,7 +169,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { return InvalidArgument("shape %s does not have a layout", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } return ValidateLayoutForShape(shape.layout(), shape); } else { @@ -177,7 +177,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (shape.has_layout()) { return InvalidArgument( "shape of primitive type %s should not have a layout", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return Status::OK(); } @@ -194,7 +194,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { layout.padded_dimensions_size() != 0) { return InvalidArgument( "shape of primitive type %s should not have a non-trivial layout", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return Status::OK(); } @@ -202,17 +202,17 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (layout.format() == INVALID_FORMAT) { return InvalidArgument( "Layout does not have a valid format: layout {%s}, shape {%s}", - layout.ShortDebugString().c_str(), shape.ShortDebugString().c_str()); + layout.ShortDebugString(), shape.ShortDebugString()); } if (layout.format() == DENSE) { if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( "layout minor_to_major field contains %d elements, " - "but shape is rank %lld: {%s}; shape: %s", + "but shape is rank %d: {%s}; shape: %s", layout.minor_to_major_size(), ShapeUtil::Rank(shape), - tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), - shape.ShortDebugString().c_str()); + absl::StrJoin(layout.minor_to_major(), ", "), + shape.ShortDebugString()); } std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); @@ -221,12 +221,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { return InvalidArgument( "layout minor_to_major field has out-of-bounds value: %s", - HumanString(layout).c_str()); + HumanString(layout)); } if (dimensions_in_layout[dim]) { return InvalidArgument( "layout minor_to_major field has duplicate values: {%s}", - HumanString(layout).c_str()); + HumanString(layout)); } dimensions_in_layout[dim] = true; } @@ -234,14 +234,14 @@ Layout CreateDefaultLayoutForRank(int64 rank) { if (layout.padded_dimensions_size() > 0) { if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( - "layout has %d padded dimensions, but shape is rank %lld", + "layout has %d padded dimensions, but shape is rank %d", layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); } for (int i = 0; i < layout.padded_dimensions_size(); ++i) { if (layout.padded_dimensions(i) < shape.dimensions(i)) { return InvalidArgument( - "for dimension %d, dimension padding (%lld) is smaller than " - "the dimension size (%lld) of the shape", + "for dimension %d, dimension padding (%d) is smaller than " + "the dimension size (%d) of the shape", i, layout.padded_dimensions(i), shape.dimensions(i)); } } @@ -403,12 +403,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ string LayoutUtil::HumanString(const Layout& layout) { if (IsSparse(layout)) { - return tensorflow::strings::StrCat("sparse{", layout.max_sparse_elements(), - "}"); + return absl::StrCat("sparse{", layout.max_sparse_elements(), "}"); } CHECK(IsDense(layout)); - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}"); + return absl::StrCat("{", absl::StrJoin(layout.minor_to_major(), ","), "}"); } namespace { diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 89353448e29ec3d97275dac288e23aa8e96e31b2..3e79129aafd234e5eab05d205f2017b54057795e 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -26,6 +26,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -39,6 +40,7 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", ], ) @@ -56,6 +58,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -73,5 +76,7 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 5d27e4a46b57242c96ee84d37466ffb7d613a974..0d3136b0cc6a3a695eacb98c16200e46a144c571 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -17,9 +17,9 @@ limitations under the License. #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" #include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace legacy_flags { @@ -87,7 +87,7 @@ void AllocateFlags() { // Custom "sub-parser" lambda for xla_disable_hlo_passes. auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) { std::vector disabled_passes = - tensorflow::str_util::Split(comma_separated_values, ','); + absl::StrSplit(comma_separated_values, ','); for (const auto& passname : disabled_passes) { flag_values->add_xla_disable_hlo_passes(passname); } diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h index e9cf435d83d8345e974d83f8e5340dafeba8e3b2..ee7eb019c07cf898e48886955b18710146644cac 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h @@ -17,10 +17,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ #include +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace legacy_flags { @@ -30,7 +30,7 @@ template void parse_xla_backend_extra_options(T* extra_options_map, string comma_separated_values) { std::vector extra_options_parts = - tensorflow::str_util::Split(comma_separated_values, ','); + absl::StrSplit(comma_separated_values, ','); // The flag contains a comma-separated list of options; some options // have arguments following "=", some don't. @@ -59,8 +59,7 @@ void parse_xla_backend_extra_options(T* extra_options_map, inline bool parse_xla_reduce_precision_option( HloReducePrecisionOptions* options, string option_string) { // Split off "LOCATION" from remainder of string. - std::vector eq_split = - tensorflow::str_util::Split(option_string, '='); + std::vector eq_split = absl::StrSplit(option_string, '='); if (eq_split.size() != 2) { return false; } @@ -80,26 +79,25 @@ inline bool parse_xla_reduce_precision_option( } // Split off "E,M" from remainder of string. - std::vector colon_split = - tensorflow::str_util::Split(eq_split[1], ':'); + std::vector colon_split = absl::StrSplit(eq_split[1], ':'); if (colon_split.size() != 2) { return false; } // Split E and M, and parse. std::vector bitsizes; - if (!tensorflow::str_util::SplitAndParseAsInts(colon_split[0], ',', - &bitsizes) || - bitsizes.size() != 2) { - return false; + for (const auto& s : absl::StrSplit(colon_split[0], ',')) { + bitsizes.emplace_back(); + if (!absl::SimpleAtoi(s, &bitsizes.back())) { + return false; + } } options->set_exponent_bits(bitsizes[0]); options->set_mantissa_bits(bitsizes[1]); // Split off OPS comma-separated list from remainder of string, if the // remainder exists. - std::vector semicolon_split = - tensorflow::str_util::Split(colon_split[1], ';'); + std::vector semicolon_split = absl::StrSplit(colon_split[1], ';'); if (semicolon_split.size() > 2) { return false; } @@ -113,8 +111,7 @@ inline bool parse_xla_reduce_precision_option( options->add_opcodes_to_suffix(i); } } else { - std::vector opcodes = - tensorflow::str_util::Split(opcode_string, ','); + std::vector opcodes = absl::StrSplit(opcode_string, ','); for (const string& opcode : opcodes) { bool found = false; for (int i = 0; i < HloOpcodeCount(); i++) { @@ -132,8 +129,7 @@ inline bool parse_xla_reduce_precision_option( // Process the NAMES string, if it exists. if (semicolon_split.size() == 2) { - std::vector opnames = - tensorflow::str_util::Split(semicolon_split[1], ','); + std::vector opnames = absl::StrSplit(semicolon_split[1], ','); for (const string& opname : opnames) { if (opname.length() > 0) { options->add_opname_substrings_to_suffix(opname); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc index 0ed788a9676fe9b1bd06fb3ceabf627c108a2c70..6f197aec53c7596e84437a03affa9118f22f5a1d 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc index 7b6ae311c1099dccb8dceb2f49743c1b185cd5ab..138c0c852e2bb0527d171f25b4d96cedc5671516 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/test.h" @@ -106,8 +106,8 @@ TEST(ParseFlagsFromEnv, File) { if (tmp_dir == nullptr) { tmp_dir = kTempDir; } - string tmp_file = tensorflow::strings::Printf("%s/parse_flags_from_env.%d", - tmp_dir, getpid()); + string tmp_file = + absl::StrFormat("%s/parse_flags_from_env.%d", tmp_dir, getpid()); FILE* fp = fopen(tmp_file.c_str(), "w"); CHECK_NE(fp, nullptr) << "can't write to " << tmp_file; for (int i = 0; kTestFlagString[i] != '\0'; i++) { diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index d54f051a1a959488fe716e17b69ba087e4020ae3..93e808469af9b3d2bee9c3aed33cb15996f2a07e 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -23,6 +23,9 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,19 +34,15 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::Printf; -using tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; +using absl::StrFormat; + constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; // Converts between little and big endian. @@ -304,7 +303,7 @@ MutableLiteralBase::CreateFromProto(const LiteralProto& proto) { if (proto_element->tuple_literals_size() != ShapeUtil::TupleElementCount(piece->subshape())) { return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", + "Expected %d tuple elements in LiteralProto, has %d", ShapeUtil::TupleElementCount(piece->subshape()), proto_element->tuple_literals_size()); } @@ -405,7 +404,7 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { default: return Unimplemented( "Copying a Literal object with element type %s is not implemented.", - PrimitiveType_Name(subshape().element_type()).c_str()); + PrimitiveType_Name(subshape().element_type())); } } return Status::OK(); @@ -421,8 +420,8 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { return InvalidArgument( "Destination subshape incompatible with source subshape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_subshape).c_str()); + ShapeUtil::HumanString(dest_subshape), + ShapeUtil::HumanString(src_subshape)); } return root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -459,8 +458,8 @@ Status Literal::MoveFrom(Literal&& src_literal, if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { return InvalidArgument( "Destination subshape not equal to source shape: %s vs %s", - ShapeUtil::HumanString(dest_subshape).c_str(), - ShapeUtil::HumanString(src_literal.shape()).c_str()); + ShapeUtil::HumanString(dest_subshape), + ShapeUtil::HumanString(src_literal.shape())); } src_literal.root_piece_->ForEachSubpiece( @@ -655,8 +654,8 @@ StatusOr> LiteralBase::Reshape( return InvalidArgument( "Shapes before and after Literal::Reshape have different numbers " "of elements: %s vs %s.", - ShapeUtil::HumanString(shape()).c_str(), - ShapeUtil::HumanString(output->shape()).c_str()); + ShapeUtil::HumanString(shape()), + ShapeUtil::HumanString(output->shape())); } return std::move(output); } @@ -875,9 +874,8 @@ StatusOr LiteralBase::GetIntegralAsS64( case U64: return Get(multi_index); default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); + return FailedPrecondition("Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type())); } } @@ -925,9 +923,8 @@ Status MutableLiteralBase::SetIntegralAsS64( Set(multi_index, value); break; default: - return FailedPrecondition( - "Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type()).c_str()); + return FailedPrecondition("Array element type is not integral: %s", + PrimitiveType_Name(shape().element_type())); } return Status::OK(); } @@ -1030,9 +1027,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, element_index.push_back(i); std::vector element_pieces; ToStringHelper(literal, element_index, print_layout, &element_pieces); - tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); + tuple_pieces.push_back(absl::StrJoin(element_pieces, "")); } - pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); + pieces->push_back(absl::StrJoin(tuple_pieces, ",\n")); pieces->push_back("\n)"); return; } @@ -1056,8 +1053,7 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(": "); } else { pieces->push_back("["); - pieces->push_back( - tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back(absl::StrJoin(literal.GetSparseIndex(i), ", ")); pieces->push_back("]: "); } pieces->push_back(literal.GetSparseElementAsString(i)); @@ -1118,9 +1114,9 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {\n"); for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { pieces->push_back(" {"); for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { @@ -1138,11 +1134,11 @@ void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, pieces->push_back(shape_to_string(subshape)); pieces->push_back(" {\n"); for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { - pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + pieces->push_back(StrFormat(" { /*i0=%d*/\n", i0)); for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { - pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + pieces->push_back(StrFormat(" { /*i1=%d*/\n", i1)); for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { - pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); + pieces->push_back(StrFormat(" { /*i2=%d*/\n", i2)); for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { pieces->push_back(" {"); for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { @@ -1183,7 +1179,7 @@ string LiteralBase::ToString(bool print_layout) const { std::vector pieces; CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, &pieces); - return tensorflow::str_util::Join(pieces, ""); + return absl::StrJoin(pieces, ""); } void LiteralBase::EachCellAsString( @@ -1314,10 +1310,9 @@ StatusOr> ConvertIfDestTypeMatches( default: break; } - return Unimplemented( - "Converting from type %s to type %s is not implemented.", - PrimitiveType_Name(src_literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented("Converting from type %s to type %s is not implemented.", + PrimitiveType_Name(src_literal.shape().element_type()), + PrimitiveType_Name(primitive_dest_type)); } StatusOr> ConvertSwitch( @@ -1346,11 +1341,10 @@ StatusOr> ConvertSwitch( #undef CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. default: - return Unimplemented( - "%s from type %s to type %s is not implemented.", - (bitcast ? "Bitcast converting" : "Converting"), - PrimitiveType_Name(literal.shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str()); + return Unimplemented("%s from type %s to type %s is not implemented.", + (bitcast ? "Bitcast converting" : "Converting"), + PrimitiveType_Name(literal.shape().element_type()), + PrimitiveType_Name(primitive_dest_type)); } } @@ -1368,8 +1362,8 @@ StatusOr> LiteralBase::BitcastConvert( return InvalidArgument( "Cannot bitcast convert from %s to %s, bit widths are different: %d != " "%d", - PrimitiveType_Name(shape().element_type()).c_str(), - PrimitiveType_Name(primitive_dest_type).c_str(), + PrimitiveType_Name(shape().element_type()), + PrimitiveType_Name(primitive_dest_type), primitive_util::BitWidth(shape().element_type()), primitive_util::BitWidth(primitive_dest_type)); } @@ -1436,6 +1430,12 @@ bool LiteralBase::Piece::EqualElementsInternal( bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); + if (ShapeUtil::Equal(subshape(), other.subshape()) && + LayoutUtil::IsDenseArray(subshape())) { + CHECK_EQ(size_bytes(), other.size_bytes()); + return memcmp(buffer(), other.buffer(), size_bytes()) == 0; + } + std::vector multi_index; switch (subshape().element_type()) { case PRED: diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index ed9de652994bc948efe38a8fcc3ba9bed36c9f3a..aad435ed5b288176ebada8d1bcf1cd0239e0de68 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -26,6 +26,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 6883a6bbab4de252ba47c6d34bcecd2e75c80818..43388ac9d1be508a1400c23014077ac893b5fdcb 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -19,16 +19,16 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" -using tensorflow::strings::Appendf; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrAppendFormat; +using absl::StrCat; namespace xla { namespace literal_comparison { @@ -47,10 +47,10 @@ Status CompareFloatsBitwiseEqual( if (ulhs != urhs) { return InvalidArgument( "floating values are not bitwise-equal; and equality testing " - "was requested: %s=%g=%a vs %s=%g=%a at index %s", - StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double, - StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double, - LiteralUtil::MultiIndexAsString(multi_index).c_str()); + "was requested: %s=%g=%a vs %s=%g=%a at array index %s", + StrCat(absl::Hex(ulhs)), lhs_double, lhs_double, + StrCat(absl::Hex(urhs)), rhs_double, rhs_double, + LiteralUtil::MultiIndexAsString(multi_index)); } return Status::OK(); } @@ -65,9 +65,9 @@ Status CompareEqual(NativeT lhs, NativeT rhs, return Status::OK(); } return InvalidArgument( - "Expected equality of these values:\n %s\n %s\nat index %s", - StrCat(lhs).c_str(), StrCat(rhs).c_str(), - LiteralUtil::MultiIndexAsString(multi_index).c_str()); + "first mismatch at array index %s:\n expected value: %s\n actual " + "value: %s", + LiteralUtil::MultiIndexAsString(multi_index), StrCat(lhs), StrCat(rhs)); } // Specializations for floating types that do bitwise comparisons when equality @@ -119,7 +119,8 @@ Status Equal(LiteralSlice expected, LiteralSlice actual, Status result; for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { multi_index[dimension] = i; - result.Update(Equal(expected, actual, multi_index, dimension + 1)); + TF_RETURN_IF_ERROR( + Equal(expected, actual, multi_index, dimension + 1)); } return result; } @@ -166,12 +167,12 @@ bool NanMismatch(half expected, half actual, bool relaxed_nans) { // Converts the given floating-point value to a string. template string FpValueToString(NativeT value) { - return Printf("%8.4g", static_cast(value)); + return absl::StrFormat("%8.4g", static_cast(value)); } template <> string FpValueToString(complex64 value) { - return Printf("%8.4g + %8.4fi", value.real(), value.imag()); + return absl::StrFormat("%8.4g + %8.4fi", value.real(), value.imag()); } // Returns the absolute value of the given floating point value. This function @@ -226,13 +227,12 @@ class NearComparator { } string ToString(const Shape& shape) const { - return Printf( + return absl::StrFormat( "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", - FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), + FpValueToString(actual), FpValueToString(expected), LiteralUtil::MultiIndexAsString( IndexUtil::LinearIndexToMultidimensionalIndex(shape, - linear_index)) - .c_str(), + linear_index)), rel_error, abs_error); } }; @@ -251,17 +251,12 @@ class NearComparator { // Runs the comparison between expected and actual literals. Status Run() { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, ToStringTruncated(expected_)); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, ToStringTruncated(actual_)); - // If the shapes mismatch, we simply fail the expectation instead of // printing out data, as it's a type error rather than a value error. TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); if (!ShapeUtil::IsArray(expected_.shape())) { return InvalidArgument("Expected array shape; got %s.", - ShapeUtil::HumanString(expected_.shape()).c_str()); + ShapeUtil::HumanString(expected_.shape())); } mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); @@ -274,7 +269,7 @@ class NearComparator { } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { miscompare_callback_(expected_, actual_, mismatches_); } - return InvalidArgument("%s", ErrorMessage().c_str()); + return InvalidArgument("%s", ErrorMessage()); } // Insert the given absolute value into the absolute value bucket vector. The @@ -413,23 +408,23 @@ class NearComparator { auto percent_string = [](float a, float b) { float pct = b == 0.0 ? 0.0 : 100.0 * a / b; - return Printf("%0.4f%%", pct); + return absl::StrFormat("%0.4f%%", pct); }; - Appendf(&out, - "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " - "%g, rel bound %g\n", - num_mismatches_, - percent_string(num_mismatches_, element_count).c_str(), - ShapeUtil::HumanString(actual_.shape()).c_str(), - ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); + StrAppendFormat( + &out, + "\nMismatch count %d (%s) in shape %s (%d elements), abs bound " + "%g, rel bound %g\n", + num_mismatches_, percent_string(num_mismatches_, element_count), + ShapeUtil::HumanString(actual_.shape()), + ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); if (num_nan_mismatches_ > 0) { StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); } - Appendf(&out, "Top relative error mismatches:\n"); + StrAppendFormat(&out, "Top relative error mismatches:\n"); for (auto it = top_rel_mismatches_.rbegin(); it != top_rel_mismatches_.rend(); ++it) { - StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); + StrAppend(&out, " ", it->ToString(actual_.shape()), "\n"); } if (!detailed_message_) { @@ -441,36 +436,37 @@ class NearComparator { for (int i = 0; i < abs_value_buckets_.size(); ++i) { const int64 bucket_size = abs_value_buckets_[i].first; const int64 bucket_mismatches = abs_value_buckets_[i].second; - string mismatch_str = bucket_mismatches > 0 - ? Printf(", mismatches %lld", bucket_mismatches) - : ""; - Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", - kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], - bucket_size, percent_string(bucket_size, element_count).c_str(), - mismatch_str.c_str()); + string mismatch_str = + bucket_mismatches > 0 + ? absl::StrFormat(", mismatches %d", bucket_mismatches) + : ""; + StrAppendFormat(&out, " %-6g <= x < %-6g : %7d (%9s)%s\n", + kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], + bucket_size, percent_string(bucket_size, element_count), + mismatch_str); } auto print_accum_buckets = [&](const string& header, int64 total, tensorflow::gtl::ArraySlice buckets) { StrAppend(&out, header, ":\n"); - Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], - total - buckets[0], - percent_string(total - buckets[0], total).c_str()); + StrAppendFormat(&out, " < %-6g : %7d (%s)\n", kErrorBucketBounds[0], + total - buckets[0], + percent_string(total - buckets[0], total)); CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); for (int i = 0; i < kErrorBucketBounds.size(); ++i) { - Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], - buckets[i], percent_string(buckets[i], total).c_str()); + StrAppendFormat(&out, " >= %-6g : %7d (%s)\n", kErrorBucketBounds[i], + buckets[i], percent_string(buckets[i], total)); } }; - Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", - error_.abs, num_abs_mismatches_, - percent_string(num_abs_mismatches_, element_count).c_str()); + StrAppendFormat(&out, "Elements exceeding abs error bound %g: %d (%s)\n", + error_.abs, num_abs_mismatches_, + percent_string(num_abs_mismatches_, element_count)); print_accum_buckets( "Relative error breakdown of elements exceeding abs error bound", num_abs_mismatches_, rel_error_buckets_); - Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", - error_.rel, num_rel_mismatches_, - percent_string(num_rel_mismatches_, element_count).c_str()); + StrAppendFormat(&out, "Elements exceeding rel error bound %g: %d (%s)\n", + error_.rel, num_rel_mismatches_, + percent_string(num_rel_mismatches_, element_count)); print_accum_buckets( "Absolute error breakdown of elements exceeding rel error bound", num_rel_mismatches_, abs_error_buckets_); @@ -539,6 +535,62 @@ constexpr std::array NearComparator::kAbsValueBucketBounds; template constexpr std::array NearComparator::kErrorBucketBounds; +Status EqualHelper(const LiteralSlice& expected, const LiteralSlice& actual) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + std::vector multi_index(expected.shape().dimensions_size(), 0); + Status result; + switch (expected.shape().element_type()) { + case PRED: + result = Equal(expected, actual, &multi_index, 0); + break; + case U8: + result = Equal(expected, actual, &multi_index, 0); + break; + case S32: + result = Equal(expected, actual, &multi_index, 0); + break; + case S64: + result = Equal(expected, actual, &multi_index, 0); + break; + case U32: + result = Equal(expected, actual, &multi_index, 0); + break; + case U64: + result = Equal(expected, actual, &multi_index, 0); + break; + case BF16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F32: + result = Equal(expected, actual, &multi_index, 0); + break; + case F64: + result = Equal(expected, actual, &multi_index, 0); + break; + case C64: + result = Equal(expected, actual, &multi_index, 0); + break; + case TUPLE: { + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + result.Update(EqualHelper(LiteralSlice(expected, {i}), + LiteralSlice(actual, {i}))); + } + break; + } + case TOKEN: + // Tokens have no on-device representation and are trivially equal. + return Status::OK(); + default: + LOG(FATAL) << "Unsupported primitive type: " + << PrimitiveType_Name(expected.shape().element_type()); + } + + return result; +} + // Helper function for comparing two literals for nearness. Handles tuple-shapes // via recursion. shape_index is the ShapeIndex of expected (or actual) // currently being compared. @@ -555,17 +607,18 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, const auto actual_element = LiteralSlice(actual, {i}); ShapeIndex element_index = shape_index; element_index.push_back(i); - Status res = + Status element_result = NearHelper(expected_element, actual_element, error, detailed_message, miscompare_callback, element_index); - if (!res.ok()) { - string err_message = Printf("\nArray at shape index %s%s", - element_index.ToString().c_str(), - res.error_message().c_str()); + if (!element_result.ok()) { + element_result = InvalidArgument("Array at shape index %s, %s", + element_index.ToString(), + element_result.error_message()); if (return_status.ok()) { - return_status = res; + return_status = element_result; } else { - return_status = AppendStatus(return_status, res.error_message()); + return_status = + AppendStatus(return_status, element_result.error_message()); } } } @@ -573,10 +626,10 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, // Emit a top-level error message containing the top-level shape in case // of mismatch. int64 total_elements = RecursiveElementCount(actual.shape()); - return_status = InvalidArgument( - "\nMismatches in shape %s (%lld elements):\n%s", - ShapeUtil::HumanString(actual.shape()).c_str(), total_elements, - return_status.error_message().c_str()); + return_status = + InvalidArgument("\nMismatches in shape %s (%d elements):\n%s", + ShapeUtil::HumanString(actual.shape()), + total_elements, return_status.error_message()); } return return_status; } @@ -611,8 +664,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, } } - // Non-floating point literal. - return literal_comparison::Equal(expected, actual); + // Non-floating point, non-tuple literal. + return EqualHelper(expected, actual); } } // namespace @@ -620,14 +673,14 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, Status EqualShapes(const Shape& expected, const Shape& actual) { if (expected.element_type() != actual.element_type()) { return InvalidArgument("element type mismatch, want: %s got %s", - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), + ShapeUtil::HumanString(actual)); } if (ShapeUtil::IsTuple(expected)) { if (ShapeUtil::TupleElementCount(expected) != ShapeUtil::TupleElementCount(actual)) { return InvalidArgument( - "want tuple element count: %lld got tuple element count: %lld", + "want tuple element count: %d got tuple element count: %d", ShapeUtil::TupleElementCount(expected), ShapeUtil::TupleElementCount(actual)); } @@ -641,14 +694,13 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { } else if (ShapeUtil::IsArray(expected)) { if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { return InvalidArgument("want rank of %s got rank of %s", - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), + ShapeUtil::HumanString(actual)); } if (expected.element_type() != actual.element_type()) { - return InvalidArgument( - "mismatch in primitive type %s vs %s", - PrimitiveType_Name(expected.element_type()).c_str(), - PrimitiveType_Name(actual.element_type()).c_str()); + return InvalidArgument("mismatch in primitive type %s vs %s", + PrimitiveType_Name(expected.element_type()), + PrimitiveType_Name(actual.element_type())); } if (expected.dimensions_size() != actual.dimensions_size()) { return InvalidArgument("want dimensions_size %d got dimensions_size %d", @@ -659,8 +711,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { if (expected.dimensions(i) != actual.dimensions(i)) { return InvalidArgument( "mismatch in dimension #%d expected: %s actual: %s", i, - ShapeUtil::HumanString(expected).c_str(), - ShapeUtil::HumanString(actual).c_str()); + ShapeUtil::HumanString(expected), ShapeUtil::HumanString(actual)); } } } @@ -668,81 +719,43 @@ Status EqualShapes(const Shape& expected, const Shape& actual) { return Status::OK(); } +namespace { + +// If result is an error, extend the error message with the expected and actual +// literals. +Status EmitLiteralsInErrorMessage(const Status& result, + const LiteralSlice& expected, + const LiteralSlice& actual) { + if (result.ok()) { + return result; + } + return InvalidArgument("%s\n\nExpected literal:\n%s\n\nActual literal:\n%s", + result.error_message(), ToStringTruncated(expected), + ToStringTruncated(actual)); +} + +} // namespace + Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { VLOG(1) << "expected:"; XLA_VLOG_LINES(1, expected.ToString()); VLOG(1) << "actual:"; XLA_VLOG_LINES(1, actual.ToString()); - - TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); - std::vector multi_index(expected.shape().dimensions_size(), 0); - Status result; - switch (expected.shape().element_type()) { - case PRED: - result = Equal(expected, actual, &multi_index, 0); - break; - case U8: - result = Equal(expected, actual, &multi_index, 0); - break; - case S32: - result = Equal(expected, actual, &multi_index, 0); - break; - case S64: - result = Equal(expected, actual, &multi_index, 0); - break; - case U32: - result = Equal(expected, actual, &multi_index, 0); - break; - case U64: - result = Equal(expected, actual, &multi_index, 0); - break; - case BF16: - result = Equal(expected, actual, &multi_index, 0); - break; - case F16: - result = Equal(expected, actual, &multi_index, 0); - break; - case F32: - result = Equal(expected, actual, &multi_index, 0); - break; - case F64: - result = Equal(expected, actual, &multi_index, 0); - break; - case C64: - result = Equal(expected, actual, &multi_index, 0); - break; - case TUPLE: { - for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - result.Update( - Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}))); - } - break; - } - case TOKEN: - // Tokens have no on-device representation and are trivially equal. - return Status::OK(); - default: - LOG(FATAL) - << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " - << PrimitiveType_Name(expected.shape().element_type()); - } - - if (result.ok()) { - return Status::OK(); - } - - return AppendStatus( - result, tensorflow::strings::Printf("\nexpected: %s\nactual: %s", - ToStringTruncated(expected).c_str(), - ToStringTruncated(actual).c_str())); + Status result = EqualHelper(expected, actual); + return EmitLiteralsInErrorMessage(result, expected, actual); } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, const ErrorSpec& error, bool detailed_message, const MiscompareCallback& miscompare_callback) { - return NearHelper(expected, actual, error, detailed_message, - miscompare_callback, - /*shape_index=*/{}); + VLOG(1) << "Expected literal:"; + XLA_VLOG_LINES(1, expected.ToString()); + VLOG(1) << "Actual literal:"; + XLA_VLOG_LINES(1, actual.ToString()); + Status result = + NearHelper(expected, actual, error, detailed_message, miscompare_callback, + /*shape_index=*/{}); + return EmitLiteralsInErrorMessage(result, expected, actual); } string ToStringTruncated(const LiteralSlice& literal) { diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index c5d0c2c267e06f7d10651f57496c4d1dd76eff52..e08a9d6e415d14896804371da19b891062c2ec81 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -97,42 +99,42 @@ class LiteralUtilTest : public ::testing::Test { TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); - ASSERT_EQ("true", true_lit->ToString()); + EXPECT_EQ("true", true_lit->ToString()); auto false_lit = LiteralUtil::CreateR0(false); - ASSERT_EQ("false", false_lit->ToString()); + EXPECT_EQ("false", false_lit->ToString()); auto u32_lit = LiteralUtil::CreateR0(42); - ASSERT_EQ("42", u32_lit->ToString()); + EXPECT_EQ("42", u32_lit->ToString()); auto s32_lit = LiteralUtil::CreateR0(-999); - ASSERT_EQ("-999", s32_lit->ToString()); + EXPECT_EQ("-999", s32_lit->ToString()); auto f32_lit = LiteralUtil::CreateR0(3.14f); - ASSERT_EQ("3.14", f32_lit->ToString()); + EXPECT_EQ("3.14", f32_lit->ToString()); auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", f16_lit->ToString()); + EXPECT_EQ("0.5", f16_lit->ToString()); auto c64_lit = LiteralUtil::CreateR0({3.14f, 2.78f}); - ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); + EXPECT_EQ("(3.14, 2.78)", c64_lit->ToString()); auto bf16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", bf16_lit->ToString()); + EXPECT_EQ("0.5", bf16_lit->ToString()); // 3.14 will be truncated to 3.125 in bfloat16 format. auto bf16_lit_truncated = LiteralUtil::CreateR0(static_cast(3.14f)); - ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); + EXPECT_EQ("3.125", bf16_lit_truncated->ToString()); auto bf16_lit_truncated2 = LiteralUtil::CreateR0(static_cast(9.001f)); - ASSERT_EQ("9", bf16_lit_truncated2->ToString()); + EXPECT_EQ("9", bf16_lit_truncated2->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - ASSERT_EQ("{101}", pred_vec->ToString()); + EXPECT_EQ("{101}", pred_vec->ToString()); } TEST_F(LiteralUtilTest, R2ToString) { @@ -142,7 +144,7 @@ TEST_F(LiteralUtilTest, R2ToString) { { 3, 4 }, { 5, 6 } })"; - ASSERT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal->ToString()); } TEST_F(LiteralUtilTest, R3ToString) { @@ -156,7 +158,7 @@ TEST_F(LiteralUtilTest, R3ToString) { { { 5 }, { 6 } } })"; - ASSERT_EQ(expected, literal->ToString()); + EXPECT_EQ(expected, literal->ToString()); } TEST_F(LiteralUtilTest, TupleToString) { @@ -170,7 +172,7 @@ f32[2,2] { { 3, 4 } } ))"; - ASSERT_EQ(expected, tuple->ToString()); + EXPECT_EQ(expected, tuple->ToString()); } TEST_F(LiteralUtilTest, CreateR3FromArray3d) { @@ -196,7 +198,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { { 9, 10 }, { 11, 12 } } })"; - ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, CreateSparse) { @@ -249,7 +251,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { } } })"; - ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { @@ -282,7 +284,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { } } })"; - ASSERT_EQ(expected, result); + EXPECT_EQ(expected, result); } TEST_F(LiteralUtilTest, EachCellR2F32) { @@ -1037,7 +1039,7 @@ TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { auto vector = LiteralUtil::CreateR1({5.0, 7.0}); Status status = matrix->CopyFrom(*vector); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Destination subshape incompatible")); } @@ -1324,8 +1326,8 @@ TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) { auto literal = LiteralUtil::CreateR0(1234); Status status = literal->BitcastConvert(F64).status(); EXPECT_NE(Status::OK(), status); - EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(), - "bit widths are different")); + EXPECT_TRUE( + absl::StrContains(status.error_message(), "bit widths are different")); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { @@ -1392,10 +1394,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { Literal::CreateFromProto(p)); auto r = literal->data(); ASSERT_EQ(4, r.size()); - ASSERT_EQ(h1, r[0]); - ASSERT_EQ(h2, r[1]); - ASSERT_EQ(h2, r[2]); - ASSERT_EQ(h1, r[3]); + EXPECT_EQ(h1, r[0]); + EXPECT_EQ(h2, r[1]); + EXPECT_EQ(h2, r[2]); + EXPECT_EQ(h1, r[3]); } TEST_F(LiteralUtilTest, LiteralSliceTest) { @@ -1578,7 +1580,7 @@ TEST_F(LiteralUtilTest, MoveIntoTuple) { TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) { Literal literal = Literal::MoveIntoTuple({}); ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); - ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); + EXPECT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); } TEST_F(LiteralUtilTest, LiteralMoveAssignment) { @@ -1691,7 +1693,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoValues) { *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 3 elements in LiteralProto")); } @@ -1703,7 +1705,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoShape) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); + EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); } TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { @@ -1715,7 +1717,7 @@ TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 3 elements in LiteralProto")); } @@ -1728,7 +1730,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { proto.add_f32s(3.0); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 84 elements in LiteralProto")); } @@ -1741,7 +1743,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { proto.add_s32s(100); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 elements in LiteralProto")); } @@ -1756,7 +1758,7 @@ TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { proto.add_preds(false); Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); + EXPECT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); } TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { @@ -1772,7 +1774,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); } TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { @@ -1795,7 +1797,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { Status status = Literal::CreateFromProto(proto).status(); ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); + EXPECT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); } TEST_F(LiteralUtilTest, SortSparseElements) { @@ -1805,7 +1807,7 @@ TEST_F(LiteralUtilTest, SortSparseElements) { literal->AppendSparseElement({3, 4, 5}, 3.0); literal->AppendSparseElement({1, 2, 3}, 1.0); literal->SortSparseElements(); - ASSERT_EQ(literal->ToString(false), + EXPECT_EQ(literal->ToString(false), "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); } @@ -1813,27 +1815,26 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { std::vector dimensions = {10, 10, 10}; SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); - ASSERT_EQ( + EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {true, false, true}) ->GetSparseElementAsString(1), "false"); - ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) + EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {1, 2, 3}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(int64{2})); - ASSERT_EQ( + absl::StrCat(int64{2})); + EXPECT_EQ( LiteralUtil::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(double{2.0})); - ASSERT_EQ(LiteralUtil::CreateSparse(dimensions, indices, + absl::StrCat(double{2.0})); + EXPECT_EQ(LiteralUtil::CreateSparse(dimensions, indices, {half{1.0}, half{2.0}, half{3.0}}) ->GetSparseElementAsString(1), - tensorflow::strings::StrCat(static_cast(half{2.0}))); - ASSERT_EQ( - LiteralUtil::CreateSparse( - dimensions, indices, - std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) - ->GetSparseElementAsString(1), - tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); + absl::StrCat(static_cast(half{2.0}))); + EXPECT_EQ(LiteralUtil::CreateSparse( + dimensions, indices, + std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) + ->GetSparseElementAsString(1), + absl::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index d4c7b76b2819d8b6b07297351d7cd9180e764c25..931d2c631bc40c7da08c5076b2b224c5ebbe6ee6 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,19 +33,15 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/types.h" -using tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; + // Return a literal with all arrays of type FromNativeT converted to type // ToNativeT in the given literal. template @@ -287,7 +285,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { } /* static */ std::unique_ptr LiteralUtil::CreateR1U8( - tensorflow::StringPiece value) { + absl::string_view value) { auto literal = absl::make_unique( ShapeUtil::MakeShape(U8, {static_cast(value.size())})); for (int i = 0; i < value.size(); ++i) { @@ -477,7 +475,7 @@ std::unique_ptr ConvertType(LiteralSlice literal) { /* static */ string LiteralUtil::MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index) { - return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); + return StrCat("{", absl::StrJoin(multi_index, ","), "}"); } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 1109021ea892a38c1134b3fee6c608c25167c675..3d28c070f29052f2686cf605e068deadd998719c 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -28,6 +28,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -43,7 +44,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -187,7 +187,7 @@ class LiteralUtil { const Array4D& values, const Layout& layout); // Creates a new vector of U8s literal value from a string. - static std::unique_ptr CreateR1U8(tensorflow::StringPiece value); + static std::unique_ptr CreateR1U8(absl::string_view value); // Creates a linspace-populated literal with the given number of rows and // columns. diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 3c74e070da529b7f1431e01fbaf31932f582db44..fcff48b6b18ba115a67f3141a9aea4ca461be55d 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -60,7 +60,7 @@ MaybeFind(const Collection& collection, if (it == collection.end()) { std::ostringstream os; os << key; - return NotFound("key not found: %s", os.str().c_str()); + return NotFound("key not found: %s", os.str()); } return {it->second}; } diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index 69ef4f7a2f3ea559a334a11cbe8392b610742bab..4eab4fa4290c270697c00be20840cf4e85459183 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -18,7 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/strings/stringprintf.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -84,7 +85,7 @@ void MetricTableReport::WriteReportToInfoLog(double expected_metric_sum) { if (end_of_line == string::npos) { end_of_line = report.size(); } - tensorflow::StringPiece line(report.data() + pos, end_of_line - pos); + absl::string_view line(report.data() + pos, end_of_line - pos); // TODO(b/34779244): Figure out how to do this without the verbose log-line // prefix. The usual way didn't compile on open source. @@ -152,8 +153,8 @@ void MetricTableReport::AppendCategoryTable() { if (text.empty()) { text = "[no category]"; } - tensorflow::strings::StrAppend(&text, " (", category.entries.size(), " ", - entry_name_, ")"); + absl::StrAppend(&text, " (", category.entries.size(), " ", entry_name_, + ")"); AppendTableRow(text, category.metric_sum, metric_sum); // Show the top entries in the category. @@ -177,9 +178,9 @@ void MetricTableReport::AppendCategoryTable() { } const int64 remaining_categories = categories.size() - categories_shown; if (remaining_categories > 0) { - AppendTableRow(tensorflow::strings::StrCat("... (", remaining_categories, - " more categories)"), - expected_metric_sum_ - metric_sum, expected_metric_sum_); + AppendTableRow( + absl::StrCat("... (", remaining_categories, " more categories)"), + expected_metric_sum_ - metric_sum, expected_metric_sum_); } } @@ -206,9 +207,9 @@ void MetricTableReport::AppendEntryTable() { } const int64 remaining_entries = entries_.size() - entries_shown; if (remaining_entries > 0) { - AppendTableRow(tensorflow::strings::StrCat("... (", remaining_entries, - " more ", entry_name_, ")"), - expected_metric_sum_ - metric_sum, expected_metric_sum_); + AppendTableRow( + absl::StrCat("... (", remaining_entries, " more ", entry_name_, ")"), + expected_metric_sum_ - metric_sum, expected_metric_sum_); } } @@ -241,10 +242,10 @@ double MetricTableReport::UnaccountedMetric() { string MetricTableReport::MetricString(double metric) { // Round to integer and stringify. - string s1 = tensorflow::strings::StrCat(std::llround(metric)); + string s1 = absl::StrCat(std::llround(metric)); // Code below commafies the string, e.g. "1234" becomes "1,234". - tensorflow::StringPiece sp1(s1); + absl::string_view sp1(s1); string output; // Copy leading non-digit characters unconditionally. // This picks up the leading sign. @@ -263,8 +264,7 @@ string MetricTableReport::MetricString(double metric) { } string MetricTableReport::MetricPercent(double metric) { - return tensorflow::strings::Printf("%5.2f%%", - metric / expected_metric_sum_ * 100.0); + return absl::StrFormat("%5.2f%%", metric / expected_metric_sum_ * 100.0); } } // namespace xla diff --git a/tensorflow/compiler/xla/metric_table_report.h b/tensorflow/compiler/xla/metric_table_report.h index 818fb1d3fe0b8bbe1a8eba363ff6445e2f3df9d2..062d8ed99b213535ad39d840aaaf10a6fe0da84c 100644 --- a/tensorflow/compiler/xla/metric_table_report.h +++ b/tensorflow/compiler/xla/metric_table_report.h @@ -18,9 +18,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -108,7 +107,7 @@ class MetricTableReport { // Append all parameters to the report. template void AppendLine(Args... args) { - tensorflow::strings::StrAppend(&report_, std::forward(args)..., "\n"); + absl::StrAppend(&report_, std::forward(args)..., "\n"); } // Represents a set of entries with the same category_text. diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 55c4a80e29b7d493e676e412dfd259677169b417..6e42775f6fb08cc00d42411e7feae077f2356dd2 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -54,7 +54,7 @@ StatusOr> PackedLiteralReader::Read( if (shape.element_type() != F32) { return Unimplemented( "not yet implemented element type for packed literal reading: %s", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } auto result = absl::make_unique(literal_shape); @@ -64,7 +64,7 @@ StatusOr> PackedLiteralReader::Read( tensorflow::gtl::ArraySlice field = result->data(); char* data = tensorflow::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); - tensorflow::StringPiece sp; + tensorflow::StringPiece sp; // non-absl OK auto s = file_->Read(offset_, bytes, &sp, data); offset_ += sp.size(); if (!s.ok()) { @@ -85,7 +85,7 @@ bool PackedLiteralReader::IsExhausted() const { // Try to read a single byte from offset_. If we can't, we've // exhausted the data. char single_byte[1]; - tensorflow::StringPiece sp; + tensorflow::StringPiece sp; // non-absl OK auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte); return !s.ok(); } diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index a91336c3ac920bc1f28a17e2b9835eba81c94d75..fe91dc06185d6035c3f3f46ea601b5f45b288ec3 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -39,6 +39,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/python:numpy_lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 00e36c3c86a8b46b8479ac8245405459c3cfdd81..b5fd747cfab18e58781c1f7bfbd9905f46f11926 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -251,7 +251,7 @@ StatusOr> CompiledLocalComputation::Execute( return InternalError( "Failed running replica %d (other replicas may have failed as well): " "%s.", - replica, statusor.status().ToString().c_str()); + replica, statusor.status().ToString()); } } @@ -696,8 +696,7 @@ StatusOr DestructureLocalShapedBufferTuple( "Attemped to destructure a LocalShapedBuffer that did not have a tuple " "shape; shape: %s", ShapeUtil::HumanString( - local_shaped_buffer->shaped_buffer()->on_device_shape()) - .c_str()); + local_shaped_buffer->shaped_buffer()->on_device_shape())); } DeviceMemoryAllocator* allocator = diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index e1060d54e260cfecb283da1c75f26e59c5c1d870..f6169ebf19041b4fd35a9842ba5d6ceb90d70270 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -109,6 +109,8 @@ limitations under the License. // Must be included first #include "tensorflow/python/lib/core/numpy.h" +#include "third_party/absl/strings/str_cat.h" +#include "third_party/absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -154,8 +156,8 @@ bool HandleStringAttribute(PyObject* o, return true; // The attribute is None, which we consider ok. } if (!PyString_Check(attr)) { - string message = tensorflow::strings::Printf("%s must be a string or none; got %s", - attr_name, numpy::PyObjectCppRepr(attr).c_str()); + string message = absl::StrFormat("%s must be a string or none; got %s", + attr_name, numpy::PyObjectCppRepr(attr)); PyErr_SetString(PyExc_TypeError, message.c_str()); Py_DECREF(attr); return false; // Type error, not ok. @@ -896,7 +898,7 @@ tensorflow::ImportNumpy(); if (o != Py_None) { StatusOr statusor = numpy::XlaShapeFromPyShape(o); if (!statusor.ok()) { - PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); + PyErr_SetString(PyExc_TypeError, absl::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); Py_DECREF(o); SWIG_fail; } diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 4b9970eadcb7edec90468647ab93ccb9d26236da..fc6511bef566cb6f4e0d4e52972954de0792e959 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" @@ -149,9 +151,7 @@ static int NumpyTypenum(PyObject* o) { // // NOTE: this is an internal helper for conversion to a C++, and so decrefs r. static string ExtractStringAndDecref(PyObject* r) { - auto error = [r] { - return tensorflow::strings::Printf("", r); - }; + auto error = [r] { return absl::StrFormat("", r); }; if (r == nullptr) { return error(); } @@ -191,8 +191,8 @@ StatusOr XlaShapeFromPyShape(PyObject* o) { PyObject* result = PyObject_CallMethod(o, const_cast(method.c_str()), nullptr); if (result == nullptr) { - return error(tensorflow::strings::StrCat( - "Failed to call method of shape object:", method)); + return error( + absl::StrCat("Failed to call method of shape object:", method)); } return result; }; diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 44b22a5586dee3f7dd8ea0edbf9deb2090986ac8..97fcd37f6b89d6dd737c233ef19f55a8faa1b624 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -43,6 +43,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -62,6 +63,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 67886761813f0bb45a600661b017be91ffeade73..43fd8fe1bd0f41eb2ac5c42021a8ca4f63282646 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -23,12 +23,12 @@ limitations under the License. #include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/net.h" #include "tensorflow/core/platform/subprocess.h" @@ -46,7 +46,7 @@ class GRPCClientTestBase : public ::testing::Test { int port = tensorflow::internal::PickUnusedPortOrDie(); subprocess_.SetProgram( service_main_path, - {service_main_path, tensorflow::strings::Printf("--port=%d", port)}); + {service_main_path, absl::StrFormat("--port=%d", port)}); subprocess_.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_DUPPARENT); subprocess_.SetChannelAction(tensorflow::CHAN_STDERR, @@ -54,9 +54,8 @@ class GRPCClientTestBase : public ::testing::Test { CHECK(subprocess_.Start()); LOG(INFO) << "Launched subprocess"; - auto channel = - ::grpc::CreateChannel(tensorflow::strings::Printf("localhost:%d", port), - ::grpc::InsecureChannelCredentials()); + auto channel = ::grpc::CreateChannel(absl::StrFormat("localhost:%d", port), + ::grpc::InsecureChannelCredentials()); channel->WaitForConnected(gpr_time_add( gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(10, GPR_TIMESPAN))); LOG(INFO) << "Channel to server is connected on port " << port; diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index c68c857c304138ff4318e243f66547c6acce1005..d6b5149a24c491d1e9d7cd9119b36d7eb2ad65d3 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -18,8 +18,8 @@ limitations under the License. #include "grpcpp/security/server_credentials.h" #include "grpcpp/server.h" #include "grpcpp/server_builder.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" @@ -44,7 +44,7 @@ int RealMain(int argc, char** argv) { xla::GRPCService::NewService().ConsumeValueOrDie(); ::grpc::ServerBuilder builder; - string server_address(tensorflow::strings::Printf("localhost:%d", port)); + string server_address(absl::StrFormat("localhost:%d", port)); builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); builder.RegisterService(service.get()); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 32723849a655f2ce64288074e755a6c254a0be0d..b68785949ca3f160fe211c689528120c8c8dd818 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -99,6 +99,7 @@ cc_library( ":bfloat16_support", ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", @@ -176,6 +177,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -227,6 +230,7 @@ cc_library( hdrs = ["hlo_evaluator.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_query", ":shape_inference", "//tensorflow/compiler/xla:literal", @@ -241,6 +245,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -320,6 +325,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -346,7 +352,7 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -398,7 +404,7 @@ cc_library( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -460,6 +466,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -564,6 +572,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -587,6 +596,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -629,6 +639,8 @@ cc_library( "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], alwayslink = 1, ) @@ -662,6 +674,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -684,6 +698,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -735,6 +750,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -784,6 +801,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -832,6 +850,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -851,6 +870,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -886,6 +906,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -896,6 +917,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -930,6 +952,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -939,6 +963,7 @@ tf_cc_test( deps = [ ":buffer_liveness", ":hlo", + ":hlo_dataflow_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -974,6 +999,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1021,6 +1048,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1086,6 +1115,7 @@ cc_library( deps = [ ":hlo", ":hlo_casting_utils", + ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -1113,6 +1143,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1132,6 +1163,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], ) @@ -1139,17 +1171,18 @@ tf_cc_test( name = "hlo_scheduling_test", srcs = ["hlo_scheduling_test.cc"], deps = [ - ":buffer_value", ":heap_simulator", ":hlo", + ":hlo_dce", ":hlo_ordering", + ":hlo_parser", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1199,6 +1232,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1215,6 +1249,7 @@ cc_library( "//tensorflow/compiler/xla:util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1330,6 +1365,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -1355,6 +1391,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1419,6 +1456,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1457,6 +1495,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1497,6 +1536,7 @@ cc_library( ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -1511,6 +1551,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1701,6 +1742,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], alwayslink = True, # Contains per-platform computation placer registration ) @@ -1714,6 +1756,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1807,6 +1851,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1839,6 +1884,7 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1871,6 +1917,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", ], ) @@ -1898,6 +1945,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", ], ) @@ -1916,6 +1964,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1936,6 +1985,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1978,6 +2028,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2014,6 +2065,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2034,6 +2086,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2093,6 +2146,8 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2144,6 +2199,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2166,6 +2223,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2233,8 +2291,10 @@ cc_library( ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2277,7 +2337,10 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2400,6 +2463,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2619,6 +2684,7 @@ cc_library( hdrs = ["elemental_ir_emitter.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_module_config", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -2627,12 +2693,14 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:transform_utils", ], @@ -2666,8 +2734,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -2681,6 +2749,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2717,8 +2786,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -2752,6 +2821,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", ], alwayslink = 1, @@ -2769,6 +2840,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2997,8 +3069,8 @@ cc_library( ":hlo_creation_utils", ":tuple_util", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -3091,13 +3163,13 @@ cc_library( cc_library( name = "source_map_util", - srcs = ["source_map_util.cc"], + srcs = [], hdrs = ["source_map_util.h"], deps = [ ":executable", "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -3114,6 +3186,7 @@ cc_library( "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -3146,10 +3219,11 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -3164,6 +3238,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", # fixdeps: keep + "@com_google_absl//absl/strings", ], ) @@ -3182,6 +3257,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index b86b7d2e71e4d0fa6edcfffffdbfdc911ad2d90e..c236453fc77c4082be295156889e7be22f55152e 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -1989,9 +1990,9 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() - << (convert != nullptr ? tensorflow::strings::StrCat( - "\nvia convert: ", convert->ToString()) - : ""); + << (convert != nullptr + ? absl::StrCat("\nvia convert: ", convert->ToString()) + : ""); // Do not fold interior padding into ReduceWindow since the backends do not // support it. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index c48196e861a559a5abfa360841ec70b39356fa2b..b864c372fa5877ca329d2efbbf7d747c763ae2c0 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -47,7 +47,7 @@ class AlgebraicSimplifier : public HloPassInterface { enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} ~AlgebraicSimplifier() override = default; - tensorflow::StringPiece name() const override { return "algsimp"; } + absl::string_view name() const override { return "algsimp"; } // Run algebraic simplification on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 427069af5f49866d4e7c818696a6912302643b54..bb63ea26d453e52a6f39551a83a36eabe9709438 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -34,13 +36,12 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" - -using ::testing::ElementsAre; namespace xla { namespace { +using ::testing::ElementsAre; + namespace op = xla::testing::opcode_matchers; AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { @@ -51,7 +52,12 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; +class AlgebraicSimplifierTest : public HloVerifiedTestBase { + public: + AlgebraicSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { @@ -2143,9 +2149,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { root->operand(0)->opcode() == HloOpcode::kDot) { auto lhs_shape = root->operand(0)->operand(0)->shape(); auto rhs_shape = root->operand(0)->operand(1)->shape(); - return tensorflow::strings::StrCat( - tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ", - tensorflow::str_util::Join(rhs_shape.dimensions(), "x")); + return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ", + absl::StrJoin(rhs_shape.dimensions(), "x")); } return "UNEXPECTED CHANGE"; }; @@ -2660,11 +2665,10 @@ struct PadReduceWindowEffectiveBroadcastCase { bool should_become_broadcast; string ToTestCaseName() const { - return tensorflow::strings::StrCat( - tensorflow::str_util::Join(input_spatials, ","), ";", - tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";", - tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a, - ";", should_become_broadcast); + return absl::StrCat(absl::StrJoin(input_spatials, ","), ";", + absl::StrJoin(symmetric_pad_spatials, ","), ";", + absl::StrJoin(reduce_window_spatials, ","), ";", + prepend_a, ";", should_become_broadcast); } }; @@ -2852,7 +2856,12 @@ struct DotOfConcatTestSpec { class DotOfConcatSimplificationTest : public HloVerifiedTestBase, - public ::testing::WithParamInterface {}; + public ::testing::WithParamInterface { + public: + DotOfConcatSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Test that we transform // dot(const, concat(A, B, C)) @@ -3025,7 +3034,12 @@ struct DotOfGatherTestSpec { class DotOfGatherSimplificationTest : public HloVerifiedTestBase, - public ::testing::WithParamInterface {}; + public ::testing::WithParamInterface { + public: + DotOfGatherSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // input: dot(DS(ctA), ctB)) // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}. diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index d0806d24a22ce57af3116b9aaddb487ec24bfbae..1ed6142dcecdc830cb7b8386e0cc20a2ea54aa7f 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -69,8 +69,7 @@ StatusOr AllocationTracker::RegisterInternal( return InvalidArgument( "AllocationTracker for platform %s cannot register buffer from " "platform %s", - backend_->platform()->Name().c_str(), - shaped_buffer.platform()->Name().c_str()); + backend_->platform()->Name(), shaped_buffer.platform()->Name()); } } @@ -125,7 +124,7 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { // "handle does not exist". auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { - return NotFound("no allocation record for global data handle: %lld", + return NotFound("no allocation record for global data handle: %d", data.handle()); } for (auto& shaped_buffer : it->second) { @@ -144,7 +143,7 @@ StatusOr> AllocationTracker::DeconstructTuple( // the same for all buffers across replicas. const ShapedBuffer* shaped_buffer = replicated_buffers[0]; if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { - return InvalidArgument("global data handle %lld is not a tuple", + return InvalidArgument("global data handle %d is not a tuple", data.handle()); } // If the on-host representation is a tuple, then the on-device one should be @@ -201,14 +200,14 @@ StatusOr> AllocationTracker::ResolveInternal( VLOG(2) << "resolve:" << data.handle(); auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { - return NotFound("no allocation record for global data handle: %lld", + return NotFound("no allocation record for global data handle: %d", data.handle()); } std::vector replicated_buffers; for (const auto& shaped_buffer : it->second) { if (shaped_buffer == nullptr) { - return InvalidArgument( - "global data handle %lld was previously deallocated", data.handle()); + return InvalidArgument("global data handle %d was previously deallocated", + data.handle()); } replicated_buffers.push_back(shaped_buffer.get()); } diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 841d0fa85bb9c548cd737e21bb988886f43378bd..a6889cb171b91de3182bc2c25bd3145d6916dc38 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -177,7 +177,7 @@ StatusOr Backend::stream_executor( } } return InvalidArgument("device %s not supported by XLA service", - device_name(device_ordinal).c_str()); + device_name(device_ordinal)); } StatusOr Backend::devices_equivalent(int device_ordinal_a, diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 1bc3796fa48c1627538474d04ef5358ba64dfce9..4a6a78daf07256684402f448725b219d5983ed9e 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -130,7 +130,7 @@ class Backend { // Return a string identifier for the given device, eg: "GPU:3". string device_name(int device_ordinal) const { - return tensorflow::strings::StrCat(platform_->Name(), ":", device_ordinal); + return absl::StrCat(platform_->Name(), ":", device_ordinal); } // Returns true if the devices with the given ordinals are equivalent from diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index be6fbcc9e361c7a07e953054ca456dbe35445f37..a16b85a0a5e3f72f54e9733bb974b01377e0c358 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -78,7 +78,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( return true; } -tensorflow::StringPiece BatchDotSimplification::name() const { +absl::string_view BatchDotSimplification::name() const { return "batch-dot-simplification"; } diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h index c0ca8d8ebac1a3b218e7bd4d6db02b69cfb6916f..79d37f08d3553321ebbabc44c8f2488b194954d5 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.h +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -28,7 +28,7 @@ namespace xla { class BatchDotSimplification : public HloPassInterface { public: StatusOr Run(HloModule* module) override; - tensorflow::StringPiece name() const override; + absl::string_view name() const override; private: StatusOr ElideDegenerateBatchDimensionFromBatchDot( diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 38f1a5d3a645f98220ec445bb9bbdf2b9b842109..b342acb0259498c2255f55da1cb7a3da700bdca4 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -24,7 +24,12 @@ namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloVerifiedTestBase {}; +class BatchDotSimplificationTest : public HloVerifiedTestBase { + public: + BatchDotSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 7ae202c583516443a6263403fb5460d1adbabd97..76e32174f3ee7d319df6f1f465e19d265d5330f2 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -36,7 +36,7 @@ class BatchNormExpander : public HloPassInterface { rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; - tensorflow::StringPiece name() const override { return "batchnorm_expander"; } + absl::string_view name() const override { return "batchnorm_expander"; } // Run operation expander on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index f62ab12319bf2cf6d37a5133b8e07dc4052179d0..aba0d9bb5b977d89656580df46838eefb8cd6662 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h index c9398387098fad84ba28735c30e426fedd9b0cb0..5dcd31b83d24f836d31f44181f39cb8371ca1033 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -37,7 +37,7 @@ class BFloat16ConversionFolding : public HloPassInterface { : bfloat16_support_(bfloat16_support) {} ~BFloat16ConversionFolding() override = default; - tensorflow::StringPiece name() const override { return "bfloat16-fold"; } + absl::string_view name() const override { return "bfloat16-fold"; } // Run BF16 conversion folding on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 16e99b57220cc185fbfaa75d30a0de709cf61ee7..32573ed3555204c059d092ef65b18b38b19f9ea5 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -34,11 +35,6 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum and sort which can have a tuple - // output. - Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleSort(HloInstruction* sort) override; - static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { BFloat16NormalizationVisitor visitor(computation, bfloat16_support); @@ -150,23 +146,6 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations( return Status::OK(); } -Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( - HloInstruction* crs) { - if (!ShapeUtil::IsTuple(crs->shape())) { - return HandleInstruction(crs); - } else { - return HandleMultipleOutputs(crs); - } -} - -Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) { - if (!ShapeUtil::IsTuple(sort->shape())) { - return HandleInstruction(sort); - } else { - return HandleMultipleOutputs(sort); - } -} - Status BFloat16NormalizationVisitor::HandleMultipleOutputs( HloInstruction* hlo) { std::vector operand_types(hlo->operand_count()); @@ -380,6 +359,11 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kConditional) { return Status::OK(); } + if ((hlo->opcode() == HloOpcode::kSort || + hlo->opcode() == HloOpcode::kCrossReplicaSum) && + ShapeUtil::IsTuple(hlo->shape())) { + return HandleMultipleOutputs(hlo); + } return HandleInstruction(hlo); } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h index 2a60fe0af3218484acb95e6c69815d551350764c..30b6346312790f0a199f96f1956ba9ce3e617f72 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.h +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -31,7 +31,7 @@ class BFloat16Normalization : public HloPassInterface { : bfloat16_support_(bfloat16_support) {} ~BFloat16Normalization() override = default; - tensorflow::StringPiece name() const override { return "bf16-normalization"; } + absl::string_view name() const override { return "bf16-normalization"; } // Run BF16 normalization on the given computation. Returns whether the // computation was changed. @@ -54,7 +54,7 @@ class BFloat16MixedPrecisionRemoval : public HloPassInterface { ~BFloat16MixedPrecisionRemoval() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "bf16-mixed-precision-removal"; } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 49ae5320b0f5aca452e5b2a8f98a5f8ce80fe081..b08705d4c2b644fe1a7ba9994876fd6397f8a5df 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -76,7 +76,8 @@ class BFloat16NormalizationTest : public HloTestBase { StatusOr result = normalization.Run(module); EXPECT_IS_OK(result.status()); - HloVerifier verifier(/*allow_mixed_precision=*/true); + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); EXPECT_IS_OK(verifier.Run(module).status()); return result.ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 02b8cad089dd8465b7af5c1014e37b77ded6949d..1ee64971ab53e1775294afde1c779369a838008a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -64,9 +64,7 @@ class BFloat16Propagation : public HloPassInterface { ~BFloat16Propagation() override = default; - tensorflow::StringPiece name() const override { - return "bfloat16-propagation"; - } + absl::string_view name() const override { return "bfloat16-propagation"; } // Runs the pass on the given module. Returns whether the module was changed // (precision reductions were added). diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 0f08e7c52b29963a55f460c578e6d3d1591a520d..b11f15ec7bdce021879c85602c6c5b05a5f3fd52 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -23,6 +23,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" @@ -36,20 +38,15 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { +namespace { +using absl::StrAppend; +using absl::StrAppendFormat; using ::tensorflow::gtl::FlatMap; using ::tensorflow::gtl::FlatSet; -using ::tensorflow::strings::Appendf; using ::tensorflow::strings::HumanReadableNumBytes; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; - -namespace { template string ColocatedBufferSetsToString(const T& container, const char* title) { @@ -107,7 +104,7 @@ Status GatherComputationsByAllocationType( return InvalidArgument( "computation %s has conflicting allocation requirements (global " "and thread-local)", - computation->name().c_str()); + computation->name()); } if (is_thread_local) { @@ -130,7 +127,7 @@ Status GatherComputationsByAllocationType( return InvalidArgument( "computation %s cannot contain call/while op because it " "requires thread-local buffer allocations", - computation->name().c_str()); + computation->name()); } worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. @@ -147,9 +144,8 @@ Status GatherComputationsByAllocationType( true)); // Thread local. break; default: - return InternalError( - "Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); + return InternalError("Unexpected calling opcode: %s", + HloOpcodeString(instruction->opcode())); } } } @@ -236,8 +232,8 @@ size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { } string BufferAllocation::Slice::ToString() const { - return tensorflow::strings::StrCat("{index:", index(), ", offset:", offset_, - ", size:", size_, "}"); + return absl::StrCat("{index:", index(), ", offset:", offset_, + ", size:", size_, "}"); } BufferAllocation::Slice BufferAllocation::GetSlice( @@ -298,7 +294,7 @@ BufferAllocationProto BufferAllocation::ToProto() const { string BufferAllocation::ToString() const { string output; - Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); + StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size()); if (color().value() != 0) { StrAppend(&output, ", color ", color().value()); } @@ -330,11 +326,10 @@ string BufferAllocation::ToString() const { }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); - StrAppend(&output, - tensorflow::strings::Printf( - " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), - offset_size.offset, offset_size.size, - ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); + StrAppend(&output, absl::StrFormat( + " %s [%d,%d]: %s\n", buffer->ToString(), + offset_size.offset, offset_size.size, + ShapeUtil::HumanStringWithLayout(buffer->shape()))); } return output; } @@ -427,7 +422,7 @@ StatusOr BufferAssignment::GetUniqueSlice( return FailedPrecondition( "BufferAllocation::Slice for instruction %s at index %s cannot " "be determined at compile-time.", - instruction->name().c_str(), index.ToString().c_str()); + instruction->name(), index.ToString()); } } else { VLOG(3) << "No allocation"; @@ -436,7 +431,7 @@ StatusOr BufferAssignment::GetUniqueSlice( if (result.allocation() == nullptr) { return FailedPrecondition( "BufferAllocation::Slice not assigned for instruction %s at index %s", - instruction->name().c_str(), index.ToString().c_str()); + instruction->name(), index.ToString()); } return result; } @@ -648,39 +643,38 @@ Status BufferAssignment::ComputeSummaryStats() { string BufferAssignment::Stats::ToString() const { string s; - Appendf(&s, "BufferAssignment stats:\n"); - Appendf(&s, " parameter allocation: %10s\n", - HumanReadableNumBytes(parameter_allocation_bytes).c_str()); - Appendf(&s, " constant allocation: %10s\n", - HumanReadableNumBytes(constant_allocation_bytes).c_str()); - Appendf(&s, " maybe_live_out allocation: %10s\n", - HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str()); - Appendf(&s, " preallocated temp allocation: %10s\n", - HumanReadableNumBytes(preallocated_temp_allocation_bytes).c_str()); + StrAppendFormat(&s, "BufferAssignment stats:\n"); + StrAppendFormat(&s, " parameter allocation: %10s\n", + HumanReadableNumBytes(parameter_allocation_bytes)); + StrAppendFormat(&s, " constant allocation: %10s\n", + HumanReadableNumBytes(constant_allocation_bytes)); + StrAppendFormat(&s, " maybe_live_out allocation: %10s\n", + HumanReadableNumBytes(maybe_live_out_allocation_bytes)); + StrAppendFormat(&s, " preallocated temp allocation: %10s\n", + HumanReadableNumBytes(preallocated_temp_allocation_bytes)); if (preallocated_temp_fragmentation_bytes >= 0) { const double percent = 100. * preallocated_temp_fragmentation_bytes / preallocated_temp_allocation_bytes; - Appendf( + StrAppendFormat( &s, " preallocated temp fragmentation: %10s (%.2f%%)\n", - HumanReadableNumBytes(preallocated_temp_fragmentation_bytes).c_str(), - percent); + HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent); } - Appendf(&s, " total allocation: %10s\n", - HumanReadableNumBytes(total_allocation_bytes).c_str()); + StrAppendFormat(&s, " total allocation: %10s\n", + HumanReadableNumBytes(total_allocation_bytes)); if (total_fragmentation_bytes >= 0) { const double percent = 100. * total_fragmentation_bytes / total_allocation_bytes; - Appendf(&s, " total fragmentation: %10s (%.2f%%)\n", - HumanReadableNumBytes(total_fragmentation_bytes).c_str(), percent); + StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n", + HumanReadableNumBytes(total_fragmentation_bytes), percent); } return s; } string BufferAssignment::ToString() const { string output; - tensorflow::strings::StrAppend(&output, "BufferAssignment:\n"); + absl::StrAppend(&output, "BufferAssignment:\n"); for (auto& allocation : allocations_) { - tensorflow::strings::StrAppend(&output, allocation.ToString()); + absl::StrAppend(&output, allocation.ToString()); } return output; } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 810d597e730c1823668c81598df6138655e58b55..9b2783a214a686f3148723d19bbc94421fc8b4e4 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -28,8 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -75,27 +75,25 @@ Status BufferLiveness::Analyze() { string BufferLiveness::ToString() const { std::vector pieces; - pieces.push_back(tensorflow::strings::Printf("BufferLiveness(module=%s):", - module_->name().c_str())); + pieces.push_back( + absl::StrFormat("BufferLiveness(module=%s):", module_->name())); pieces.push_back("HloOrdering:"); pieces.push_back(hlo_ordering_->ToString()); - pieces.push_back(tensorflow::strings::Printf("Aliased buffers:")); + pieces.push_back("Aliased buffers:"); for (const LogicalBuffer* buffer : aliased_buffers_) { - pieces.push_back( - tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + pieces.push_back(absl::StrFormat(" %s", buffer->ToString())); } - pieces.push_back(tensorflow::strings::Printf("Live out buffers:")); + pieces.push_back("Live out buffers:"); for (const LogicalBuffer* buffer : maybe_live_out_buffers_) { - pieces.push_back( - tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + pieces.push_back(absl::StrFormat(" %s", buffer->ToString())); } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, const LogicalBuffer& b) const { - TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a)); - TF_CHECK_OK(points_to_analysis_->VerifyBuffer(b)); + TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(a)); + TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(b)); if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) { return false; diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 3ffb7de65fb63b24e8be4978063d3f9f78f3e9ac..26e26e316d6281a97f8317f8ed1d7a6f21b0d374 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -610,11 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { protected: // Builds and runs a computation (see test case computation graphs below). - // Runs BufferLiveness on this computation. - // Returns whether buffer interference is detected between tuple-shaped - // parameter and root instructions at tuple element 1. - bool Run(const bool update_uses_tuple_element1, - const bool fuse_gte0 = false) { + std::unique_ptr BuildModule(const bool update_uses_tuple_element1, + const bool fuse_gte0) { auto builder = HloComputation::Builder(TestName()); // Create param0 Tuple. Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -645,12 +643,12 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); // Create output tuple. - auto tuple_root = builder.AddInstruction( + builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. auto module = CreateNewModule(); - module->AddEntryComputation(BuildDummyComputation()); - auto* computation = module->AddEmbeddedComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); + auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. if (update_uses_tuple_element1) { computation->CreateFusionInstruction( @@ -666,7 +664,14 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { computation->CreateFusionInstruction({gte0}, HloInstruction::FusionKind::kLoop); } + return module; + } + // Returns whether buffer interference is detected between tuple-shaped + // parameter and root instructions at tuple element 1. + bool Run(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { + auto module = BuildModule(update_uses_tuple_element1, fuse_gte0); // Run BufferLiveness on 'module'. auto liveness = BufferLiveness::Run( module.get(), @@ -674,8 +679,24 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. + auto tuple_param0 = FindInstruction(module.get(), "param0"); + auto tuple_root = module->entry_computation()->root_instruction(); return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); } + bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { + auto module = BuildModule(update_uses_tuple_element1, fuse_gte0); + // Run BufferLiveness on 'module'. + auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie(); + auto hlo_ordering = absl::make_unique(module.get()); + // Return whether or not buffers interference is detected between + // 'tuple_param0' and 'tuple_root' at shape index '{1}'. + auto tuple_param0 = FindInstruction(module.get(), "param0"); + auto tuple_root = module->entry_computation()->root_instruction(); + return hlo_ordering->MayInterfere( + dataflow->GetUniqueValueAt(tuple_param0, {1}), + dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow); + } }; // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) @@ -693,6 +714,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false)); + EXPECT_FALSE( + RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false)); } // Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases @@ -712,6 +735,8 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { // TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true)); + EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false, + /*fuse_gte0=*/true)); } // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) @@ -736,6 +761,7 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { // TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) { EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true)); + EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true)); } class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc index 2bc556a9e270136f5f3eaf2433f8c96eeeaea0a2..fdf822c666b15afbc7553ca89d4f92ab08201869 100644 --- a/tensorflow/compiler/xla/service/buffer_value.cc +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -17,11 +17,10 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index d6efef5f12f62733ddd3a5314249ee9262571f97..23b2a327096dfdb3c756a4acc5476ec01dcac1b3 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -18,20 +18,20 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::strings::Appendf; -using ::tensorflow::strings::StrCat; +using absl::StrAppendFormat; +using absl::StrCat; string CallContextToString(CallContext context) { switch (context) { @@ -71,10 +71,10 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { } string CallSite::ToString() const { - return StrCat(instruction()->name(), " calls in context ", - CallContextToString(context()), ": ", - tensorflow::str_util::Join( - called_computations(), ", ", + return StrCat( + instruction()->name(), " calls in context ", + CallContextToString(context()), ": ", + absl::StrJoin(called_computations(), ", ", [](string* out, const HloComputation* computation) { out->append(computation->name()); })); @@ -356,20 +356,20 @@ CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, string CallGraph::ToString() const { string out; - Appendf(&out, "Call graph for module %s:\n", module_->name().c_str()); + StrAppendFormat(&out, "Call graph for module %s:\n", module_->name()); for (const CallGraphNode& node : nodes()) { - Appendf(&out, "Computation %s:\n", node.computation()->name().c_str()); - Appendf(&out, " calls:\n"); + StrAppendFormat(&out, "Computation %s:\n", node.computation()->name()); + StrAppendFormat(&out, " calls:\n"); for (const HloComputation* callee : node.callees()) { - Appendf(&out, " %s\n", callee->name().c_str()); + StrAppendFormat(&out, " %s\n", callee->name()); } - Appendf(&out, " called by:\n"); + StrAppendFormat(&out, " called by:\n"); for (const HloComputation* caller : node.callers()) { - Appendf(&out, " %s\n", caller->name().c_str()); + StrAppendFormat(&out, " %s\n", caller->name()); } - Appendf(&out, " callsites:\n"); + StrAppendFormat(&out, " callsites:\n"); for (const CallSite& callsite : node.callsites()) { - Appendf(&out, " %s\n", callsite.ToString().c_str()); + StrAppendFormat(&out, " %s\n", callsite.ToString()); } } return out; diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 256d05a73e0bf61d959d21795c106286b52d0b19..1d4214044409ae06239506e610000c839450a030 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -96,7 +96,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { if (it == subcomputation_hlo_to_new_hlo_.end()) { return NotFound( "Could not find mapping from subcomputation HLO %s to a cloned HLO.", - subcomputation_hlo->ToString().c_str()); + subcomputation_hlo->ToString()); } return it->second; } diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index c0e95e1578bcf587647aa75bd68e9f9ca0c4b816..c5cd88b9ea2a9c308786d4d7476316b1e592d40a 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -35,7 +35,7 @@ class CallInliner : public HloPassInterface { static StatusOr Inline(HloInstruction* call); ~CallInliner() override = default; - tensorflow::StringPiece name() const override { return "CallInliner"; } + absl::string_view name() const override { return "CallInliner"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index e75f6f146d7c5896cfe6566fdec212a60e9f8457..5d85a3f173d50a964420e720f5c9b416731d948c 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace op = xla::testing::opcode_matchers; diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 9c9e373821d7f84f3468ef6c6a4f7dae9715b9f8..3c2d1ae6d82ebc6c10d52194fd1cec5e291025f7 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -16,13 +16,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/channel_tracker.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -73,20 +73,20 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { if (opaque_to_channel_.count(handle.handle()) == 0) { - return NotFound("channel handle not found: %lld", handle.handle()); + return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.type == ChannelHandle::HOST_TO_DEVICE) { return FailedPrecondition( "host-to-device channels cannot be used with a Send operation; " - "channel handle: %lld", + "channel handle: %d", handle.handle()); } if (channel.has_sender) { return FailedPrecondition( "when registering send, passed a channel handle that is already used " - "by a sender: %lld", + "by a sender: %d", handle.handle()); } channel.has_sender = true; @@ -95,13 +95,13 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { if (opaque_to_channel_.count(handle.handle()) == 0) { - return NotFound("channel handle not found: %lld", handle.handle()); + return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.type == ChannelHandle::DEVICE_TO_HOST) { return FailedPrecondition( "device-to-host channels cannot be used with a Recv operation; " - "channel handle: %lld", + "channel handle: %d", handle.handle()); } @@ -109,7 +109,7 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { if (channel.receiver_count >= 1) { return FailedPrecondition( "when registering recv, passed a channel handle that is already used " - "by a receiver: %lld", + "by a receiver: %d", handle.handle()); } channel.receiver_count += 1; diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 7426672a7a2a9102bd5ea98bd51092982e1e09b4..3079695e9674f4000fdf4c54ac1e78c98968aa27 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -76,9 +76,9 @@ CompileOnlyService::CompileAheadOfTime( if (!directory_path.empty()) { HloSnapshot hlo_snapshot; *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; - string filename = tensorflow::strings::StrCat( - "computation_", instance.computation.id(), "__", - instance.computation.entry_computation_name()); + string filename = + absl::StrCat("computation_", instance.computation.id(), "__", + instance.computation.entry_computation_name()); const string& per_host_path = tensorflow::io::JoinPath( directory_path, tensorflow::port::Hostname()); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 6b3b9820f09803c8a04504e6c35c22de51abf04b..687ecafe0c308ecc22857fae650c6998677f605d 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -101,7 +101,7 @@ Compiler::GetPlatformCompilers() { return NotFound( "could not find registered compiler for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } // And then we invoke the factory, placing the result into the mapping. diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index cb61f3da39fb8eef69fd81066d87a1da91a62935..af8f7f1027a40703137d6880a9865449c560a47b 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -52,9 +52,8 @@ string ComputationLayout::ToString() const { for (auto& param_layout : parameter_layouts_) { params.push_back(param_layout.ToString()); } - return tensorflow::strings::StrCat("(", - tensorflow::str_util::Join(params, ", "), - ") => ", result_layout_.ToString()); + return absl::StrCat("(", absl::StrJoin(params, ", "), ") => ", + result_layout_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index afbbea35b893b8c14dbc0454e0a01fcb451cb709..2210a8578ad73efb27dc9c230b142c55228d2af5 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" @@ -29,12 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; namespace xla { @@ -132,7 +132,7 @@ StatusOr ComputationPlacer::AssignDevices( return NotFound( "could not find registered computation placer for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } if (it->second.placer == nullptr) { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index b7be3ba605a89a736b032eaab5a5085ac64fc549..4ea3a13f2835c5fef99c274f14d7d683c9ff5fc8 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -28,8 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h index 063261e26d06e21a297e8e3c405898a17221b7ca..3de50cbd7ff752e8722a103b68f75144c6c889cd 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.h +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -27,9 +27,7 @@ namespace xla { // with their true or false computation as appropriate. class ConditionalSimplifier : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "simplify-conditional"; - } + absl::string_view name() const override { return "simplify-conditional"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index c43a31b167d47af3c92ed35fa52594fa5da1e4af..6c477da03820681e381dd64978d30edf27e2c422 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -39,6 +39,10 @@ namespace op = xla::testing::opcode_matchers; class ConditionalSimplifierTest : public HloVerifiedTestBase { public: + ConditionalSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + // Makes a computation that contains a conditional with constant predicate. HloComputation* MakeConditional(HloModule* module); }; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index f213cc870918d476e839f97ae067504038f8cacc..498894737fa37a6d8cca6ead2a86c72eb84ababd 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -29,7 +29,7 @@ class ConvolutionFeatureGroupConverter : public HloPassInterface { public: ConvolutionFeatureGroupConverter() {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "convolution-feature-group-converter"; } diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 3e39c1bab1e07d192a8c145be5103085fd3c189b..1b7a7b36eac31f972e1166e17859cc0c64265538 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -31,18 +33,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { - -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace { +using absl::StrAppend; + bool IsEntryParameterValue(const HloValue& value) { const HloComputation* computation = value.defining_instruction()->parent(); return value.defining_instruction()->opcode() == HloOpcode::kParameter && @@ -381,7 +378,7 @@ class CopyRemover { } string ToString() const { - string out = StrCat("CopyRemover, module ", module_->name(), "\n"); + string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n"); StrAppend(&out, " Buffer values, in dependency order:\n"); for (const HloBuffer& buffer : alias_analysis_.buffers()) { StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); @@ -863,16 +860,16 @@ class CopyRemover { for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { values.push_back(p->value); } - return StrCat("{", - Join(values, ", ", - [](string* s, const HloValue* value) { - StrAppend(s, value->ToShortString()); - }), - "}"); + return absl::StrCat("{", + absl::StrJoin(values, ", ", + [](string* s, const HloValue* value) { + StrAppend(s, value->ToShortString()); + }), + "}"); } string ToString() const { - string out = StrCat("BufferValueTracker:\n"); + string out = absl::StrCat("BufferValueTracker:\n"); StrAppend(&out, " Def-use chains in each buffer:\n"); for (const ValueNode* head : value_lists_) { StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), @@ -880,10 +877,10 @@ class CopyRemover { const ValueNode* p = head; do { StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", - Join(p->uses, "; ", - [](string* s, const HloUse* use) { - StrAppend(s, use->ToString()); - }), + absl::StrJoin(p->uses, "; ", + [](string* s, const HloUse* use) { + StrAppend(s, use->ToString()); + }), "\n"); p = p->next; @@ -960,16 +957,11 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { return Status::OK(); } -// Add copies to address special constraints on the roots of computations not -// related to live range interference: -// -// (1) Entry computation root must be unambiguous and distinct. -// -// (2) Any computation called by a kCall instruction must have an -// unambiguous root. -// -// (3) Constants and parameters cannot be live out of the entry computation -// +Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) { + std::unique_ptr call_graph = CallGraph::Build(module); + return AddSpecialCaseCopies(*call_graph, module); +} + Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, @@ -1065,15 +1057,6 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, for (HloInstruction* user : users) { TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); } - // Special case copies are not eligible for later copy elision passes. - indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) { - if (has_copy) { - HloInstruction* copy = *copies_added.mutable_element(index); - if (copy != nullptr) { - copy->SetCopyElisionAllowed(false); - } - } - }); if (instruction == instruction->parent()->root_instruction()) { instruction->parent()->set_root_instruction(deep_copy); } @@ -1081,10 +1064,10 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, return Status::OK(); } -Status CopyInsertion::VerifyNoLiveRangeInterference(HloModule* module) { +Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering, + HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); - DependencyHloOrdering ordering(module); TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); return Status::OK(); } @@ -1101,8 +1084,7 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, std::unique_ptr call_graph = CallGraph::Build(module); for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - instruction->CopyElisionAllowed()) { + if (instruction->opcode() == HloOpcode::kCopy) { TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); } } @@ -1168,10 +1150,10 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + DependencyHloOrdering dep_ordering(module); + TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module)); - DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module)); TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); @@ -1179,7 +1161,8 @@ StatusOr CopyInsertion::Run(HloModule* module) { TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + TF_DCHECK_OK( + VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module)); MaybeDumpModule("after copy insertion", *module); diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 5ba64b78a3c9aff5f323691df2ece9b5e6bf3232..d308f6bc84670b78b9cab476f2893bce267df2cf 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -45,7 +45,7 @@ namespace xla { // InstructionAliasSet::IsDistinct return true. class CopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } // fusion_can_share_buffer: backend specific function that decides whether a // fusion can share buffer with its operand. @@ -77,15 +77,29 @@ class CopyInsertion : public HloPassInterface { Status RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module); - private: - // Verifies that no HLO values have interfering live ranged assuming the - // ordering used by copy insertion. - Status VerifyNoLiveRangeInterference(HloModule* module); + // Add copies to address special constraints on the roots of computations not + // related to live range interference: + // + // (1) Entry computation root must be unambiguous and distinct. + // + // (2) Any computation called by a kCall instruction must have an + // unambiguous root. + // + // (3) Constants and parameters cannot be live out of the entry computation + // + Status AddSpecialCaseCopies(HloModule* module); - Status AddCopiesToResolveInterference(HloModule* module); + // Verifies that no HLO values have interfering live ranges using the given + // ordering. + Status VerifyNoLiveRangeInterference(const HloOrdering& ordering, + HloModule* module); + private: + // Override which requires the caller to pass in a call graph. Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module); + Status AddCopiesToResolveInterference(HloModule* module); + // Backend specific function that decides whether a fusion can share buffer // with its operand. HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_; diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 850948b54b8c8ef7ac4e5da4c64e7ce018e31624..4cd192873f0c5fed884871ec3313f715f70210cc 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -87,6 +87,8 @@ cc_library( ":parallel_task_assignment", ":simple_orc_jit", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ":target_machine_features", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", @@ -232,6 +234,8 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm//:orc_jit", ], ) @@ -274,11 +278,14 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm//:code_gen", "@llvm//:core", "@llvm//:support", @@ -323,6 +330,7 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -333,12 +341,12 @@ cc_library( hdrs = ["parallel_loop_emitter.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:core", ], ) @@ -365,6 +373,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -385,6 +394,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -398,6 +408,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:mc", "@llvm//:mc_disassembler", "@llvm//:object", @@ -639,6 +650,7 @@ tf_cc_test( "//tensorflow/core:test", "//third_party/eigen3", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -653,6 +665,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -816,6 +829,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -852,6 +866,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index e6fd1499edd0095395194200a5b444ad61e7e39d..59437e88af27528654a0af86baf69ec7a1e91d60 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -38,7 +38,7 @@ class ConvCanonicalization : public HloPassInterface { : target_machine_features_(*target_machine_features) {} ~ConvCanonicalization() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "convolution-canonicalization"; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 5116f926f50bf0344951ebb67def7eddd0919f2b..6420180b1307ae7a41a0ac8539a525f7e4ea11e3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -27,6 +27,7 @@ limitations under the License. // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc" // IWYU pragma: no_include "llvm/Config/Targets.def.inc" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" @@ -101,8 +102,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace cpu { @@ -235,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { std::unordered_map* hlo_to_profile_idx_; const std::unordered_map& assigned_indices_; }; -} // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, - llvm::TargetMachine* target_machine) { - LLVMTargetMachineFeatures target_machine_features(target_machine); +} // namespace - // Optimization pipeline. - HloPassPipeline pipeline("CPU"); - pipeline.AddInvariantChecker(); +Status CpuCompiler::RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes through layout assignment"); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( @@ -260,11 +259,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(&target_machine_features); + pipeline.AddPass(target_machine_features); { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(); + pass.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pass.AddPass( /*rewrite_training_op=*/true, @@ -291,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, } pipeline.AddPass(); pipeline.AddPass( - [&target_machine_features]( - const HloInstruction& dot, + [&](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot, target_machine_features) + return PotentiallyImplementedAsEigenDot(dot, *target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -309,12 +308,28 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_entry_computation_layout(), &target_machine_features); + module->mutable_entry_computation_layout(), target_machine_features); + return pipeline.Run(module).status(); +} + +Status CpuCompiler::RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes after layout assignment"); + // After layout assignment, use a layout-sensitive verifier. + auto& after_layout_assn = + pipeline.AddPass("after layout assignment"); + after_layout_assn.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. { auto& pass = pipeline.AddPass>( - "after layout assignement"); + "simplification after layout assignement"); + pass.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); pass.AddPass>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, @@ -322,7 +337,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pass.AddPass(); pass.AddPass(/*is_layout_sensitive=*/true); } + pipeline.AddPass(BF16, F32); + // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -335,14 +352,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, // binary size (and most AOT applications are single-threaded). // TODO(b/29630486) Support multi-threaded AOT. pipeline.AddPass( - max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); + max_parallelism, ShapeSizeBytesFunction(), target_machine_features); } - // Copy insertion should be performed immediately before IR emission to avoid - // inserting unnecessary copies (later pass adds an instruction which - // materializes the value) or missing a necessary copy (later pass removes an - // instruction which materializes a value). DCE must be run immediately before - // (and sometime after) copy insertion, to avoid dead code from interfering - // with the rewrites. + // Copy insertion should be performed immediately before IR emission to + // avoid inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes + // an instruction which materializes a value). DCE must be run immediately + // before (and sometime after) copy insertion, to avoid dead code from + // interfering with the rewrites. pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -350,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, return pipeline.Run(module).status(); } +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile, + &target_machine_features)); + return RunHloPassesAfterLayoutAssn(module, is_aot_compile, + &target_machine_features); +} + namespace { // Align buffers to 16-byte boundaries. @@ -679,8 +705,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, const llvm::Target* target = llvm::TargetRegistry::lookupTarget(triple.getTriple(), error); if (target == nullptr) { - return InternalError("TargetRegistry::lookupTarget failed: %s", - error.c_str()); + return InternalError("TargetRegistry::lookupTarget failed: %s", error); } llvm::Reloc::Model reloc_model = llvm::Reloc::Static; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 04e1c48872ed55ca7f2aa3bec08c44a1666b90f1..47b5edabff79d1df23cbeae0823536bbdcd07aaa 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -157,6 +158,16 @@ class CpuCompiler : public LLVMCompiler { Status RunHloPasses(HloModule* module, bool is_aot_compile, llvm::TargetMachine* target_machine); + // Runs HLO passes up to and including layout assignment. + Status RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features); + + // Runs HLO passes after layout assignment. + Status RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features); + TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h index 6398d8c98d0b4fec98519a53452effcface7e4a4..d49f7d7cc2d9b1d00847feda62fa62dd740820d8 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -32,7 +32,7 @@ namespace xla { // (module-scoped). class CpuCopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index c376864c3e1f882e11bc05f8cf93f2fb1c88e4ec..08773693fba766bec78839d1557a587a832da95f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -22,6 +22,9 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -35,9 +38,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" @@ -171,20 +171,18 @@ Status CpuExecutable::ExecuteComputeFunction( void* result_buffer = buffer_pointers[result_slice.index()]; if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; - VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[null], void* temps[%zu], " - "uint64 profile_counters[%zu])", + VLOG(3) << absl::StrFormat( + " func(void* result, void* params[null], void* temps[%u], " + "uint64 profile_counters[%u])", buffer_pointers.size(), profile_counters_size); - VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); + VLOG(3) << absl::StrFormat(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { - tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); + absl::StrAppend(out, absl::StrFormat("%p", p)); }; VLOG(3) << " params = nullptr"; - VLOG(3) << tensorflow::strings::Printf( - " temps = [%s]", - tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); - VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters); + VLOG(3) << absl::StrFormat( + " temps = [%s]", absl::StrJoin(buffer_pointers, ", ", ptr_printer)); + VLOG(3) << absl::StrFormat(" profile_counters = %p", profile_counters); } compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc index 7bd4741a04b1135d9780e0cf765b7b33378526e1..7fbe0fa157c57eb0c274662a1de95cf5328ccfa8 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -34,9 +34,8 @@ StatusOr CpuHloSupportChecker::Run(HloModule* module) { return xla::Unimplemented( "CPU backend does not support HLO instruction %s with shape " "containing a sparse layout: %s", - instruction->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction->shape()) - .c_str()); + instruction->ToString(), + ShapeUtil::HumanStringWithLayout(instruction->shape())); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h index 2924b6365943f0a3ec998d7a77767a76cbb576ae..6af724b2a5d71b9c30f3485ffb7e51d1d201cb6b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -28,9 +28,7 @@ class CpuHloSupportChecker : public HloPassInterface { CpuHloSupportChecker() = default; ~CpuHloSupportChecker() override = default; - tensorflow::StringPiece name() const override { - return "cpu_hlo_support_checker"; - } + absl::string_view name() const override { return "cpu_hlo_support_checker"; } // Note: always returns false (no instructions are ever modified by this // pass). diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index b40d264c03aba6e9308e8a621ae86e180e33c335..7f867fa1495b5bfa492a12e312980cbad2670b9b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -78,7 +78,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (!CanBeLoopFused(*producer)) { - VLOG(2) << "Producer is not fusile."; + VLOG(2) << "Producer is not fusible."; return false; } @@ -140,7 +140,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (CanBeLoopFused(*consumer)) { - VLOG(2) << "Fusing: consumer is elementwise or fusile."; + VLOG(2) << "Fusing: consumer is elementwise or fusible."; return true; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index e6130c7d76e0383d03fe56d19aee239c5992309d..28aaa28cdb54b6ded6e9a1229169a085d85be786 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" @@ -566,7 +567,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { HloOpcode::kParameter, HloOpcode::kParameter}); } -TEST_F(OpcodeFusionTest, MessOfFusileNodes) { +TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); @@ -773,8 +774,8 @@ class GatherLoopFusionTest TEST_P(GatherLoopFusionTest, GatherLoopFusion) { const GatherLoopFusionTestSpec& spec = GetParam(); - string hlo_string = tensorflow::strings::StrCat( - "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); + string hlo_string = absl::StrCat("HloModule ", spec.test_name, "\n\n", + spec.hlo_computation_text); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseHloString(hlo_string)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index 69acca86bffdaa9427c2fff03a36ea057be6bafe..bfecbd6e017893e4f6d3dcbc01d46c899e6060fa 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -34,8 +34,8 @@ namespace cpu { // instruction stream. namespace { -using ::absl::nullopt; -using ::absl::optional; +using absl::nullopt; +using absl::optional; using ShouldMakeOperandColMajorCache = tensorflow::gtl::FlatMap; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index b6039b465ed6deb90be94e74a364db62d4f447c7..b8ace5702688096822573c7afae234cbcbe77b28 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace { @@ -51,7 +52,7 @@ absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config) { auto it = extra_options_map.find(kLlvmIrDotTilingFactor); int64 tiling_factor; if (it != extra_options_map.end() && - tensorflow::strings::safe_strto64(it->second, &tiling_factor)) { + absl::SimpleAtoi(it->second, &tiling_factor)) { return tiling_factor; } return absl::nullopt; @@ -63,8 +64,8 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; } -static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str, - tensorflow::StringPiece suffix) { +static absl::string_view RemoveSuffix(absl::string_view str, + absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); CHECK_EQ(str.substr(str.size() - suffix.size()), suffix); return str.substr(0, str.size() - suffix.size()); @@ -79,22 +80,21 @@ absl::optional> LlvmIrGemmTileSize( return absl::nullopt; } - std::vector tile_components = - tensorflow::str_util::Split(it->second, ':'); + std::vector tile_components = absl::StrSplit(it->second, ':'); CHECK_EQ(tile_components.size(), 3); int64 tile_size_m; int64 tile_size_k; int64 tile_size_n_in_vector_width; - CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m)); - CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k)); + CHECK(absl::SimpleAtoi(tile_components[0], &tile_size_m)); + CHECK(absl::SimpleAtoi(tile_components[1], &tile_size_k)); - tensorflow::StringPiece tile_size_n_in_vector_width_str = + absl::string_view tile_size_n_in_vector_width_str = RemoveSuffix(tile_components[2], "*vectwidth"); - CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str, - &tile_size_n_in_vector_width)); + CHECK(absl::SimpleAtoi(tile_size_n_in_vector_width_str, + &tile_size_n_in_vector_width)); return std::tuple(tile_size_m, tile_size_k, tile_size_n_in_vector_width); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index bc4cfc099965e2ab12212f55e62bdf79c0cfb739..1ae3aa57111e3a3b7ac18b4907c5c282edf89b7e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -142,10 +142,10 @@ class EigenMatMulTest : public CpuRuntimeTest, bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); - return tensorflow::strings::Printf( - "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); + return absl::StrFormat("EigenMatMul_%d_%d_%d_%s%s%s_threaded", shape.m, + shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", + transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); } }; @@ -178,10 +178,10 @@ class MKLMatMulTest : public CpuRuntimeTest, bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); - return tensorflow::strings::Printf( - "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); + return absl::StrFormat("MKLMatMul_%d_%d_%d_%s%s%s_threaded", shape.m, + shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", + transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); } }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index b07cd675ffc4dbd0c7d56da715b29014bb12ce88..0df2abf0012db169d01e6d9bb19430db1ac80c14 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -104,7 +104,7 @@ Status CpuTransferManager::TransferLiteralToInfeed( if (ShapeUtil::IsNestedTuple(shape)) { return Unimplemented( "Infeed with a nested tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); + ShapeUtil::HumanString(literal.shape())); } // For a tuple, we transfer each of its elements to the device and @@ -152,11 +152,11 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + return InvalidArgument("Infeed shape is too large: needs %d bytes", size); } if (size <= 0) { - return InvalidArgument("Infeed shape must have positive size; got %lld", + return InvalidArgument("Infeed shape must have positive size; got %d", size); } @@ -244,12 +244,12 @@ StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( for (auto b : buffer_data) { int64 size = b.second; if (size > std::numeric_limits::max()) { - return InvalidArgument("Outfeed shape is too large: needs %lld bytes", + return InvalidArgument("Outfeed shape is too large: needs %d bytes", size); } if (size <= 0) { - return InvalidArgument("Outfeed shape must have positive size; got %lld", + return InvalidArgument("Outfeed shape must have positive size; got %d", size); } diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index e4c674e227ffc6725ca929f720b9aa7cf7c4c032..3ae64142cd7e32d3aa8d50870efaf94698c06440 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -21,13 +21,13 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "llvm/MC/MCInst.h" #include "llvm/Support/TargetRegistry.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -151,7 +151,7 @@ StatusOr Disassembler::DisassembleObjectFile( size = 1; } - ostream << tensorflow::strings::Printf("0x%08lx", index) << " "; + ostream << absl::StrFormat("0x%08lx", index) << " "; if (decode_status == llvm::MCDisassembler::Success) { // For branches, try to determine the actual address and emit it as an @@ -163,7 +163,7 @@ StatusOr Disassembler::DisassembleObjectFile( uint64_t target; if (inst_analysis_->evaluateBranch( instruction, section_address + index, size, target)) { - annotation = tensorflow::strings::Printf("[0x%08lx]", target); + annotation = absl::StrFormat("[0x%08lx]", target); } } inst_printer_->printInst(&instruction, ostream, annotation.c_str(), diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 797392c26575d57b02e97e26f4cdb0d715c251b5..dd060f54a29d9872bc086ff6718c46b25142a83e 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -146,9 +147,9 @@ class GemvConfig { bool has_addend() const { return has_addend_; } string GetCacheKey() const { - return tensorflow::strings::StrCat( - name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_", - tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : ""); + return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", + tile_rows(), "_", tile_cols(), "_", m(), "_", k(), + has_addend() ? "_with_addend" : ""); } protected: @@ -642,9 +643,7 @@ class TiledSmallGemmEmitter { int64 k() const { return k_; } int64 n() const { return n_; } - string ToString() const { - return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); - } + string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } private: const int64 m_; @@ -687,10 +686,10 @@ class TiledSmallGemmEmitter { tile_size_k_(tile_size_k) {} string GetCacheKey() const { - return tensorflow::strings::StrCat( - "gemm_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), - "_", max_vectorization_width(), "_", min_vectorization_width(), "_", - tile_size_m(), "_", tile_size_k()); + return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", + dims().ToString(), "_", max_vectorization_width(), + "_", min_vectorization_width(), "_", tile_size_m(), + "_", tile_size_k()); } PrimitiveType scalar_type() const { return scalar_type_; } @@ -1468,7 +1467,7 @@ Status DotOpEmitter::EmitCallToRuntime() { break; default: return Unimplemented("Invalid type %s for dot operation", - PrimitiveType_Name(type).c_str()); + PrimitiveType_Name(type)); } llvm::Type* float_ptr_type = float_type->getPointerTo(); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 05322faa75c78f350b540e14c218eac47c60e62c..4c2041b556aa8bf8fe8fb8e0674c0f4f04f0acae 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ +#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index db54454707983ade31594119b2e868fa168d4cc2..c8312d80bd5012e5bcb42a410db18a7fa77a2eb6 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -30,15 +30,16 @@ limitations under the License. namespace xla { namespace cpu { -StatusOr CpuElementalIrEmitter::EmitAtan2( - PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { +StatusOr CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) { string function_name; bool cast_result_to_fp16 = false; switch (prim_type) { case F16: cast_result_to_fp16 = true; - lhs = b_->CreateFPCast(lhs, b_->getFloatTy()); - rhs = b_->CreateFPCast(rhs, b_->getFloatTy()); + lhs = FPCast(lhs, b_->getFloatTy()); + rhs = FPCast(rhs, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "atan2f"; @@ -58,21 +59,21 @@ StatusOr CpuElementalIrEmitter::EmitAtan2( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. - llvm::Value* result = b_->CreateCall(function, {lhs, rhs}); + llvm::Value* result = Call(function, {lhs, rhs}); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } -StatusOr CpuElementalIrEmitter::EmitTanh( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { bool cast_result_to_fp16 = false; string function_name; switch (prim_type) { case F16: cast_result_to_fp16 = true; - value = b_->CreateFPCast(value, b_->getFloatTy()); + value = FPCast(value, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "tanhf"; @@ -91,16 +92,16 @@ StatusOr CpuElementalIrEmitter::EmitTanh( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. - llvm::Value* result = b_->CreateCall(function, value); + llvm::Value* result = Call(function, value); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const { + const HloToElementGeneratorMap& operand_to_generator) { if (hlo->opcode() == HloOpcode::kMap) { return [this, hlo, &operand_to_generator]( const llvm_ir::IrArray::Index& index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 76833e765d05f2477961cd06cead66797c5be623..e3fba9306b72904803259047fafea245a8e183db 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -36,13 +36,13 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const override; + const HloToElementGeneratorMap& operand_to_generator) override; protected: StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 6f433b4f30372da9cf4503396dbb60172cfc0cb0..460363e18fd6505fb09167542ae65c274d467a27 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" @@ -67,8 +69,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -170,9 +170,9 @@ IrEmitter::~IrEmitter() {} Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { VLOG(2) << "HandleBitcast: " << bitcast->ToString(); emitted_value_[bitcast] = - b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)), - IrShapeType(bitcast->shape())->getPointerTo(), - AsStringRef(IrName(bitcast))); + BitCast(GetEmittedValueFor(bitcast->operand(0)), + IrShapeType(bitcast->shape())->getPointerTo(), + AsStringRef(IrName(bitcast))); return Status::OK(); } @@ -230,9 +230,8 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { // Use the elemental emitter for array shapes. return DefaultAction(copy); } - return Unimplemented( - "unsupported operand type %s for copy instruction", - PrimitiveType_Name(copy->shape().element_type()).c_str()); + return Unimplemented("unsupported operand type %s for copy instruction", + PrimitiveType_Name(copy->shape().element_type())); } // Calculate the alignment of a buffer allocated for a given primitive type. @@ -389,7 +388,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, int64 length = ByteSizeOf(shape); if (length <= 0 || length > std::numeric_limits::max()) { return InvalidArgument( - "xfeed (infeed or outfeed) buffer length %lld is outside the valid " + "xfeed (infeed or outfeed) buffer length %d is outside the valid " "size range", length); } @@ -440,22 +439,22 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, // of size exactly 'length_32', and the runtime is responsible for // check-failing the process if there is a mismatch, versus passing us back a // buffer that we might overrun. - llvm::Value* acquired_pointer = b_.CreateCall( - acquire_func, - {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); + llvm::Value* acquired_pointer = + Call(acquire_func, + {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, - /*SrcAlign=*/1, length_32); + MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. - b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, - /*SrcAlign=*/1, length_32); + MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, + /*SrcAlign=*/1, length_32); } - b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer, - shape_ptr, b_.getInt32(shape_length)}); + Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr, + b_.getInt32(shape_length)}); return Status::OK(); } @@ -502,7 +501,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { llvm::Value* IrEmitter::EmitElementalMap( const HloMapInstruction& map_instr, tensorflow::gtl::ArraySlice elemental_operands, - tensorflow::StringPiece name) { + absl::string_view name) { return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } @@ -519,8 +518,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "reduce_window_accumulator_address", &b_, MinimumAlignmentForPrimitiveType(operand_element_type)); - b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))), - accumulator_address); + Store(Load(GetEmittedValueFor(reduce_window->operand(1))), + accumulator_address); llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_); std::vector window_size; @@ -537,22 +536,21 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm::Value* in_bounds_condition = nullptr; for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = - b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = - b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); + input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); // We need to check if 0 <= input_index[i] < bound, as otherwise we are in // the padding so that we can skip the computation. That is equivalent to // input_index[i] < bound as an *unsigned* comparison, since a negative // value will wrap to a large positive value. - llvm::Value* index_condition = b_.CreateICmpULT( - input_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + llvm::Value* index_condition = + ICmpULT(input_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); if (in_bounds_condition == nullptr) { in_bounds_condition = index_condition; } else { - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + in_bounds_condition = And(in_bounds_condition, index_condition); } } CHECK(in_bounds_condition != nullptr); @@ -565,12 +563,12 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm_ir::IrArray input_array(GetIrArrayFor(operand)); llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); llvm::Value* result = EmitThreadLocalCall( - *reduce_window->to_apply(), - {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); - b_.CreateStore(result, accumulator_address); + *reduce_window->to_apply(), {Load(accumulator_address), input_value}, + "reducer_function"); + Store(result, accumulator_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_address); + return Load(accumulator_address); } Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { @@ -647,7 +645,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"), [this, init_value](const llvm_ir::IrArray::Index& target_index) { llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - return b_.CreateLoad(init_value_addr); + return Load(init_value_addr); })); // Create a loop to iterate over the source array to scatter to the output. @@ -667,7 +665,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( b_.getInt1Ty(), "initialized_flag_address", &b_); - b_.CreateStore(b_.getInt1(false), initialized_flag_address); + Store(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_); @@ -685,15 +683,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size()); llvm::Value* in_bounds_condition = b_.getTrue(); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = b_.CreateNSWMul( - source_index[i], b_.getInt64(window.dimensions(i).stride())); - operand_index[i] = - b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); - llvm::Value* index_condition = b_.CreateICmpULT( - operand_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + llvm::Value* strided_index = + NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride())); + operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); + llvm::Value* index_condition = + ICmpULT(operand_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = And(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); @@ -703,7 +700,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - b_.CreateLoad(initialized_flag_address), "initialized", &b_); + Load(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. @@ -712,38 +709,37 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { [&](const llvm_ir::IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - b_.CreateStore(operand_index[i], selected_index_address_slot); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + Store(operand_index[i], selected_index_address_slot); } }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); - b_.CreateStore(operand_data, selected_value_address); + Store(operand_data, selected_value_address); save_operand_index(operand_index); - b_.CreateStore(b_.getInt1(true), initialized_flag_address); + Store(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. SetToFirstInsertPoint(if_initialized.true_block, &b_); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::Value* operand_element = b_.CreateLoad(operand_address); + llvm::Value* operand_element = Load(operand_address); llvm::Value* result = EmitThreadLocalCall( *select_and_scatter->select(), - {b_.CreateLoad(selected_value_address), operand_element}, - "select_function"); + {Load(selected_value_address), operand_element}, "select_function"); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. - llvm::Value* cond = b_.CreateICmpNE( + llvm::Value* cond = ICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); + Store(Load(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -754,8 +750,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::IrArray::Index selected_index(source_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(Load(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value = @@ -837,7 +833,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( lhs_llvm_type, "convolution_sum_address", &b_, MinimumAlignmentForPrimitiveType(lhs_element_type)); llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type); - b_.CreateStore(constant_zero, sum_address); + Store(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_); std::vector kernel_spatial(num_spatial_dims); @@ -846,7 +842,7 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( loops .AddLoop( 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)), - tensorflow::strings::StrCat("k", i)) + absl::StrCat("k", i)) ->GetIndVarValue(); } llvm::Value* input_feature = @@ -864,11 +860,11 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm::Value* kernel_index, const WindowDimension& window_dim) { llvm::Value* strided_index = - b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride())); - llvm::Value* dilated_kernel_index = b_.CreateNSWMul( - kernel_index, b_.getInt64(window_dim.window_dilation())); - return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index), - b_.getInt64(window_dim.padding_low())); + NSWMul(output_index, b_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = + NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation())); + return NSWSub(NSWAdd(strided_index, dilated_kernel_index), + b_.getInt64(window_dim.padding_low())); }; std::vector input_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -885,9 +881,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( // Also need to check that the input coordinates are not in one of the // holes created by base dilation. const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) { - llvm::Value* remainder = - b_.CreateSRem(input_index, b_.getInt64(base_dilation)); - return b_.CreateICmpEQ(remainder, b_.getInt64(0)); + llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation)); + return ICmpEQ(remainder, b_.getInt64(0)); }; llvm::Value* in_bounds_condition = b_.getInt1(true); @@ -895,17 +890,17 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound( lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), window.dimensions(i).base_dilation())); - llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound); + llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound); llvm::Value* dim_not_in_hole = not_in_hole(input_spatial[i], window.dimensions(i).base_dilation()); - llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok); + llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole); + in_bounds_condition = And(in_bounds_condition, dim_ok); } // Now we need to map the dilated base coordinates back to the actual // data indices on the lhs. const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) { - return b_.CreateSDiv(input_index, b_.getInt64(base_dilation)); + return SDiv(input_index, b_.getInt64(base_dilation)); }; for (int i = 0; i < num_spatial_dims; ++i) { input_spatial[i] = @@ -930,8 +925,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = window.dimensions(i).window_reversal() - ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1), - kernel_spatial[i]) + ? NSWSub(b_.getInt64(window.dimensions(i).size() - 1), + kernel_spatial[i]) : kernel_spatial[i]; } @@ -940,13 +935,13 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); llvm::Value* product = - b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_), - kernel_array.EmitReadArrayElement(kernel_index, &b_)); - llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product); - b_.CreateStore(sum, sum_address); + FMul(input_array.EmitReadArrayElement(input_index, &b_), + kernel_array.EmitReadArrayElement(kernel_index, &b_)); + llvm::Value* sum = FAdd(Load(sum_address), product); + Store(sum, sum_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(sum_address); + return Load(sum_address); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { @@ -1072,34 +1067,32 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { conv_func->setCallingConv(llvm::CallingConv::C); conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); - b_.CreateCall( - conv_func, - { - GetExecutableRunOptionsArgument(), - b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type), - b_.CreateBitCast(lhs_address, ir_ptr_type), - b_.CreateBitCast(rhs_address, ir_ptr_type), - b_.getInt64(input_batch), - b_.getInt64(input_rows), - b_.getInt64(input_cols), - b_.getInt64(input_channels), - b_.getInt64(kernel_rows), - b_.getInt64(kernel_cols), - b_.getInt64(kernel_channels), - b_.getInt64(kernel_filters), - b_.getInt64(output_rows), - b_.getInt64(output_cols), - b_.getInt64(row_stride), - b_.getInt64(col_stride), - b_.getInt64(padding_top), - b_.getInt64(padding_bottom), - b_.getInt64(padding_left), - b_.getInt64(padding_right), - b_.getInt64(lhs_row_dilation), - b_.getInt64(lhs_col_dilation), - b_.getInt64(rhs_row_dilation), - b_.getInt64(rhs_col_dilation), - }); + Call(conv_func, { + GetExecutableRunOptionsArgument(), + BitCast(GetEmittedValueFor(convolution), ir_ptr_type), + BitCast(lhs_address, ir_ptr_type), + BitCast(rhs_address, ir_ptr_type), + b_.getInt64(input_batch), + b_.getInt64(input_rows), + b_.getInt64(input_cols), + b_.getInt64(input_channels), + b_.getInt64(kernel_rows), + b_.getInt64(kernel_cols), + b_.getInt64(kernel_channels), + b_.getInt64(kernel_filters), + b_.getInt64(output_rows), + b_.getInt64(output_cols), + b_.getInt64(row_stride), + b_.getInt64(col_stride), + b_.getInt64(padding_top), + b_.getInt64(padding_bottom), + b_.getInt64(padding_left), + b_.getInt64(padding_right), + b_.getInt64(lhs_row_dilation), + b_.getInt64(lhs_col_dilation), + b_.getInt64(rhs_row_dilation), + b_.getInt64(rhs_col_dilation), + }); return Status::OK(); } @@ -1159,15 +1152,14 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { fft_func->setDoesNotThrow(); fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); const int fft_rank = fft_length.size(); - b_.CreateCall( - fft_func, - {GetExecutableRunOptionsArgument(), - b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), - b_.CreateBitCast(operand_address, int8_ptr_type), - b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank), - b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), - b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), - b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); + Call(fft_func, + {GetExecutableRunOptionsArgument(), + BitCast(GetEmittedValueFor(fft), int8_ptr_type), + BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()), + b_.getInt32(fft_rank), b_.getInt64(input_batch), + b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); return Status::OK(); } @@ -1206,8 +1198,8 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. - b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, - /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); + MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); } llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_); return Status::OK(); @@ -1466,19 +1458,19 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( accumulator_shard_type, "accumulator", &b_, 0)); } - llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value)); + llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value)); for (llvm::Value* accumulator_shard : accumulator) { llvm::Value* initial_value; auto shard_type = accumulator_shard->getType()->getPointerElementType(); if (auto vector_type = llvm::dyn_cast(shard_type)) { initial_value = - b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa); + VectorSplat(vector_type->getNumElements(), init_value_ssa); } else { initial_value = init_value_ssa; } - b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment); + AlignedStore(initial_value, accumulator_shard, element_alignment); } llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), @@ -1500,24 +1492,24 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( } CHECK(output_index.end() == it); - llvm::Value* input_address = b_.CreateBitCast( + llvm::Value* input_address = BitCast( arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy()); for (int i = 0; i < accumulator.size(); i++) { auto input_address_typed = - b_.CreateBitCast(input_address, accumulator[i]->getType()); + BitCast(input_address, accumulator[i]->getType()); auto current_accumulator_value = - b_.CreateAlignedLoad(accumulator[i], element_alignment); - auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment); + AlignedLoad(accumulator[i], element_alignment); + auto addend = AlignedLoad(input_address_typed, element_alignment); arg_array.AnnotateLoadStoreInstructionWithMetadata(addend); auto reduced_result = reduction_generator(&b_, current_accumulator_value, addend); - b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment); + AlignedStore(reduced_result, accumulator[i], element_alignment); if (i != (accumulator.size() - 1)) { - input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(), - input_address_typed, 1); + input_address = ConstInBoundsGEP1_32(reduced_result->getType(), + input_address_typed, 1); } } @@ -1526,8 +1518,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( ShardedVector result_ssa; result_ssa.reserve(accumulator.size()); for (auto accumulator_shard : accumulator) { - result_ssa.push_back( - b_.CreateAlignedLoad(accumulator_shard, element_alignment)); + result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment)); } return result_ssa; } @@ -1536,18 +1527,18 @@ void IrEmitter::EmitShardedVectorStore( llvm::Value* store_address, const std::vector& value_to_store, const int alignment, const llvm_ir::IrArray& containing_array) { for (int i = 0; i < value_to_store.size(); i++) { - auto store_address_typed = b_.CreateBitCast( - store_address, - llvm::PointerType::getUnqual(value_to_store[i]->getType())); + auto store_address_typed = + BitCast(store_address, + llvm::PointerType::getUnqual(value_to_store[i]->getType())); - auto store_instruction = b_.CreateAlignedStore( - value_to_store[i], store_address_typed, alignment); + auto store_instruction = + AlignedStore(value_to_store[i], store_address_typed, alignment); containing_array.AnnotateLoadStoreInstructionWithMetadata( store_instruction); if (i != (value_to_store.size() - 1)) { - store_address = b_.CreateConstInBoundsGEP1_32( - value_to_store[i]->getType(), store_address_typed, 1); + store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(), + store_address_typed, 1); } } } @@ -1620,9 +1611,8 @@ StatusOr IrEmitter::EmitVectorizedReduce( int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); int64 start_index = 0; int64 end_index = reduce->shape().dimensions(dimension); - std::unique_ptr loop = - loop_nest.AddLoop(start_index, end_index, - tensorflow::strings::Printf("dim.%lld", dimension)); + std::unique_ptr loop = loop_nest.AddLoop( + start_index, end_index, absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } @@ -1641,9 +1631,9 @@ StatusOr IrEmitter::EmitVectorizedReduce( int64 start_index = 0; int64 end_index = (innermost_dimension_size / vectorization_factor) * vectorization_factor; - std::unique_ptr loop = loop_nest.AddLoop( - start_index, end_index, vectorization_factor, - tensorflow::strings::Printf("dim.%lld", innermost_dimension)); + std::unique_ptr loop = + loop_nest.AddLoop(start_index, end_index, vectorization_factor, + absl::StrFormat("dim.%d", innermost_dimension)); array_index[innermost_dimension] = loop->GetIndVarValue(); SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_); @@ -1713,8 +1703,8 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", &b_, MinimumAlignmentForPrimitiveType(accumulator_type)); llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - llvm::Value* load_init_value = b_.CreateLoad(init_value_addr); - b_.CreateStore(load_init_value, accumulator_addr); + llvm::Value* load_init_value = Load(init_value_addr); + Store(load_init_value, accumulator_addr); // The enclosing loops go over all the target elements. Now we have to compute // the actual target element. For this, we build a new loop nest to iterate @@ -1747,12 +1737,12 @@ StatusOr IrEmitter::EmitTargetElementLoopBodyForReduce( // Apply the reduction function to the loaded value. llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); llvm::Value* result = EmitThreadLocalCall( - *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, + *reduce->to_apply(), {Load(accumulator_addr), input_element}, "reduce_function"); - b_.CreateStore(result, accumulator_addr); + Store(result, accumulator_addr); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_addr); + return Load(accumulator_addr); } Status IrEmitter::HandleReduce(HloInstruction* reduce) { @@ -1990,7 +1980,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { [this, pad](const llvm_ir::IrArray::Index& target_index) { const HloInstruction* padding_value = pad->operand(1); llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); - return b_.CreateLoad(padding_value_addr); + return Load(padding_value_addr); })); // Create a loop to iterate over the operand elements and update the output @@ -2012,10 +2002,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { const PaddingConfig& padding_config = pad->padding_config(); llvm_ir::IrArray::Index output_index(operand_index.GetType()); for (size_t i = 0; i < operand_index.size(); ++i) { - llvm::Value* offset = b_.CreateMul( - operand_index[i], - b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); - llvm::Value* index = b_.CreateAdd( + llvm::Value* offset = + Mul(operand_index[i], + b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); + llvm::Value* index = Add( offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low())); output_index.push_back(index); } @@ -2118,7 +2108,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { gtl::ArraySlice operands(custom_call->operands()); - tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); + absl::string_view custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( @@ -2126,10 +2116,10 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { for (size_t i = 0; i < operands.size(); ++i) { const HloInstruction* operand = operands[i]; llvm::Value* operand_as_i8ptr = - b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); + PointerCast(GetEmittedValueFor(operand), i8_ptr_type); llvm::Value* slot_in_operands_alloca = - b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)}); - b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); + InBoundsGEP(operands_alloca, {b_.getInt64(i)}); + Store(operand_as_i8ptr, slot_in_operands_alloca); } auto* custom_call_ir_function = llvm::cast(module_->getOrInsertFunction( @@ -2141,9 +2131,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); auto* output_address_arg = - b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); + PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); - b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); + Call(custom_call_ir_function, {output_address_arg, operands_alloca}); return Status::OK(); } @@ -2170,8 +2160,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { return InternalError( "instruction %s %s does not share slice with " "instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); + a->ToString(), slice_a.ToString(), b->ToString(), + slice_b.ToString()); } return Status::OK(); }; @@ -2202,15 +2192,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), compute_function_->function()); - b_.CreateBr(header_bb); + Br(header_bb); b_.SetInsertPoint(header_bb); // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); - llvm::Value* while_predicate = b_.CreateICmpNE( - b_.CreateLoad( - GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), + llvm::Value* while_predicate = ICmpNE( + Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2219,7 +2208,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); - b_.CreateCondBr(while_predicate, body_bb, exit_bb); + CondBr(while_predicate, body_bb, exit_bb); // Calls the body function from the body block. b_.SetInsertPoint(body_bb); @@ -2228,7 +2217,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); // Finishes with a branch back to the header. - b_.CreateBr(header_bb); + Br(header_bb); // Adds the exit block to the function and sets the insert point there. compute_function_->function()->getBasicBlockList().push_back(exit_bb); @@ -2275,7 +2264,6 @@ StatusOr IrEmitter::EmitFastConcatenate( output_min2maj.end()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); - llvm::Type* i8_type = b_.getInt8Ty(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); @@ -2298,9 +2286,9 @@ StatusOr IrEmitter::EmitFastConcatenate( // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. llvm::Value* target_region_begin = - b_.CreateBitCast(target_array.EmitArrayElementAddress( - outer_dims_index, &b_, "target_region"), - i8_ptr_type); + BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_, + "target_region"), + i8_ptr_type); int64 byte_offset_into_target_region = 0; int64 inner_dims_product = @@ -2314,13 +2302,12 @@ StatusOr IrEmitter::EmitFastConcatenate( for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); llvm_ir::IrArray source_array = GetIrArrayFor(operand); - llvm::Value* copy_source_address = b_.CreateBitCast( + llvm::Value* copy_source_address = BitCast( source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"), i8_ptr_type); llvm::Value* copy_target_address = - b_.CreateGEP(i8_type, target_region_begin, - b_.getInt64(byte_offset_into_target_region)); + GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region)); EmitTransferElements( copy_target_address, copy_source_address, @@ -2352,15 +2339,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); if (element_count == 1) { - auto* load_instruction = b_.CreateAlignedLoad( - b_.CreateBitCast(source, primitive_ptr_type), element_alignment); + auto* load_instruction = + AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment); source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); - auto* store_instruction = b_.CreateAlignedStore( - load_instruction, b_.CreateBitCast(target, primitive_ptr_type), - element_alignment); + auto* store_instruction = + AlignedStore(load_instruction, BitCast(target, primitive_ptr_type), + element_alignment); target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { - auto* memcpy_instruction = b_.CreateMemCpy( + auto* memcpy_instruction = MemCpy( target, /*DstAlign=*/element_alignment, source, /*SrcAlign=*/element_alignment, element_count * primitive_type_size); @@ -2422,9 +2409,9 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { // cond_result = true_computation(true_operand) // else // cond_result = false_computation(false_operand) - llvm::LoadInst* pred_value = b_.CreateLoad( - GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = b_.CreateICmpNE( + llvm::LoadInst* pred_value = + Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ICmpNE( pred_value, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); @@ -2450,11 +2437,6 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { return Status::OK(); } -Status IrEmitter::HandleIota(HloInstruction* iota) { - // TODO(b/64798317): implement iota on CPU. - return Unimplemented("Iota is not implemented on CPU."); -} - Status IrEmitter::HandleRng(HloInstruction* rng) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : rng->operands()) { @@ -2511,8 +2493,8 @@ llvm::Value* IrEmitter::GetProfileCounterCommon( int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); - return b_.CreateGEP(GetProfileCountersArgument(), - b_.getInt64(prof_counter_idx), AsStringRef(counter_name)); + return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx), + AsStringRef(counter_name)); } void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b, @@ -2666,8 +2648,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); - llvm::LoadInst* param_address_untyped = - b_.CreateLoad(param_address_offset); + llvm::LoadInst* param_address_untyped = Load(param_address_offset); if (!ShapeUtil::IsOpaque(target_shape)) { AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); @@ -2687,17 +2668,15 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( auto buf_it = thread_local_buffers_.find(key); if (buf_it == thread_local_buffers_.end()) { llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( - IrShapeType(shape), - tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, - MinimumAlignmentForShape(target_shape)); + IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()), + &b_, MinimumAlignmentForShape(target_shape)); auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); CHECK(it_inserted_pair.second); buf_it = it_inserted_pair.first; } return buf_it->second; }(); - return b_.CreateBitCast(tempbuf_address, - IrShapeType(target_shape)->getPointerTo()); + return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo()); } llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( @@ -2705,7 +2684,7 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( GetTempBuffersArgument(), slice.index(), &b_); - llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); + llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr); if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { tempbuf_address_base->setMetadata( @@ -2719,10 +2698,10 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( if (slice.offset() > 0) { // Adjust the address to account for the slice offset. tempbuf_address_untyped = - b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); + InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } - return b_.CreateBitCast(tempbuf_address_untyped, - IrShapeType(target_shape)->getPointerTo()); + return BitCast(tempbuf_address_untyped, + IrShapeType(target_shape)->getPointerTo()); } llvm::Value* IrEmitter::EmitTempBufferPointer( @@ -2753,7 +2732,7 @@ Status IrEmitter::EmitTargetElementLoop( } Status IrEmitter::EmitTargetElementLoop( - HloInstruction* target_op, tensorflow::StringPiece desc, + HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); @@ -2808,8 +2787,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. - b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, - /*SrcAlign=*/1, source_size); + MemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } @@ -2827,8 +2806,8 @@ Status IrEmitter::ElementTypesSameAndSupported( if (std::find(supported_types.begin(), supported_types.end(), primitive_type) == supported_types.end()) { return Unimplemented("unsupported operand type %s in op %s", - PrimitiveType_Name(primitive_type).c_str(), - HloOpcodeString(instruction.opcode()).c_str()); + PrimitiveType_Name(primitive_type), + HloOpcodeString(instruction.opcode())); } return Status::OK(); } @@ -2848,7 +2827,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { llvm::Value* IrEmitter::EmitThreadLocalCall( const HloComputation& callee, tensorflow::gtl::ArraySlice parameters, - tensorflow::StringPiece name) { + absl::string_view name) { const Shape& return_shape = callee.root_instruction()->shape(); // Lifting this restriction to allow "small" arrays should be easy. Allowing @@ -2863,38 +2842,37 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( CHECK(!parameter->getType()->isPointerTy()); llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( parameter->getType(), "arg_addr", &b_); - b_.CreateStore(parameter, parameter_addr); + Store(parameter, parameter_addr); parameter_addrs.push_back(parameter_addr); } llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(return_type, module_), - tensorflow::strings::StrCat(name, "_retval_addr"), &b_, + absl::StrCat(name, "_retval_addr"), &b_, MinimumAlignmentForPrimitiveType(return_type)); - b_.CreateCall( - FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - parameter_addrs, &b_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), - /*profile_counters_arg=*/GetProfileCountersArgument())); + Call(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + parameter_addrs, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), + /*profile_counters_arg=*/GetProfileCountersArgument())); - return b_.CreateLoad(return_value_buffer); + return Load(return_value_buffer); } void IrEmitter::EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name) { - b_.CreateCall(FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - /*parameter_addresses=*/{}, &b_, name, - /*return_value_buffer=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()), - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); + absl::string_view name) { + Call(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + /*parameter_addresses=*/{}, &b_, name, + /*return_value_buffer=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()), + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index c9a1dab62dcbcd926baa82737d24efa03fd326e9..f98891246b0c281514a0249fff5d654bdf8e31ea 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -39,12 +40,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" @@ -55,7 +56,8 @@ namespace cpu { // This class is the top-level API for the XLA HLO --> LLVM IR compiler. It // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR // functions. -class IrEmitter : public DfsHloVisitorWithDefault { +class IrEmitter : public DfsHloVisitorWithDefault, + public IrBuilderMixin { public: // Create a new LLVM IR emitter. // @@ -100,6 +102,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::IRBuilder<>* b() { return &b_; } + // builder() is for IrBuilderMixin. + llvm::IRBuilder<>* builder() { return &b_; } + // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); @@ -107,7 +112,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* EmitElementalMap( const HloMapInstruction& map_instr, tensorflow::gtl::ArraySlice elemental_operands, - tensorflow::StringPiece name); + absl::string_view name); protected: // @@ -152,7 +157,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleConditional(HloInstruction* conditional) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* gen_token) override; - Status HandleIota(HloInstruction* iota) override; Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; @@ -239,7 +243,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // function that a map operation applies. StatusOr EmitFunction( HloComputation* function, // The function to emit. - tensorflow::StringPiece + absl::string_view function_name_suffix); // Used for LLVM IR register names. // Emits a call to a thread local function (e.g. to the computation nested @@ -251,14 +255,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* EmitThreadLocalCall( const HloComputation& callee, tensorflow::gtl::ArraySlice parameters, - tensorflow::StringPiece name); + absl::string_view name); // Emits a call to a "global" function (e.g. to the computation nested within // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to // the parameters and return values for these computations so there is no need // to explicitly pass parameters or return results. - void EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name); + void EmitGlobalCall(const HloComputation& callee, absl::string_view name); // Returns the buffer to which a global call to `callee` would have written // its result. @@ -285,7 +288,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, const llvm_ir::ElementGenerator& element_generator); Status EmitTargetElementLoop( - HloInstruction* target_op, tensorflow::StringPiece desc, + HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator); // Emits a memcpy from the source instruction's result value to the diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 2db4d000f5b149969c88fb4325ca28aa11dc3708..784045313dfa2d44da64c6b50be80258c5e8466a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -189,7 +190,7 @@ void IrFunction::Initialize(const string& function_name, llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_GT(num_dynamic_loop_bounds_, 0); CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + string name = absl::StrCat("dynamic_loop_bound_", offset); return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), b_->getInt64(offset), AsStringRef(name))); } @@ -200,7 +201,7 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { // address buffer). std::vector GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* b, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, absl::string_view name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { llvm::Value* parameter_addresses_buffer; @@ -211,13 +212,13 @@ std::vector GetArrayFunctionCallArguments( } else { parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + absl::StrCat(name, "_parameter_addresses"), b); for (size_t i = 0; i < parameter_addresses.size(); ++i) { llvm::Value* parameter_as_i8ptr = b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat( - name, "_parameter_", i, "_address_as_i8ptr"))); + AsStringRef(absl::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); llvm::Value* slot_in_param_addresses = b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); @@ -320,8 +321,7 @@ Status EmitCallToParallelForkJoin( /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/partitions_array, /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + AsStringRef(absl::StrCat(name, "_parallel_dimension_partitions"))); // Add argument specifying parallel dimension partitions. fork_join_arguments.push_back( diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index a41cbb64cdd9f5b6de5d1eadfbf7e63e1e984801..ee7595f6e9706902a3e6b4b2e7e38c3f022abca3 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -116,7 +116,7 @@ class IrFunction { // Returns an array of compute function call argument ir values. std::vector GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice parameter_addresses, - llvm::IRBuilder<>* b, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, absl::string_view name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 8560e4296aa95fe791446abb1b4363b9145f343e..f8441c3e345504616485c6b34b4302acd5cc23a3 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace cpu { @@ -30,8 +30,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( dynamic_loop_bounds_(dynamic_loop_bounds) {} std::vector -ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, + llvm::Type* index_type) { CHECK_NE(index_type, nullptr); CHECK(!ShapeUtil::IsTuple(shape_)); @@ -52,15 +52,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; std::unique_ptr loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); + /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index, + end_index); array_index[dimension] = loop->GetIndVarValue(); } else { // Emit static loop bounds for this dimension. std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + /*suffix=*/absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 076c683ca566f2c53992c358903d2aadead290f9..a604e1db222139c239a2a89359a7359463e0def7 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 286d407ca6e796a184738aee4d14bd5ed7e2f356..b4c0c09ec06bac9b5e228428c072948afdd4a547 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" @@ -217,8 +218,7 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( // Outline 'instruction' in 'computation' for parallel task assignment. auto* call = module->OutlineExpressionFromComputation( - {instruction}, - tensorflow::strings::StrCat("parallel_", instruction->name()), + {instruction}, absl::StrCat("parallel_", instruction->name()), computation); // Set assigned dimension partitioning to 'instruction'. diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 8becc8fa23424d7454cc783eb9d853aecb5d053b..a99cd99c14abb66fc426c43656520e01f34a1700 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -73,7 +73,7 @@ class ParallelTaskAssigner : public HloPassInterface { target_machine_features_(*target_machine_features) {} ~ParallelTaskAssigner() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cpu-parallel-task-assigner"; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index ee272b5f4f49904a9e75a4653b7dc1fdc89434c1..a84ee78b19981e480858320e445de7f5dae27d61 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -36,7 +35,9 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : target_machine_features_([](int64 shape_size) { + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false), + target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index f227e4ae139b92e56786e38ef8eef72c9e2cd424..942e2ddd3940fffd5d87518f059beaced3cdc925 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -67,8 +67,8 @@ int main(int argc, char** argv) { /*execution_profile=*/&profile); std::unique_ptr actual = result.ConsumeValueOrDie(); - LOG(INFO) << tensorflow::strings::Printf("computation took %lldns", - profile.compute_time_ns()); + LOG(INFO) << absl::StrFormat("computation took %dns", + profile.compute_time_ns()); LOG(INFO) << actual->ToString(); return 0; diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index b026aef3fec729716234a1f38c4ac4993666aeb5..bf98064647f4c29ba689902da4d737e1922391d3 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -170,15 +170,14 @@ namespace { bool RegisterKnownJITSymbols() { CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global(); -#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ - do { \ - auto* function_address = \ - reinterpret_cast(__xla_cpu_runtime_##base_name); \ - registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ - function_address); \ - CHECK_EQ( \ - tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \ - "__xla_cpu_runtime_" #base_name); \ +#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ + do { \ + auto* function_address = \ + reinterpret_cast(__xla_cpu_runtime_##base_name); \ + registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ + function_address); \ + CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \ + "__xla_cpu_runtime_" #base_name); \ } while (false) REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 4635fa5d74f86eb7f2543d263132d87e6eaa20e0..2384166fd2002a67a8aa785ad5fb341d037ee01f 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -110,6 +110,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -123,6 +124,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 6fcce42eaa4599eb8a6dacc1bd39eefd39aa5e50..fcd87b36b32915773546c211d7d2c447a69bef49 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index b68ac67574d0b9f20ecc0370cdaed87d4465b225..22721051e54e2cf9590b60333c51d1d028bb28e9 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -129,8 +129,8 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { error_spec_); } -TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { - // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the +TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { + // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the // middle. auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 973aac8766f5aabca15e5173b43480c113c100dd..a434c04a980b9b3cd849792b97a0d9e965ba09f2 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -32,9 +32,9 @@ const char* const kTriple_android_arm = "armv7-none-android"; struct IntrinsicTestSpec { HloOpcode opcode; - tensorflow::StringPiece triple; - tensorflow::StringPiece features; - tensorflow::StringPiece check_lines; + absl::string_view triple; + absl::string_view features; + absl::string_view check_lines; }; // Tests that unary functions get lowered using intrinsic calls. @@ -65,9 +65,8 @@ class CpuUnaryIntrinsicTest features = ""; } - return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(), - features.empty() ? "" : "_With", - features.c_str()); + return absl::StrCat(opcode, "_On_", triple, + (features.empty() ? "" : "_With"), features); } }; diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h index 56b28fd22da1ea6bc19f98e76f0f2ef4044cd3af..c326beb899f9a434d772c0fda032efc9113b6f42 100644 --- a/tensorflow/compiler/xla/service/defuser.h +++ b/tensorflow/compiler/xla/service/defuser.h @@ -29,7 +29,7 @@ class Defuser : public HloPassInterface { public: Defuser() {} ~Defuser() override {} - tensorflow::StringPiece name() const override { return "defuser"; } + absl::string_view name() const override { return "defuser"; } // Run defusion on the given module. Returns whether the module was // changed. diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index e727ba49cb6321e499b5d50d5f45e7f7f6bb6fef..37d1895d41447ba0219bb57170e61154fdd8bcdd 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -26,6 +26,11 @@ namespace xla { namespace { class DefuserTest : public HloVerifiedTestBase { + public: + DefuserTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: // Returns the number of fusion instructions in the module. int FusionCount() { diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index 48e44714998f61c9bdccaa43719abc533eb83565..ba2a674d9af547ad574ae49e1e87f3afcaf6112a 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -27,9 +27,7 @@ namespace { class ControlDepRemover : public HloPassInterface { public: ControlDepRemover() = default; - tensorflow::StringPiece name() const override { - return "control-dep-remover"; - } + absl::string_view name() const override { return "control-dep-remover"; } StatusOr Run(HloModule* module) override { bool changed = false; diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index cc1695b7f863805e0b483478639c17cb9061310a..7be70add2f7566376b3179740e411d6341badf7c 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -33,7 +33,7 @@ namespace xla { class Despecializer : public HloPassInterface { public: Despecializer(); - tensorflow::StringPiece name() const override { return "despecializer"; } + absl::string_view name() const override { return "despecializer"; } StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index e228bb56bce8febcca28ae171f6de90973d020ab..1d0297cfbfc26c562fb36ecd02163c90af4b3003 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -36,9 +36,8 @@ StatusOr StreamExecutorMemoryAllocator::Allocate( se::DeviceMemoryBase result = stream_executor->AllocateArray(size); if (size > 0 && result == nullptr) { return ResourceExhausted( - "Failed to allocate request for %s (%lluB) on device ordinal %d", - tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, - device_ordinal); + "Failed to allocate request for %s (%uB) on device ordinal %d", + tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal); } return OwningDeviceMemory(result, device_ordinal, this); } @@ -61,12 +60,12 @@ StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( } if (device_ordinal >= stream_executors_.size()) { return InvalidArgument( - "device ordinal value (%d) >= number of devices (%zu)", device_ordinal, + "device ordinal value (%d) >= number of devices (%u)", device_ordinal, stream_executors_.size()); } if (stream_executors_[device_ordinal] == nullptr) { return NotFound("Device %s:%d present but not supported", - platform()->Name().c_str(), device_ordinal); + platform()->Name(), device_ordinal); } return stream_executors_[device_ordinal]; } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 2172ae0a29626660e8abd29a789e0baa3831519d..3e7373adc5ab8a60fd18348ce2477175aaaa8fd4 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -28,14 +28,14 @@ template Status DfsHloVisitorBase::HandleElementwiseUnary( HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); } template Status DfsHloVisitorBase::HandleElementwiseBinary( HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); } template diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 690b5df514310b0943de2cd69bc889adad58bb3f..f6f8fc5a2ad63af1462b16a9281013b3418b2930 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,13 +19,13 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" @@ -107,6 +107,7 @@ class DfsHloVisitorBase { virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 20c6bafe7c22b02588c034f4532dd38fe10add65..4f620e4c3a3d3c2ecf3fd4a2815b45831faef9e6 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,13 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -94,8 +94,11 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } - Status HandleAllToAll(HloInstructionPtr crs) override { - return DefaultAction(crs); + Status HandleAllToAll(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status HandleCollectivePermute(HloInstructionPtr hlo) override { + return DefaultAction(hlo); } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index 1959b687f16d6909a3283021c8635b3e65e6e412..fc38e317001695921d20f9bbe5775e61a8eeaa45 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -29,7 +29,7 @@ class DotDecomposer : public HloPassInterface { DotDecomposer(bool decompose_batch_dot = true) : decompose_batch_dot_(decompose_batch_dot) {} ~DotDecomposer() = default; - tensorflow::StringPiece name() const override { return "dot_decomposer"; } + absl::string_view name() const override { return "dot_decomposer"; } // Run DotDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 4b19aa5df972001ab1975fac5f88ad02703ff84b..813e93fafa1b67c8abf4ff189642fd3fa8ed6198 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -22,11 +22,14 @@ limitations under the License. // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -39,17 +42,16 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +using absl::StrCat; using llvm_ir::AsStringRef; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -using tensorflow::strings::StrCat; namespace { @@ -204,7 +206,7 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, } // namespace StatusOr ElementalIrEmitter::EmitUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { if (op->opcode() == HloOpcode::kCopy) { return operand_value; } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || @@ -218,7 +220,7 @@ StatusOr ElementalIrEmitter::EmitUnaryOp( } StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -230,14 +232,14 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (to_type == PRED) { return b_->CreateZExt( - b_->CreateICmpNE(operand_value, llvm::ConstantInt::get( - operand_value->getType(), 0)), + ICmpNE(operand_value, + llvm::ConstantInt::get(operand_value->getType(), 0)), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsIntegralType(to_type)) { - return b_->CreateIntCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), - primitive_util::IsSignedIntegralType(from_type)); + return IntCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_), + primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == BF16) { @@ -253,19 +255,17 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { return EmitComposeComplex( - op, b_->CreateSIToFP(operand_value, to_ir_component_type), - nullptr); + op, SIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { return EmitComposeComplex( - op, b_->CreateUIToFP(operand_value, to_ir_component_type), - nullptr); + op, UIToFP(operand_value, to_ir_component_type), nullptr); } } return Unimplemented("conversion from primitive type %s to %s", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str()); + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } case HloOpcode::kBitcastConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -276,14 +276,13 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return b_->CreateBitCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " "bit-widths (%u versus %u) ", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str(), + PrimitiveType_Name(from_type), PrimitiveType_Name(to_type), primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } @@ -293,10 +292,8 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( if (is_signed) { auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = b_->CreateICmpSGE(operand_value, zero); - return b_->CreateSelect(cmp, operand_value, - b_->CreateNeg(operand_value)); + auto cmp = ICmpSGE(operand_value, GetZero(type)); + return Select(cmp, operand_value, Neg(operand_value)); } else { return operand_value; } @@ -308,44 +305,37 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( {operand_value->getType()}, b_); } case HloOpcode::kSign: { - bool is_signed = - primitive_util::IsSignedIntegralType(op->shape().element_type()); + CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type())) + << op->shape().element_type(); auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = b_->CreateICmpEQ(operand_value, zero); - if (is_signed) { - auto ashr = - b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1); - return b_->CreateSelect(cmp, zero, b_->CreateOr(ashr, 1)); - } else { - return b_->CreateSelect(cmp, zero, llvm::ConstantInt::get(type, 1)); - } + auto cmp = ICmpEQ(operand_value, GetZero(type)); + auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1); + return Select(cmp, GetZero(type), Or(ashr, 1)); } case HloOpcode::kNegate: - return b_->CreateNeg(operand_value); + return Neg(operand_value); case HloOpcode::kNot: { auto type = op->shape().element_type(); if (type == PRED) { // It is not sufficient to just call CreateNot() here because a PRED // is represented as an i8 and the truth value is stored only in the // bottom bit. - return b_->CreateZExt( - b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())), - llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } else if (primitive_util::IsIntegralType(type)) { - return b_->CreateNot(operand_value); + return Not(operand_value); } return Unimplemented("unary op Not is not defined for type '%d'", type); } default: return Unimplemented("unary integer op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -362,8 +352,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitComposeComplex( op, - b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType( - to_component_type, module_)), + FPCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } if (from_type == BF16) { @@ -379,26 +369,25 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (to_type == PRED) { return b_->CreateZExt( - b_->CreateFCmpUNE( - operand_value, - llvm::ConstantFP::get(operand_value->getType(), 0.0)), + FCmpUNE(operand_value, + llvm::ConstantFP::get(operand_value->getType(), 0.0)), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsFloatingPointType(to_type)) { - return b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsSignedIntegralType(to_type)) { - return b_->CreateFPToSI( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPToSI(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsUnsignedIntegralType(to_type)) { - return b_->CreateFPToUI( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPToUI(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return Unimplemented("unhandled conversion operation: %s => %s", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str()); + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } case HloOpcode::kBitcastConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -409,14 +398,13 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return b_->CreateBitCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " "bit-widths (%u versus %u) ", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str(), + PrimitiveType_Name(from_type), PrimitiveType_Name(to_type), primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } @@ -454,11 +442,10 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( // TODO(b/32151903): Ensure consistent sign behavior for -0.0. auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(operand_value, zero); - auto olt = b_->CreateFCmpOLT(operand_value, zero); - return b_->CreateSelect( - oeq, zero, - b_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), + auto oeq = FCmpOEQ(operand_value, zero); + auto olt = FCmpOLT(operand_value, zero); + return Select(oeq, zero, + Select(olt, llvm::ConstantFP::get(type, -1.0), llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { @@ -468,24 +455,24 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, b_); auto infinity = llvm::ConstantFP::getInfinity(type); - auto not_infinite = b_->CreateFCmpONE(abs_value, infinity); + auto not_infinite = FCmpONE(abs_value, infinity); return b_->CreateZExt(not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: - return b_->CreateFNeg(operand_value); + return FNeg(operand_value); case HloOpcode::kReal: return operand_value; case HloOpcode::kImag: return llvm::ConstantFP::get(operand_value->getType(), 0.0); default: return Unimplemented("unary floating-point op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType component_type = primitive_util::IsComplexType(input_type) @@ -497,12 +484,11 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); - auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); + auto sum_sq = FAdd(FMul(a, a), FMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); } case HloOpcode::kLog1p: { // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) @@ -510,14 +496,12 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto one = llvm::ConstantFP::get(llvm_ty, 1.0); - auto a_plus_one = b_->CreateFAdd(a, one); - auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one), - b_->CreateFMul(b, b)); + auto a_plus_one = FAdd(a, one); + auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -531,11 +515,9 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return EmitComposeComplex(op, - b_->CreateFPCast(EmitExtractReal(operand_value), - to_ir_component_type), - b_->CreateFPCast(EmitExtractImag(operand_value), - to_ir_component_type)); + return EmitComposeComplex( + op, FPCast(EmitExtractReal(operand_value), to_ir_component_type), + FPCast(EmitExtractImag(operand_value), to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) @@ -545,8 +527,7 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); - return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b), - b_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b)); } case HloOpcode::kExpm1: { // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i @@ -557,8 +538,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); - auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one); - auto imag_result = b_->CreateFMul(exp_a, sin_b); + auto real_result = FSub(FMul(exp_a, cos_b), one); + auto imag_result = FMul(exp_a, sin_b); return EmitComposeComplex(op, real_result, imag_result); } case HloOpcode::kCos: { @@ -573,14 +554,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); - auto half_exp_neg_b = - b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); - return EmitComposeComplex( - op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)), - b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b))); + return EmitComposeComplex(op, + FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)), + FMul(sin_a, FSub(half_exp_neg_b, half_exp_b))); } case HloOpcode::kSin: { // sin(z) = .5i(e^(-iz) - e^(iz)) @@ -596,14 +576,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); - auto half_exp_neg_b = - b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); - return EmitComposeComplex( - op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)), - b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b))); + return EmitComposeComplex(op, + FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)), + FMul(cos_a, FSub(half_exp_b, half_exp_neg_b))); } case HloOpcode::kTanh: { /* @@ -631,74 +610,63 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); - auto exp_neg_a = - b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = b_->CreateFSub( - b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = b_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = b_->CreateFMul(sin_b, sin_b); - auto real_num = - b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a); + auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = + FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = FMul(cos_b, cos_b); + auto sin_b_sq = FMul(sin_b, sin_b); + auto real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + FMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = FMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a); auto exp_a_plus_exp_neg_a_sq = - b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a); + FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a); auto exp_a_minus_exp_neg_a_sq = - b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = b_->CreateFMul( - cos_b_sin_b, - b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); - auto denom = - b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom), - b_->CreateFDiv(imag_num, denom)); + FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = FMul( + cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); + auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, FDiv(real_num, denom), + FDiv(imag_num, denom)); } case HloOpcode::kAbs: { - auto sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - b_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = FAdd( + FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)), + FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value))); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); } case HloOpcode::kSign: { // Sign(c) = c / |c| - auto sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - b_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = FAdd( + FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)), + FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero); - return b_->CreateSelect( + auto oeq = FCmpOEQ(cplx_abs, zero); + return Select( oeq, EmitComposeComplex(op, zero, zero), - EmitComposeComplex( - op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), - b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs))); + EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs), + FDiv(EmitExtractImag(operand_value), cplx_abs))); } case HloOpcode::kNegate: - return EmitComposeComplex(op, - b_->CreateFNeg(EmitExtractReal(operand_value)), - b_->CreateFNeg(EmitExtractImag(operand_value))); + return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), + FNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: return EmitExtractReal(operand_value); case HloOpcode::kImag: return EmitExtractImag(operand_value); default: return Unimplemented("unary complex op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType operand_type = op->operand(0)->shape().element_type(); if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || operand_type == PRED) { @@ -713,21 +681,20 @@ StatusOr ElementalIrEmitter::EmitBinaryOp( } StatusOr ElementalIrEmitter::EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kComplex: return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: - return b_->CreateFAdd(lhs_value, rhs_value); + return FAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: - return b_->CreateFSub(lhs_value, rhs_value); + return FSub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return b_->CreateFMul(lhs_value, rhs_value); + return FMul(lhs_value, rhs_value); case HloOpcode::kDivide: - return b_->CreateFDiv(lhs_value, rhs_value); + return FDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: - return b_->CreateFRem(lhs_value, rhs_value); + return FRem(lhs_value, rhs_value); // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas // unordered comparisons return true. @@ -764,66 +731,52 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value); default: return Unimplemented("binary floating point op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr ElementalIrEmitter::EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kAdd: - return EmitComposeComplex(op, - b_->CreateFAdd(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFAdd(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex( + op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return EmitComposeComplex(op, - b_->CreateFSub(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFSub(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex( + op, FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: return EmitComposeComplex( op, - b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)))); + FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))), + FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) auto rhs_sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(rhs_value), - EmitExtractImag(rhs_value))); + FAdd(FMul(EmitExtractReal(rhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(rhs_value), EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero); - auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero); - auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero); - return b_->CreateSelect( + auto oeq = FCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = FDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = FDiv(EmitExtractImag(lhs_value), zero); + return Select( oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), - EmitComposeComplex( - op, - b_->CreateFDiv( - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - rhs_sum_sq), - b_->CreateFDiv( - b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value))), - rhs_sum_sq))); + EmitComposeComplex(op, + FDiv(FAdd(FMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), + rhs_sum_sq), + FDiv(FSub(FMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + FMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), + rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas @@ -833,21 +786,19 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. case HloOpcode::kEq: - return b_->CreateAnd( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); + return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); case HloOpcode::kNe: - return b_->CreateOr( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); + return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); case HloOpcode::kPower: { // (a+bi)^(c+di) = @@ -859,45 +810,43 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( auto b = EmitExtractImag(lhs_value); auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); + auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = b_->CreateFMul(one_half, c); + auto half_c = FMul(one_half, c); TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, EmitPow(component_type, aa_p_bb, half_c)); - auto neg_d = b_->CreateFNeg(d); + auto neg_d = FNeg(d); TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); - auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs); + auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, EmitExp(component_type, neg_d_arg_lhs)); - auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); - auto half_d = b_->CreateFMul(one_half, d); - auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs), - b_->CreateFMul(half_d, ln_aa_p_bb)); + auto half_d = FMul(one_half, d); + auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q), - b_->CreateFMul(coeff, sin_q)); + return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)); } default: return Unimplemented("binary complex op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + llvm::Value* rhs_value) { return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_); } llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + llvm::Value* rhs_value) { return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); } StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, - llvm::Value* x) const { + llvm::Value* x) { if (prim_type != F32) { // TODO(b/34339814): Implement inverse erf for F64. return Unimplemented( @@ -910,9 +859,9 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, auto multiply_add = [&](tensorflow::gtl::ArraySlice coefficients, llvm::Value* w) { llvm::Value* p = getFloat(coefficients.front()); - coefficients.pop_front(); + coefficients.remove_prefix(1); for (float coefficient : coefficients) { - p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient)); + p = FAdd(FMul(p, w), getFloat(coefficient)); } return p; }; @@ -932,25 +881,24 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( module_, llvm::Intrinsic::log, {b_->getFloatTy()}); - llvm::Value* w = b_->CreateFNeg(b_->CreateCall( - logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x), - b_->CreateFAdd(getFloat(1.0f), x))})); + llvm::Value* w = FNeg( + Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))})); llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_); llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); + FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); // Handle true BB. SetToFirstInsertPoint(if_data.true_block, b_); { - llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f)); + llvm::Value* lw = FSub(w, getFloat(2.5f)); tensorflow::gtl::ArraySlice lq{ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, -4.39150654e-06f, 0.00021858087f, -0.00125372503f, -0.00417768164f, 0.246640727f, 1.50140941f}; llvm::Value* p = multiply_add(lq, lw); - b_->CreateStore(p, p_addr); + Store(p, p_addr); } // Handle false BB. @@ -959,76 +907,73 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); - llvm::Value* gw = - b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f)); + llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f)); tensorflow::gtl::ArraySlice gq{ -0.000200214257f, 0.000100950558f, 0.00134934322f, -0.00367342844f, 0.00573950773f, -0.0076224613f, 0.00943887047f, 1.00167406f, 2.83297682f}; llvm::Value* p = multiply_add(gq, gw); - b_->CreateStore(p, p_addr); + Store(p, p_addr); } SetToFirstInsertPoint(if_data.after_block, b_); - llvm::Value* p = b_->CreateLoad(p_addr); - return b_->CreateFMul(p, x); + llvm::Value* p = Load(p_addr); + return FMul(p, x); } -StatusOr ElementalIrEmitter::EmitErfcInv( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type, + llvm::Value* value) { // Compute erfcinv(value) by calculating erfinv(1.0 - value). auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); - return EmitErfInv(prim_type, b_->CreateFSub(one, value)); + return EmitErfInv(prim_type, FSub(one, value)); } StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); auto negative_half = llvm::ConstantFP::get(type, -0.5); // When x is large, the naive evaluation of ln(x + 1) is more // accurate than the Taylor series. - TF_ASSIGN_OR_RETURN(auto for_large_x, - EmitLog(prim_type, b_->CreateFAdd(x, one))); + TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one))); // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. - auto for_small_x = - b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x); + auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x); const auto kAntilogarithmIsSmallThreshold = 1e-4; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - auto x_is_small = b_->CreateFCmpOLT( + auto x_is_small = FCmpOLT( abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); - return b_->CreateSelect(x_is_small, for_small_x, for_large_x); + return Select(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitCos(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, {value->getType()}, b_); } StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); @@ -1036,40 +981,40 @@ StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, // When the exponent is large, the naive evaluation of e^(x) - 1 is more // accurate than the Taylor series. TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); - auto for_large_x = b_->CreateFSub(exp_x, one); + auto for_large_x = FSub(exp_x, one); // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. - auto x_squared = b_->CreateFAdd(x, x); - auto x_squared_over_two = b_->CreateFMul(x_squared, half); - auto for_small_x = b_->CreateFAdd(x, x_squared_over_two); + auto x_squared = FAdd(x, x); + auto x_squared_over_two = FMul(x_squared, half); + auto for_small_x = FAdd(x, x_squared_over_two); const auto kExponentIsSmallThreshold = 1e-5; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - auto x_is_small = b_->CreateFCmpOLT( - abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); - return b_->CreateSelect(x_is_small, for_small_x, for_large_x); + auto x_is_small = + FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); + return Select(x_is_small, for_small_x, for_large_x); } StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, {lhs->getType()}, b_); } StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return Unimplemented("atan2"); } StatusOr ElementalIrEmitter::EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return Unimplemented("tanh"); } StatusOr ElementalIrEmitter::EmitReducePrecision( - const HloInstruction* hlo, llvm::Value* x) const { + const HloInstruction* hlo, llvm::Value* x) { if (hlo->operand(0)->shape().element_type() != F32) { return Unimplemented("reduce-precision only implemented for F32"); } @@ -1100,23 +1045,103 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b, return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value); } +llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) { + return llvm::ConstantInt::get(llvm::cast(type), 1); +} + +llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) { + return llvm::ConstantInt::get(llvm::cast(type), 0); +} + +llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) { + auto* integer_type = llvm::cast(type); + return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue( + integer_type->getBitWidth())); +} + +llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) { + auto* integer_type = llvm::cast(type); + return llvm::ConstantInt::get( + integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth())); +} + +llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) { + return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0)); +} + +llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs, + llvm::Value* rhs) { + return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())), + ICmpEQ(rhs, GetMinusOne(rhs->getType()))); +} + +llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs, + llvm::Value* rhs, + bool is_signed) { + // Integer division overflow behavior: + // + // X / 0 == -1 + // INT_SMIN /s -1 = INT_SMIN + + if (!is_signed) { + llvm::Value* udiv_is_unsafe = IsZero(rhs); + llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_div = UDiv(lhs, safe_rhs); + return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div); + } + + llvm::Value* has_zero_divisor = IsZero(rhs); + llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); + llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor); + llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_div = SDiv(lhs, safe_rhs); + + return Select( + has_zero_divisor, GetMinusOne(lhs->getType()), + Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div)); +} + +llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs, + llvm::Value* rhs, + bool is_signed) { + // Integer remainder overflow behavior: + // + // X % 0 == X + // INT_SMIN %s -1 = 0 + + if (!is_signed) { + llvm::Value* urem_is_unsafe = IsZero(rhs); + llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_rem = URem(lhs, safe_rhs); + return Select(urem_is_unsafe, lhs, safe_rem); + } + + llvm::Value* has_zero_divisor = IsZero(rhs); + llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); + llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor); + llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_rem = SRem(lhs, safe_rhs); + + return Select( + has_zero_divisor, lhs, + Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem)); +} + StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { + bool is_signed) { switch (op->opcode()) { // TODO(jingyue): add the "nsw" attribute for signed types. case HloOpcode::kAdd: - return b_->CreateAdd(lhs_value, rhs_value); + return Add(lhs_value, rhs_value); case HloOpcode::kSubtract: - return b_->CreateSub(lhs_value, rhs_value); + return Sub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return b_->CreateMul(lhs_value, rhs_value); + return Mul(lhs_value, rhs_value); case HloOpcode::kDivide: - return is_signed ? b_->CreateSDiv(lhs_value, rhs_value) - : b_->CreateUDiv(lhs_value, rhs_value); + return EmitIntegerDivide(lhs_value, rhs_value, is_signed); case HloOpcode::kRemainder: - return is_signed ? b_->CreateSRem(lhs_value, rhs_value) - : b_->CreateURem(lhs_value, rhs_value); + return EmitIntegerRemainder(lhs_value, rhs_value, is_signed); case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, rhs_value, b_); @@ -1144,11 +1169,11 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( case HloOpcode::kMaximum: return EmitIntegralMax(lhs_value, rhs_value, is_signed); case HloOpcode::kAnd: - return b_->CreateAnd(lhs_value, rhs_value); + return And(lhs_value, rhs_value); case HloOpcode::kOr: - return b_->CreateOr(lhs_value, rhs_value); + return Or(lhs_value, rhs_value); case HloOpcode::kXor: - return b_->CreateXor(lhs_value, rhs_value); + return Xor(lhs_value, rhs_value); // Shifting out bits >= the number of bits in the type being shifted // produces a poison value in LLVM which is basically "deferred undefined @@ -1157,43 +1182,43 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( // UB. case HloOpcode::kShiftRightArithmetic: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateAShr(lhs_value, rhs_value), + AShr(lhs_value, rhs_value), /*saturate_to_sign_bit=*/true); case HloOpcode::kShiftLeft: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateShl(lhs_value, rhs_value), + Shl(lhs_value, rhs_value), /*saturate_to_sign_bit=*/false); case HloOpcode::kShiftRightLogical: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateLShr(lhs_value, rhs_value), + LShr(lhs_value, rhs_value), /*saturate_to_sign_bit=*/false); default: return Unimplemented("binary integer op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { - return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE - : llvm::ICmpInst::ICMP_UGE, - lhs_value, rhs_value), - lhs_value, rhs_value); + bool is_signed) { + return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { - return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE - : llvm::ICmpInst::ICMP_ULE, - lhs_value, rhs_value), - lhs_value, rhs_value); + bool is_signed) { + return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) const { + int64 operand_no) { CHECK(hlo.IsElementwise()) << "HLO " << hlo.ToString() << " is not elementwise."; @@ -1234,7 +1259,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( StatusOr ElementalIrEmitter::ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const { + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) { TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean, operand_to_generator.at(hlo->operand(0))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma, @@ -1252,17 +1277,17 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( // Perform the division using the float type with the same number of bits // as the raw value to avoid overflow. if (raw_value_size_in_bits == 32) { - elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy()); - elem_value = b_->CreateFDiv( - elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); + elem_value = UIToFP(elem_value, b_->getFloatTy()); + elem_value = FDiv(elem_value, + llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); } else { - elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy()); - elem_value = b_->CreateFDiv( + elem_value = UIToFP(elem_value, b_->getDoubleTy()); + elem_value = FDiv( elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); } if (elem_ir_ty != elem_value->getType()) { - elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty); + elem_value = FPTrunc(elem_value, elem_ir_ty); } } @@ -1270,9 +1295,7 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( switch (hlo->random_distribution()) { case RNG_UNIFORM: { if (elem_ir_ty->isFloatingPointTy()) { - return b_->CreateFAdd( - b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value), - a_or_mean); + return FAdd(FMul(FSub(b_or_sigma, a_or_mean), elem_value), a_or_mean); } else { // To generate a uniform random value in [a, b) from a raw random sample // in range [0, 2^N), we let range = b - a and return @@ -1285,22 +1308,21 @@ StatusOr ElementalIrEmitter::ConvertValueForDistribution( // the same cost as if the whole warp were to re-sample. So an // efficient re-sampling implementation on GPU would need to do // nontrivial work to share entropy between threads in the warp. - auto range = b_->CreateSub(b_or_sigma, a_or_mean); - return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range)); + auto range = Sub(b_or_sigma, a_or_mean); + return Add(a_or_mean, URem(elem_value, range)); } } case RNG_NORMAL: { TF_ASSIGN_OR_RETURN( llvm::Value * r, - EmitErfcInv(elem_prim_ty, - b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), - elem_value))); - return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean); + EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), + elem_value))); + return FAdd(FMul(r, b_or_sigma), a_or_mean); } default: return InvalidArgument( "unhandled distribution %s", - RandomDistribution_Name(hlo->random_distribution()).c_str()); + RandomDistribution_Name(hlo->random_distribution())); } } @@ -1415,8 +1437,7 @@ std::array CalculateSampleValues( // Precondition: the RNG instruction is not fused. llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) { VLOG(3) << "Using philox RNG algorithm"; CHECK(!hlo->IsFused()); // A random number generated by the per module random number generator. @@ -1439,7 +1460,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // Load the global state variable for the Philox RNG algorithm. llvm::GlobalVariable* rng_state_ptr = llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_); - llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value"); + llvm::Value* rng_state = Load(rng_state_ptr, "rng_state_value"); // Build and return the elemental IR generator to generate a random value for // the element corresponding to the current thread. @@ -1465,8 +1486,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // element within the sample. llvm::Value* elems_per_sample_value = llvm::ConstantInt::get(index_ty, elems_per_sample); - llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value); - llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value); + llvm::Value* sample_idx = UDiv(elem_idx, elems_per_sample_value); + llvm::Value* elem_offset = URem(elem_idx, elems_per_sample_value); std::array counter_values = CalculateSampleValues( sample_idx, hlo_random_value, global_random_number, rng_state, b_); @@ -1474,18 +1495,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // Store the four counter_values into the sample_address alloca so we can // load the elem_offset'th one below. for (int idx = 0; idx < 4; ++idx) { - b_->CreateStore(counter_values[idx], - b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx))); + Store(counter_values[idx], + InBoundsGEP(sample_address, b_->getInt32(idx))); } llvm::Type* int64_ty = b_->getInt64Ty(); CHECK(elems_per_sample == 2 || elems_per_sample == 4); llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty; // Retrieve the raw value for the current element from the current sample. - llvm::Value* raw_elem_value = b_->CreateLoad( - b_->CreateInBoundsGEP( - b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()), - elem_offset), + llvm::Value* raw_elem_value = Load( + InBoundsGEP(PointerCast(sample_address, raw_value_ty->getPointerTo()), + elem_offset), "raw_elem_value"); return ConvertValueForDistribution(hlo, operand_to_generator, index, @@ -1496,7 +1516,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( StatusOr ElementalIrEmitter::EmitElementalSelect( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, operand_to_generator.at(hlo->operand(0))( ElementwiseSourceIndex(index, *hlo, 0))); @@ -1506,14 +1526,14 @@ StatusOr ElementalIrEmitter::EmitElementalSelect( TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, operand_to_generator.at(hlo->operand(2))( ElementwiseSourceIndex(index, *hlo, 2))); - return b_->CreateSelect(b_->CreateTrunc(pred_value, b_->getInt1Ty()), - on_true_value, on_false_value); + return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value, + on_false_value); } StatusOr ElementalIrEmitter::EmitElementalClamp( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * min_value, operand_to_generator.at(hlo->operand(0))( ElementwiseSourceIndex(index, *hlo, 0))); @@ -1532,14 +1552,14 @@ StatusOr ElementalIrEmitter::EmitElementalClamp( max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed); } else { return Unimplemented("Clamp unimplemented for %s", - PrimitiveType_Name(prim_type).c_str()); + PrimitiveType_Name(prim_type)); } } StatusOr ElementalIrEmitter::EmitElementalConcatenate( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& target_index) const { + const llvm_ir::IrArray::Index& target_index) { const int64 concat_dim = hlo->dimensions(0); auto source_index = target_index; @@ -1561,9 +1581,9 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( } llvm_ir::SetToFirstInsertPoint(exit_block, b_); - llvm::PHINode* output = b_->CreatePHI( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), - hlo->operands().size()); + llvm::PHINode* output = + PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + hlo->operands().size()); auto prior_insert_point = b_->GetInsertPoint(); b_->SetInsertPoint(init_block); @@ -1578,9 +1598,8 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - b_->CreateCondBr( - b_->CreateICmpULT(source_index[concat_dim], concat_dim_size), - true_block, false_block); + CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block, + false_block); // Create the terminator of the true block before calling operand // generators, because they require non-degenerate basic blocks. @@ -1593,11 +1612,10 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( // Subtract the size of the concat dimension of the current operand // from the source index. b_->SetInsertPoint(false_block); - source_index[concat_dim] = - b_->CreateSub(source_index[concat_dim], concat_dim_size); + source_index[concat_dim] = Sub(source_index[concat_dim], concat_dim_size); } - b_->CreateUnreachable(); + Unreachable(); b_->SetInsertPoint(exit_block, prior_insert_point); return output; } @@ -1605,7 +1623,7 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); const int64 rank = ShapeUtil::Rank(input_hlo->shape()); @@ -1622,7 +1640,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) - start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + start_index_value = SExtOrTrunc(start_index_value, index_type); int64 largest_valid_start_index = input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i); CHECK_GE(largest_valid_start_index, 0); @@ -1642,7 +1660,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index - input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]); + input_index[i] = Add(slice_start_index[i], index[i]); } return operand_to_generator.at(input_hlo)(input_index); } @@ -1650,7 +1668,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( StatusOr ElementalIrEmitter::EmitElementalGather( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { const Shape& operand_shape = hlo->operand(0)->shape(); const Shape& indices_shape = hlo->operand(1)->shape(); const Shape& output_shape = hlo->shape(); @@ -1699,7 +1717,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { llvm::Value* gather_dim_component_extended = - b_->CreateSExtOrTrunc(index_component, index_type); + SExtOrTrunc(index_component, index_type); int64 operand_dim = dim_numbers.start_index_map(dim); int64 output_dim = operand_to_output_dim[operand_dim]; // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. @@ -1723,8 +1741,8 @@ StatusOr ElementalIrEmitter::EmitElementalGather( gather_dim_component_extended, is_signed), is_signed); - operand_index[operand_dim] = b_->CreateAdd( - operand_index[operand_dim], gather_dim_component_extended_inbound); + operand_index[operand_dim] = + Add(operand_index[operand_dim], gather_dim_component_extended_inbound); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { @@ -1748,7 +1766,7 @@ StatusOr ElementalIrEmitter::EmitElementalGather( StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { const HloInstruction* input_hlo = hlo->operand(0); const HloInstruction* update_hlo = hlo->operand(1); const HloInstruction* start_hlo = hlo->operand(2); @@ -1771,7 +1789,7 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Clamp the start index so that the update region fits in the operand. // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) - start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + start_index_value = SExtOrTrunc(start_index_value, index_type); llvm::Value* update_dim_size = index_typed_const(update_hlo->shape().dimensions(i)); int64 largest_valid_start_index = @@ -1787,14 +1805,14 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; - slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size); - - slice_intersection = b_->CreateAnd( - slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection"); - slice_intersection = b_->CreateAnd( - slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection"); + slice_limit_index[i] = Add(slice_start_index[i], update_dim_size); + + slice_intersection = + And(slice_intersection, ICmpSGE(index[i], slice_start_index[i]), + "slice_intersection"); + slice_intersection = + And(slice_intersection, ICmpSLT(index[i], slice_limit_index[i]), + "slice_intersection"); } // Emit: @@ -1811,26 +1829,26 @@ StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Compute update index for intersection case. llvm_ir::IrArray::Index update_index(index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - update_index[i] = b_->CreateSub(index[i], slice_start_index[i]); + update_index[i] = Sub(index[i], slice_start_index[i]); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); - b_->CreateStore(true_value, ret_value_addr); + Store(true_value, ret_value_addr); // Handle false BB (return data from 'input') SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * false_value, operand_to_generator.at(input_hlo)(index)); - b_->CreateStore(false_value, ret_value_addr); + Store(false_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); - return b_->CreateLoad(ret_value_addr); + return Load(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalPad( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& padded_index) const { + const llvm_ir::IrArray::Index& padded_index) { auto index = padded_index; llvm::Value* in_bounds = b_->getTrue(); for (size_t i = 0; i < index.size(); ++i) { @@ -1838,26 +1856,22 @@ StatusOr ElementalIrEmitter::EmitElementalPad( return llvm::ConstantInt::get(index[i]->getType(), n); }; const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = - b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = b_->CreateAnd(in_bounds, - b_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = b_->CreateAnd( + index[i] = Sub(index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = + And(in_bounds, ICmpSGE(index[i], index_typed_const(0)), "in_bounds"); + in_bounds = And( in_bounds, - b_->CreateICmpEQ( + ICmpEQ( index_typed_const(0), - b_->CreateURem(index[i], - index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = b_->CreateSDiv( - index[i], index_typed_const(pad_dim.interior_padding() + 1)); - in_bounds = b_->CreateAnd( - in_bounds, - b_->CreateICmpSLT( - index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), + URem(index[i], index_typed_const(pad_dim.interior_padding() + 1))), "in_bounds"); + index[i] = + SDiv(index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = + And(in_bounds, + ICmpSLT(index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + "in_bounds"); } // if (in_bounds) { @@ -1873,26 +1887,26 @@ StatusOr ElementalIrEmitter::EmitElementalPad( SetToFirstInsertPoint(if_data.true_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); - b_->CreateStore(operand_value, ret_value_addr); + Store(operand_value, ret_value_addr); SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - b_->CreateStore(padding_value, ret_value_addr); + Store(padding_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); // Don't create phi(operand_value, padding_value) here, because invoking // operand_to_generator may create new basic blocks, making the parent // of operand_value or padding_value no longer a predecessor of // if_data.after_block. - return b_->CreateLoad(ret_value_addr); + return Load(ret_value_addr); } StatusOr ElementalIrEmitter::EmitElementalDot( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& dot_result_index) const { + const llvm_ir::IrArray::Index& dot_result_index) { auto lhs_generator = operand_to_generator.at(hlo->operand(0)); auto rhs_generator = operand_to_generator.at(hlo->operand(1)); @@ -1920,8 +1934,7 @@ StatusOr ElementalIrEmitter::EmitElementalDot( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_); - b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); + Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca); SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_); @@ -1943,42 +1956,37 @@ StatusOr ElementalIrEmitter::EmitElementalDot( } rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); - llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca); + llvm::Value* current_accumulator = Load(accumulator_alloca); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = b_->CreateFSub( - b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); - llvm::Value* product_imag = b_->CreateFAdd( - b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); - next_accumulator = b_->CreateInsertValue( + llvm::Value* product_real = + FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); + llvm::Value* product_imag = + FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); + next_accumulator = InsertValue( current_accumulator, - b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real), - {0}); - next_accumulator = b_->CreateInsertValue( + FAdd(EmitExtractReal(current_accumulator), product_real), {0}); + next_accumulator = InsertValue( next_accumulator, - b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag), - {1}); + FAdd(EmitExtractImag(current_accumulator), product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = b_->CreateFAdd(current_accumulator, - b_->CreateFMul(lhs_value, rhs_value)); + next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value)); } else { - next_accumulator = - b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value)); + next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value)); } - b_->CreateStore(next_accumulator, accumulator_alloca); + Store(next_accumulator, accumulator_alloca); SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_); - return b_->CreateLoad(accumulator_alloca); + return Load(accumulator_alloca); } llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -2072,10 +2080,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); auto source_index = target_index; for (int64 dim : hlo->dimensions()) { - source_index[dim] = b_->CreateSub( - llvm::ConstantInt::get(target_index[dim]->getType(), - hlo->shape().dimensions(dim) - 1), - target_index[dim]); + source_index[dim] = + Sub(llvm::ConstantInt::get(target_index[dim]->getType(), + hlo->shape().dimensions(dim) - 1), + target_index[dim]); } return operand_to_generator.at(operand)(source_index); }; @@ -2089,6 +2097,50 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(), hlo->dimensions(), b_)); }; + case HloOpcode::kIota: + return [this, hlo]( + const IrArray::Index& target_index) -> StatusOr { + auto* iota = Cast(hlo); + PrimitiveType element_type = iota->shape().element_type(); + IrArray::Index elem_index = + ShapeUtil::Rank(iota->shape()) > 1 + ? target_index.SourceIndexOfBroadcast( + iota->shape(), + ShapeUtil::MakeShapeWithDescendingLayout( + element_type, + {iota->shape().dimensions(iota->iota_dimension())}), + {iota->iota_dimension()}, b_) + : target_index; + llvm::Value* elem_index_linear = elem_index.linear(); + if (elem_index_linear == nullptr) { + std::vector iota_bound = { + iota->shape().dimensions(iota->iota_dimension())}; + elem_index_linear = elem_index.Linearize(iota_bound, b_); + } + if (ShapeUtil::ElementIsIntegral(iota->shape())) { + return b_->CreateIntCast( + elem_index_linear, + llvm_ir::PrimitiveTypeToIrType(element_type, module_), + /*isSigned=*/false); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsFloating(iota->shape())) + << element_type; + llvm::Type* float_ir_type; + if (element_type == BF16) { + float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_); + } else { + float_ir_type = + llvm_ir::PrimitiveTypeToIrType(element_type, module_); + } + llvm::Value* float_val = + b_->CreateUIToFP(elem_index_linear, float_ir_type); + if (element_type == BF16) { + return EmitF32ToBF16(float_val, b_); + } else { + return float_val; + } + } + }; case HloOpcode::kSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { @@ -2154,28 +2206,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( default: return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); }; } } -llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { - return b_->CreateExtractValue(value, {0}); +llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) { + return ExtractValue(value, {0}); } -llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { - return b_->CreateExtractValue(value, {1}); +llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) { + return ExtractValue(value, {1}); } llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const { + llvm::Value* imag) { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto complex = b_->CreateInsertValue( - llvm::ConstantAggregateZero::get(cplx_type), real, {0}); + auto complex = + InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0}); if (imag != nullptr) { - complex = b_->CreateInsertValue(complex, imag, {1}); + complex = InsertValue(complex, imag, {1}); } return complex; } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 1598a4dd85632cfa9835a81a21eddff3e57bfa1f..d3e2acaabd4f602171def70ccd3d4fd5adce0d0d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -23,12 +23,13 @@ limitations under the License. #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { -class ElementalIrEmitter { +class ElementalIrEmitter : public IrBuilderMixin { public: using HloToElementGeneratorMap = std::unordered_map; @@ -40,100 +41,114 @@ class ElementalIrEmitter { virtual ~ElementalIrEmitter() = default; virtual StatusOr EmitUnaryOp(const HloInstruction* op, - llvm::Value* operand_value) const; + llvm::Value* operand_value); virtual StatusOr EmitBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); // Returns a function to generate an element of the output of `hlo`, given a // map of functions to generate elements of its operands. virtual llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const; + const HloToElementGeneratorMap& operand_to_generator); - llvm::IRBuilder<>* b() const { return b_; } - llvm::Module* module() const { return module_; } + llvm::IRBuilder<>* b() { return b_; } + + // builder() is for IrBuilderMixin. + llvm::IRBuilder<>* builder() { return b_; } + + llvm::Module* module() { return module_; } protected: - virtual StatusOr EmitIntegerUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitIntegerUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); + + virtual StatusOr EmitFloatUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitComplexUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + llvm::Value* IsZero(llvm::Value* v); + llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* GetZero(llvm::Type* type); + llvm::Value* GetOne(llvm::Type* type); + llvm::Value* GetIntSMin(llvm::Type* type); + llvm::Value* GetMinusOne(llvm::Type* type); + + llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed); + llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed); virtual StatusOr EmitIntegerBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); - virtual StatusOr EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + virtual StatusOr EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); - virtual StatusOr EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + virtual StatusOr EmitComplexBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); virtual StatusOr EmitErfInv(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitErfcInv(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitAtan2(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs) const; + llvm::Value* lhs, llvm::Value* rhs); virtual StatusOr EmitLog(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitSin(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitCos(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitExp(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitPow(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs) const; + llvm::Value* lhs, llvm::Value* rhs); virtual StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, - llvm::Value* x) const; + llvm::Value* x); - virtual llvm::Value* EmitExtractReal(llvm::Value* value) const; - virtual llvm::Value* EmitExtractImag(llvm::Value* value) const; + virtual llvm::Value* EmitExtractReal(llvm::Value* value); + virtual llvm::Value* EmitExtractImag(llvm::Value* value); // Composes a complex struct. imag may be nullptr for simple cast operations. llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const; + llvm::Value* imag); // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its @@ -142,50 +157,50 @@ class ElementalIrEmitter { // Precondition: `hlo` is an elementwise op. llvm_ir::IrArray::Index ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) const; + int64 operand_no); // Identifier of the thread unique among all threads on the device - virtual llvm::Value* EmitThreadId() const { return b_->getIntN(128, 0); } + virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } StatusOr EmitElementalSelect( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalClamp( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalConcatenate( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& target_index) const; + const llvm_ir::IrArray::Index& target_index); StatusOr EmitElementalDynamicSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalGather( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr EmitElementalPad( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& padded_index) const; + const llvm_ir::IrArray::Index& padded_index); StatusOr EmitElementalDot( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& dot_result_index) const; + const llvm_ir::IrArray::Index& dot_result_index); llvm::IRBuilder<>* const b_; @@ -200,13 +215,13 @@ class ElementalIrEmitter { // random number generation algorithm. llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const; + const HloToElementGeneratorMap& operand_to_generator); // Converts the raw value generated by a random number generation algorithm // to the distribution requested by the RNG HloInstruction. StatusOr ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const; + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 1c9f396b68fa20a03986d81d642d1726b26cd0dc..78edf918a4de633be31bd69e93fee940e539e392 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" using tensorflow::gtl::ArraySlice; @@ -155,9 +155,9 @@ Status Executable::DumpHloSnapshot() { const string& directory_path = module_config().debug_options().xla_dump_executions_to(); const auto& module = hlo_snapshot_->hlo().hlo_module(); - string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", module.id(), - module.entry_computation_name().c_str(), ++execution_count_); + string filename = + absl::StrFormat("computation_%d__%s__execution_%d", module.id(), + module.entry_computation_name(), ++execution_count_); return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 70a78c8a2b6f3cf360ca2ac7255f8dc35235125e..997db7c058af6da8ecff399769b85b803e2e5785 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -66,7 +66,7 @@ Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { - return NotFound("no execution record for execution handle: %lld", + return NotFound("no execution record for execution handle: %d", handle.handle()); } handle_to_execution_.erase(handle.handle()); @@ -78,7 +78,7 @@ StatusOr ExecutionTracker::Resolve( tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { - return NotFound("no execution record for execution handle: %lld", + return NotFound("no execution record for execution handle: %d", handle.handle()); } return it->second.get(); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h index d3efab3614912e4b0c2c8aa3b80277c326382ed0..3cccec9862e0f92df478006939552099868121b9 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.h +++ b/tensorflow/compiler/xla/service/flatten_call_graph.h @@ -28,7 +28,7 @@ namespace xla { // points-to analysis (see b/36865746 for details). class FlattenCallGraph : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "flatten-call-graph"; } + absl::string_view name() const override { return "flatten-call-graph"; } // Duplicates computations called from multiple call- or while-nodes to // flatten the call graph. diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index d889fd8e88ed4008749c116314e9a0c54e6fa63d..3f1a881372174bd775efc17631b3287956fef66a 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -323,7 +323,7 @@ StatusOr GatherExpander::ExpandGather( return Unimplemented( "Gather operations with more than 2147483647 gather indices are not " "supported. This error occurred for %s.", - gather_instr->ToString().c_str()); + gather_instr->ToString()); } TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index c1fc8574da99fff223c7dbb570b4533f76905b9a..7bd9ea598417a931d2df507d472c6a60be05e0bc 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -25,7 +25,7 @@ namespace xla { // nevertheless have a minimum level of support. class GatherExpander : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "gather_expander"; } + absl::string_view name() const override { return "gather_expander"; } StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index fbef487ac8095ef1c4143bf3f46cbe85a1343422..d6e943634814fd7fd96494d7a2d9e01f685885ff 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -57,6 +57,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -110,6 +111,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -129,6 +131,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -174,6 +177,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", @@ -186,6 +190,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@llvm//:core", "@llvm//:support", @@ -231,6 +236,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:support", ], @@ -347,6 +353,8 @@ cc_library( "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", ], ) @@ -384,6 +392,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", ], ) @@ -402,6 +412,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -496,6 +507,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -527,6 +539,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -687,6 +700,7 @@ cc_library( "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@llvm//:core", ], @@ -775,6 +789,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/strings", ], ) @@ -809,6 +824,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -888,9 +904,8 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_runner", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index e208ad61e331ecac12fe128359da7585a2a3a7b4..86af83b6b975e3fda9f6dd0b62866ced9e8f1b5f 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -62,7 +62,7 @@ StatusOr> BufferAllocations::Builder::Build( if (reinterpret_cast(address.opaque()) % expected_alignment != 0) { return InternalError( - "Address of registered buffer %lld must be a multiple of %llx, but " + "Address of registered buffer %d must be a multiple of %x, but " "was %p", i, kEntryParameterAlignBytes, address.opaque()); } @@ -83,7 +83,7 @@ StatusOr> BufferAllocations::Builder::Build( 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " - "multiple of %llx, but was %p", + "multiple of %x, but was %p", kXlaAllocatedBufferAlignBytes, buffer.opaque()); } // We do manual memory management within BufferAllocations. Be sure not diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 6a285a6b989b29428fc15fd6aef29110577c226e..13c83c9199fb1bbd8b00dbd601afcb677f92bbee 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include +#include "absl/strings/str_replace.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace gpu { @@ -74,9 +74,8 @@ ENTRY MaxDifference { %error = f32[SIZE] divide(%sub_abs, %denominator) ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32 })"; - auto size_string = std::to_string(num_elements); - return tensorflow::str_util::StringReplace( - kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true); + return absl::StrReplaceAll(kF16CompHloText, + {{"SIZE", absl::StrCat(num_elements)}}); } StatusOr F16BufferComparator::Create( @@ -125,7 +124,7 @@ StatusOr F16BufferComparator::Create( StatusOr F16BufferComparator::CompareEqualImpl( se::DeviceMemory test_buffer) { if (ref_buffer_.root_buffer().size() != test_buffer.size()) { - return InternalError("Mismatched buffer size: %lld vs %lld", + return InternalError("Mismatched buffer size: %d vs %d", ref_buffer_.root_buffer().size(), test_buffer.size()); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 8b0426aa27fa3fbc7225dda81cef17e543f1cf28..9ed523998bf07567133fdac0e40b12b8ce4ea3b0 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -59,7 +59,7 @@ Status ConditionalThunk::ExecuteOnStream( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to retrieve predicate value on stream %p: %s.", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } // Execute the true or the false computation depending on the value of the diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 7833a4077e6c6ee4960665f37fb01a35530fd302..eea31f3de1029f8ddfeedf67f006e638b7a7d683 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,12 +17,11 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h index e09cde9abf85454c7a020566cd8c2671ae12ffc3..6e2e330edd4beabe0b395f05b80d57612d63f110 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -54,9 +54,7 @@ namespace gpu { // BatchNormRewriter. class CudnnBatchNormRewriter : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "cudnn_batchnorm_rewriter"; - } + absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 7b172812c36bb141787ef3a9285d6f7ce13e343b..bc3c6f72f6799f84169748465d62c3f2a306d5fc 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -17,12 +17,11 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 5a8fc76e85db02d08ea0fb24472b9d6645060971..dbdf8e7a0e959ea05e98a006464b66cfb2fa9f58 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" @@ -21,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -59,8 +60,8 @@ StatusOr> ScratchAllocator::AllocateBytes( if (byte_size > GetMemoryLimitInBytes(stream)) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, GetMemoryLimitInBytes(stream))); } @@ -128,14 +129,14 @@ std::vector GetAlgorithms(CudnnConvKind kind, string AlgorithmToString(const AlgorithmDesc& algo) { if (algo.tensor_ops_enabled()) { - return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + return absl::StrCat(algo.algo_id(), "+TC"); } - return tensorflow::strings::StrCat(algo.algo_id()); + return absl::StrCat(algo.algo_id()); } string NumBytesToString(int64 bytes) { - return tensorflow::strings::StrCat( - tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); + return absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (", + bytes, "B)"); } // Acquires a process-global lock on the device pointed to by the given @@ -361,7 +362,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( return InternalError( "All algorithms tried for convolution %s failed. Falling back to " "default algorithm.", - instr->ToString().c_str()); + instr->ToString()); } StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 472de2ff0f8b0253ca380db94d461046fb3c2fb6..f76d273e8c641dfbdbba38eb161ab8a00a19e1f8 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -39,7 +39,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { Compiler* compiler) : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cudnn-convolution-algorithm-picker"; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index 0c0578d88840fed1d77f7456c9acef27dec380f5..fbe7e9849458e9d52be15b3f5610479ab68ffa4c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -26,7 +26,7 @@ namespace gpu { // backwards-input convolutions into CustomCall HLOs that call into cuDNN. class CudnnConvolutionRewriter : public HloPassInterface { public: - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cudnn-convolution-rewriter"; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 7b0d9e53d60dda620714b3443b627405e562b353..07b96fbd3f008143d322f9228e3700458d65a1b6 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -56,7 +57,7 @@ class ScratchBufAllocator : public se::ScratchAllocator { "Can't allocate twice from a ScratchBufAllocator."); } if (byte_size > scratch_.size()) { - return se::port::InternalError(tensorflow::strings::StrCat( + return se::port::InternalError(absl::StrCat( "Can't allocate ", byte_size, " bytes from a ScratchBufAllocator of size ", scratch_.size())); } @@ -196,8 +197,8 @@ Status RunCudnnConvolution( if (!stream->ok()) { return InternalError( - "Unable to launch convolution with type %s and algorithm (%lld, %lld)", - CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(), + "Unable to launch convolution with type %s and algorithm (%d, %d)", + CudnnConvKindToString(kind), algorithm.algorithm().algo_id(), algorithm.algorithm_no_scratch().algo_id()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 9b6de115ad7e7f87e431f839c1690858f4bce3fd..57a3a43a6fa08e958ed041e2e00c630195781881 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -23,6 +23,8 @@ limitations under the License. #include "tensorflow/core/platform/types.h" // IWYU pragma: no_include "llvm/IR/Attributes.gen.inc" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" @@ -43,16 +45,14 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace gpu { +using absl::StrAppend; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -using tensorflow::strings::StrAppend; namespace { // Returns whether operand is a floating-point literal with the given value. @@ -77,7 +77,7 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( const string& callee_name, tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + PrimitiveType output_type) { // The libdevice math functions differentiate between "double" and "float" by // appending an 'f' to the function's name. libdevice doesn't have f16 math // functions, so we convert the operands to f32 before calling the function @@ -94,7 +94,7 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( for (int64 i = 0; i < operands.size(); ++i) { if (input_types[i] == F16) { converted_operands[i] = - b_->CreateFPCast(converted_operands[i], b_->getFloatTy()); + FPCast(converted_operands[i], b_->getFloatTy()); converted_input_types[i] = F32; } } @@ -107,13 +107,13 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( break; default: return Unimplemented("Bad type for libdevice math call: %s", - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(output_type)); } llvm::Value* result = EmitMathCall(munged_callee, converted_operands, converted_input_types, output_type) .ValueOrDie(); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } @@ -122,7 +122,7 @@ StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( const string& callee_name, tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + PrimitiveType output_type) { // llvm intrinsics differentiate between half/float/double functions via // the suffixes ".f16", ".f32" and ".f64". string munged_callee = callee_name; @@ -138,7 +138,7 @@ StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( break; default: return Unimplemented("Bad type for llvm intrinsic math call: %s", - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(output_type)); } return EmitMathCall(munged_callee, operands, input_types, output_type); } @@ -147,13 +147,13 @@ StatusOr GpuElementalIrEmitter::EmitMathCall( const string& callee_name, tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const { + PrimitiveType output_type) { // Binary math functions transform are of type [T] -> T. for (PrimitiveType input_type : input_types) { if (output_type != input_type) { return Unimplemented("Input type ≠ output type: %s ≠ %s", - PrimitiveType_Name(input_type).c_str(), - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(input_type), + PrimitiveType_Name(output_type)); } } @@ -163,8 +163,7 @@ StatusOr GpuElementalIrEmitter::EmitMathCall( } StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); @@ -183,8 +182,7 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( } StatusOr GpuElementalIrEmitter::EmitPowerOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { CHECK_EQ(op->opcode(), HloOpcode::kPower); PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); @@ -218,7 +216,7 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( // TODO(jlebar): Does this happen with fastmath disabled? If not, should // we force-enable it? TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); - return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); + return FDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString(); @@ -227,55 +225,56 @@ StatusOr GpuElementalIrEmitter::EmitPowerOp( } StatusOr GpuElementalIrEmitter::EmitErfcInv( - PrimitiveType prim_type, llvm::Value* value) const { + PrimitiveType prim_type, llvm::Value* value) { return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitLog( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitLog1p( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitSin( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitCos( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitExp( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitExpm1( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type); } StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitAtan2( - PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { +StatusOr GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) { return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type}, prim_type); } -StatusOr GpuElementalIrEmitter::EmitTanh( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { // Emit a fast approximation of tanh instead of calling __nv_tanh. // __nv_tanh is particularly bad because it contains branches, thus // preventing LLVM's load-store vectorizer from working its magic across a @@ -285,9 +284,9 @@ StatusOr GpuElementalIrEmitter::EmitTanh( // Upcast F16 to F32 if necessary. llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType(); - llvm::Value* input = b_->CreateFPCast(value, type); + llvm::Value* input = FPCast(value, type); llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); - return b_->CreateFPCast(fast_tanh, value->getType()); + return FPCast(fast_tanh, value->getType()); } llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( @@ -295,7 +294,7 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, PrimitiveType output_type, - tensorflow::gtl::ArraySlice attributes) const { + tensorflow::gtl::ArraySlice attributes) { std::vector ir_input_types; for (PrimitiveType input_type : input_types) { ir_input_types.push_back( @@ -315,29 +314,28 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( callee->addFnAttr(attribute); } - return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); + return Call(callee, llvm_ir::AsArrayRef(operands)); } -llvm::Value* GpuElementalIrEmitter::EmitThreadId() const { - llvm::Value* block_id = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); - return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block), - thread_id_in_block); +llvm::Value* GpuElementalIrEmitter::EmitThreadId() { + llvm::Value* block_id = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "block.id"); + llvm::Value* thread_id_in_block = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + llvm::Value* threads_per_block = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const { + const HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { case HloOpcode::kMap: return [=, &operand_to_generator]( @@ -383,7 +381,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - b_->CreateStore(init_value, accum_ptr); + Store(init_value, accum_ptr); } llvm::Type* index_type = index.GetType(); @@ -405,22 +403,21 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( IrArray::Index input_index(index_type, index.size()); llvm::Value* in_bounds = b_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = b_->CreateNSWMul( + llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); - input_index[i] = b_->CreateNSWSub( - b_->CreateNSWAdd(stridden_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); + input_index[i] = + NSWSub(NSWAdd(stridden_index, window_index[i]), + index_typed_const(window.dimensions(i).padding_low())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This // comparison is equivalent to the unsigned comparison // input_index[i] < bound, as a negative value wraps to a large // positive value. - in_bounds = b_->CreateAnd( - in_bounds, - b_->CreateICmpULT( - input_index[i], - index_typed_const(operand->shape().dimensions(i)))); + in_bounds = + And(in_bounds, + ICmpULT(input_index[i], + index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = @@ -432,12 +429,11 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( operand_to_generator.at(operand)(input_index)); TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), - {b_->CreateLoad(accum_ptr), input_value})); - b_->CreateStore(accum_value, accum_ptr); + compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value})); + Store(accum_value, accum_ptr); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); - return b_->CreateLoad(accum_ptr); + return Load(accum_ptr); }; case HloOpcode::kReduce: // TODO(b/112040122): This should be supported. diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 84454d31bb820a3de6ef3364bd205b8115bd95c0..91942785d286d7ff9f9e7001c788315c77362ea4 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -48,50 +48,50 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const override; + const HloToElementGeneratorMap& operand_to_generator) override; protected: - StatusOr EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const override; + StatusOr EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value) override; StatusOr EmitErfcInv(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitLog(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitSin(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitCos(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitExp(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; - llvm::Value* EmitThreadId() const override; + llvm::Value* EmitThreadId() override; private: // Emits IR for op, which must have opcode kPower. StatusOr EmitPowerOp(const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); // Emits IR to call a device function named "callee_name" on the given // operand. Returns the IR value that represents the return value. @@ -100,7 +100,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_type, PrimitiveType output_type, - tensorflow::gtl::ArraySlice attributes) const; + tensorflow::gtl::ArraySlice attributes); // Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the @@ -109,7 +109,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const string& callee_name, tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + PrimitiveType output_type); // Emits IR to call a libdevice function of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the @@ -118,7 +118,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const string& callee_name, tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + PrimitiveType output_type); // Emits IR to call a function of type [T] -> T. Does not munge callee_name. // Returns the IR value that represents the return value of the function. @@ -126,7 +126,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const string& callee_name, tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, - PrimitiveType output_type) const; + PrimitiveType output_type); const HloModuleConfig& hlo_module_config_; NestedComputer compute_nested_; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 0cdddf8bcfd4e849b311bf810eda471d79dbf106..11549cdac53c58cf006b3e4e1a8338c96e772889 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -17,11 +17,11 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -43,8 +43,8 @@ StatusOr> FftScratchAllocator::AllocateBytes( if (byte_size > GetMemoryLimitInBytes(stream)) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, GetMemoryLimitInBytes(stream))); } @@ -213,7 +213,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, return Status::OK(); } return InternalError("Unable to launch fft for thunk %p with type %s", this, - FftTypeToString(fft_type_).c_str()); + FftTypeToString(fft_type_)); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 9b86e5315bf51e88cca569499fe9acbe17998e48..1bd88233e183af89268865e2a80155b2d7f638b6 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace gpu { @@ -289,11 +289,10 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { << " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion) << " merged_to_current_bytes_ratio: " << merged_to_current_bytes_ratio << " into users { " - << tensorflow::str_util::Join(users, ", ", - [](string* out, HloInstruction* user) { - tensorflow::strings::StrAppend( - out, user->name()); - }) + << absl::StrJoin(users, ", ", + [](string* out, HloInstruction* user) { + absl::StrAppend(out, user->name()); + }) << " }"; // Remove 'fusion' instruction. CHECK_EQ(0, fusion->user_count()); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index 4c523a66de977cd32423b25f0d165c4f4ba51c4a..7e3f5775b8d97f43a0bba201d24f34c2d337fabb 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -34,7 +34,7 @@ namespace gpu { // class FusionMerger : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "fusion merger"; } + absl::string_view name() const override { return "fusion merger"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 74282c568c09921dbeec2e9cce79b6c73b6ea592..9c4a4903667ea1a6c99ce9e912c9d0497b8e389f 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -186,7 +186,7 @@ StatusOr DoGemmAutotune( } return InternalError( - "Unable to autotune cuBLAS gemm on stream %p; none of the %zu algorithms " + "Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms " "ran successfully", stream, algorithms.size()); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 0c6f9b511f3aac5f62182273b827adcd068cd633..8ffae18fe820aa01701731ee56a83aeacf0eab0d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -27,7 +27,7 @@ namespace gpu { // inserting kCopy instructions. class GpuCopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 88be63e2679dcb145a1d7c1d3e18206c9e62a9c3..71a02e70df7383a84eb577c4bb2b061651d18a35 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -160,7 +160,7 @@ Status GpuExecutable::ExecuteThunks( if (!block_status.ok()) { return InternalError( "Failed to complete all kernels launched on stream %p: %s", - main_stream, block_status.error_message().c_str()); + main_stream, block_status.error_message()); } } @@ -260,10 +260,9 @@ StatusOr GpuExecutable::ExecuteOnStream( if (buffer.is_null() && buffer.size() > 0) { return FailedPrecondition( "Cannot run XLA computation because pointer to (sub-)buffer at " - "index %s of parameter %lld was null. All pointers to " - "(sub-)buffers must not be null, unless the (sub-)buffer has zero " - "elements.", - allocation.param_shape_index().ToString().c_str(), param_no); + "index %s of parameter %d was null. All pointers to (sub-)buffers " + "must not be null, unless the (sub-)buffer has zero elements.", + allocation.param_shape_index().ToString(), param_no); } buffer_allocations_builder.RegisterBuffer(i, buffer); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 09a1d9c12b05c8ecba0619b84dbe139a9dd955db..627a05e2401e9f07f764988637e87773780ab1f2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc index 4944c41f7d8dc7a78a3cd094aee4d7087c74857e..4268fb2c7a813b3b53e4cd48746028a7b369f28e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -34,9 +34,8 @@ StatusOr GpuHloSupportChecker::Run(HloModule* module) { return xla::Unimplemented( "GPU backend does not support HLO instruction %s with shape " "containing a sparse layout: %s", - instruction->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction->shape()) - .c_str()); + instruction->ToString(), + ShapeUtil::HumanStringWithLayout(instruction->shape())); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h index d63e213d2b1efab4bcff75541cc5ab33d7a07976..bbb3340760c8330bd6570f33382f004315c6d0bd 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -28,9 +28,7 @@ class GpuHloSupportChecker : public HloPassInterface { GpuHloSupportChecker() = default; ~GpuHloSupportChecker() override = default; - tensorflow::StringPiece name() const override { - return "gpu_hlo_support_checker"; - } + absl::string_view name() const override { return "gpu_hlo_support_checker"; } // Note: always returns false (no instructions are ever modified by this // pass). diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 286547ebae2f1a4b8d783a06d13b4dd96052b952..fbc8ddf599570b90e93eb463a1fd6c275b73711c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -119,7 +120,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); @@ -192,7 +193,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { // Enumerate all combinations of shapes. for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); @@ -265,7 +266,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { for (int constrained_param_no : {0, 4}) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 44303724bb5cda4f392c8d17d60c114286b6b7e2..f3c274429242d5c989146d14ea523b5910408cff 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -84,7 +84,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } infeed_manager->EnqueueDestination(std::move(buffers)); @@ -97,7 +97,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( StatusOr GpuTransferManager::TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits::max()) { - return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + return InvalidArgument("Infeed shape is too large: needs %d bytes", size); } if (size == 0) { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index d4a96cd5b353436ea4d1d6db3810b3e777449cd4..bb147c8d9828cebb7b710041234ece4b54d7ed11 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -266,7 +267,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) { params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 8c11cd05419289d82b033c936bb60884f45cb636..0e205b9c028dee91b422bd9f18a1c128d54e15f8 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -24,16 +25,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; void HloToIrBindings::EmitBasePointersForHlos( tensorflow::gtl::ArraySlice io_hlos, diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index fee6d2af3bfd4976f5845edf592e8310b55a3feb..8c3a026740851767855beae59d6a3c92f7a0d6bd 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -96,7 +96,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } VLOG(2) << "Infeeding to GPU complete"; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 0f2c83aeb2633a007559d8caac78ea2d233539ed..0bcaaee2b75a80063e1a1a66fcdd7325d3e2f616 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -26,7 +26,7 @@ namespace gpu { namespace { -bool IsFusile(const HloInstruction& hlo) { +bool IsFusible(const HloInstruction& hlo) { // Don't fuse get-tuple-element on GPU: We can, but it's slower than not // fusing. We never generate kernels for unfused GTEs. Instead, if an // unfused GTE is an input to a kernel (including a fusion kernel), we @@ -245,7 +245,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return true; } - if (!IsFusile(*producer) || !IsFusile(*consumer) || + if (!IsFusible(*producer) || !IsFusible(*consumer) || !InstructionFusion::ShouldFuse(consumer, operand_index)) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 8d0522bd8fd6659e64d18c52807df8dc7fc2f3b8..f53dfaee3dec9902d2881122c36509079e0393c5 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -365,7 +365,7 @@ static StatusOr FindHloInstruction( } return NotFound( "Computation '%s' does not contain an instruction with op code '%s'.", - computation.name().c_str(), HloOpcodeString(op).c_str()); + computation.name(), HloOpcodeString(op)); } TEST_F(InstructionFusionTest, MultiOutputFusion) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index c349063c71f000435a05306101ad724505f2d197..f544bcc91976233eff19d97037be989ea0855b86 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -215,7 +215,7 @@ bool IsReductionToVector(const HloInstruction& reduce) { // This emits a device-side call to // "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see // http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls -llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, +llvm::Value* EmitPrintf(absl::string_view fmt, tensorflow::gtl::ArraySlice arguments, llvm::IRBuilder<>* builder) { std::vector argument_types; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 5d23a3d01842c7b4ff405171cd49c96a19f7e5b0..a35e250101c0743018b76fffb82e9db591c33de3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -126,7 +126,7 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo); bool IsReductionToVector(const HloInstruction& reduce); // Emits call to "vprintf" with given format and arguments. -llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, +llvm::Value* EmitPrintf(absl::string_view fmt, tensorflow::gtl::ArraySlice arguments, llvm::IRBuilder<>* builder); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 7111b53944770c9dbfcd0611f67b18900bcf1ffb..bdf6aadde675ec6fca28efce32f962238dd3d459 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -156,7 +156,7 @@ Status IrEmitter::EmitCallToNestedComputation( std::vector arguments(operands.begin(), operands.end()); arguments.push_back(output); arguments.push_back(bindings_.GetTempBufferBase()); - b_.CreateCall(emitted_function, arguments); + Call(emitted_function, arguments); return Status::OK(); } @@ -178,7 +178,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( computation.root_instruction()->shape().element_type(); bool is_atomic_integral = element_type == S32 || element_type == U32 || element_type == S64 || element_type == U64; - llvm::Value* source = b_.CreateLoad(source_address, "source"); + llvm::Value* source = Load(source_address, "source"); if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. if (element_type == F32) { @@ -190,8 +190,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( } if (is_atomic_integral) { // integral + integral - b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } } @@ -202,8 +202,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Max : llvm::AtomicRMWInst::UMax; - b_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -212,8 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Min : llvm::AtomicRMWInst::UMin; - b_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -292,10 +292,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // cas_old_output_address and cas_new_output_address point to the scratch // memory where we store the old and new values for the repeated atomicCAS // operations. - llvm::Value* cas_old_output_address = b_.CreateAlloca( - atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); - llvm::Value* cas_new_output_address = b_.CreateAlloca( - atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); + llvm::Value* cas_old_output_address = + Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); + llvm::Value* cas_new_output_address = + Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); // Emit preparation code to the preheader. llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock(); @@ -309,29 +309,26 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, CHECK_EQ((element_size % sizeof(char)), 0); llvm::Type* address_int_type = module_->getDataLayout().getIntPtrType(output_address_type); - atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type); + atomic_memory_address = PtrToInt(output_address, address_int_type); llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); - llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask); + llvm::Value* offset = And(atomic_memory_address, mask); mask = llvm::ConstantInt::get(address_int_type, -4); - atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = And(atomic_memory_address, mask); atomic_memory_address = - b_.CreateIntToPtr(atomic_memory_address, atomic_address_type); - binop_output_address = b_.CreateAdd( - b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset); + IntToPtr(atomic_memory_address, atomic_address_type); binop_output_address = - b_.CreateIntToPtr(binop_output_address, element_address_type); + Add(PtrToInt(cas_new_output_address, address_int_type), offset); + binop_output_address = IntToPtr(binop_output_address, element_address_type); } else { - atomic_memory_address = - b_.CreateBitCast(output_address, atomic_address_type); + atomic_memory_address = BitCast(output_address, atomic_address_type); binop_output_address = - b_.CreateBitCast(cas_new_output_address, element_address_type); + BitCast(cas_new_output_address, element_address_type); } // Use the value from the memory that atomicCAS operates on to initialize // cas_old_output. - llvm::Value* cas_old_output = - b_.CreateLoad(atomic_memory_address, "cas_old_output"); - b_.CreateStore(cas_old_output, cas_old_output_address); + llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output"); + Store(cas_old_output, cas_old_output_address); llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( b_.GetInsertPoint(), "atomic_op_loop_exit"); @@ -344,32 +341,29 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // Emit the body of the loop that repeatedly invokes atomicCAS. // // Use cas_old_output to initialize cas_new_output. - cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output"); - b_.CreateStore(cas_old_output, cas_new_output_address); + cas_old_output = Load(cas_old_output_address, "cas_old_output"); + Store(cas_old_output, cas_new_output_address); // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( computation, {binop_output_address, source_address}, binop_output_address)); - llvm::Value* cas_new_output = - b_.CreateLoad(cas_new_output_address, "cas_new_output"); + llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output"); // Emit code to perform the atomicCAS operation // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, // cas_new_output); - llvm::Value* ret_value = b_.CreateAtomicCmpXchg( - atomic_memory_address, cas_old_output, cas_new_output, - llvm::AtomicOrdering::SequentiallyConsistent, - llvm::AtomicOrdering::SequentiallyConsistent); + llvm::Value* ret_value = + AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output, + llvm::AtomicOrdering::SequentiallyConsistent, + llvm::AtomicOrdering::SequentiallyConsistent); // Extract the memory value returned from atomicCAS and store it as // cas_old_output. - b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"), - cas_old_output_address); + Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address); // Extract the success bit returned from atomicCAS and generate a // conditional branch on the success bit. - b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, - loop_body_bb); + CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. @@ -384,8 +378,8 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( // TODO(b/30258929): We only accept binary computations so far. return Unimplemented( "We only support atomic functions with exactly two parameters, but " - "computation %s has %lld.", - computation.name().c_str(), computation.num_parameters()); + "computation %s has %d.", + computation.name(), computation.num_parameters()); } if (MaybeEmitDirectAtomicOperation(computation, output_address, @@ -472,10 +466,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_value, rhs_value, &b_); result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = b_.CreateInsertValue(result, value.first, {0}); - result = b_.CreateInsertValue(result, value.second, {1}); + result = InsertValue(result, value.first, {0}); + result = InsertValue(result, value.second, {1}); } else { - result = b_.CreateFMul(lhs_value, rhs_value); + result = FMul(lhs_value, rhs_value); } target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); @@ -559,21 +553,21 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_); llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_); - llvm::Value* accum = b_.CreateLoad(accum_address); + llvm::Value* accum = Load(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_element, rhs_element, &b_); llvm::Value* accum_real = Real(accum, &b_); - llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first); - updated_accum = b_.CreateInsertValue(accum, real_sum, {0}); + llvm::Value* real_sum = FAdd(accum_real, value.first); + updated_accum = InsertValue(accum, real_sum, {0}); llvm::Value* accum_imag = Imag(accum, &b_); - llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second); - updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1}); + llvm::Value* imag_sum = FAdd(accum_imag, value.second); + updated_accum = InsertValue(updated_accum, imag_sum, {1}); } else { - llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element); - updated_accum = b_.CreateFAdd(accum, product); + llvm::Value* product = FMul(lhs_element, rhs_element); + updated_accum = FAdd(accum, product); } - b_.CreateStore(updated_accum, accum_address); + Store(updated_accum, accum_address); // After the reduction loop exits, store the accumulator into the target // address. The index into the target address is the concatenation of the rhs @@ -595,7 +589,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_); target_array.EmitWriteArrayElement( target_index, - b_.CreateLoad(accum_address), // The value written to the target array. + Load(accum_address), // The value written to the target array. &b_); // Set the IR builder insert point to the exit basic block of the outer most @@ -646,10 +640,9 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { [=](const llvm_ir::IrArray::Index& index) -> StatusOr { // Initialize an accumulator with init_value. llvm::AllocaInst* accumulator_addr = - b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + Alloca(llvm_ir::PrimitiveTypeToIrType( reduce->shape().element_type(), module_)); - b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)), - accumulator_addr); + Store(Load(GetBasePointer(*init_value)), accumulator_addr); // The enclosing loops go over all the target elements. Now we have to // compute the actual target element. For this, we build a new loop nest @@ -686,7 +679,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { *function, {accumulator_addr, input_address}, accumulator_addr)); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_addr); + return Load(accumulator_addr); }); } @@ -753,11 +746,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -Status IrEmitter::HandleIota(HloInstruction*) { - // TODO(b/64798317): implement iota on GPU. - return Unimplemented("Iota is not implemented on GPU."); -} - StatusOr IrEmitter::ComputeNestedElement( const HloComputation& computation, tensorflow::gtl::ArraySlice parameter_elements) { @@ -769,11 +757,11 @@ StatusOr IrEmitter::ComputeNestedElement( for (llvm::Value* parameter_element : parameter_elements) { parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry( parameter_element->getType(), "parameter_buffer", &b_)); - b_.CreateStore(parameter_element, parameter_buffers.back()); + Store(parameter_element, parameter_buffers.back()); } TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers, return_buffer)); - return b_.CreateLoad(return_buffer); + return Load(return_buffer); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 561c6838798aa92ce2c96b3c45d5ba42fe6edef3..3673b9f58d6cd1e7b88015746b14b737c00d3722 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -35,12 +36,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" @@ -64,7 +65,8 @@ namespace gpu { // IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is // not a subclass of gpu::IrEmitter, and in fact is better understood as an IR // generator generator. See comments on that class. -class IrEmitter : public DfsHloVisitorWithDefault { +class IrEmitter : public DfsHloVisitorWithDefault, + public IrBuilderMixin { public: IrEmitter(const IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; @@ -95,10 +97,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; - Status HandleIota(HloInstruction* iota) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } + llvm::IRBuilder<>* builder() { return &b_; } + protected: // Constructs an IrEmitter with the given IrEmitter context. // ir_emitter_context is owned by the caller and should outlive the IrEmitter diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index bda298620230225414b701b831024b163e5bf108..c0c8ae181a0eb3d5f38a8b233002f03d1a7a49cf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" @@ -90,10 +91,10 @@ namespace { using absl::InlinedVector; using absl::nullopt; using absl::optional; +using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; -using tensorflow::strings::StrCat; // If a dimensions is smaller than this, untiled transposition may be more // efficient. @@ -728,7 +729,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, extra_output_gens[i].first(index)); - b_.CreateStore(extra_output_ir_value, extra_output_address); + Store(extra_output_ir_value, extra_output_address); } return Status::OK(); } @@ -801,8 +802,7 @@ Status IrEmitterUnnested::EmitReductionToScalar( // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), // // // // and threads_per_block is a multiple of warpSize. - // reduce_kernel<<>>(); - // + // reduce_kernel // auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = @@ -810,17 +810,17 @@ Status IrEmitterUnnested::EmitReductionToScalar( std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { @@ -832,15 +832,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &b_); - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)), - tile_element_loop->GetIndVarValue()); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)), + tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", - &b_); + ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. @@ -849,11 +848,11 @@ Status IrEmitterUnnested::EmitReductionToScalar( IrArray::Index input_index( /*linear=*/x, input_shape, &b_); - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -864,14 +863,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. - llvm::Value* x_end = b_.CreateNSWAdd( - index_typed_constant(kTileSize), - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize))); + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileSize), + NSWMul(x_in_tiles, index_typed_constant(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. llvm::Value* tile_in_bounds = - b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)), - b_.getInt1(all_threads_in_bounds)); + Or(ICmpULE(x_end, index_typed_constant(num_elems)), + b_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); @@ -892,20 +891,18 @@ Status IrEmitterUnnested::EmitReductionToScalar( for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* result_from_other_lane = - b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); + Alloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = b_.CreateLoad( - b_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) << "Requires block size a multiple of the warp size, otherwise we " "will read undefined elements."; - b_.CreateStore( - EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - b_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -920,10 +917,9 @@ Status IrEmitterUnnested::EmitReductionToScalar( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm::Value* lane_id = - b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); + URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", - &b_); + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { @@ -1043,12 +1039,12 @@ Status IrEmitterUnnested::EmitColumnReduction( for (int i = 0; i != num_reduces; ++i) { for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + - llvm::Twine(i * kTileWidth + x_offset)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + + llvm::Twine(i * kTileWidth + x_offset)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1059,8 +1055,8 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm::Value* y_in_tiles = tile_index[0]; llvm::Value* x_in_tiles = tile_index[1]; - y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty); - x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); auto emit_tile_element_loop = [=](bool tile_in_y_bounds, bool tile_in_x_bounds) -> Status { @@ -1072,34 +1068,32 @@ Status IrEmitterUnnested::EmitColumnReduction( // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &b_); - llvm::Value* y = b_.CreateNSWAdd( - b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)), - tile_element_loop->GetIndVarValue()); + llvm::Value* y = + NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)), + tile_element_loop->GetIndVarValue()); // Unless we know that y is in bounds, we have to emit a check before // reading from the input. if (!tile_in_y_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds", - &b_); + ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); // Unless we know that x is in bounds, we have to emit a check before // reading from the input. if (!tile_in_x_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds", - &b_); + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); // {y,x} is an index to input_matrix_shape [height,width]. We need to // convert that to an index to input_shape (the shape of the operand of // "reduce"). This conversion is composed of a transposition from @@ -1126,7 +1120,7 @@ Status IrEmitterUnnested::EmitColumnReduction( for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i * kTileWidth + x_offset], @@ -1141,20 +1135,20 @@ Status IrEmitterUnnested::EmitColumnReduction( // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location // that's immediately beyond the tile. - llvm::Value* y_end = b_.CreateNSWAdd( - index_typed_constant(kTileHeight), - b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight))); + llvm::Value* y_end = + NSWAdd(index_typed_constant(kTileHeight), + NSWMul(y_in_tiles, index_typed_constant(kTileHeight))); // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location // that's immediately beyond the tile. - llvm::Value* x_end = b_.CreateNSWAdd( - index_typed_constant(kTileWidth), - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth))); + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileWidth), + NSWMul(x_in_tiles, index_typed_constant(kTileWidth))); llvm::Value* tile_in_y_bounds = - b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)), - b_.getInt1(height % kTileHeight == 0)); + Or(ICmpULE(y_end, index_typed_constant(height)), + b_.getInt1(height % kTileHeight == 0)); llvm::Value* tile_in_x_bounds = - b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)), - b_.getInt1(width % kTileWidth == 0)); + Or(ICmpULE(x_end, index_typed_constant(width)), + b_.getInt1(width % kTileWidth == 0)); // The tile is in y bounds if "height" is a multiple of kTileHeight or // y_end <= height. llvm_ir::LlvmIfData if_tile_in_y_bounds_data = @@ -1188,9 +1182,9 @@ Status IrEmitterUnnested::EmitColumnReduction( reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; for (int i = 0; i != num_reduces; ++i) { for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( @@ -1379,11 +1373,11 @@ Status IrEmitterUnnested::EmitRowReduction( std::vector partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1392,22 +1386,20 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty); + x_tile = ZExtOrTrunc(x_tile, index_ty); llvm::Value* warp_id = - b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); + UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); llvm::Value* lane_id = - b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id"); + URem(x_tile, index_typed_constant(kWarpSize), "lane_id"); // The x-location of the last element in this z-x-tile. // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = b_.CreateNSWAdd( + llvm::Value* last_x = NSWAdd( lane_id, - b_.CreateNSWMul( - index_typed_constant(kWarpSize), - b_.CreateNSWAdd( - index_typed_constant(x_tile_size - 1), - b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size))))); + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(index_typed_constant(x_tile_size - 1), + NSWMul(warp_id, index_typed_constant(x_tile_size))))); KernelSupportLibrary ksl( &b_, @@ -1419,9 +1411,8 @@ Status IrEmitterUnnested::EmitRowReduction( auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, int64 x_tile_loop_bound) -> Status { auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = b_.CreateNSWAdd( - z_indvar, - b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile)); + llvm::Value* z = + NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile)); TF_RETURN_IF_ERROR(ksl.For( "x_tile", /*start=*/index_typed_constant(0), @@ -1429,22 +1420,20 @@ Status IrEmitterUnnested::EmitRowReduction( /*step=*/1, [&](llvm::Value* x_indvar) -> Status { // x = lane_id + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = b_.CreateNSWAdd( + llvm::Value* x = NSWAdd( lane_id, - b_.CreateNSWMul( - index_typed_constant(kWarpSize), - b_.CreateNSWAdd( - x_indvar, b_.CreateNSWMul( - warp_id, llvm::ConstantInt::get( - index_ty, x_tile_size))))); + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(x_indvar, + NSWMul(warp_id, llvm::ConstantInt::get( + index_ty, x_tile_size))))); // Unless we know the x-tile is entirely in bounds, we have to // emit a x-in-bounds check before reading from the input. if (!x_tile_in_bounds) { llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(width)), - "x_in_bounds", &b_); + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", + &b_); // Points b_ to the then-block. llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, &b_); @@ -1452,7 +1441,7 @@ Status IrEmitterUnnested::EmitRowReduction( // Emit code that reads the input element and accumulates it // to the partial reduction result. - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); { // {z,y,x} is an index to input_3d_tensor_shape // [depth,height,width]. We need to convert that to an index @@ -1483,7 +1472,7 @@ Status IrEmitterUnnested::EmitRowReduction( for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1503,8 +1492,8 @@ Status IrEmitterUnnested::EmitRowReduction( }; llvm::Value* tile_in_bounds = - b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), - b_.CreateICmpULT(last_x, index_typed_constant(width))); + Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), + ICmpULT(last_x, index_typed_constant(width))); TF_RETURN_IF_ERROR( ksl.If(tile_in_bounds, @@ -1532,20 +1521,18 @@ Status IrEmitterUnnested::EmitRowReduction( for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* result_from_other_lane = - b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); + Alloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = b_.CreateLoad( - b_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) << "Requires block size a multiple of the warp size, otherwise we " "will read undefined elements."; - b_.CreateStore( - EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - b_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1560,8 +1547,7 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", - &b_); + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = @@ -1845,7 +1831,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( b_.getInt1Ty(), "initialized_flag_address", &b_); - b_.CreateStore(b_.getInt1(false), initialized_flag_address); + Store(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, @@ -1866,15 +1852,15 @@ Status IrEmitterUnnested::HandleSelectAndScatter( IrArray::Index operand_index(index_type, source_index.size()); llvm::Value* in_bounds_condition = b_.getInt1(true); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = b_.CreateNSWMul( + llvm::Value* strided_index = NSWMul( source_index[i], index_typed_constant(window.dimensions(i).stride())); - operand_index[i] = b_.CreateNSWSub( - b_.CreateNSWAdd(strided_index, window_index[i]), - index_typed_constant(window.dimensions(i).padding_low())); - llvm::Value* index_condition = b_.CreateICmpULT( + operand_index[i] = + NSWSub(NSWAdd(strided_index, window_index[i]), + index_typed_constant(window.dimensions(i).padding_low())); + llvm::Value* index_condition = ICmpULT( operand_index[i], index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + in_bounds_condition = And(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); @@ -1884,7 +1870,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - b_.CreateLoad(initialized_flag_address), "initialized", &b_); + Load(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. @@ -1892,16 +1878,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter( const auto save_operand_index = [&](const IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - b_.CreateStore(operand_index[i], selected_index_address_slot); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + Store(operand_index[i], selected_index_address_slot); } }; IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); - b_.CreateStore(operand_data, selected_value_address); + Store(operand_data, selected_value_address); save_operand_index(operand_index); - b_.CreateStore(b_.getInt1(true), initialized_flag_address); + Store(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to // potentially update the selected value and index with the currently @@ -1917,11 +1903,11 @@ Status IrEmitterUnnested::HandleSelectAndScatter( TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *select_and_scatter->select(), {selected_value_address, operand_address}, select_return_buffer)); - llvm::Value* result = b_.CreateLoad(select_return_buffer); + llvm::Value* result = Load(select_return_buffer); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. - llvm::Value* cond = b_.CreateICmpNE( + llvm::Value* cond = ICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType( PRED, ir_emitter_context_->llvm_module()), @@ -1930,7 +1916,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); + Store(Load(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -1942,8 +1928,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(Load(selected_index_address_slot)); } llvm::Value* source_value_address = GetIrArray(*source, *select_and_scatter) @@ -2367,8 +2353,8 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( *slice.allocation()))); CHECK_NE(loc, nullptr); } else { - loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {b_.getInt64(slice.offset())}); + loc = InBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); } // If gte_index is nonempty, we have to dereference `loc` to get to the @@ -2376,8 +2362,8 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::Type* int8_double_pointer = llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0); for (int64 idx : gte_index) { - loc = b_.CreateBitCast(loc, int8_double_pointer); - loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)})); + loc = BitCast(loc, int8_double_pointer); + loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)})); } bindings_.BindHloToIrValue(*instr, loc, index); @@ -2674,8 +2660,7 @@ Status CheckHloBuffersShareAllocation( if (slice_a != slice_b) { return InternalError( "instruction %s %s does not share allocation with instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); + a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString()); } return Status::OK(); } @@ -3155,9 +3140,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( const IrArray::Index output_tile_origin = [&] { IrArray::Index index = output_tile_index; for (int i = 1; i < 3; ++i) { - index[i] = - b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize), - "tile_origin." + std::to_string(i)); + index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize), + "tile_origin." + std::to_string(i)); } return index; }(); @@ -3170,12 +3154,12 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( std::vector output_tile_bounds(3); for (int i = 1; i < 3; ++i) { // Only last row or column may not have full size. - output_tile_bounds[i] = b_.CreateSelect( - b_.CreateICmpEQ(output_tile_index[i], - index_typed_constant(output_dims_in_tiles[i] - 1)), - index_typed_constant(reduced_output_dims[i] - - (output_dims_in_tiles[i] - 1) * kTileSize), - index_typed_constant(kTileSize), "kTileSize"); + output_tile_bounds[i] = + Select(ICmpEQ(output_tile_index[i], + index_typed_constant(output_dims_in_tiles[i] - 1)), + index_typed_constant(reduced_output_dims[i] - + (output_dims_in_tiles[i] - 1) * kTileSize), + index_typed_constant(kTileSize), "kTileSize"); } KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); @@ -3193,7 +3177,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( // Adds `addend` to the given `dim` of `index`. auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = b_.CreateAdd(index[dim], addend); + index[dim] = Add(index[dim], addend); return index; }; const IrArray::Index input_index = @@ -3209,10 +3193,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( llvm::Value* shmem_buffer = param_shmem_buffers[id]; // TODO(jlebar): Add AA metadata to this store. Tile buffers are // global variables, so LLVM can't infer much about it. - b_.CreateStore( - input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); + Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); } }); @@ -3233,9 +3216,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( output_index, "output", output_tile_bounds[2], output_tile_bounds[1], [&](const IrArray::Index& index, llvm::Value* y_loc) { // TODO(jlebar): Add AA metadata to this load. - llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad( - b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), - "output_element"); + llvm::Instruction* load_from_shmem_buffer = + Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), + "output_element"); output_in_reduced_shape_arrays[0].EmitWriteArrayElement( index, load_from_shmem_buffer, &b_); }); @@ -3263,7 +3246,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( output_in_reduced_shape_arrays.size()); for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { output_in_reduced_shape_arrays[i].EmitWriteArrayElement( - index, b_.CreateExtractValue(output_value, i), &b_); + index, ExtractValue(output_value, i), &b_); } } else { output_in_reduced_shape_arrays[0].EmitWriteArrayElement( @@ -3345,7 +3328,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { // if there's a Right Choice. // // This is only sound if tiled transposes are the only place where we use - // shared memory in fusions. If in the future other fusile ops use shared + // shared memory in fusions. If in the future other fusible ops use shared // memory, we'll have to adjust this heuristic. constexpr int kMinBlocksPerCore = 3; constexpr int64 kShmemPerCore = 48 * 1024; diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 6305396635eae7bb3fcda1d4675fb3b5f7d60553..3259eaa2a26d2b8ec8744323d90a0c6a31d5133e 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -16,11 +16,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,8 +41,8 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, tensorflow::mutex_lock lock(mutex_); if (!loader_spec_) { loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - tensorflow::StringPiece ptx = executable.ptx(); - // Convert tensorflow::StringPiece to se::port::StringPiece because + absl::string_view ptx = executable.ptx(); + // Convert absl::string_view to se::port::StringPiece because // StreamExecutor uses the latter. loader_spec_->AddCudaPtxInMemory( se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); @@ -63,7 +63,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, if (kernel_cache_.end() == it) { it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + return InternalError("Unable to load kernel %s", kernel_name_); } } @@ -107,7 +107,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, *kernel_args)) { - return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); + return InternalError("Unable to launch kernel %s", kernel_name_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 6bd9c58f83063554d57aea5e2289907be701a2c1..698d2d51cc81a6c87f6578f1f35cdb47cf6bb4f2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -35,6 +35,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm//:amdgpu_code_gen", "@llvm//:analysis", "@llvm//:bit_reader", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc index 12a8a59488bfdd6ce55f762926cd63ba56bf9d7f..85bc58cb445627695a46171db64cd8a1f10e0fc8 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -86,10 +86,11 @@ void IrDumpingPassManager::run(llvm::Module &module) { const llvm::PassInfo *PI = llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID()); const string basename = ReplaceFilenameExtension( - tensorflow::io::Basename(input_filename_), - tensorflow::strings::Printf( + absl::string_view(tensorflow::io::Basename(input_filename_)), + absl::StrFormat( "pass-%02d.before.%s.ll", i, - (PI == nullptr ? "unknown" : PI->getPassArgument().data()))); + absl::string_view(PI == nullptr ? "unknown" + : PI->getPassArgument().data()))); llvm::legacy::PassManager::add( new DumpIrPass(tensorflow::io::JoinPath(output_dir_, basename))); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index cce6e4814174c022f40b9aa199335a85ffaa6ed7..8751e3a9c2a4c8da46d3ecd8437629450d4a2ba2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" @@ -54,10 +56,7 @@ limitations under the License. #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Scalar.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" @@ -107,8 +106,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, << ", " << compute_capability.second << ") ." << "Defaulting to libdevice for compute_" << libdevice_version; } - return tensorflow::strings::StrCat("libdevice.compute_", libdevice_version, - ".10.bc"); + return absl::StrCat("libdevice.compute_", libdevice_version, ".10.bc"); } // Gets the GPU name as it's known to LLVM for a given compute capability. If @@ -138,15 +136,16 @@ static string GetSmName(std::pair compute_capability) { << "Defaulting to telling LLVM that we're compiling for sm_" << sm_version; } - return tensorflow::strings::StrCat("sm_", sm_version); + return absl::StrCat("sm_", sm_version); } // Convenience function for producing a name of a temporary compilation product // from the input filename. string MakeNameForTempProduct(const std::string& input_filename, - tensorflow::StringPiece extension) { - return ReplaceFilenameExtension( - tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension); + absl::string_view extension) { + return ReplaceFilenameExtension(absl::string_view(tensorflow::io::Basename( + llvm_ir::AsString(input_filename))), + extension); } // Initializes LLVM passes. Uses the PassRegistry mechanism. @@ -167,7 +166,7 @@ void InitializePasses(llvm::PassRegistry* pass_registry) { // Returns the TargetMachine, given a triple. std::unique_ptr GetTargetMachine( - llvm::Triple triple, tensorflow::StringPiece cpu_name, + llvm::Triple triple, absl::string_view cpu_name, const HloModuleConfig& hlo_module_config) { std::string error; const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error); @@ -243,9 +242,9 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level, } // Emits the given module to a bit code file. -void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) { +void EmitBitcodeToFile(const Module& module, absl::string_view filename) { std::error_code error_code; - llvm::ToolOutputFile outfile(filename.ToString().c_str(), error_code, + llvm::ToolOutputFile outfile(string(filename).c_str(), error_code, llvm::sys::fs::F_None); if (error_code) { LOG(FATAL) << "opening bitcode file for writing: " << error_code.message(); @@ -266,8 +265,9 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { // get creative to add a suffix. string module_id(llvm_ir::AsString(module->getModuleIdentifier())); IrDumpingPassManager codegen_passes( - ReplaceFilenameExtension(tensorflow::io::Basename(module_id), - "-nvptx.dummy"), + ReplaceFilenameExtension( + absl::string_view(tensorflow::io::Basename(module_id)), + "-nvptx.dummy"), "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -332,8 +332,8 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { - return tensorflow::errors::Internal(tensorflow::strings::StrCat( - "Error linking libdevice from ", libdevice_path)); + return tensorflow::errors::Internal( + absl::StrCat("Error linking libdevice from ", libdevice_path)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h index 54e0e140dea1c3a8b21ffde2950c4bc9b703b71c..9654175bfafbb2521743e7894188abe5b5a15217 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h @@ -20,11 +20,11 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc index 9ef9bc3a50fc76f83f05e19163ab339f2da6ef3c..3b2c3591d95ee5a319c82336e9b500d14f88734f 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc @@ -17,13 +17,13 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/SourceMgr.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace { @@ -52,14 +52,13 @@ std::unique_ptr LoadIRModule(const string& filename, return module; } -string ReplaceFilenameExtension(tensorflow::StringPiece filename, - tensorflow::StringPiece new_extension) { +string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension) { auto pos = filename.rfind('.'); - tensorflow::StringPiece stem = - pos == tensorflow::StringPiece::npos - ? filename - : tensorflow::StringPiece(filename.data(), pos); - return tensorflow::strings::StrCat(stem, ".", new_extension); + absl::string_view stem = pos == absl::string_view::npos + ? filename + : absl::string_view(filename.data(), pos); + return absl::StrCat(stem, ".", new_extension); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h index a6daeca95a6da66cb31b82805a6896f57cb80354..60f4926849cd3e8ad144f657f9feb3c3e1ea25e2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace llvm { class LLVMContext; @@ -41,8 +41,8 @@ std::unique_ptr LoadIRModule(const string& filename, // // For example: // ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc" -string ReplaceFilenameExtension(tensorflow::StringPiece filename, - tensorflow::StringPiece new_extension); +string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 5575f6c0c6be1d13555d11a863165fd2290947ce..7a43f0be5481721d13370ce1cf795eb9e55cd39b 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -49,7 +49,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, // If possible, we want to pick a reduce operand of the fusion root, // because it has the most constraints. for (const auto* inst : fused_expression_root->operands()) { - if (inst->opcode() == HloOpcode::kReduce) { + if (IsReductionToVector(*inst)) { return inst; } } @@ -64,7 +64,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, auto get_element_shape = [&](const HloInstruction* element_instr) { // Special handling of kReduce instructions -- the fusion // applies to the first operand. - if (element_instr->opcode() == HloOpcode::kReduce) { + if (IsReductionToVector(*element_instr)) { return element_instr->operand(0)->shape(); } return element_instr->shape(); @@ -141,10 +141,15 @@ bool ReduceFriendlyInputLayouts(HloInstruction* instr) { } // namespace bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { - // We can fuse reduces and loop fusions. - return IsInputFusibleReduction(instr) || - (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kLoop); + // We can fuse reduces and loop fusions. Elementwise instructions can be fused + // with any other instruction. + // TODO(b/112957171): This should use the same isFusible logic as + // instruction_fusion. + return instr->IsFusible() && + (IsInputFusibleReduction(instr) || + (instr->opcode() == HloOpcode::kFusion && + instr->fusion_kind() == HloInstruction::FusionKind::kLoop) || + instr->IsElementwise()); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, @@ -178,28 +183,16 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, // merge into bigger loop fusions and input (reduce) fusions become fusions // with multiple reduce outputs. We could fuse reduce and loop fusions // together too (the result being an input fusion) if we find cases where this - // improves things. + // improves things. Also disable fusing standalone input-fusible reduces into + // loop fusions. CHECK(instr1->opcode() == HloOpcode::kFusion); if ((instr2->opcode() == HloOpcode::kFusion && instr1->fusion_kind() != instr2->fusion_kind()) || - (instr2->opcode() != HloOpcode::kFusion && + (IsReductionToVector(*instr2) && instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) { return false; } - // Multi-output loop fusions must have equal output shapes to be lowered. - if (instr1->fusion_kind() == HloInstruction::FusionKind::kLoop) { - Shape shape1 = instr1->IsMultiOutputFusion() - ? instr1->shape().tuple_shapes(0) - : instr1->shape(); - Shape shape2 = instr2->IsMultiOutputFusion() - ? instr2->shape().tuple_shapes(0) - : instr2->shape(); - if (!ShapeUtil::Equal(shape1, shape2)) { - return false; - } - } - // Do this check last, as it may be expensive. return !GpuInstructionFusion::FusionWouldBeTooLarge(instr1, instr2); } @@ -211,7 +204,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { tensorflow::gtl::FlatSet to_fuse; // Keep a list of the instructions to fuse after making all the fusion // decisions. We first aggressively add instructions to potential_fusion_list, - // then filter out instructions that will be no longer fusable because of + // then filter out instructions that will be no longer fusible because of // reachability change. This avoids recalculating reachability on a large set // of instructions. std::vector> @@ -227,7 +220,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { continue; } if (!IsInputFusibleReduction(consumer)) { - VLOG(3) << consumer->name() << " is not an input-fusable reduction."; + VLOG(3) << consumer->name() << " is not an input-fusible reduction."; continue; } VLOG(3) << consumer->name() @@ -236,8 +229,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { auto consumer_operands = consumer->operands(); for (size_t i = 0; i < consumer_operands.size(); ++i) { HloInstruction* producer = consumer_operands[i]; - if (!producer->IsFusable()) { - VLOG(3) << producer->name() << " is not fusable."; + if (!producer->IsFusible()) { + VLOG(3) << producer->name() << " is not fusible."; continue; } const bool is_loop_fusion = @@ -277,7 +270,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } } - // Filter out pairs that will be no longer fusable because of reachability + // Filter out pairs that will be no longer fusible because of reachability // change. for (auto& fusion_pair : potential_fusion_list) { HloInstruction* producer = fusion_pair.first; diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index 67ca5d49eee8508e93284b134f8410eb3a89f9ce..f0b4d67ab8463a39161f71908746cad9e2a8670a 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -22,7 +22,7 @@ namespace xla { namespace gpu { // Multi-output fusion of sibling and producer-consumer instructions for the -// Jellyfish backend. +// GPU backend. class GpuMultiOutputFusion : public MultiOutputFusion { public: GpuMultiOutputFusion(); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 072f885bc13ad9d8e3dd909c35fe221c688e38df..c822c94f1b102e02be4a13a35892a2c181702383 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -15,19 +15,19 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" - -namespace op = xla::testing::opcode_matchers; namespace xla { namespace gpu { +namespace op = xla::testing::opcode_matchers; + using MultiOutputFusionTest = HloTestBase; const char kModulePrefix[] = R"( @@ -47,7 +47,7 @@ const char kModulePrefix[] = R"( TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { // Fusion with reduce instruction root and a sibling reduce instruction // sharing the same input param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation { p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) @@ -74,7 +74,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[6400]{0} parameter(1) mul = f32[6400]{0} multiply(p1.1, p1.1) @@ -101,7 +101,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) mul = f32[10,10]{1,0} multiply(p1.1, p1.1) @@ -130,7 +130,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) { // Two sibling fusions with reduce instruction roots sharing the same input // param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) @@ -165,7 +165,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { // Multi-output fusion with two reduce instructions root and a sibling reduce // instruction sharing the same input param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { const.1 = f32[] constant(1) p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) @@ -198,7 +198,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { // Verify that if we already have a multi-output fusion that we prefer to pick // a reduce op from its operands for checking shape compatibility. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) mul = f32[10,10]{1,0} multiply(p1.1, p1.1) @@ -228,7 +228,7 @@ TEST_F(MultiOutputFusionTest, } TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -256,6 +256,50 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { op::Tuple(op::Multiply(), op::Divide())); } +TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { + // Fusing a reduce into a loop fusion would require changing the fusion kind. + // That's not supported yet. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(0) + reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add_computation + ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(1) + div = f32[6400]{0} divide(p0, const.2) + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Divide())); +} + TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( fused_computation_1 { @@ -341,7 +385,7 @@ TEST_F(MultiOutputFusionTest, } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -361,7 +405,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_add { p0.1 = f32[2,2,2]{2,1,0} parameter(0) p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -388,7 +432,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { p1.1 = f32[2,2,2]{2,1,0} parameter(1) c0 = f32[] constant(0) @@ -429,7 +473,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_element_wise { p0.1 = f32[2,2,2]{2,1,0} parameter(0) p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -456,7 +500,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { TEST_F(MultiOutputFusionTest, ProducerConsumerFusionFp16LoopFusionAndReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { p1.1 = f16[2,2,2]{2,1,0} parameter(1) c0 = f16[] constant(0) @@ -497,7 +541,7 @@ TEST_F(MultiOutputFusionTest, TEST_F(MultiOutputFusionTest, ProducerConsumerFusionReduceUnfriendlyLoopFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( mixed_input_layouts_computation { p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0) p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 5868c1a42e6986c82648c9a7b2935d8e9100f968..695feadb11ce9a3baf0c6732a9f6df61a4fcd308 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" @@ -85,7 +87,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -140,7 +141,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, Compiler* compiler) { { HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -156,7 +158,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(); + pass.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls // where possible. Not every batchnorm op can be implemented as a call to @@ -203,7 +206,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // Convert convolutions into CustomCalls to cudnn, then canonicalize them // (PadInsertion). HloPassPipeline pipeline("conv_canonicalization"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // TODO(b/31709653): Directly use the grouped convolution support of Cudnn. pipeline.AddPass(); pipeline.AddPass(); @@ -218,9 +222,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, } { - HloPassPipeline pipeline("layout_assignment"); + // Run layout assignment in a separate pipeline from + // "post-layout-assignment" because we want everything after layout + // assignment to have a layout-sensitive invariant-checker, but + // HloPassPipeline also runs its invariant checker before any passes are + // run, meaning, the pipeline that contains layout assignment cannot contain + // a layout-sensitive verifier! + HloPassPipeline pipeline("layout assignment"); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), stream_exec); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("post-layout_assignment"); + pipeline.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -266,17 +283,20 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix fusion("fusion"); - fusion.AddInvariantChecker(); + fusion.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); fusion.AddPass(); fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); + fusion.AddPass(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); - reduce_pipeline.AddInvariantChecker(); + reduce_pipeline.AddInvariantChecker( + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -302,7 +322,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -352,9 +373,9 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { string vmaj_str, vmin_str, vdot_str; if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, &vmin_str, &vdot_str) || - !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) || - !tensorflow::strings::safe_strto64(vmin_str, &vmin) || - !tensorflow::strings::safe_strto64(vdot_str, &vdot)) { + !absl::SimpleAtoi(vmaj_str, &vmaj) || + !absl::SimpleAtoi(vmin_str, &vmin) || + !absl::SimpleAtoi(vdot_str, &vdot)) { LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path << " --version:\n" << out; @@ -466,7 +487,7 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, tensorflow::SubProcess ptxas_info_dumper; std::vector ptxas_args = { ptxas_path, ptx_path, "-o", cubin_path, - tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)}; + absl::StrCat("-arch=sm_", cc_major, cc_minor)}; if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } @@ -674,7 +695,7 @@ StatusOr> NVPTXCompiler::RunBackend( // Write PTX to IR dump directory, if IR dumping was requested. if (!ir_dump_directory.empty()) { const string ptx_outfile = tensorflow::io::JoinPath( - ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx")); + ir_dump_directory, absl::StrCat(module->name(), ".ptx")); auto status = [&] { auto* env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index b99d998c4d7df514c024b1f8d643d08c72059d0e..e0f3e84a4cb25792cf10d38fc529f3e638acf8e4 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -96,7 +96,7 @@ Status OutfeedThunk::ExecuteOnStream( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } VLOG(2) << "Outfeeding from GPU complete"; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h index 192359f026bfb2f1d5436713e4a30725fa0ad6ba..11dc56a64fda74cab12024e5f2c6fa2f63c9167d 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -32,9 +32,7 @@ namespace gpu { // TODO(jlebar): Also pad dots. class PadForTensorCores : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "pad for tensor cores"; - } + absl::string_view name() const override { return "pad for tensor cores"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc index 99e7580b826fc5cd6d98a037a5eb064552952e18..104af48c82ab1be9792eff11406af8d2a439e954 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc @@ -29,7 +29,12 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -using PadForTensorCoresTest = HloVerifiedTestBase; +class PadForTensorCoresTest : public HloVerifiedTestBase { + public: + PadForTensorCoresTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { ParseAndVerifyModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 67e51509e4c717951c83c7e41943af1de762dee0..a622e894ed9c0d1534262e6b72a5f4ea7b7821ad 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -26,7 +26,7 @@ namespace gpu { // padding, so that they can be lowered to cuDNN convolution. class PadInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "pad insertion"; } + absl::string_view name() const override { return "pad insertion"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 3838fee674566196e10ddd98462c1a1aa7835e1a..ca57cacb983bd2492a36dc462c09b357abb7ec37 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -57,8 +57,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( unroll_factor_(unroll_factor) {} std::vector -ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, + llvm::Type* index_type) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index b82a23419df08cafdc69b6d2f14528484b95dc73..cc7da2e73b681bb351e722cc3fb39f7746f45568 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -58,7 +58,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type) override; private: // The thread and block dimension to parallelize the loop on. diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index c927c5ee1666b6198d96750ff372ac83813a9df9..cf9f102d31305da15dabaf6247f23c5ca9a9e054 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -34,9 +34,8 @@ namespace gpu { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims) { - out << tensorflow::strings::Printf("[block: %lld, thread: %lld]", - launch_dims.block_count(), - launch_dims.threads_per_block()); + out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(), + launch_dims.threads_per_block()); return out; } @@ -91,9 +90,9 @@ LaunchDimensions CalculateLaunchDimensions( } int64 block_count = CeilOfRatio(num_elements, threads_per_block); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << absl::StrFormat( "Initialized the block count to ceil(# of elements / threads per " - "block) = ceil(%lld/%lld) = %lld", + "block) = ceil(%d/%d) = %d", num_elements, threads_per_block, block_count); return LaunchDimensions(block_count, threads_per_block); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 3f75d8b55959495017f1b08d61bd6e7b44bed27f..091aca23e54bf0585b91e7a05c0837d8a0a2b764 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -16,13 +16,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace gpu { @@ -98,7 +98,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 0e84ec7e621fcd1778725dc2743d7a70fb01c47a..79e77d4c4d649020cf52ac25c220c3f90e8469b9 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -39,8 +39,7 @@ void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr hlo_module, const string& pattern) { std::unique_ptr executable = std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie()); - string ptx_str = - std::string(static_cast(executable.get())->ptx()); + string ptx_str(static_cast(executable.get())->ptx()); StatusOr filecheck_result = RunFileCheck(ptx_str, pattern); ASSERT_TRUE(filecheck_result.ok()); EXPECT_TRUE(filecheck_result.ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index cca35316f0c472d2a17c466f8cd1af7f22575a8b..15d1e269cc22b88f5269175084f20600f165011c 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -27,13 +27,22 @@ namespace { class GpuKernelTilingTest : public GpuCodegenTest { protected: - GpuKernelTilingTest() { + GpuKernelTilingTest() {} + + // Most tests in this file want to skip layout assignment, but a few need it + // enabled. + HloModuleConfig ConfigWithLayoutAssignment() { + return GetModuleConfigForTest(); + } + + HloModuleConfig ConfigWithoutLayoutAssignment() { + HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); - config_.set_debug_options(debug_options); // Disable layout_assignment to use the preassigned layouts. - debug_options.add_xla_disable_hlo_passes("layout_assignment"); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + return config; } - HloModuleConfig config_; }; TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { @@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // + // We must enable layout assignment in order for this test to work correctly. + // AlgebraicSimplifier removes copy1; it's added back by layout assignment, + // which respects the module's entry computation layout. But if we don't run + // layout assignment...well, nobody else adds the copy back. + auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0) })"; - // Check that a call to llvm.nvvm.barrier0 is not generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // Check that a call to llvm.nvvm.barrier0 is not generated. As in + // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment + // here. + auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest, })"; // Check that a call to llvm.nvvm.barrier0 is not generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 962293630683fcbbce3941f622061a2ff0f02dda..0f2d5568cafc9db0f5f067437fdd5e2e775ad2c8 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_max_kernel_unroll_factor(2); + // Disable layout assignment for this test. Layout assignment does not expect + // fusions to be present, and so it does the wrong thing. + debug_options.add_xla_disable_hlo_passes("layout-assignment"); config.set_debug_options(debug_options); const char *const kMultiOutputFusionModule = R"( diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index bdb062837c5ba4b588ea0d535a786f33fe4f4015..141f3219387940a08ef22cbcc0be0971a14c2cd6 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -144,16 +144,15 @@ const std::list& ThunkSchedule::DependsOn( string ThunkSchedule::ToString() const { string result = "Total order:\n"; for (Thunk* thunk : thunk_total_order_) { - tensorflow::strings::StrAppend(&result, "\t", - thunk->hlo_instruction()->ToString(), "\n"); + absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n"); } - tensorflow::strings::StrAppend(&result, "Dependencies:\n"); + absl::StrAppend(&result, "Dependencies:\n"); for (const auto& entry : depends_on_) { const Thunk* dependent = entry.first; for (const Thunk* dependency : entry.second) { - tensorflow::strings::StrAppend( - &result, "\t", dependent->hlo_instruction()->name(), " depends on ", - dependency->hlo_instruction()->name(), "\n"); + absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(), + " depends on ", dependency->hlo_instruction()->name(), + "\n"); } } return result; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 828fc2884bd7d58333d86c35a537f06467cf6e4a..c4754fe378960834e1157b0ff25c03c0fc4754c7 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -70,7 +70,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, if (!block_status.ok()) { return InternalError( "Failed to complete all kernels launched on stream %p: %s", stream, - block_status.error_message().c_str()); + block_status.error_message()); } if (!condition_result) { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index c5f3906356d821e059d2b1213c9083c4408a4d1c..40183de96ee363996e6b0b883a78e7a8b5d13ab2 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -118,7 +118,8 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { - HloVerifier verifier; + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 31431f115f8ffd72df65638a2b00e63b3c433a7e..a2be89511babc23ebcd5cb40abee2a95d16dc451 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/types.h" @@ -43,8 +43,7 @@ namespace { // Adds a computation to the given HLO module which adds a scalar constant to // its parameter and returns the result. HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { - auto builder = - HloComputation::Builder(tensorflow::strings::StrCat("add_", addend)); + auto builder = HloComputation::Builder(absl::StrCat("add_", addend)); auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 821c599863839865c77a778ba569c56609fea0de..58b7af93ebfce74951c0f2d65ab226fc94d62e4b 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 52 +// Next ID: 53 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -173,6 +173,9 @@ message HloInstructionProto { // Precision configuration for the instruction. Has backend-specific meaning. xla.PrecisionConfigProto precision_config = 51; + + // Collective permute field. + repeated SourceTarget source_target_pairs = 52; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 0ca489846e7137a9ffa341e63c8a289ed4af2043..0986da65cbd3d550ecfa01212364518aba651d86 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -28,15 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; // Data structure used to construct the alias analysis. Thrown away after alias // analysis is complete. This data structure keeps track of which sets of @@ -414,7 +412,7 @@ Status HloAliasAnalysis::Verify() const { } string HloAliasAnalysis::ToString() const { - string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); + string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); StrAppend(&out, " Buffers at each position:\n"); for (const HloComputation* computation : module_->computations()) { for (const HloInstruction* instruction : computation->instructions()) { @@ -537,10 +535,10 @@ bool HloAliasAnalysis::HasLiveRangeInterference( if (ordering.MayInterfere(*values[i - 1], *values[i], dataflow_analysis())) { VLOG(1) << "In buffer " << buffer.id() << " containing values:\n " - << Join(values, ", ", - [](string* out, const HloValue* value) { - StrAppend(out, value->ToShortString()); - }) + << absl::StrJoin(values, ", ", + [](string* out, const HloValue* value) { + StrAppend(out, value->ToShortString()); + }) << "\nValue " << values[i - 1]->ToShortString() << " may interfere with value " << values[i]->ToShortString(); diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index e16413f361fb0216792b47c3c67ef3c1357c2221..6c11a073b74c61e44dfe81a32261ae78ae7b46fb 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -27,15 +29,10 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; - bool HloBuffer::operator==(const HloBuffer& other) const { bool equal = id() == other.id(); if (equal) { @@ -59,10 +56,11 @@ std::vector HloBuffer::ComputePositions() const { } string HloBuffer::ToString() const { - return StrCat("HloBuffer ", id_, ", values: ", - Join(values_, ", ", [](string* result, const HloValue* value) { - result->append(value->ToShortString()); - })); + return absl::StrCat( + "HloBuffer ", id_, ", values: ", + absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); } std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 70b18ff35676fbe43d264d186a675c1436f60317..c2d0673f4918116e9bfa9e92702344b24555391b 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -25,6 +25,9 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -37,13 +40,11 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::strings::StrCat; +using absl::StrCat; std::unique_ptr HloComputation::Builder::Build( HloInstruction* root_instruction) { @@ -136,7 +137,7 @@ string RenameFusionParameter(const string& original_name, int64 new_param_no) { } string after_param = original_name.substr(index + param_underscore.size()); int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + if (absl::SimpleAtoi(after_param, &numeric_suffix)) { return StrCat(original_name.substr(0, index + param_underscore.size()), new_param_no); } @@ -318,12 +319,12 @@ void ComputeComputationPostOrder( } } -enum State { kVisiting, kVisited }; +} // namespace -void ComputeInstructionPostOrder( - std::map> channel_dependency_map, +void HloComputation::ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap* visited) { + tensorflow::gtl::FlatMap* visited) const { std::vector dfs_stack; dfs_stack.push_back(root); while (!dfs_stack.empty()) { @@ -361,20 +362,22 @@ void ComputeInstructionPostOrder( // dependencies. switch (current->opcode()) { case HloOpcode::kRecvDone: { - const auto& dependencies = - channel_dependency_map[current->channel_id()]; - for (HloInstruction* op : dependencies) { - dfs_stack.emplace_back(op); + auto it = channel_dependency_map.find(current->channel_id()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } } break; } case HloOpcode::kCrossReplicaSum: { auto all_reduce_id = current->all_reduce_id(); if (all_reduce_id) { - const auto& dependencies = - channel_dependency_map[all_reduce_id.value()]; - for (HloInstruction* op : dependencies) { - dfs_stack.emplace_back(op); + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } } } break; @@ -385,11 +388,9 @@ void ComputeInstructionPostOrder( } } -} // namespace - -std::map> +HloComputation::ChannelDependencyMap HloComputation::ComputeChannelDependencies() const { - std::map> channel_dependency_map; + ChannelDependencyMap channel_dependency_map; for (const auto& instruction : instructions_) { switch (instruction->opcode()) { case HloOpcode::kSend: { @@ -420,7 +421,7 @@ std::vector HloComputation::MakeInstructionPostOrder() const { std::vector post_order; post_order.reserve(instruction_count()); std::vector trace_instructions; - tensorflow::gtl::FlatMap visited; + tensorflow::gtl::FlatMap visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -624,16 +625,15 @@ StatusOr HloComputation::DeepCopyInstruction( if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", - instruction->name().c_str(), name().c_str()); + instruction->name(), name()); } if (indices_to_copy != nullptr && !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { return FailedPrecondition( "Can't deep copy instruction %s: given shape tree of indices to copy " "has incompatible shapes: %s vs. %s", - instruction->name().c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - ShapeUtil::HumanString(indices_to_copy->shape()).c_str()); + instruction->name(), ShapeUtil::HumanString(instruction->shape()), + ShapeUtil::HumanString(indices_to_copy->shape())); } ShapeIndex index; @@ -663,7 +663,7 @@ StatusOr HloComputation::DeepCopyInstructionWithCustomCopier( if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", - instruction->name().c_str(), name().c_str()); + instruction->name(), name()); } ShapeIndex index; return DeepCopyHelper(instruction, &index, copy_leaf); @@ -682,6 +682,9 @@ ProgramShape HloComputation::ComputeProgramShape() const { } bool HloComputation::operator==(const HloComputation& other) const { + if (this == &other) { + return true; + } std::set> visited; std::function eq = [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { @@ -743,16 +746,19 @@ std::unique_ptr HloComputation::ComputeReachability() switch (hlo->opcode()) { case HloOpcode::kRecvDone: { - const auto& dependencies = channel_dependency_map[hlo->channel_id()]; - absl::c_copy(dependencies, std::back_inserter(inputs)); + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } break; } case HloOpcode::kCrossReplicaSum: { auto all_reduce_id = hlo->all_reduce_id(); if (all_reduce_id) { - const auto& dependencies = - channel_dependency_map[all_reduce_id.value()]; - absl::c_copy(dependencies, std::back_inserter(inputs)); + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } } break; } @@ -802,11 +808,10 @@ std::vector HloComputation::CollectUnreachableRoots() const { } } VLOG(3) << "Unreachable roots:" - << tensorflow::str_util::Join( - unreachable_roots, "\n\t", - [](string* out, const HloInstruction* hlo) { - tensorflow::strings::StrAppend(out, hlo->ToString()); - }); + << absl::StrJoin(unreachable_roots, "\n\t", + [](string* out, const HloInstruction* hlo) { + absl::StrAppend(out, hlo->ToString()); + }); return unreachable_roots; } @@ -977,8 +982,7 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } -HloInstruction* HloComputation::GetInstructionWithName( - tensorflow::StringPiece name) { +HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) { auto instructions_in_computation = instructions(); auto it = absl::c_find_if( instructions_in_computation, diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index faa33f0f90e8b070982347820a994818f21f93a8..59016624f764d985f2dc3816600466ea66aade77 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -367,7 +367,7 @@ class HloComputation { // Returns the instruction in this computation that has name `name`. Returns // null if there is no such computation. - HloInstruction* GetInstructionWithName(tensorflow::StringPiece name); + HloInstruction* GetInstructionWithName(absl::string_view name); int64 unique_id() const { return unique_id_; } @@ -403,8 +403,15 @@ class HloComputation { // instructions. For send&recv pairs it means the send instruction and for // cross-replica-sum the union of the dependencies for all participating // instructions. - std::map> ComputeChannelDependencies() - const; + using ChannelDependencyMap = + tensorflow::gtl::FlatMap>; + ChannelDependencyMap ComputeChannelDependencies() const; + + enum VisitState { kVisiting, kVisited }; + void ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, + std::vector* post_order, HloInstruction* root, + tensorflow::gtl::FlatMap* visited) const; string name_; int64 unique_id_; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h index 331480bd029727fa15476cb9ced2e7b7afd170f3..4557983a9c0b0006cc2189c96a88478d469475c1 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.h +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -25,7 +25,7 @@ namespace xla { // computation on constants. class HloConstantFolding : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "constant_folding"; } + absl::string_view name() const override { return "constant_folding"; } // Run constant folding operations on the given module. Returns whether the // module was changed (constant expressions folded). diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 3e68f59bd9a0439ed04794bcd964d1734abc4bbc..0e12a1ee03497b2ff0afd48509ae1f10c05e5f60 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -540,15 +540,10 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { } Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { - // TODO(b/110096724): Compute correct cost here. - double flops = 0.0; - ShapeUtil::ForEachSubshape(hlo->shape(), - [&](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsArray(subshape)) { - flops += ShapeUtil::ElementsIn(subshape); - } - }); - current_properties_[kFlopsKey] = flops; + return Status::OK(); +} + +Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 1bf1c4a315655e78e10a8a66b571347357cc23e9..c6a2007904a4c550f520d4725cd67796686e4b88 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -72,6 +72,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; + Status HandleCollectivePermute(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; Status HandleRng(const HloInstruction* random) override; diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index c4e27dc558ecb2a3a0acfd036de73506ce7631fa..131846794d9cfa9268cc7a96ad045bba6161e05c 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -16,14 +16,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" namespace xla { +using absl::StrCat; using tensorflow::gtl::ArraySlice; -using tensorflow::strings::StrCat; StatusOr MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { @@ -336,8 +337,8 @@ StatusOr BroadcastZeros( StatusOr> CreateComputationWithSignature( ArraySlice domain, const Shape& range, - tensorflow::StringPiece name) { - HloComputation::Builder b{std::string(name)}; + absl::string_view name) { + HloComputation::Builder b{string(name)}; int64 param_idx = 0; for (const Shape* param_shape : domain) { b.AddInstruction(HloInstruction::CreateParameter( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 5ff8946fb098b57ae563a8ade47e8323f807a369..1bc6d09b4502c88d0d4e4e207075d64714190611 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -177,7 +177,7 @@ StatusOr BroadcastZeros( // a value of type `range`. StatusOr> CreateComputationWithSignature( tensorflow::gtl::ArraySlice domain, const Shape& range, - tensorflow::StringPiece name); + absl::string_view name); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 06484f4012fc091f70df7bc8ec231ce3fcf89669..cb367adf5ef29111838dd6ee1b770394eef1301c 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -103,6 +104,9 @@ int64 CseHash(const HloInstruction* instruction) { for (auto operand : instruction->operands()) { hash = tensorflow::Hash64Combine(hash, operand->unique_id()); } + if (instruction->opcode() == HloOpcode::kConstant) { + hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash()); + } return hash; } diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h index 5e2b348bdda2b31556fb692e24d2bad2e4173ef5..a28c03599a8765da708f37b986010713654647cb 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.h +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -34,7 +34,7 @@ class HloCSE : public HloPassInterface { : is_layout_sensitive_(is_layout_sensitive), only_fusion_computations_(only_fusion_computations) {} ~HloCSE() override = default; - tensorflow::StringPiece name() const override { return "cse"; } + absl::string_view name() const override { return "cse"; } // Run CSE on the given module. Returns whether the module was changed (common // subexpressions were found and eliminated). diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 01840a56e2114eb3d478425f12aa3c7c7063bdf2..3376d170e64a71c0fa6b659e1d5ed195ac9eaba3 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,8 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -79,8 +78,8 @@ bool MultiDynamicSliceUseShareSameIndices( } // namespace -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; HloDataflowAnalysis::HloDataflowAnalysis( const HloModule& module, bool ssa_form, bool bitcast_defines_value, @@ -838,7 +837,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { return Unimplemented( "Computation %s is called in both a parallel (eg, kMap) and " "sequential (eg, kCall) context", - computation->name().c_str()); + computation->name()); } if (call_graph_node.caller_callsites().empty() || call_graph_node.context() == CallContext::kParallel) { @@ -977,28 +976,22 @@ Status HloDataflowAnalysis::Verify() const { bool HloDataflowAnalysis::DoesNotUseOperandBuffer( const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - // Iterate through all users of all uses of the fusion parameter value. - // Return false if any uses are detected, returns true otherwise. - const HloValue& value = GetValueDefinedAt(fusion_param, index); - return value.uses().empty(); - } else { - // Return false if no value at 'operand' and 'index' is used at 'user'. - for (const HloValue* value : GetValueSet(operand, index).values()) { - for (const HloUse& use : value->uses()) { - if (use.instruction == user) { - return false; + // Return false if no value at 'operand' and 'index' is used at 'user'. + for (const HloValue* value : GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + HloInstruction* fusion_param = + user->fused_parameter(use.operand_number); + const HloValue& value = + GetValueDefinedAt(fusion_param, use.operand_index); + return value.uses().empty(); } + return false; } } } - return true; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index f4abc7a7c7dcfb223067fe946bec0c5ef32f206b..a1678d4943c7c722df38c4dc93e284d614279217 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -138,7 +138,8 @@ class HloDataflowAnalysis { // Returns true if 'user' cannot possibly use the buffer at 'index' in // 'operand'. Returns false otherwise. // - // REQUIRES: 'operand' is an operand of 'user'. + // 'operand' does not have to be an operand of 'user'. This can be the case + // with indirect uses. bool DoesNotUseOperandBuffer(const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4755c4a0cf8d268b1c47e596a14605eb2c60b36c..d1a96c10f88e3c05e21a6db4eccb46683cd64c4a 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1963,6 +1963,54 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); } +// Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the +// parameter tuple. +TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto t0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0)); + auto t1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1)); + // Swap the tuple elements. + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0})); + + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); + // The same holds for the parameter tuple, except that the tuple elements are + // swapped in 'tuple'. + EXPECT_TRUE( + dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion)); + EXPECT_FALSE( + dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion)); +} + class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {}; TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index 4e244494d6f98c48f4376bd762f116b9a9c2084d..1fe69b1395753a612499e6e87bfc22f8ac8e767b 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -36,7 +36,7 @@ namespace xla { class HloDCE : public HloPassInterface { public: ~HloDCE() override {} - tensorflow::StringPiece name() const override { return "dce"; } + absl::string_view name() const override { return "dce"; } // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index af904647f82576c8a9b3d9dd3fbe110bc712df0c..72185698c9bdcbf2bebed7ee82bc4ed082ce6a14 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -31,31 +31,10 @@ class HloDomainIsolator::RunContext { StatusOr Run(); private: - // Inserts a kDomain instruction between operand and instruction in case - // the attribute (ie, sharding) values change between root and instruction. - // Returns the newly inserted kDomain instruction, or nullptr if no kDomain - // instruction was necessary. - StatusOr CreateDomain(HloInstruction* instruction, - HloInstruction* root, - HloInstruction* operand); - HloModule* module_; HloDomainIsolator* isolator_; }; -StatusOr HloDomainIsolator::RunContext::CreateDomain( - HloInstruction* instruction, HloInstruction* root, - HloInstruction* operand) { - HloInstruction* domain = nullptr; - std::unique_ptr domain_instruction = - isolator_->creator_(instruction, root, operand); - if (domain_instruction != nullptr) { - domain = operand->parent()->AddInstruction(std::move(domain_instruction)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); - } - return domain; -} - StatusOr HloDomainIsolator::RunContext::Run() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); @@ -76,10 +55,11 @@ StatusOr HloDomainIsolator::RunContext::Run() { root = root->mutable_operand(0); } // Check whether a kDomain is necessary between instruction and operand. - TF_ASSIGN_OR_RETURN(HloInstruction * domain, - CreateDomain(instruction, root, operand)); + HloInstruction* domain = + isolator_->creator_(instruction, root, operand); if (domain != nullptr) { VLOG(4) << "New domain: " << domain->ToString(); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); ++added_domains; } } diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index bb1537766c6bd20ebd6a8adf15442bcbd9b29250..d36631fc2f16902ed8f1f89f903027081f9b3801 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -38,12 +38,12 @@ class HloDomainIsolator : public HloPassInterface { // instruction differes from the attribute of the root (the second // HloInstruction argument). // Returns nullptr in case no domain separation is necessary. - using DomainCreator = std::function( + using DomainCreator = std::function; explicit HloDomainIsolator(DomainCreator creator); - tensorflow::StringPiece name() const override { return "domain_isolator"; } + absl::string_view name() const override { return "domain_isolator"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index edf0073f3091ef4da7ced3f13b56961a7db4b430..8b2846e0c277b3e7cffd578d988d0a09c13833ed 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -72,6 +72,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { } Status HloDomainMap::Populate(HloComputation* computation) { + InstructionOrderMap instructions_post_order; + int64 count = 0; + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { + instructions_post_order.insert(std::make_pair(instruction, count++)); + } for (HloInstruction* instruction : computation->instructions()) { if (IsDomainInstruction(instruction)) { // If this is a kDomain of the kind we are currently processing, check @@ -85,7 +90,7 @@ Status HloDomainMap::Populate(HloComputation* computation) { continue; } TF_ASSIGN_OR_RETURN(std::unique_ptr domain, - CreateDomain(instruction)); + CreateDomain(instruction, instructions_post_order)); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } return Status::OK(); @@ -143,10 +148,12 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction, } StatusOr> HloDomainMap::CreateDomain( - HloInstruction* instruction) const { + HloInstruction* instruction, + const InstructionOrderMap& instructions_order) const { auto domain = absl::make_unique(); TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); - domain->instructions = MakeNonDomainInstructions(domain->reach_set); + domain->instructions = + MakeNonDomainInstructions(domain->reach_set, instructions_order); return std::move(domain); } @@ -168,7 +175,8 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { /* static */ std::vector HloDomainMap::MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set) { + const tensorflow::gtl::FlatSet& instruction_set, + const InstructionOrderMap& instructions_order) { std::vector instructions; instructions.reserve(instruction_set.size()); for (HloInstruction* instruction : instruction_set) { @@ -176,9 +184,10 @@ HloDomainMap::MakeNonDomainInstructions( instructions.push_back(instruction); } } + // sort instructions according to instructions_order std::sort(instructions.begin(), instructions.end(), - [](HloInstruction* a, HloInstruction* b) { - return a->unique_id() < b->unique_id(); + [&instructions_order](HloInstruction* a, HloInstruction* b) { + return instructions_order.at(a) < instructions_order.at(b); }); return instructions; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 1ca71597253eecfb45ae8f384240033a57045277..633109249a91eec3d7b4cbe5b423b73f980217c9 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -70,6 +70,11 @@ class HloDomainMap { int64 GetDomainId(HloInstruction* instruction) const; private: + // Map used for representing instruction ordering, i.e. + // order_map[a] < order_map[b] means a must be ordered before b. + using InstructionOrderMap = + tensorflow::gtl::FlatMap; + HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} // Check if the kDomain instruction is facing (via its operand link) another @@ -95,12 +100,14 @@ class HloDomainMap { // Creates a domain data structure using the ExpandDomain() API. StatusOr> CreateDomain( - HloInstruction* instruction) const; + HloInstruction* instruction, + const InstructionOrderMap& instructions_order) const; // Out of an instruction set, returns a vector of all the ones which are not // a kDomain kind. static std::vector MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet& instruction_set); + const tensorflow::gtl::FlatSet& instruction_set, + const InstructionOrderMap& instructions_order); string domain_kind_; std::vector> instruction_domains_; diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index f855f2a1fc944fcc11c9afed278bef4af87813da..6c142ee47421049e8a25dfb80a6297e02fe782f1 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -20,10 +20,10 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -44,7 +44,10 @@ class DomainMetadata { // two domains of different kind intersect each other. tensorflow::gtl::FlatSet reach_set; - // The same instructions in reach_set, but purged from kDomain instructions. + // The same instructions in reach_set, but purged from kDomain instructions + // and ordered according to their computation graph post-order, i.e. + // if instructions[pos_a] depends on instructions[pos_b], then pos_a > + // pos_b. std::vector instructions; // If we consider a graph edge as an arrow oriented from the operand to the @@ -63,7 +66,7 @@ class DomainMetadata { // Returns the metadata type. A unique identifier which describes the real // metadata type. - virtual tensorflow::StringPiece Kind() const = 0; + virtual absl::string_view Kind() const = 0; // Compares the metadata object with another one and returns true if the // two matches. diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index c859e05f02e54d601804b641094ecdd11bbe1aed..97bc8ef604092acc849b55b09af8a24bf775529e 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -35,13 +35,13 @@ class HloDomainRemover : public HloPassInterface { // instructions in it with the same attributes (ie, sharding), a normalizer // function is tasked at applying attribute normalization on the instructions // within such domain. - HloDomainRemover(tensorflow::StringPiece kind, + HloDomainRemover(absl::string_view kind, std::function normalizer) - : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + : kind_(kind), normalizer_(std::move(normalizer)) {} - tensorflow::StringPiece name() const override { return "domain_remover"; } + absl::string_view name() const override { return "domain_remover"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 2654929bf056a87e9e8330f94f6168f7b7354a5b..c8e0a9e289ea15a9b60334e31eec1dc8cb093245 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -29,6 +29,11 @@ namespace xla { namespace { class HloDomainTest : public HloVerifiedTestBase { + public: + HloDomainTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -46,9 +51,8 @@ class HloDomainTest : public HloVerifiedTestBase { // Checks whether there is a kDomain instruction in the edge between the // instruction and the operand. - bool HasDomainEdge(HloModule* module, - tensorflow::StringPiece instruction_name, - tensorflow::StringPiece operand_name) { + bool HasDomainEdge(HloModule* module, absl::string_view instruction_name, + absl::string_view operand_name) { HloInstruction* instruction = FindInstruction(module, instruction_name); HloInstruction* operand = FindInstruction(module, operand_name); CHECK_NE(instruction, nullptr); @@ -66,7 +70,7 @@ class HloDomainTest : public HloVerifiedTestBase { return false; } - StatusOr ParseModule(tensorflow::StringPiece hlo_string) { + StatusOr ParseModule(absl::string_view hlo_string) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); ParseAndVerifyModule(hlo_string, config); @@ -84,7 +88,7 @@ class OpNameMetadata : public DomainMetadata { return absl::make_unique(opname_); } - tensorflow::StringPiece Kind() const override { return KindName(); } + absl::string_view Kind() const override { return KindName(); } bool Matches(const DomainMetadata& other) const override { const OpNameMetadata* other_ptr = @@ -98,16 +102,16 @@ class OpNameMetadata : public DomainMetadata { string ToString() const override { return opname_; } - static tensorflow::StringPiece KindName() { return "opname"; } + static absl::string_view KindName() { return "opname"; } private: string opname_; }; // Creator function for OpNameMetadata domains. -std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, - HloInstruction* root, - HloInstruction* operand) { +HloInstruction* OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* root, + HloInstruction* operand) { if (instruction->metadata().op_name() == root->metadata().op_name()) { return nullptr; } @@ -115,9 +119,9 @@ std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, absl::make_unique(root->metadata().op_name()); std::unique_ptr user_side_metadata = absl::make_unique(instruction->metadata().op_name()); - return HloInstruction::CreateDomain(operand->shape(), operand, - std::move(operand_side_metadata), - std::move(user_side_metadata)); + return operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, std::move(operand_side_metadata), + std::move(user_side_metadata))); } Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain, @@ -144,7 +148,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -186,7 +190,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(!isolator_changed); } @@ -213,7 +217,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -250,7 +254,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_FALSE(isolator_changed); } @@ -304,7 +308,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator sharding_isolator(CreateShardingDomain); + HloDomainIsolator sharding_isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, sharding_isolator.Run(module)); EXPECT_TRUE(sharding_isolator_changed); @@ -346,7 +350,8 @@ ENTRY entry { token = token[] after-all() infeed = ((f32[4], f32[4]), token[]) infeed(token), sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} - infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0 + infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0, + sharding={{maximal device=1}, {maximal device=0}} gte0 = f32[4] get-tuple-element(infeed.data), index=0 gte1 = f32[4] get-tuple-element(infeed.data), index=1 copy0 = f32[4] copy(gte0) @@ -358,7 +363,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -380,11 +385,8 @@ ENTRY entry { // \ / // TUPLE // | - HloInstruction* infeed = FindInstruction(module, "infeed"); - ASSERT_NE(infeed, nullptr); - HloInstruction* infeed_data = - infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + HloInstruction* infeed_data = FindInstruction(module, "infeed.data"); + ASSERT_NE(infeed_data, nullptr); auto infeed_data_users = infeed_data->users(); HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction( @@ -447,7 +449,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -492,6 +494,7 @@ TEST_F(HloDomainTest, DumpParseNullSharding) { ASSERT_TRUE(ParseModule(hlo_string).status().ok()); } +// Tuple inputs are domain instructions. TEST_F(HloDomainTest, DomainTuple) { const char* const hlo_string = R"( HloModule Module @@ -499,14 +502,15 @@ HloModule Module ENTRY entry { p0 = f32[4] parameter(0), sharding={maximal device=0} cst = u32[] constant(0), sharding={maximal device=1} - tpl = (u32[], f32[4]) tuple(cst, p0), sharding={{maximal device=1}, {maximal device=0}} + tpl = (u32[], f32[4]) tuple(cst, p0), + sharding={{maximal device=1}, {maximal device=0}} ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0} } )"; TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -584,5 +588,109 @@ ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { EXPECT_FALSE(HasDomainEdge(module, "d", "c")); } +// Emulate instructions inserted at top and bottom within nested tuple domain. +TEST_F(HloDomainTest, DomainTupleTopBottomInsert) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = f32[4] parameter(0), sharding={maximal device=1} + p1 = (f32[5], f32[6]) parameter(1), + sharding={{maximal device=1}, {maximal device=0}} + tuple.0 = (f32[4], (f32[5], f32[6])) tuple(p0, p1), + sharding={{maximal device=1}, {maximal device=1}, {maximal device=0}} + ROOT res = (f32[5], f32[6]) get-tuple-element(tuple.0), index=1, + sharding={{maximal device=1}, {maximal device=0}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(ShardingDomainCreator{}); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + // Clear sharding of tuple.0 instruction, in order to test domain sharding + // application. + auto tuple0 = FindInstruction(module, "tuple.0"); + tuple0->clear_sharding(); + + // Insert the following instructons above and below tuple.0, to emulate other + // passes effects: + // COPY.0 + // \ / + // TUPLE.0 + // / \ + // COPY.1 \ + // / \ + // GTE.0 GTE.1 + // | | + // | COPY.2 + // \ / + // \ / + // TUPLE.1 + // | + auto tuple0_users = tuple0->users(); + auto computation = tuple0->parent(); + HloInstruction* copy0 = computation->AddInstruction( + HloInstruction::CreateUnary(tuple0->operand(1)->shape(), HloOpcode::kCopy, + tuple0->mutable_operand(1))); + TF_EXPECT_OK(tuple0->ReplaceOperandWith(1, copy0)); + + HloInstruction* copy1 = computation->AddInstruction( + HloInstruction::CreateUnary(tuple0->shape(), HloOpcode::kCopy, tuple0)); + HloInstruction* gte0 = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(copy1->shape(), 0), copy1, 0)); + HloInstruction* gte1 = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(tuple0->shape(), 1), tuple0, 1)); + HloInstruction* copy2 = computation->AddInstruction( + HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1)); + HloInstruction* tuple1 = + computation->AddInstruction(HloInstruction::CreateTuple({gte0, copy2})); + + for (HloInstruction* user : tuple0_users) { + TF_EXPECT_OK(tuple0->ReplaceUseWith(user, tuple1)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + EXPECT_TRUE(tuple0->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(tuple0->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + tuple0->sharding()); + + EXPECT_TRUE(copy0->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(copy0->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + copy0->sharding()); + + // copy1 has partial information only from gte.0, so in the end it gets no + // sharding at all. During propagation it does propagate the information from + // gte.0 though, enabling Tuple.0 to be fully sharded. + EXPECT_FALSE(copy1->has_sharding()); + + EXPECT_TRUE(gte0->has_sharding()); + EXPECT_EQ(HloSharding::AssignDevice(1), gte0->sharding()); + + EXPECT_TRUE(gte1->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(gte1->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + gte1->sharding()); + + EXPECT_TRUE(copy2->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(copy2->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + copy2->sharding()); + + EXPECT_TRUE(tuple1->has_sharding()); + EXPECT_EQ(tuple0->sharding(), tuple1->sharding()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc index 751fc677e2d955fd3d9f8970f7c0370a22c054bf..dc514ae3e5c6907f6398805d171e69ee8635d08e 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc @@ -52,7 +52,7 @@ Status HloDomainVerifier::RunContext::PopulateDomainKinds() { TF_RET_CHECK(instruction->user_side_metadata().Kind() == instruction->operand_side_metadata().Kind()) << instruction->ToString(); - kinds.insert(instruction->user_side_metadata().Kind().ToString()); + kinds.insert(string(instruction->user_side_metadata().Kind())); } } } diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h index 8e53cf97f8ba9a88140a909ad20c1a938aec8c1f..81d6d69a8c59da2fc77cb2bab808602cd964fdaf 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -33,7 +33,7 @@ class HloDomainVerifier : public HloPassInterface { public: HloDomainVerifier(std::vector kinds) : kinds_(std::move(kinds)) {} - tensorflow::StringPiece name() const override { return "domain_verifier"; } + absl::string_view name() const override { return "domain_verifier"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h index 2b109225d0b192e5c9e4f6d841377ffad8078dc2..44ded2c2faf7c38d1e2f2aae577ddc07089bbb6a 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -32,9 +32,7 @@ class HloElementTypeConverter : public HloPassInterface { HloElementTypeConverter(PrimitiveType eliminate_type, PrimitiveType replace_with_type); - tensorflow::StringPiece name() const override { - return "element_type_converter"; - } + absl::string_view name() const override { return "element_type_converter"; } // Returns the pass on the module and returns whether the module was modified. StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index fb900494919fc66d2f7010f2b0137720586dcf78..71f91fde93904cbd4ef157e0bc7098b81a53907f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -435,7 +435,7 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { if (!ShapeUtil::ElementIsFloating(operand->shape())) { return InvalidArgument( "expected element type in shape to be float for IsFinite op, got: %s", - PrimitiveType_Name(operand->shape().element_type()).c_str()); + PrimitiveType_Name(operand->shape().element_type())); } switch (operand->shape().element_type()) { @@ -476,9 +476,9 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s", - ShapeUtil::HumanString(compare->shape()).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); + ShapeUtil::HumanString(compare->shape()), + ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape())); } TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); @@ -1105,8 +1105,8 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloEvaluator loop_body_evaluator(max_loop_iterations_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { - return InvalidArgument("Loop %s exceeded loop iteration limit (%lld).", - while_hlo->name().c_str(), max_loop_iterations_); + return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", + while_hlo->name(), max_loop_iterations_); } TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate( *cond_comp, {lcv.get()})); @@ -1262,7 +1262,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); if (sort_dim != rank - 1) { return Unimplemented( - "Trying to support along dimension %lld, which is not the last " + "Trying to support along dimension %d, which is not the last " "dimension", sort_dim); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 7588916de5068416410daf1a71a0bbad56f3ef0b..0ea708955237a92c2b9f9d8bac1e5e6b4185ca49 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -222,8 +222,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); + ShapeUtil::HumanString(shape), + ShapeUtil::HumanString(operand->shape())); } auto result = absl::make_unique(shape); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 4b8e6260ac837fa88a64126aaf83998b060d7efc..c3af15c6a88e42d0339fddcccd7dae7c6b62fb52 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -52,7 +52,10 @@ static std::array use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface, public HloVerifiedTestBase { protected: - HloEvaluatorTest() : use_bfloat16_(GetParam()) { + HloEvaluatorTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false), + use_bfloat16_(GetParam()) { evaluator_ = absl::make_unique(); } @@ -1216,7 +1219,12 @@ TEST_P(HloEvaluatorTest, EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } -class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; +class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase { + public: + HloEvaluatorPreciseReduceTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 2da2cc2d71ed94315cfc15a737155b65f9e8f7ad..f682e69ee93b874c614376cc69c425a7f58de259 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -21,7 +21,9 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" @@ -143,7 +145,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(hlo_instruction->opcode()).c_str()); + HloOpcodeString(hlo_instruction->opcode())); } // TODO(b/35950897): many of the stl functions used in the handlers are not @@ -2493,11 +2495,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same::value || std::is_same::value || std::is_same::value>::type* = nullptr> - Status HandleIota(HloInstruction* iota) { - auto result = absl::make_unique(iota->shape()); - auto data = result->data(); + Status HandleIota(HloInstruction* instruction) { + auto* iota = Cast(instruction); + std::vector data(iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); - parent_->evaluated_[iota] = std::move(result); + auto result = LiteralUtil::CreateR1(data); + + if (ShapeUtil::Rank(iota->shape()) > 1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[iota], + result->Broadcast(iota->shape(), {iota->iota_dimension()})); + } else { + TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + parent_->evaluated_[iota] = std::move(result); + } + return Status::OK(); } template shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); + ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape())); } const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); @@ -2690,10 +2701,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str(), - ShapeUtil::HumanString(ehs->shape()).c_str()); + ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape()), + ShapeUtil::HumanString(ehs->shape())); } const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index eba80c0f199f6224f4b46ac19af482c713585154..460ae2b5eca78659f86df1227e6a0a4e57508611 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -14,15 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -using tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::AllOf; using ::testing::ContainsRegex; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index f8ade39e8cd27aa87a9bc530cc08ae1a9aff65e2..3041d94fa9f55b1acffc1295d07e48c967322865 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -26,6 +26,11 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -40,47 +45,23 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" -using ::absl::nullopt; -using ::absl::optional; -using ::tensorflow::Env; -using ::tensorflow::WriteStringToFile; -using ::tensorflow::io::JoinPath; -using ::tensorflow::str_util::Join; -using ::tensorflow::str_util::StringReplace; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace xla { namespace hlo_graph_dumper { namespace { -// Helpers for Printf and Appendf. -template -struct PrintfConvert { - const T& operator()(const T& t) const { return t; } -}; -template <> -struct PrintfConvert { - const char* operator()(const string& s) const { return s.c_str(); } -}; - -// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str() -// on strings. -template -string Printf(const char* fmt, const Ts&... ts) { - return tensorflow::strings::Printf(fmt, PrintfConvert()(ts)...); -} -template -void Appendf(string* s, const char* fmt, const Ts&... ts) { - tensorflow::strings::Appendf(s, fmt, PrintfConvert()(ts)...); -} +using absl::nullopt; +using absl::optional; +using absl::StrAppend; +using absl::StrCat; +using absl::StrFormat; +using absl::StrJoin; +using tensorflow::Env; +using tensorflow::WriteStringToFile; +using tensorflow::io::JoinPath; // Used to indicate how we should treat a given HLOInstruction in the graph. // should we treat it like normal, hide it, and so on? @@ -209,17 +190,15 @@ NodeColors NodeColorsForScheme(ColorScheme color) { string NodeColorAttributes(ColorScheme color) { NodeColors node_colors = NodeColorsForScheme(color); - return Printf( - R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", - node_colors.style, node_colors.font_color, node_colors.stroke_color, - node_colors.fill_color); + return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, + node_colors.stroke_color, node_colors.fill_color); } // Replaces <> with <>, so that this string is safe(er) for use in a // graphviz HTML-like string. -string HtmlLikeStringSanitize(tensorflow::StringPiece s) { - return StringReplace(StringReplace(s, "<", "<", /*replace_all=*/true), ">", - ">", /*replace_all=*/true); +string HtmlLikeStringSanitize(absl::string_view s) { + return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}}); } // Tries to generates a human-readable one-word description of the given @@ -322,11 +301,11 @@ optional MatchTrivialComputation(const HloComputation* computation) { // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax). class HloDotDumper { public: - HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, + HloDotDumper(const HloComputation* computation, absl::string_view label, const DebugOptions& debug_options, bool show_backend_config, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), - label_(std::string(label)), + label_(label), debug_options_(debug_options), show_backend_config_(show_backend_config), profile_(profile), @@ -448,7 +427,7 @@ string HloDotDumper::Dump() { } string HloDotDumper::Header() { - const char* fmt = R"(digraph G { + constexpr char fmt[] = R"(digraph G { rankdir = TB; compound = true; label = <%s>; @@ -457,7 +436,7 @@ labelloc = t; tooltip = " "; // DOT graphs accept a stylesheet as a URI. So naturally, an inline // stylesheet is a data URI! -stylesheet=" +stylesheet=< data:text/css, @import url(https://fonts.googleapis.com/css?family=Roboto:400,700); svg text { @@ -466,7 +445,7 @@ stylesheet=" } %s -" +> )"; @@ -481,8 +460,8 @@ stylesheet=" } if (profile_ != nullptr) { auto cycles = profile_->total_cycles_executed(*computation_); - Appendf(&graph_label, "
total cycles = %lld (%s)", cycles, - tensorflow::strings::HumanReadableNum(cycles)); + absl::StrAppendFormat(&graph_label, "
total cycles = %d (%s)", cycles, + tensorflow::strings::HumanReadableNum(cycles)); } // Create CSS rules that say, when you hover over the given node or cluster, @@ -509,14 +488,14 @@ stylesheet=" // One could imagine other ways of writing this CSS rule that involve // less duplication, but this way seems to be relatively performant. edge_css_rules.push_back( - Printf(" #%s%d:hover ~ #edge%lld text { fill: %s; }\n" - " #%s%d:hover ~ #edge%lld path { " - "stroke: %s; stroke-width: .2em; }\n" - " #%s%d:hover ~ #edge%lld polygon { " - "fill: %s; stroke: %s; stroke-width: .2em; }\n", - elem_type, elem_id, edge_id, color, // - elem_type, elem_id, edge_id, color, // - elem_type, elem_id, edge_id, color, color)); + StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n" + " #%s%d:hover ~ #edge%d path { " + "stroke: %s; stroke-width: .2em; }\n" + " #%s%d:hover ~ #edge%d polygon { " + "fill: %s; stroke: %s; stroke-width: .2em; }\n", + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, color)); }; // The "to_node" value may be a NULL, indicating that this points to the @@ -559,10 +538,10 @@ stylesheet=" } } - return Printf(fmt, graph_label, Join(edge_css_rules, "\n")); + return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n")); } -string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); } +string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); } bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) { CHECK_EQ(instr->opcode(), HloOpcode::kFusion); @@ -600,9 +579,9 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() << " as " << next_edge_id_; edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = + constexpr char edge_fmt[] = R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( + edges_.push_back(StrFormat( edge_fmt, InstructionId(from), InstructionId(parent_instr), SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); } @@ -619,9 +598,10 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, string subcomp_label, style; if (parent_instr->opcode() == HloOpcode::kFusion) { - subcomp_label = Printf("Fused expression for %s
%s", - HtmlLikeStringSanitize(parent_instr->name()), - HtmlLikeStringSanitize(parent_instr->ToCategory())); + subcomp_label = + StrFormat("Fused expression for %s
%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(parent_instr->ToCategory())); string extra_info = GetInstructionNodeExtraInfo(parent_instr); if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); @@ -647,18 +627,18 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; } style = - Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", - fillcolor, strokecolor); + StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", + fillcolor, strokecolor); } else { - subcomp_label = Printf("Subcomputation for %s
%s", - HtmlLikeStringSanitize(parent_instr->name()), - HtmlLikeStringSanitize(subcomp->name())); + subcomp_label = StrFormat("Subcomputation for %s
%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(subcomp->name())); style = "style=rounded; color=black;"; } string comp_body = DumpComputation(subcomp); - const char* computation_fmt = R"(subgraph %s { + constexpr char computation_fmt[] = R"(subgraph %s { %s label = <%s>; labelloc = t; @@ -667,7 +647,7 @@ tooltip = " "; } // %s )"; - return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { @@ -718,11 +698,11 @@ string HloDotDumper::DumpRootTag() { VLOG(2) << "Adding edge from " << from->name() << " to root tag as " << next_edge_id_; edge_ids_.insert({{from, to}, next_edge_id_++}); - edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)", from_id, to_id)); + edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id)); - return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" - "\n", - to_id, node_body, node_shape, NodeColorAttributes(color)); + return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + "\n", + to_id, node_body, node_shape, NodeColorAttributes(color)); } static const HloConstantInstruction* TryGetFusionParameterConstant( @@ -817,10 +797,10 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } } - return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" - "\n", - InstructionId(instr), node_body, node_shape, node_metadata, - NodeColorAttributes(color)); + return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" + "\n", + InstructionId(instr), node_body, node_shape, node_metadata, + NodeColorAttributes(color)); } string HloDotDumper::GetInstructionNodeInlinedOperands( @@ -833,7 +813,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. if (ShapeUtil::IsZeroElementArray(shape)) { - return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); + return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape())); } // Print the literal value of constants with <= K elements. @@ -848,19 +828,19 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // collected from profiling tools. Those constants may not have a valid // literal. if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { - return Printf("%s (%s)", constant->literal().ToString(), - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s (%s)", constant->literal().ToString(), + ShapeUtil::HumanString(constant->shape())); } // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; - if (tensorflow::str_util::StartsWith(constant->name(), "constant")) { + if (absl::StartsWith(constant->name(), "constant")) { constant_name = constant->name(); } else { constant_name = StrCat("constant ", constant->name()); } - return Printf("%s %s", constant_name, - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s %s", constant_name, + ShapeUtil::HumanString(constant->shape())); }; std::vector lines; @@ -881,7 +861,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( TryGetFusionParameterConstant(operand)) { operand_str = stringify_constant(constant); } else { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + operand_str = StrFormat("Parameter %d", operand->parameter_number()); } } else { operand_str = operand->name(); @@ -890,13 +870,13 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( if (operand_str) { if (instr->operand_count() > 1) { - lines.push_back(Printf("operand %lld = %s", i, *operand_str)); + lines.push_back(StrFormat("operand %d = %s", i, *operand_str)); } else { - lines.push_back(Printf("operand = %s", *operand_str)); + lines.push_back(StrFormat("operand = %s", *operand_str)); } } } - return Join(lines, "
"); + return StrJoin(lines, "
"); } ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { @@ -1049,6 +1029,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kGray; case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: @@ -1079,14 +1060,13 @@ string HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // If we have a parameter, put the param number in the name. if (instr->opcode() == HloOpcode::kParameter) { - return Printf("Parameter %lld", instr->parameter_number()); + return StrFormat("Parameter %d", instr->parameter_number()); } // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. - if (tensorflow::str_util::StartsWith(instr->name(), - HloOpcodeString(instr->opcode()))) { - return Printf("%s", HtmlLikeStringSanitize(instr->name())); + if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) { + return StrFormat("%s", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = StrCat(HloOpcodeString(instr->opcode()), @@ -1094,8 +1074,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { ? "" : StrCat(":", xla::ToString(instr->fusion_kind()))); // If the name does not contain the opcode, render both. - return Printf("%s
%s", HtmlLikeStringSanitize(extended_opcode), - HtmlLikeStringSanitize(instr->name())); + return StrFormat("%s
%s", HtmlLikeStringSanitize(extended_opcode), + HtmlLikeStringSanitize(instr->name())); } string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { @@ -1104,16 +1084,16 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); } if (!instr->metadata().op_type().empty()) { - lines.push_back(Printf( + lines.push_back(StrFormat( "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type()))); } if (!instr->metadata().source_file().empty() && instr->metadata().source_line() != 0) { - lines.push_back(Printf("op_type: %s", instr->metadata().source_file(), - instr->metadata().source_line())); + lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(), + instr->metadata().source_line())); } - return Join(lines, "
"); + return StrJoin(lines, "
"); } string HloDotDumper::GetInstructionNodeBackendConfig( @@ -1160,13 +1140,12 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { constexpr int kMaxShapeLen = 64; if (instr_shape.length() > kMaxShapeLen) { instr_shape = StrCat( - tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3), - "..."); + absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "..."); } lines.push_back(instr_shape); } if (debug_options_.xla_hlo_graph_addresses()) { - lines.push_back(Printf("[%p]", instr)); + lines.push_back(StrFormat("[%p]", instr)); } if (profile_ != nullptr) { double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr); @@ -1174,11 +1153,11 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { profile_->total_cycles_executed(*instr->parent()); if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { lines.push_back( - Printf("%% of cycles executed=%.2f", - 100 * hlo_cycles_executed / total_cycles_executed)); + StrFormat("%% of cycles executed=%.2f", + 100 * hlo_cycles_executed / total_cycles_executed)); } } - return Join(lines, "
"); + return StrJoin(lines, "
"); } // Gets the total number of array elements in the given shape. For tuples, this @@ -1210,7 +1189,8 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { string edge_label; if (instr->operand_count() > 1 && !control_edge) { - edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num); + edge_label = + StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num); } else if (control_edge) { edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; } @@ -1220,10 +1200,11 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { // means. bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; - const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; - edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), - (is_big_array ? "normal" : "empty"), from->name(), - to->name(), edge_label)); + constexpr char kEdgeFmt[] = + R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; + edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to), + (is_big_array ? "normal" : "empty"), + from->name(), to->name(), edge_label)); }; // Add edges from instr's operands to instr. Parameters within fusion @@ -1264,14 +1245,14 @@ string HloDotDumper::GetInstructionTrivialComputationStr( continue; } if (instr->called_computations().size() == 1) { - lines.push_back(Printf("Subcomputation: %s", - HtmlLikeStringSanitize(*computation_type))); + lines.push_back(StrFormat("Subcomputation: %s", + HtmlLikeStringSanitize(*computation_type))); } else { - lines.push_back(Printf("Subcomputation %lld: %s", i, - HtmlLikeStringSanitize(*computation_type))); + lines.push_back(StrFormat("Subcomputation %d: %s", i, + HtmlLikeStringSanitize(*computation_type))); } } - return Join(lines, "
"); + return StrJoin(lines, "
"); } const HloInstruction* HloDotDumper::GetNodeForEdge( diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 1d7a062c55696de9db4b187efd86bce191279083..064c53252c0ac4d4e7b93169ad7cbee4807cb963 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,12 +24,11 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -using ::tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::HasSubstr; string TestName() { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a211167519cd5f8c531aa6289ce312ee35a6d44b..ed4e15991052cba0707ca02c32abf652e41de623 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -24,6 +24,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -41,17 +46,15 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using tensorflow::str_util::CEscape; -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::CEscape; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; /* static */ StatusOr> HloInstruction::CreateFromProto( @@ -314,8 +317,18 @@ StatusOr> HloInstruction::CreateFromProto( proto.shape(), all_operands(), /*replica_groups=*/ std::vector(proto.replica_groups().begin(), - proto.replica_groups().end()), - /*barrier=*/proto.cross_replica_sum_barrier()); + proto.replica_groups().end())); + break; + } + case HloOpcode::kCollectivePermute: { + std::vector> source_target_pairs( + proto.source_target_pairs_size()); + for (int i = 0; i < source_target_pairs.size(); i++) { + source_target_pairs[i].first = proto.source_target_pairs(i).source(); + source_target_pairs[i].second = proto.source_target_pairs(i).target(); + } + instruction = CreateCollectivePermute(proto.shape(), operands(0), + source_target_pairs); break; } case HloOpcode::kConvolution: @@ -415,6 +428,12 @@ StatusOr> HloInstruction::CreateFromProto( computations(0), *scatter_dimension_numbers); break; } + case HloOpcode::kIota: + TF_RET_CHECK(proto.dimensions_size() <= 1) + << "Iota instruction should have at most 1 dimension but sees " + << proto.dimensions_size(); + instruction = CreateIota(proto.shape(), proto.dimensions(0)); + break; default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -477,8 +496,8 @@ StatusOr> HloInstruction::CreateFromProto( } /* static */ std::unique_ptr HloInstruction::CreateIota( - const Shape& shape) { - return absl::WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); + const Shape& shape, int64 iota_dimension) { + return absl::make_unique(shape, iota_dimension); } /* static */ std::unique_ptr @@ -665,8 +684,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction::CreateCrossReplicaSum( const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* reduce_computation, - const std::vector& replica_groups, - tensorflow::StringPiece barrier, + const std::vector& replica_groups, absl::string_view barrier, const absl::optional& all_reduce_id) { return absl::make_unique( shape, operands, reduce_computation, replica_groups, barrier, @@ -675,10 +693,17 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr HloInstruction::CreateAllToAll( const Shape& shape, tensorflow::gtl::ArraySlice operands, - const std::vector& replica_groups, - tensorflow::StringPiece barrier) { + const std::vector& replica_groups) { return absl::make_unique(shape, operands, - replica_groups, barrier); + replica_groups); +} + +/* static */ std::unique_ptr +HloInstruction::CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs) { + return absl::make_unique( + shape, operand, source_target_pairs); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -690,7 +715,7 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr HloInstruction::CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { + HloInstruction* token_operand, absl::string_view outfeed_config) { return absl::make_unique( outfeed_shape, operand, token_operand, outfeed_config); } @@ -1068,7 +1093,7 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) { + absl::string_view custom_call_target) { return absl::make_unique(shape, operands, custom_call_target); } @@ -1154,6 +1179,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kReducePrecision: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: @@ -1347,7 +1373,7 @@ std::unique_ptr HloInstruction::Clone( // If names ends with .suffix[0-9]+ then replace with a suffix with the // numeric value incremented. int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { clone->name_ = StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1); } else { @@ -1622,6 +1648,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: @@ -1819,7 +1846,7 @@ void HloInstruction::set_false_computation(HloComputation* false_computation) { string HloInstruction::SignatureString() const { string operands = - Join(operands_, ", ", [](string* out, HloInstruction* operand) { + StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) { StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); @@ -1966,7 +1993,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } - operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) { // If operand is already been deleted, put `null` to the string output. if (operand == nullptr) { StrAppend(out, "null "); @@ -1986,7 +2013,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( } else if (!options.compact_operands()) { str.push_back(PrintName(operand->name(), options)); } - StrAppend(out, Join(str, " ")); + StrAppend(out, StrJoin(str, " ")); }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { @@ -2033,11 +2060,11 @@ std::vector HloInstruction::ExtraAttributesToString( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { extra.push_back(StrCat( - "calls=", Join(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, - PrintName(computation->name(), options)); - }))); + "calls=", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, PrintName(computation->name(), options)); + }))); } } else if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kFullBodies) { @@ -2070,12 +2097,12 @@ std::vector HloInstruction::ExtraAttributesToString( break; default: if (!called_computations().empty()) { - extra.push_back( - StrCat("calls=\n", - Join(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, computation->ToString(new_options)); - }))); + extra.push_back(StrCat( + "calls=\n", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, computation->ToString(new_options)); + }))); } break; } @@ -2086,11 +2113,11 @@ std::vector HloInstruction::ExtraAttributesToString( } if (!control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", - Join(control_predecessors_, ", ", - [&](string* out, HloInstruction* pre) { - StrAppend(out, - PrintName(pre->name(), options)); - }), + StrJoin(control_predecessors_, ", ", + [&](string* out, HloInstruction* pre) { + StrAppend(out, + PrintName(pre->name(), options)); + }), "}")); } if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { @@ -2104,10 +2131,10 @@ std::vector HloInstruction::ExtraAttributesToString( string HloInstruction::ToShortString() const { return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(", - Join(operands_, ", ", - [](string* out, HloInstruction* operand) { - StrAppend(out, "%", operand->name()); - }), + StrJoin(operands_, ", ", + [](string* out, HloInstruction* operand) { + StrAppend(out, "%", operand->name()); + }), ")"); } @@ -2168,7 +2195,7 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } -bool HloInstruction::IsFusable() const { +bool HloInstruction::IsFusible() const { // Instructions which are traced should not be fused. if (tracing()) { return false; @@ -2274,6 +2301,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleCrossReplicaSum(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); + case HloOpcode::kCollectivePermute: + return visitor->HandleCollectivePermute(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2380,7 +2409,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return InternalError( "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " "please file a bug for XLA.", - HloOpcodeString(opcode_).c_str()); + HloOpcodeString(opcode_)); } // Explicit instantiations. @@ -2463,7 +2492,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - current_node->ToString().c_str()); + current_node->ToString()); } } @@ -2472,7 +2501,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - current_node->ToString().c_str()); + current_node->ToString()); } } } @@ -2788,7 +2817,7 @@ StatusOr StringToFusionKind( if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } - return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); + return InvalidArgument("Unknown fusion kind: %s", kind_name); } string PaddingConfigToString(const PaddingConfig& padding) { @@ -2797,7 +2826,7 @@ string PaddingConfigToString(const PaddingConfig& padding) { [](const PaddingConfig::PaddingConfigDimension& dim) { return dim.interior_padding() != 0; }); - return Join( + return StrJoin( padding.dimensions(), "x", [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { StrAppend( @@ -2821,16 +2850,15 @@ string OpMetadataToString(const OpMetadata& metadata) { if (metadata.source_line() != 0) { result.push_back(StrCat("source_line=", metadata.source_line())); } - return Join(result, " "); + return StrJoin(result, " "); } string RandomDistributionToString(const RandomDistribution& distribution) { - return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); + return absl::AsciiStrToLower(RandomDistribution_Name(distribution)); } string PrecisionToString(const PrecisionConfigProto::Precision& precision) { - return tensorflow::str_util::Lowercase( - PrecisionConfigProto::Precision_Name(precision)); + return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision)); } string ConvolutionDimensionNumbersToString( @@ -2858,8 +2886,8 @@ string ConvolutionDimensionNumbersToString( output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", - Join(output_dims, "")); + return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->", + StrJoin(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -2870,19 +2898,21 @@ string HloInstruction::DotDimensionNumbersToString() const { const DotDimensionNumbers& dnums = *dot_dimension_numbers_; if (!dnums.lhs_batch_dimensions().empty()) { result.push_back(StrCat("lhs_batch_dims={", - Join(dnums.lhs_batch_dimensions(), ","), "}")); + StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); } result.push_back(StrCat("lhs_contracting_dims={", - Join(dnums.lhs_contracting_dimensions(), ","), "}")); + StrJoin(dnums.lhs_contracting_dimensions(), ","), + "}")); if (!dnums.rhs_batch_dimensions().empty()) { result.push_back(StrCat("rhs_batch_dims={", - Join(dnums.rhs_batch_dimensions(), ","), "}")); + StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); } result.push_back(StrCat("rhs_contracting_dims={", - Join(dnums.rhs_contracting_dimensions(), ","), "}")); + StrJoin(dnums.rhs_contracting_dimensions(), ","), + "}")); - return Join(result, ", "); + return StrJoin(result, ", "); } StatusOr StringToRandomDistribution(const string& name) { @@ -2896,7 +2926,7 @@ StatusOr StringToRandomDistribution(const string& name) { } return map; }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); + auto found = map->find(absl::AsciiStrToLower(name)); if (found == map->end()) { return InvalidArgument("Unknown distribution"); } @@ -2909,15 +2939,14 @@ string HloInstruction::PrecisionConfigToString() const { } return StrCat( "operand_precision={", - Join(precision_config_.operand_precision(), ",", - [](string* out, int32 precision) { - CHECK(PrecisionConfigProto::Precision_IsValid(precision)) - << precision; - StrAppend( - out, - PrecisionToString( - static_cast(precision))); - }), + StrJoin(precision_config_.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfigProto::Precision_IsValid(precision)) + << precision; + StrAppend(out, PrecisionToString( + static_cast( + precision))); + }), "}"); } @@ -2934,7 +2963,7 @@ StatusOr StringToPrecision( } return map; }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); + auto found = map->find(absl::AsciiStrToLower(name)); if (found == map->end()) { return InvalidArgument("Unknown distribution"); } @@ -3185,25 +3214,20 @@ const string& HloInstruction::outfeed_config() const { } const std::vector& HloInstruction::replica_groups() const { - if (opcode() == HloOpcode::kCrossReplicaSum) { - return Cast(this)->replica_groups(); - } - return Cast(this)->replica_groups(); + return Cast(this)->replica_groups(); +} + +const std::vector>& +HloInstruction::source_target_pairs() const { + return Cast(this)->source_target_pairs(); } string HloInstruction::cross_replica_sum_barrier() const { - if (opcode() == HloOpcode::kCrossReplicaSum) { - return Cast(this)->cross_replica_sum_barrier(); - } - return Cast(this)->cross_replica_sum_barrier(); + return Cast(this)->cross_replica_sum_barrier(); } void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - if (opcode() == HloOpcode::kCrossReplicaSum) { - return Cast(this)->set_cross_replica_sum_barrier( - barrier); - } - return Cast(this)->set_cross_replica_sum_barrier( + return Cast(this)->set_cross_replica_sum_barrier( barrier); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index fdd34544eb90022dece5ad7da5436cd757f07f1b..4a424cebc070accdac8e334410d005031775c28f 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -34,6 +34,8 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -47,7 +49,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -222,7 +223,7 @@ class CanonicalNameMap { return iter->second; } - string new_name = tensorflow::strings::StrCat("tmp_", index++); + string new_name = absl::StrCat("tmp_", index++); canonical_name_map[old_name] = new_name; return new_name; } @@ -349,7 +350,8 @@ class HloInstruction { std::unique_ptr literal); // Creates an Iota instruction. - static std::unique_ptr CreateIota(const Shape& shape); + static std::unique_ptr CreateIota(const Shape& shape, + int64 iota_dimension); // Creates a get tuple element instruction. static std::unique_ptr CreateGetTupleElement( @@ -450,8 +452,7 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* reduce_computation, const std::vector& replica_groups, - tensorflow::StringPiece barrier, - const absl::optional& all_reduce_id); + absl::string_view barrier, const absl::optional& all_reduce_id); // This op handles the communication of an Alltoall operation. On each core, // the operands are N ops in the same shape, where N is the number of cores @@ -466,12 +467,18 @@ class HloInstruction { // within replica 1, 2, 3, and in the gather phase, the received blocks will // be concatenated in the order of 1, 2, 3; another Alltoall will be applied // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. - // - // TODO(b/110096724): This is NOT YET ready to use. static std::unique_ptr CreateAllToAll( const Shape& shape, tensorflow::gtl::ArraySlice operands, - const std::vector& replica_groups, - tensorflow::StringPiece barrier); + const std::vector& replica_groups); + + // Creates a communitation instructions that permutes data cross replicas. + // Data is sent/received according to the (source_replica_id, + // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a + // target_replica_id in any pair, the output on that replica is a tensor + // conssits of 0(s) in `shape`. + static std::unique_ptr CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -496,7 +503,7 @@ class HloInstruction { // which is a TOKEN. static std::unique_ptr CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); + HloInstruction* token_operand, absl::string_view outfeed_config); // Creates an asynchronous send instruction with the given channel id, which // initiates sending the operand data to a unique receive instruction in @@ -709,7 +716,7 @@ class HloInstruction { // to the given operands. "shape" is the resultant shape. static std::unique_ptr CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target); + absl::string_view custom_call_target); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. @@ -1032,7 +1039,7 @@ class HloInstruction { // Returns true if this instruction can be legally fused into a fusion // instruction. - bool IsFusable() const; + bool IsFusible() const; // Returns the sharding applied to this operator. // REQUIRES: has_sharding() is true. @@ -1040,6 +1047,8 @@ class HloInstruction { CHECK(has_sharding()); return *sharding_; } + std::shared_ptr sharding_ptr() const { return sharding_; } + // Returns the sharding applied to this operator, or default_ if none exists. const HloSharding& sharding_or_default(const HloSharding& default_) const { return sharding_ ? *sharding_ : default_; @@ -1054,7 +1063,10 @@ class HloInstruction { // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { - sharding_ = absl::make_unique(sharding); + sharding_ = std::make_shared(sharding); + } + void set_sharding(std::shared_ptr sharding) { + sharding_ = std::move(sharding); } void set_single_sharding(const HloSharding& sharding); // Sets a sharding that assigns the current instruction to device. @@ -1090,19 +1102,6 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // TODO(b/80249101): Remove these methods once HLO scheduling and copy - // insertion are integrated, and we don't need to run a separate pass - // of copy elision anymore. - bool CopyElisionAllowed() const { - CHECK_EQ(HloOpcode::kCopy, opcode_); - return copy_elision_allowed_; - } - - void SetCopyElisionAllowed(bool value) { - CHECK_EQ(HloOpcode::kCopy, opcode_); - copy_elision_allowed_ = value; - } - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1440,9 +1439,12 @@ class HloInstruction { // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; - // Delegates to HloAllToAllInstruction::replica_groups. + // Delegates to HloCollectiveInstruction::replica_groups. const std::vector& replica_groups() const; + // Delegates to HloCollectivePermuteInstruction::source_target_pairs. + const std::vector>& source_target_pairs() const; + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); @@ -1655,7 +1657,10 @@ class HloInstruction { bool copy_elision_allowed_ = true; // The sharding, if one exists. - std::unique_ptr sharding_; + // Uses std::shared_ptr to allow reuse of the same sharding object between + // HloInstructions and other components as HloSharding can be very large for + // many element tuples. + std::shared_ptr sharding_; // Fields used by the kDomain instruction. std::unique_ptr operand_side_metadata_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 504b13043f86f152cc83b0b961bf2e8fa3ad2afb..8b0b90dfb32336821a059ed2239599a6307583b2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -53,7 +53,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { public: Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("not implemented %s", - HloOpcodeString(hlo_instruction->opcode()).c_str()); + HloOpcodeString(hlo_instruction->opcode())); } Status HandleParameter(HloInstruction* parameter) override { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 36fac4a266c6fa525fe8640804e1b51fd19b582a..ffc74cfeddb9880d1119642ac3f6c1bc2ebecfcd 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -19,6 +19,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -29,10 +33,10 @@ limitations under the License. namespace xla { namespace { -using ::tensorflow::str_util::CEscape; -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::CEscape; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, const HloInstruction* operand) { @@ -160,7 +164,7 @@ HloInstructionProto HloFftInstruction::ToProto() const { std::vector HloFftInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {StrCat("fft_type=", FftType_Name(fft_type())), - StrCat("fft_length={", Join(fft_length(), ","), "}")}; + StrCat("fft_length={", StrJoin(fft_length(), ","), "}")}; } bool HloFftInstruction::IdenticalSlowPath( @@ -297,43 +301,75 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( Cast(new_operands[0]), is_host_transfer()); } -HloAllReduceInstruction::HloAllReduceInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloComputation* reduce_computation, - const std::vector& replica_groups, - tensorflow::StringPiece barrier, const absl::optional& all_reduce_id) - : HloInstruction(HloOpcode::kCrossReplicaSum, shape), - replica_groups_(replica_groups), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()), - all_reduce_id_(all_reduce_id) { +HloCollectiveInstruction::HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups) + : HloInstruction(opcode, shape), replica_groups_(replica_groups) { for (auto operand : operands) { AppendOperand(operand); } - AppendComputation(reduce_computation); } -HloInstructionProto HloAllReduceInstruction::ToProto() const { +HloInstructionProto HloCollectiveInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_replica_groups() = {replica_groups_.begin(), replica_groups_.end()}; - // Proto3 is so sad. - if (all_reduce_id_) { - proto.set_all_reduce_id(*all_reduce_id_); - } - proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); return proto; } -std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( +std::vector HloCollectiveInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& /*options*/) const { std::vector result; std::vector replica_group_str; for (const ReplicaGroup& group : replica_groups()) { replica_group_str.push_back( - StrCat("{", Join(group.replica_ids(), ","), "}")); + StrCat("{", StrJoin(group.replica_ids(), ","), "}")); } result.push_back( - StrCat("replica_groups={", Join(replica_group_str, ","), "}")); + StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}")); + return result; +} + +bool HloCollectiveInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }); +} + +HloAllReduceInstruction::HloAllReduceInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + const std::vector& replica_groups, absl::string_view barrier, + const absl::optional& all_reduce_id) + : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands, + replica_groups), + cross_replica_sum_barrier_(barrier), + all_reduce_id_(all_reduce_id) { + AppendComputation(reduce_computation); +} + +HloInstructionProto HloAllReduceInstruction::ToProto() const { + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); + // Proto3 is so sad. + if (all_reduce_id_) { + proto.set_all_reduce_id(*all_reduce_id_); + } + proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); + return proto; +} + +std::vector HloAllReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result = + HloCollectiveInstruction::ExtraAttributesToStringImpl(options); if (!cross_replica_sum_barrier().empty()) { result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } @@ -348,11 +384,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const std::function& eq_computations) const { const auto& casted_other = static_cast(other); - return ContainersEqual(replica_groups(), casted_other.replica_groups(), - [](const ReplicaGroup& a, const ReplicaGroup& b) { - return ContainersEqual(a.replica_ids(), - b.replica_ids()); - }) && + return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && eq_computations(to_apply(), casted_other.to_apply()) && cross_replica_sum_barrier() == casted_other.cross_replica_sum_barrier() && @@ -371,63 +403,69 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( HloAllToAllInstruction::HloAllToAllInstruction( const Shape& shape, tensorflow::gtl::ArraySlice operands, - const std::vector& replica_groups, - tensorflow::StringPiece barrier) - : HloInstruction(HloOpcode::kAllToAll, shape), - replica_groups_(replica_groups), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()) { - for (auto operand : operands) { - AppendOperand(operand); - } -} - -bool HloAllToAllInstruction::IdenticalSlowPath( - const HloInstruction& other, - const std::function& - eq_computations) const { - const auto& casted_other = static_cast(other); - return ContainersEqual(replica_groups(), casted_other.replica_groups(), - [](const ReplicaGroup& a, const ReplicaGroup& b) { - return ContainersEqual(a.replica_ids(), - b.replica_ids()); - }) && - cross_replica_sum_barrier() == - casted_other.cross_replica_sum_barrier(); -} + const std::vector& replica_groups) + : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands, + replica_groups) {} std::unique_ptr HloAllToAllInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, HloCloneContext* /*context*/) const { - return absl::make_unique( - shape, new_operands, replica_groups(), cross_replica_sum_barrier()); + return absl::make_unique(shape, new_operands, + replica_groups()); } -std::vector HloAllToAllInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& options) const { - std::vector result; - std::vector replica_group_str; - for (const ReplicaGroup& group : replica_groups()) { - replica_group_str.push_back( - StrCat("{", Join(group.replica_ids(), ","), "}")); - } - result.push_back( - StrCat("replica_groups={", Join(replica_group_str, ","), "}")); +HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs) + : HloInstruction(HloOpcode::kCollectivePermute, shape), + source_target_pairs_(source_target_pairs) { + AppendOperand(operand); +} - if (!cross_replica_sum_barrier().empty()) { - result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); +HloInstructionProto HloCollectivePermuteInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (const auto& pair : source_target_pairs()) { + auto* proto_pair = proto.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); } + return proto; +} +std::vector +HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector result; + std::vector strs; + for (const auto& pair : source_target_pairs()) { + strs.push_back(StrCat("{", pair.first, ",", pair.second, "}")); + } + result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}")); return result; } -HloInstructionProto HloAllToAllInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_replica_groups() = {replica_groups_.begin(), - replica_groups_.end()}; - proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); - return proto; +bool HloCollectivePermuteInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + /*eq_computations*/) const { + const auto& casted_other = + static_cast(other); + return ContainersEqual( + source_target_pairs(), casted_other.source_target_pairs(), + [](const std::pair& a, const std::pair& b) { + return a == b; + }); +} + +std::unique_ptr +HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique( + shape, new_operands[0], source_target_pairs()); } HloReverseInstruction::HloReverseInstruction( @@ -448,7 +486,7 @@ HloInstructionProto HloReverseInstruction::ToProto() const { std::vector HloReverseInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloReverseInstruction::IdenticalSlowPath( @@ -487,7 +525,7 @@ HloInstructionProto HloConcatenateInstruction::ToProto() const { std::vector HloConcatenateInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloConcatenateInstruction::IdenticalSlowPath( @@ -530,7 +568,7 @@ HloInstructionProto HloReduceInstruction::ToProto() const { std::vector HloReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloReduceInstruction::IdenticalSlowPath( @@ -573,7 +611,7 @@ HloInstructionProto HloSortInstruction::ToProto() const { std::vector HloSortInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloSortInstruction::IdenticalSlowPath( @@ -606,7 +644,7 @@ HloTransposeInstruction::HloTransposeInstruction( Permute(dimensions, shape.dimensions()).begin())) << "shape: " << ShapeUtil::HumanString(shape) << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; + << ", dimensions: {" << StrJoin(dimensions, ", ") << "}"; AppendOperand(operand); } @@ -627,7 +665,7 @@ HloInstructionProto HloTransposeInstruction::ToProto() const { std::vector HloTransposeInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloTransposeInstruction::IdenticalSlowPath( @@ -666,7 +704,7 @@ HloInstructionProto HloBroadcastInstruction::ToProto() const { std::vector HloBroadcastInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloBroadcastInstruction::IdenticalSlowPath( @@ -727,7 +765,7 @@ bool HloMapInstruction::IsElementwiseImpl( std::vector HloMapInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloMapInstruction::IdenticalSlowPath( @@ -785,7 +823,7 @@ std::vector HloSliceInstruction::ExtraAttributesToStringImpl( bounds.push_back( StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); } - return {StrCat("slice={", Join(bounds, ", "), "}")}; + return {StrCat("slice={", StrJoin(bounds, ", "), "}")}; } bool HloSliceInstruction::IdenticalSlowPath( @@ -871,7 +909,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector v = tensorflow::str_util::Split(tmp, ' '); + std::vector v = absl::StrSplit(tmp, ' '); bool first = true; // Concatenate elements in "v" with spaces separating them, but ignoring // empty entries. @@ -1166,7 +1204,7 @@ HloInstruction* HloFusionInstruction::FuseInstructionInternal( HloInstruction* HloFusionInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse, bool add_output) { - CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); + CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString(); VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); HloInstruction* clone = nullptr; if (called_computations().empty()) { @@ -1572,12 +1610,13 @@ std::unique_ptr HloInfeedInstruction::CloneWithNewOperandsImpl( infeed_shape(), new_operands[0], infeed_config()); } -HloOutfeedInstruction::HloOutfeedInstruction( - const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) +HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + HloInstruction* token_operand, + absl::string_view outfeed_config) : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), outfeed_shape_(outfeed_shape), - outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + outfeed_config_(outfeed_config) { CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) << "Outfeed shape " << outfeed_shape << " must be compatible with operand shape " << operand->shape(); @@ -1785,7 +1824,7 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) + absl::string_view custom_call_target) : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()) { @@ -1921,8 +1960,8 @@ HloInstructionProto HloDynamicSliceInstruction::ToProto() const { std::vector HloDynamicSliceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return { - StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")}; + return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","), + "}")}; } bool HloDynamicSliceInstruction::IdenticalSlowPath( @@ -1958,17 +1997,17 @@ string HloGatherInstruction::GatherDimensionNumbersToString() const { CHECK(gather_dimension_numbers_ != nullptr); string offset_dims = StrCat("offset_dims={", - Join(gather_dimension_numbers_->offset_dims(), ","), "}"); - string collapsed_slice_dims = - StrCat("collapsed_slice_dims={", - Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); + StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}"); + string collapsed_slice_dims = StrCat( + "collapsed_slice_dims={", + StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); string start_index_map = StrCat("start_index_map={", - Join(gather_dimension_numbers_->start_index_map(), ","), "}"); + StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); - return Join>( + return StrJoin>( {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, ", "); } @@ -2005,7 +2044,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const { std::vector HloGatherInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {GatherDimensionNumbersToString(), - StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")}; + StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")}; } bool HloGatherInstruction::IdenticalSlowPath( @@ -2044,20 +2083,20 @@ HloScatterInstruction::HloScatterInstruction( } string HloScatterInstruction::ScatterDimensionNumbersToString() const { - string update_window_dims = - StrCat("update_window_dims={", - Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string update_window_dims = StrCat( + "update_window_dims={", + StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}"); string inserted_window_dims = StrCat( "inserted_window_dims={", - Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); string scatter_dims_to_operand_dims = StrCat( "scatter_dims_to_operand_dims={", - Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); - return Join>( + return StrJoin>( {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, index_vector_dim}, ", "); @@ -2116,4 +2155,34 @@ std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( scatter_dimension_numbers()); } +HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension) + : HloInstruction(HloOpcode::kIota, shape), + iota_dimension_(iota_dimension) {} + +HloInstructionProto HloIotaInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.add_dimensions(iota_dimension()); + return proto; +} + +std::vector HloIotaInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("iota_dimension=", iota_dimension())}; +} + +bool HloIotaInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return iota_dimension() == casted_other.iota_dimension(); +} + +std::unique_ptr HloIotaInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return absl::make_unique(shape, iota_dimension()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 0a6a0c62339018341fd009b474ed6b721d74110b..ee6e337b6a4ccc769a5389c5ce657337cbbd32fb 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -218,18 +218,37 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { HloCloneContext* context) const override; }; -class HloAllReduceInstruction : public HloInstruction { +class HloCollectiveInstruction : public HloInstruction { + public: + const std::vector& replica_groups() const { + return replica_groups_; + } + + protected: + explicit HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups); + + HloInstructionProto ToProto() const override; + + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + std::vector replica_groups_; +}; + +class HloAllReduceInstruction : public HloCollectiveInstruction { public: explicit HloAllReduceInstruction( const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* reduce_computation, const std::vector& replica_groups, - tensorflow::StringPiece barrier, - const absl::optional& all_reduce_id); - - const std::vector& replica_groups() const { - return replica_groups_; - } + absl::string_view barrier, const absl::optional& all_reduce_id); // Returns the barrier config used for the CrossReplicaSum implementation of // each backend. @@ -259,9 +278,6 @@ class HloAllReduceInstruction : public HloInstruction { tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const override; - // The replica ids of each subgroup for CrossReplicaSum op. - std::vector replica_groups_; - // The string representation of the barrier config used for CrossReplicaSum. string cross_replica_sum_barrier_; @@ -271,25 +287,31 @@ class HloAllReduceInstruction : public HloInstruction { absl::optional all_reduce_id_; }; -class HloAllToAllInstruction : public HloInstruction { +class HloAllToAllInstruction : public HloCollectiveInstruction { public: explicit HloAllToAllInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice operand, - const std::vector& replica_groups, - tensorflow::StringPiece barrier); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + const std::vector& replica_groups); - const std::vector& replica_groups() const { - return replica_groups_; - } + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; - // TODO(b/110096724): rename this. - void set_cross_replica_sum_barrier(string barrier) { - cross_replica_sum_barrier_ = barrier; - } - string cross_replica_sum_barrier() const { - return cross_replica_sum_barrier_; +class HloCollectivePermuteInstruction : public HloInstruction { + public: + explicit HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector>& source_target_pairs); + + const std::vector>& source_target_pairs() const { + return source_target_pairs_; } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; private: @@ -306,10 +328,7 @@ class HloAllToAllInstruction : public HloInstruction { tensorflow::gtl::ArraySlice new_operands, HloCloneContext* context) const override; - std::vector replica_groups_; - - // The string representation of the barrier config. - string cross_replica_sum_barrier_; + const std::vector> source_target_pairs_; }; class HloReverseInstruction : public HloInstruction { @@ -918,7 +937,7 @@ class HloOutfeedInstruction : public HloInstruction { explicit HloOutfeedInstruction(const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, - tensorflow::StringPiece outfeed_config); + absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); @@ -1071,7 +1090,7 @@ class HloCustomCallInstruction : public HloInstruction { public: explicit HloCustomCallInstruction( const Shape& shape, tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target); + absl::string_view custom_call_target); const Window& window() const override { CHECK(window_ != nullptr); return *window_; @@ -1260,6 +1279,30 @@ class HloScatterInstruction : public HloInstruction { std::unique_ptr scatter_dimension_numbers_; }; +class HloIotaInstruction : public HloInstruction { + public: + explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension); + // Returns the dimension sizes or numbers associated with this instruction. + int64 iota_dimension() const { return iota_dimension_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + const int64 iota_dimension_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 2e01b090beebe9280d0931c5d0c6a56f728b9eff..8350285e67554bd8d2f619884c346c696e33caf5 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,20 +17,20 @@ limitations under the License. #include +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" namespace xla { - -using ::tensorflow::StringPiece; - namespace { +using absl::string_view; + constexpr int kEOF = -1; constexpr int kError = -2; @@ -66,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -tensorflow::StringPiece HloLexer::StringPieceFromPointers( - const char* begin, const char* end) const { +absl::string_view HloLexer::StringPieceFromPointers(const char* begin, + const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return tensorflow::StringPiece(begin, end - begin); + return absl::string_view(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -235,7 +235,7 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - tensorflow::StringPiece identifier = + absl::string_view identifier = StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. @@ -269,7 +269,7 @@ TokKind HloLexer::LexIdentifier() { } } - str_val_ = std::string(identifier); + str_val_ = string(identifier); return TokKind::kIdent; } @@ -306,8 +306,7 @@ TokKind HloLexer::LexNumberOrPattern() { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), - &decimal_val_); + CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_)); return TokKind::kDecimal; } @@ -339,7 +338,7 @@ TokKind HloLexer::LexNumberOrPattern() { if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); auto slice = StringPieceFromPointers(token_start_, current_ptr_); - if (tensorflow::strings::safe_strto64(slice, &int64_val_)) { + if (absl::SimpleAtoi(slice, &int64_val_)) { return TokKind::kInt; } LOG(ERROR) << "Failed to parse int literal: " << slice; @@ -375,24 +374,24 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == tensorflow::StringPiece::npos) { + if (line_offset == absl::string_view::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { +absl::string_view HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == tensorflow::StringPiece::npos + const char* start = line_start == absl::string_view::npos ? buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); const char* end = - line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; + line_end == absl::string_view::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } @@ -404,10 +403,14 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::StringPiece raw = + absl::string_view raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; - if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { + // TODO(b/113077997): Change to absl::CUnescape once it works properly with + // copy-on-write std::string implementations. + if (!tensorflow::str_util::CUnescape( // non-absl ok + tensorflow::StringPiece(raw.data(), raw.size()), // non-absl ok + &str_val_, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; return TokKind::kError; } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index f9ecd9ccb91c19ff0801ee55a1aa4da3696e97ab..3e2f8bcd52f9043f161197756a2060b28dded1d9 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/types.h" @@ -34,7 +34,7 @@ namespace xla { // it directly. class HloLexer { public: - explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { + explicit HloLexer(absl::string_view buf) : buf_(buf) { current_ptr_ = buf_.begin(); } @@ -77,7 +77,7 @@ class HloLexer { std::pair GetLineAndColumn(LocTy location) const; // Returns the whole line given the location. - tensorflow::StringPiece GetLine(LocTy loc) const; + absl::string_view GetLine(LocTy loc) const; private: // Returns the current character. If it's neither the end of input buffer nor @@ -89,8 +89,8 @@ class HloLexer { // Creates StringPiece with the given begin and end. Exits if the begin > end, // or it's out of the range of the current buffer. - tensorflow::StringPiece StringPieceFromPointers(const char* begin, - const char* end) const; + absl::string_view StringPieceFromPointers(const char* begin, + const char* end) const; tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( const char* begin, const char* end) const; @@ -107,7 +107,7 @@ class HloLexer { TokKind LexNumberOrPattern(); TokKind LexString(); - const tensorflow::StringPiece buf_; + const absl::string_view buf_; const char* current_ptr_; // Information about the current token. diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 18f17b75aede734b4971a07347f31ba45db9dc96..3a1dd471c626ae9497cfcca62c30736bcdbb2b38 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -29,17 +30,14 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { using Worklist = std::deque; using Workset = std::unordered_set; -namespace { - void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, Workset* workset) { if (workset->count(instruction) == 0) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 7e4b8834357d39099f76450b849d6b5624e4e3b4..5269cad94d35be3dd1c009588bbe422ff1533364 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -15,15 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace testing { -using ::tensorflow::str_util::Join; - bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { @@ -210,8 +208,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { *listener << instruction->ToString() << " has wrong lhs_contracting_dimensions (got {" - << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" - << lhs_contracting_dim_ << "})"; + << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",") + << "} want {" << lhs_contracting_dim_ << "})"; return false; } @@ -219,8 +217,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { *listener << instruction->ToString() << " has wrong rhs_contracting_dimensions (got {" - << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" - << rhs_contracting_dim_ << "})"; + << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",") + << "} want {" << rhs_contracting_dim_ << "})"; return false; } diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 0a442e77f0b0aedea807e0991d4f30ead83a1a6b..9ace0d76e0c98420b085f30c0f0042a33b6e7583 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -306,7 +306,7 @@ inline ::testing::Matcher Shape( return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape)); } inline ::testing::Matcher Shape( - tensorflow::StringPiece shape) { + absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( ShapeUtil::ParseShapeString(shape).ValueOrDie())); } @@ -316,7 +316,7 @@ inline ::testing::Matcher ShapeWithLayout( new ::xla::testing::HloShapeAndLayoutMatcher(shape)); } inline ::testing::Matcher ShapeWithLayout( - tensorflow::StringPiece shape) { + absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( ShapeUtil::ParseShapeString(shape).ValueOrDie())); } @@ -329,7 +329,7 @@ inline ::testing::Matcher Sharding( } // Matcher for Sharding from sharding string inline ::testing::Matcher Sharding( - tensorflow::StringPiece sharding) { + absl::string_view sharding) { return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( ParseSharding(sharding).ValueOrDie())); } diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index d60b76d63f8fb0b3b775e743beaec58316fa3740..78167335c8efeb3de4b475bba562a8f0150a3aa6 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -24,11 +24,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -410,7 +410,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( string error_message = "The subcomputation to outline has multiple outputs:\n"; for (HloInstruction* output : outputs) { - tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n"); + absl::StrAppend(&error_message, output->ToString(), "\n"); } LOG(FATAL) << error_message; } @@ -536,8 +536,7 @@ uint64 HloModule::RandomNew64() const { return rng_(); } -HloComputation* HloModule::GetComputationWithName( - tensorflow::StringPiece name) { +HloComputation* HloModule::GetComputationWithName(absl::string_view name) { auto computations_in_module = computations(); auto it = absl::c_find_if( computations_in_module, diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index d2e726a0db63f622cd5092d56b4f746232d04aad..cf129b835db56c21245c7e98d7e7876c1e507132 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" @@ -142,7 +142,7 @@ class HloModule { // Returns the computation in this module that has the name `name`. Returns // null if there is no such computation. - HloComputation* GetComputationWithName(tensorflow::StringPiece name); + HloComputation* GetComputationWithName(absl::string_view name); // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index f9708283eb4becd67a76ff30103001c81c2c703a..9bfa3a5f45c8e810f9ea7d6bdcd72b90254d15b9 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::strings::StrAppend; +using absl::StrAppend; HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape, bool ignore_layouts) @@ -39,15 +39,14 @@ void HloModuleConfig::SetDefaultComputationLayout( } string HloModuleConfig::compilation_cache_key() const { - string key = - tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled()); + string key = absl::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } - StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", + StrAppend(&key, absl::StrJoin(params, ", "), ") => ", entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h index 29024085c1038961ef2b3721de1ce0e8a55ccf45..12ca2340a6ccaa50780e81168c755c1fec3aa1be 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.h +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -31,7 +31,7 @@ namespace xla { class HloModuleDCE : public HloPassInterface { public: ~HloModuleDCE() override {} - tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + absl::string_view name() const override { return "hlo-module-dce"; } // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index f52a37bc7426ea6f1cf8754d9ee8db98b1493f15..9c01862a4b7024826c3f701b795819abe945d07f 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -131,6 +132,14 @@ Status HloModuleGroupMetadata::Build() { if (VLOG_IS_ON(4)) { DumpCollectedStats(); } + + for (HloModule* module : modules_) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + points_to_analyses_[module] = std::move(points_to_analysis); + } + return Status::OK(); } @@ -163,7 +172,7 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const { ss << " " << hlo->name() << std::endl; } ss << "has multiple instructions on the same device"; - return FailedPrecondition("%s", ss.str().c_str()); + return FailedPrecondition("%s", ss.str()); } } } @@ -411,16 +420,16 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, Status HloModuleGroupMetadata::VerifyChannelInstructions() { for (const Channel& channel : channels_) { if (channel.send == nullptr) { - return FailedPrecondition("missing send for id : %lld", channel.id); + return FailedPrecondition("missing send for id : %d", channel.id); } if (channel.recv == nullptr) { - return FailedPrecondition("missing recv for id : %lld", channel.id); + return FailedPrecondition("missing recv for id : %d", channel.id); } if (channel.send_done == nullptr) { - return FailedPrecondition("missing send-done for id : %lld", channel.id); + return FailedPrecondition("missing send-done for id : %d", channel.id); } if (channel.recv_done == nullptr) { - return FailedPrecondition("missing recv-done for id : %lld", channel.id); + return FailedPrecondition("missing recv-done for id : %d", channel.id); } } @@ -436,33 +445,33 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { auto send_done_device = GetInstructionDevice(*channel.send_done); if (!send_device) { return FailedPrecondition("send instruction must have a device: %s", - channel.send->ToString().c_str()); + channel.send->ToString()); } if (!send_done_device) { return FailedPrecondition("send_done instruction must have a device: %s", - channel.send_done->ToString().c_str()); + channel.send_done->ToString()); } if (*send_device != *send_done_device) { return FailedPrecondition( - "send and send-done (channel=%lld) must be on the same device: %lld " - "vs. %lld", + "send and send-done (channel=%d) must be on the same device: %d " + "vs. %d", channel.id, *send_device, *send_done_device); } auto recv_device = GetInstructionDevice(*channel.recv); auto recv_done_device = GetInstructionDevice(*channel.recv_done); if (!recv_done_device) { return FailedPrecondition("recv_done instruction must have a device: %s", - channel.recv_done->ToString().c_str()); + channel.recv_done->ToString()); } if (*recv_device != *recv_done_device) { return FailedPrecondition( - "recv and recv-done (channel=%lld) must be on the same device: %lld " - "vs. %lld", + "recv and recv-done (channel=%d) must be on the same device: %d " + "vs. %d", channel.id, *recv_device, *recv_done_device); } if (*send_device == *recv_device) { return FailedPrecondition( - "send and recv (channel=%lld) must be on different devices: %lld", + "send and recv (channel=%d) must be on different devices: %d", channel.id, *send_device); } } @@ -483,7 +492,7 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { !CheckCompanionPathsCompatibility( path, GetCompanionsPath(channel.recv_done))) { return FailedPrecondition( - "Nest companion paths do not match for channel %lld", channel.id); + "Nest companion paths do not match for channel %d", channel.id); } } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index dead6d9c2090c2f296788bbb97dbd7edc4ce4392..768b0c7eb3695715de5cef7dad1ed5a110561605 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" @@ -197,6 +198,10 @@ class HloModuleGroupMetadata { // Returns the maximum channel id or all_reduce_id used in the module group. int64 max_channel_id() const { return max_channel_id_; } + TuplePointsToAnalysis* points_to_analysis(HloModule* module) const { + return points_to_analyses_.at(module).get(); + } + private: Status Build(); @@ -271,6 +276,9 @@ class HloModuleGroupMetadata { // The modules that this metadata was built from. const std::vector& modules_; + + tensorflow::gtl::FlatMap> + points_to_analyses_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 1a4da388e4ac4f4d0b303309aebfec9d75b3ebdd..d70328c8a3db60488a631a82bf27a14fd01e6dba 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -270,8 +270,8 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( string cyclic_instructions; for (const auto& state : *visit_state) { if (state.second == VisitState::kVisiting) { - tensorflow::strings::StrAppend(&cyclic_instructions, - state.first->ToString(), "\n"); + absl::StrAppend(&cyclic_instructions, state.first->ToString(), + "\n"); } } // TODO(b/64305524): Improve the error message to print out the @@ -282,7 +282,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( "following nodes. Note that the order of the nodes is arbitrary " "and that the list may include nodes that are not part of the " "cycle.\n%s", - predecessor->ToString().c_str(), cyclic_instructions.c_str()); + predecessor->ToString(), cyclic_instructions); } stack.push(predecessor); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index d1eaf357855205f1e9867e86f3042b96b6beff97..2d4e38589fe4693e73c46d6c82e51cb0a8388f85 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -39,7 +39,7 @@ StatusOr StringToHloOpcode(const string& opcode_name) { }); auto it = opcode_map->find(opcode_name); if (it == opcode_map->end()) { - return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); + return InvalidArgument("Unknown opcode: %s", opcode_name); } return it->second; } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index b8f2a21ff9df6460303610cf64c98d1b96836171..e6bfb8025d4bfeba1d334d1f946e33841a2da092 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -58,6 +58,7 @@ namespace xla { V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ V(kClamp, "clamp") \ + V(kCollectivePermute, "collective-permute") \ V(kClz, "count-leading-zeros") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 6c1e015f77a62c3e3ff7ffa5ce9dea735f46e10a..0581d5c40425d332d89cc92ca6c6b0b10dd8fcf1 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -25,8 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -254,6 +254,10 @@ bool HloOrdering::LiveRangeStrictlyBefore( } // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { + if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), + use.instruction)) { + continue; + } if (!UseIsBeforeValueDefinition(use, b, dataflow)) { VLOG(4) << "use of " << a << " (" << use << ") not before " << b << " is defined"; @@ -302,22 +306,20 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { std::vector pieces; pieces.push_back(name); for (auto* computation : module_->MakeNonfusionComputations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s:", - computation->name().c_str())); + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); const auto all = computation->MakeInstructionPostOrder(); for (auto instruction : all) { - pieces.push_back(tensorflow::strings::Printf( - " %s predecessors:", instruction->name().c_str())); + pieces.push_back( + absl::StrFormat(" %s predecessors:", instruction->name())); for (auto predecessor : all) { if (predecessors_.at(computation) ->IsReachable(predecessor, instruction)) { - pieces.push_back( - tensorflow::strings::Printf(" %s", predecessor->name().c_str())); + pieces.push_back(absl::StrFormat(" %s", predecessor->name())); } } } } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) @@ -368,8 +370,8 @@ string SequentialHloOrdering::ToString() const { std::vector pieces; pieces.push_back("SequentialHloOrdering"); for (auto* computation : module_->computations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s order:", - computation->name().c_str())); + pieces.push_back( + absl::StrFormat("computation %s order:", computation->name())); // Gather all instructions in the module sequence for this computation and // sort them by their position. std::vector instructions; @@ -384,11 +386,10 @@ string SequentialHloOrdering::ToString() const { return order_position_.at(a) < order_position_.at(b); }); for (auto instruction : instructions) { - pieces.push_back( - tensorflow::strings::Printf(" %s", instruction->name().c_str())); + pieces.push_back(absl::StrFormat(" %s", instruction->name())); } } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } std::ostream& operator<<( diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index beef96476c611cead41c936e55bcda07c7d48a6a..eae4508b24b98ec4e93d221aaa2dd3a6c221aaba 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -17,6 +17,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -26,22 +30,17 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace { -using ::absl::nullopt; -using ::absl::optional; -using ::tensorflow::StringPiece; -using ::tensorflow::str_util::Join; -using ::tensorflow::str_util::Split; -using ::tensorflow::str_util::SplitAndParseAsInts; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::nullopt; +using absl::optional; +using absl::StrAppend; +using absl::StrCat; +using absl::StrFormat; +using absl::StrJoin; const double kF16max = 65504; @@ -50,7 +49,7 @@ class HloParser { public: using LocTy = HloLexer::LocTy; - explicit HloParser(StringPiece str, const HloModuleConfig& config) + explicit HloParser(absl::string_view str, const HloModuleConfig& config) : lexer_(str), config_(config) {} // Runs the parser. Returns false if an error occurred. @@ -60,7 +59,7 @@ class HloParser { std::unique_ptr ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return Join(error_, "\n"); } + string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. StatusOr ParseShardingOnly(); @@ -253,8 +252,8 @@ class HloParser { bool CanBeParamListToShape(); // Logs the current parsing line and the given message. Always returns false. - bool TokenError(StringPiece msg); - bool Error(LocTy loc, StringPiece msg); + bool TokenError(absl::string_view msg); + bool Error(LocTy loc, absl::string_view msg); // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. @@ -293,6 +292,17 @@ class HloParser { missing_instruction_hook_; }; +bool SplitToInt64s(absl::string_view s, char delim, std::vector* out) { + for (const auto& split : absl::StrSplit(s, delim)) { + int64 val; + if (!absl::SimpleAtoi(split, &val)) { + return false; + } + out->push_back(val); + } + return true; +} + // Creates replica groups from the provided nested array. groups[i] represents // the replica ids for group 'i'. std::vector CreateReplicaGroups( @@ -307,22 +317,22 @@ std::vector CreateReplicaGroups( return replica_groups; } -bool HloParser::Error(LocTy loc, StringPiece msg) { +bool HloParser::Error(LocTy loc, absl::string_view msg) { auto line_col = lexer_.GetLineAndColumn(loc); const unsigned line = line_col.first; const unsigned col = line_col.second; std::vector error_lines; error_lines.push_back( StrCat("was parsing ", line, ":", col, ": error: ", msg)); - error_lines.push_back(std::string(lexer_.GetLine(loc))); + error_lines.emplace_back(lexer_.GetLine(loc)); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(Join(error_lines, "\n")); + error_.push_back(StrJoin(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } -bool HloParser::TokenError(StringPiece msg) { +bool HloParser::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } @@ -552,11 +562,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kIota: { + optional iota_dimension; + attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64, + &iota_dimension}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateIota(shape)); + instruction = builder->AddInstruction( + HloInstruction::CreateIota(shape, *iota_dimension)); break; } // Unary ops. @@ -681,7 +695,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional barrier; attrs["replica_groups"] = {/*required=*/false, AttrTy::kBracedInt64ListList, &tmp_groups}; - attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -689,8 +702,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (tmp_groups) { replica_groups = CreateReplicaGroups(*tmp_groups); } - instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( - shape, operands, replica_groups, barrier ? *barrier : "")); + instruction = builder->AddInstruction( + HloInstruction::CreateAllToAll(shape, operands, replica_groups)); + break; + } + case HloOpcode::kCollectivePermute: { + optional>> source_targets; + attrs["source_target_pairs"] = { + /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + std::vector> pairs(source_targets->size()); + for (int i = 0; i < pairs.size(); i++) { + if ((*source_targets)[i].size() != 2) { + return TokenError( + "expects 'source_target_pairs=' to be a list of pairs"); + } + pairs[i].first = (*source_targets)[i][0]; + pairs[i].second = (*source_targets)[i][1]; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCollectivePermute(shape, operands[0], pairs)); break; } case HloOpcode::kReshape: { @@ -1577,8 +1611,7 @@ bool HloParser::ParseInstructionNames( } std::pair* instr = FindInstruction(name); if (!instr) { - return TokenError( - Printf("instruction '%s' is not defined", name.c_str())); + return TokenError(StrFormat("instruction '%s' is not defined", name)); } instructions->push_back(instr->first); } while (EatIfPresent(TokKind::kComma)); @@ -1807,10 +1840,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, std::vector elems_seen_until_dim( elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - Join(elems_seen_until_dim, ",", - [](string* out, const tensorflow::int64& num_elems) { - StrAppend(out, num_elems - 1); - }), + StrJoin(elems_seen_until_dim, ",", + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1820,17 +1853,17 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, case TokKind::kLbrace: { nest_level++; if (nest_level > rank) { - return TokenError(Printf( - "expects nested array in rank %lld, but sees larger", rank)); + return TokenError(absl::StrFormat( + "expects nested array in rank %d, but sees larger", rank)); } if (nest_level > 1) { elems_seen_per_dim[nest_level - 2]++; if (elems_seen_per_dim[nest_level - 2] > shape.dimensions(nest_level - 2)) { - return TokenError(Printf( - "expects %lld elements in the %sth element, but sees more", + return TokenError(absl::StrFormat( + "expects %d elements in the %sth element, but sees more", shape.dimensions(nest_level - 2), - get_index_str(nest_level - 2).c_str())); + get_index_str(nest_level - 2))); } } lexer_.Lex(); @@ -1839,9 +1872,9 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, case TokKind::kRbrace: { nest_level--; if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) { - return TokenError(Printf( - "expects %lld elements in the %sth element, but sees %lld", - shape.dimensions(nest_level), get_index_str(nest_level).c_str(), + return TokenError(absl::StrFormat( + "expects %d elements in the %sth element, but sees %d", + shape.dimensions(nest_level), get_index_str(nest_level), elems_seen_per_dim[nest_level])); } elems_seen_per_dim[nest_level] = 0; @@ -1862,15 +1895,15 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, if (rank > 0) { if (nest_level != rank) { return TokenError( - Printf("expects nested array in rank %lld, but sees %lld", rank, - nest_level)); + absl::StrFormat("expects nested array in rank %d, but sees %d", + rank, nest_level)); } elems_seen_per_dim[rank - 1]++; if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { - return TokenError( - Printf("expects %lld elements on the minor-most dimension, but " - "sees more", - shape.dimensions(rank - 1))); + return TokenError(absl::StrFormat( + "expects %d elements on the minor-most dimension, but " + "sees more", + shape.dimensions(rank - 1))); } } if (lexer_.GetKind() == TokKind::kw_true || @@ -1997,7 +2030,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", Join(index, ", "), "]")); + ": [", StrJoin(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -2126,8 +2159,8 @@ bool HloParser::ParseSubAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return Error(loc, Printf("sub-attribute %s is expected but not seen", - attr_it.first.c_str())); + return Error(loc, StrFormat("sub-attribute %s is expected but not seen", + attr_it.first)); } } return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes"); @@ -2147,8 +2180,8 @@ bool HloParser::ParseAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return Error(loc, Printf("attribute %s is expected but not seen", - attr_it.first.c_str())); + return Error(loc, StrFormat("attribute %s is expected but not seen", + attr_it.first)); } } return true; @@ -2164,7 +2197,7 @@ bool HloParser::ParseAttributeHelper( } VLOG(1) << "Parsing attribute " << name; if (!seen_attrs->insert(name).second) { - return Error(loc, Printf("attribute %s already exists", name.c_str())); + return Error(loc, StrFormat("attribute %s already exists", name)); } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { @@ -2174,13 +2207,13 @@ bool HloParser::ParseAttributeHelper( } else { allowed_attrs = StrCat( "Allowed attributes: ", - Join(attrs, ", ", - [&](string* out, const std::pair& kv) { - StrAppend(out, kv.first); - })); + StrJoin(attrs, ", ", + [&](string* out, const std::pair& kv) { + StrAppend(out, kv.first); + })); } - return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), - allowed_attrs.c_str())); + return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name, + allowed_attrs)); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; @@ -2375,7 +2408,7 @@ bool HloParser::ParseAttributeHelper( } }(); if (!success) { - return Error(loc, Printf("error parsing attribute %s", name.c_str())); + return Error(loc, StrFormat("error parsing attribute %s", name)); } return true; } @@ -2490,20 +2523,24 @@ bool HloParser::ParseConvolutionDimensionNumbers( } string str = lexer_.GetStrVal(); - // The str is expected to have 3 items, lhs, rhs, out, and it must looks like + // The str is expected to have 3 items, lhs, rhs, out, and it must look like // lhs_rhs->out, that is, the first separator is "_" and the second is "->". - // So we replace the "->" with "_" and then split on "_". - str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->", - /*newsub=*/"_", - /*replace_all=*/false); - std::vector lhs_rhs_out = Split(str, "_"); - if (lhs_rhs_out.size() != 3) { + std::vector split1 = absl::StrSplit(str, "_"); + if (split1.size() != 2) { + LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " + << str; + } + std::vector split2 = absl::StrSplit(split1[1], "->"); + if (split2.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; } + absl::string_view lhs = split1[0]; + absl::string_view rhs = split2[0]; + absl::string_view out = split2[1]; - const tensorflow::int64 rank = lhs_rhs_out[0].length(); - if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { + const tensorflow::int64 rank = lhs.length(); + if (rank != rhs.length() || rank != out.length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); } @@ -2518,8 +2555,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // lhs { - const string& lhs = lhs_rhs_out[0]; - if (!is_unique(lhs)) { + if (!is_unique(string(lhs))) { return TokenError( StrCat("expects unique lhs dimension numbers, but sees ", lhs)); } @@ -2536,14 +2572,13 @@ bool HloParser::ParseConvolutionDimensionNumbers( dnums->set_input_spatial_dimensions(c - '0', i); } else { return TokenError( - Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); + StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1)); } } } // rhs { - const string& rhs = lhs_rhs_out[1]; - if (!is_unique(rhs)) { + if (!is_unique(string(rhs))) { return TokenError( StrCat("expects unique rhs dimension numbers, but sees ", rhs)); } @@ -2560,14 +2595,13 @@ bool HloParser::ParseConvolutionDimensionNumbers( dnums->set_kernel_spatial_dimensions(c - '0', i); } else { return TokenError( - Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1)); + StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1)); } } } // output { - const string& out = lhs_rhs_out[2]; - if (!is_unique(out)) { + if (!is_unique(string(out))) { return TokenError( StrCat("expects unique output dimension numbers, but sees ", out)); } @@ -2583,8 +2617,8 @@ bool HloParser::ParseConvolutionDimensionNumbers( } else if (c < '0' + rank && c >= '0') { dnums->set_output_spatial_dimensions(c - '0', i); } else { - return TokenError( - Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); + return TokenError(StrFormat( + "expects [0-%dbf] in output dimension numbers", rank - 1)); } } } @@ -2630,9 +2664,10 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { } const auto& range = ranges.back(); if (range.size() != 2 && range.size() != 3) { - return Error(loc, Printf("expects [start:limit:step] or [start:limit], " - "but sees %ld elements.", - range.size())); + return Error(loc, + StrFormat("expects [start:limit:step] or [start:limit], " + "but sees %d elements.", + range.size())); } } while (EatIfPresent(TokKind::kComma)); @@ -2818,14 +2853,13 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { - return Error(loc, - Printf("sub-attribute '%s=' already exists", name.c_str())); + return Error(loc, StrFormat("sub-attribute '%s=' already exists", name)); } // 1D if (lexer_.GetKind() == TokKind::kInt) { tensorflow::int64 number; if (!ParseInt64(&number)) { - return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); + return Error(loc, StrFormat("expects sub-attribute '%s=i'", name)); } result->push_back(number); return true; @@ -2833,9 +2867,8 @@ bool HloParser::ParseDxD(const string& name, // 2D or higher. if (lexer_.GetKind() == TokKind::kDxD) { string str = lexer_.GetStrVal(); - if (!SplitAndParseAsInts(str, 'x', result)) { - return Error(loc, - Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + if (!SplitToInt64s(str, 'x', result)) { + return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name)); } lexer_.Lex(); return true; @@ -2853,10 +2886,9 @@ bool HloParser::ParseWindowPad( return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); } string str = lexer_.GetStrVal(); - std::vector padding_str = Split(str, 'x'); - for (int i = 0; i < padding_str.size(); i++) { + for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector low_high; - if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || + if (!SplitToInt64s(padding_dim_str, '_', &low_high) || low_high.size() != 2) { return Error(loc, "expects padding_low and padding_high separated by '_'"); @@ -2877,10 +2909,9 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { } LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); - std::vector padding_str = Split(str, 'x'); - for (const auto& padding_dim_str : padding_str) { + for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector padding_dim; - if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || + if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, "expects padding config pattern like 'low_high_interior' or " @@ -2932,9 +2963,8 @@ bool HloParser::ParseOpcode(HloOpcode* result) { string val = lexer_.GetStrVal(); auto status_or_result = StringToHloOpcode(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects opcode but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2948,7 +2978,7 @@ bool HloParser::ParseFftType(FftType* result) { } string val = lexer_.GetStrVal(); if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) { - return TokenError(Printf("expects fft type but sees: %s", val.c_str())); + return TokenError(StrFormat("expects fft type but sees: %s", val)); } lexer_.Lex(); return true; @@ -2962,9 +2992,9 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { string val = lexer_.GetStrVal(); auto status_or_result = StringToFusionKind(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s", + val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2980,8 +3010,8 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { auto status_or_result = StringToRandomDistribution(val); if (!status_or_result.ok()) { return TokenError( - Printf("expects random distribution but sees: %s, error: %s", - val.c_str(), status_or_result.status().error_message().c_str())); + StrFormat("expects random distribution but sees: %s, error: %s", val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2996,9 +3026,9 @@ bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) { string val = lexer_.GetStrVal(); auto status_or_result = StringToPrecision(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects precision but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects precision but sees: %s, error: %s", + val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -3092,7 +3122,7 @@ StatusOr HloParser::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; if (!ParseSharding(&op_sharding)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after sharding"); @@ -3104,7 +3134,7 @@ StatusOr HloParser::ParseWindowOnly() { lexer_.Lex(); Window window; if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after window"); @@ -3117,7 +3147,7 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { lexer_.Lex(); ConvolutionDimensionNumbers dnums; if (!ParseConvolutionDimensionNumbers(&dnums)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument( @@ -3155,7 +3185,7 @@ Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, // Parse the instruction with the registered hook. if (!ParseInstruction(builder, root_name)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } return Status::OK(); } @@ -3163,47 +3193,46 @@ Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, } // namespace StatusOr> ParseHloString( - tensorflow::StringPiece str, const HloModuleConfig& config) { + absl::string_view str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { - return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", parser.GetError()); } return parser.ConsumeHloModule(); } -StatusOr> ParseHloString( - tensorflow::StringPiece str) { +StatusOr> ParseHloString(absl::string_view str) { HloModuleConfig config; return ParseHloString(str, config); } StatusOr> ParseHloOpToModule( - tensorflow::StringPiece str, tensorflow::StringPiece name) { + absl::string_view str, absl::string_view name) { HloModuleConfig config; HloParser parser(str, config); - auto builder = absl::make_unique(name.ToString()); + auto builder = absl::make_unique(string(name)); string root_name; TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); std::unique_ptr computation = builder->Build(); - auto module = absl::make_unique(name.ToString(), config); + auto module = absl::make_unique(string(name), config); module->AddEntryComputation(std::move(computation)); return std::move(module); } -StatusOr ParseSharding(tensorflow::StringPiece str) { +StatusOr ParseSharding(absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseShardingOnly(); } -StatusOr ParseWindow(tensorflow::StringPiece str) { +StatusOr ParseWindow(absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseWindowOnly(); } StatusOr ParseConvolutionDimensionNumbers( - tensorflow::StringPiece str) { + absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseConvolutionDimensionNumbersOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 6c184bfe9ad8a49ee67c4621f1b22b90f1659e8f..0c64b50481bf2e86a2c588fbf2d77226c8428b7c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_lexer.h" @@ -32,32 +33,31 @@ namespace xla { // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with the given config. StatusOr> ParseHloString( - tensorflow::StringPiece str, const HloModuleConfig& config); + absl::string_view str, const HloModuleConfig& config); // Parses the text for a single HLO operation into an HLO module with a function // that runs that operation (with the same parameters) as its entry computation. StatusOr> ParseHloOpToModule( - tensorflow::StringPiece str, tensorflow::StringPiece name = "single_op"); + absl::string_view str, absl::string_view name = "single_op"); // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with default config. -StatusOr> ParseHloString( - tensorflow::StringPiece str); +StatusOr> ParseHloString(absl::string_view str); // Parses the result of HloSharding::ToString(), e.g. "{replicated}". -StatusOr ParseSharding(tensorflow::StringPiece str); +StatusOr ParseSharding(absl::string_view str); // Parses the result of window_util::ToString(const Window&). -StatusOr ParseWindow(tensorflow::StringPiece str); +StatusOr ParseWindow(absl::string_view str); // Parses the result of ConvolutionDimensionNumbersToString(), e.g. // "b0f_0io->b0f". StatusOr ParseConvolutionDimensionNumbers( - tensorflow::StringPiece str); + absl::string_view str); // ParseHloString sharding from str. str is supposed to contain the body of the // sharding, i.e. just the rhs of the "sharding={...}" attribute string. -StatusOr ParseSharding(tensorflow::StringPiece str); +StatusOr ParseSharding(absl::string_view str); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index f52cfadb811ca3b9b549d8e14cb619b0c5b9c02e..ba07ec432e9dddf3f0fc45164c66b2c8403568ff 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -16,20 +16,19 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" -namespace op = ::xla::testing::opcode_matchers; - namespace xla { - namespace { -using ::tensorflow::StringPiece; +namespace op = ::xla::testing::opcode_matchers; +using absl::string_view; struct TestData { string test_name; @@ -1094,7 +1093,19 @@ R"(HloModule AllToAllWithSubgroups ENTRY AllToAllWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc" + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}} +} + +)" +}, +// collective-permute +{ +"CollectivePermute", +R"(HloModule CollectivePermute + +ENTRY CollectivePermute { + input = f32[128,32]{0,1} parameter(0) + ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } )" @@ -1105,7 +1116,7 @@ ENTRY AllToAllWithSubgroups { R"(HloModule iota ENTRY Iota { - ROOT iota = f32[100]{0} iota() + ROOT iota = f32[100]{0} iota(), iota_dimension=0 } )" @@ -1128,8 +1139,8 @@ ENTRY Computation { class HloParserTest : public ::testing::Test, public ::testing::WithParamInterface { protected: - static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected)) + static void ExpectHasSubstr(string_view s, string_view expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } @@ -1393,15 +1404,14 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; - ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( - prefix, ",dim_labels=00_01_10", suffix)) - .status() - .error_message(), - "expects dim labels pattern"); + ExpectHasSubstr( + ParseHloString(absl::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); ExpectHasSubstr( - ParseHloString(tensorflow::strings::StrCat( - prefix, ",dim_labels=010_1100->010", suffix)) + ParseHloString(absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), "must have the same rank"); diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h index 0cddf8fb8f7589739d1233fa4974ff703211a137..f1ad0f9b0148cb3d5f938e7f5d220d6cb82ea98d 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_interface.h +++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h @@ -29,7 +29,7 @@ namespace xla { class HloPassInterface { public: virtual ~HloPassInterface() = default; - virtual tensorflow::StringPiece name() const = 0; + virtual absl::string_view name() const = 0; // Run the pass on the given HLO module. Return whether it modified the // module. diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index d8f1ab916b5c5c500c2d8dcd8605be083f95862a..6e4ed0de626688c0d836d6bc9c619245db8d61dd 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,22 +17,23 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace xla { - namespace { + +using absl::StrAppend; +using absl::StrCat; + void DumpModuleGraph(const HloModule& module, const string& message) { hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; @@ -48,9 +49,9 @@ void DumpModuleProto(const HloModule& module, const string& dump_to, tensorflow::mutex_lock lock(mu); const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; - const string mod_name = SanitizeFileName(tensorflow::strings::Printf( - "module_%04d.%04lld.%s.after_%s", module.unique_id(), pass_number, - pipeline_name.c_str(), pass_name.c_str())); + const string mod_name = SanitizeFileName( + absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), + pass_number, pipeline_name, pass_name)); TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module), dump_to, mod_name)); @@ -68,7 +69,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { repeated_field.end()); if (!disabled_passes.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " - << tensorflow::str_util::Join(disabled_passes, ", "); + << absl::StrJoin(disabled_passes, ", "); } auto run_invariant_checkers = [this, @@ -90,7 +91,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { return Status::OK(); }; - string prefix = std::string(name()) + ": pipeline start"; + string prefix = StrCat(name(), ": pipeline start"); bool changed = false; string message; TF_RETURN_IF_ERROR( @@ -98,12 +99,12 @@ StatusOr HloPassPipeline::Run(HloModule* module) { const string xla_dump_per_pass_hlo_proto_to = module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - std::string(name()), "pipeline_start"); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()), + "pipeline_start"); } for (auto& pass : passes_) { - if (disabled_passes.count(std::string(pass->name())) > 0) { + if (disabled_passes.count(string(pass->name())) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() << ", disabled by --xla_disable_hlo_passes"; continue; @@ -120,8 +121,8 @@ StatusOr HloPassPipeline::Run(HloModule* module) { TF_RETURN_IF_ERROR( run_invariant_checkers(StrCat("after running pass: ", pass->name()))); if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - std::string(name()), std::string(pass->name())); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()), + string(pass->name())); } changed |= changed_this_pass; diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 3bb1342aa370c09dc5cd180e6b0abade4a62c91d..1d41a4dac1d8e2f392be0e4e856ead36a5b71d68 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -34,7 +34,7 @@ namespace xla { class HloPassPipeline : public HloPassInterface { public: explicit HloPassPipeline(const string& name) : name_(name) {} - tensorflow::StringPiece name() const override { return name_; } + absl::string_view name() const override { return name_; } // Add a pass to the pipeline. It should be called with the arguments for the // pass constructor: diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc index b9cca138703c8fa61aadf69dd7304a215a9f4be2..c3cacd7ce6b1ea3ad7cf84e898f274ae12622ac5 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 04e4a293596fe057bf770ec2949fb83ffadce117..569d2e5d2d9b3aea4b79924af7839a03fc8de285 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -21,6 +21,9 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -38,17 +41,13 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::HumanReadableNumBytes; - namespace xla { - namespace { +using ::tensorflow::strings::HumanReadableNumBytes; + // Potential optimizations: // . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue // of candidates. @@ -207,11 +206,10 @@ class InstructionList { Item* to_insert, tensorflow::gtl::ArraySlice before_instructions) { VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name() << " before {" - << tensorflow::str_util::Join(before_instructions, ", ", - [](string* out, Item* item) { - tensorflow::strings::StrAppend( - out, item->instruction->name()); - }) + << absl::StrJoin(before_instructions, ", ", + [](string* out, Item* item) { + absl::StrAppend(out, item->instruction->name()); + }) << "}"; // Find the minimal position number of any instruction in @@ -394,10 +392,9 @@ class MemoryUsageTracker { int64 unfinished_user_count; string ToString() const { - return tensorflow::strings::StrCat( - "Buffer ", id, " (defined by ", - defining_instruction->instruction->name(), ", size ", size, - " bytes)"); + return absl::StrCat("Buffer ", id, " (defined by ", + defining_instruction->instruction->name(), ", size ", + size, " bytes)"); } }; @@ -741,29 +738,27 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, } string MemoryUsageTracker::ToString() const { - string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", - computation_->name(), "\n"); - tensorflow::strings::StrAppend( - &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", - memory_usage(), " bytes)"); + string output = + absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n"); + absl::StrAppend(&output, + "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", + memory_usage(), " bytes)"); for (auto* item = instruction_list_.first(); item != nullptr; item = instruction_list_.next(item)) { const HloInstruction* instruction = item->instruction; string inprogress = item == in_progress_item_ ? " in-progress" : ""; string placed = item->placed ? " placed" : ""; - tensorflow::strings::StrAppend(&output, " ", instruction->name(), - inprogress, placed, "\n Defines:\n"); + absl::StrAppend(&output, " ", instruction->name(), inprogress, placed, + "\n Defines:\n"); for (BufferId buffer_id : item->buffers_defined) { const Buffer& buffer = buffers_[buffer_id]; string live = IsCurrentlyLive(buffer_id) ? " live" : ""; - tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, - ", ", buffer.unfinished_user_count, - " unfinished uses\n"); + absl::StrAppend(&output, " ", buffer.ToString(), live, ", ", + buffer.unfinished_user_count, " unfinished uses\n"); } - tensorflow::strings::StrAppend(&output, " Uses:\n"); + absl::StrAppend(&output, " Uses:\n"); for (BufferId buffer_id : item->buffers_used) { - tensorflow::strings::StrAppend(&output, " ", - buffers_[buffer_id].ToString(), "\n"); + absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n"); } } return output; @@ -781,10 +776,9 @@ bool MemoryUsageTracker::Check() const { CHECK(elements_are_unique(defined_buffers)) << "Instruction " << instruction->name() << " does not have unique defined buffers: " - << tensorflow::str_util::Join( + << absl::StrJoin( defined_buffers, ", ", [this](string* out, BufferId buffer_id) { - tensorflow::strings::StrAppend( - out, buffers_.at(buffer_id).ToString()); + absl::StrAppend(out, buffers_.at(buffer_id).ToString()); }); for (const Buffer& buffer : buffers_) { @@ -804,10 +798,9 @@ bool MemoryUsageTracker::Check() const { CHECK(elements_are_unique(used_buffers)) << "Instruction " << instruction->name() << " does not have unique used buffers: " - << tensorflow::str_util::Join( + << absl::StrJoin( used_buffers, ", ", [this](string* out, BufferId buffer_id) { - tensorflow::strings::StrAppend( - out, buffers_.at(buffer_id).ToString()); + absl::StrAppend(out, buffers_.at(buffer_id).ToString()); }); } for (const Buffer& buffer : buffers_) { @@ -1210,6 +1203,49 @@ StatusOr HloRematerialization::Run( VLOG(1) << "HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); + XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); + + // Create initial sequence of HLO instructions. + TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( + *module, + [this](const BufferValue& buffer) { + return size_function_(buffer.shape()); + }, + scheduler_algorithm_)); + if (copy_insertion) { + // We run a separate pass of copy elision here because the sequential + // ordering from the HLO schedule allows for more copies to be eliminated. + // TODO(b/80249101): Instead of a separate copy elision pass, use the + // ordering from the HLO schedule directly for copy insertion. + + // First create a copy of the schedule which contains HloInstruction unique + // ids instead of HloInstruction*. This is necessary for updating the + // schedule below. + // TODO(b/113175018): Remove this when the HLO schedule is self-contained + // and can update itself. + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(*sequence); + + SequentialHloOrdering ordering(module, *sequence); + TF_RETURN_IF_ERROR( + copy_insertion->RemoveUnnecessaryCopies(ordering, module)); + + // RemoveUnnecessaryCopies only considers interference when determining + // whether it is legal to remove a copy. However, copies in the graph may be + // necessary for other reason such as preventing a constant from being live + // out of the graph. So run AddSpecialCaseCopies to re-insert these copies. + // TODO(b/80249101): Break copy insertion into several passes and run each + // one once in the regular HLO pipeline. + TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module)); + + // The passes above can add and remove copies, update the schedule to + // account for these transformations. Newly added instructions will be + // placed ASAP in the schedule. + TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence)); + + TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( + SequentialHloOrdering(module, *sequence), module)); + } TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); @@ -1231,24 +1267,6 @@ StatusOr HloRematerialization::Run( << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); - XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( - *module, - [this](const BufferValue& buffer) { - return size_function_(buffer.shape()); - }, - scheduler_algorithm_)); - if (copy_insertion) { - // We run a separate pass of copy elision here because the sequential - // ordering from the HLO schedule allows for more copies to be eliminated. - // TODO(b/80249101): Instead of a separate copy elision pass, use the - // ordering from the HLO schedule directly for copy insertion. - SequentialHloOrdering ordering(module, *sequence); - TF_RETURN_IF_ERROR( - copy_insertion->RemoveUnnecessaryCopies(ordering, module)); - } - // Compute peak memory usage of all computations in the module called in a // sequential context. call_graph_ = CallGraph::Build(module); @@ -1335,12 +1353,11 @@ StatusOr HloRematerialization::Run( XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); if (current_peak_memory > memory_limit_bytes) { - LOG(WARNING) << tensorflow::strings::Printf( - "Can't reduce memory use below %s (%lld bytes) by rematerialization; " - "only reduced to %s (%lld bytes)", - HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes, - HumanReadableNumBytes(current_peak_memory).c_str(), - current_peak_memory); + LOG(WARNING) << absl::StrFormat( + "Can't reduce memory use below %s (%d bytes) by rematerialization; " + "only reduced to %s (%d bytes)", + HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes, + HumanReadableNumBytes(current_peak_memory), current_peak_memory); } return changed; diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 8f3ae9c62127d8bd79f272f801d9aa9a3043ab6a..7bd8a4a544b21a35f20eeed493f7e0528a7e87dd 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -32,7 +32,7 @@ limitations under the License. namespace xla { /*static*/ StatusOr> -HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, +HloRunner::CreateModuleFromString(const absl::string_view hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 65537f07f56e74b7fe2c2f9792af21efc7229573..cfc519063e837cb961c4c4fb1efe611a7fe273ba 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -87,8 +87,7 @@ class HloRunner { // Converts an HloModule from the given hlo textual IR string (in // HloModule::ToString format). static StatusOr> CreateModuleFromString( - const tensorflow::StringPiece hlo_string, - const DebugOptions& debug_options); + const absl::string_view hlo_string, const DebugOptions& debug_options); // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 27cc5361cde2fa021b9489f98217ae5648afc2ad..0fc3b268c059802a3882ad5032a9fe5da28cbf23 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include +#include #include #include @@ -28,16 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::HumanReadableNumBytes; - namespace xla { - namespace { +using ::tensorflow::strings::HumanReadableNumBytes; + // Class implementing a list scheduler of HLO instructions which produces a // sequence which minimizes memory usage by preferring to schedule the node that // frees bigger buffer and defines smaller outputs. @@ -582,4 +581,187 @@ StatusOr> ScheduleOneComputation( size_function, nullptr, empty_map); } +tensorflow::gtl::FlatMap> +ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) { + tensorflow::gtl::FlatMap> id_sequence; + for (const auto& computation_sequence : sequence) { + for (const HloInstruction* instruction : computation_sequence.second) { + id_sequence[computation_sequence.first].push_back( + instruction->unique_id()); + } + } + return id_sequence; +} + +Status UpdateSchedule( + const HloModule& module, + const tensorflow::gtl::FlatMap>& + id_sequence, + SequentialHloOrdering::HloModuleSequence* sequence) { + // Map from unique ID to HloInstruction pointer for instructions in the + // module. + tensorflow::gtl::FlatMap id_to_instruction; + // Set of all HloInstructions in the schedule. + tensorflow::gtl::FlatSet ids_in_schedule; + std::vector nonfusion_computations = + module.MakeNonfusionComputations(); + for (const HloComputation* computation : nonfusion_computations) { + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK( + id_to_instruction.insert({instruction->unique_id(), instruction}) + .second); + } + for (int id : id_sequence.at(computation)) { + ids_in_schedule.insert(id); + } + } + + // Map from HloInstruction X to newly added instructions (instruction is in + // module, but not in schedule) which use X. If an instruction is not in the + // map, then it has no users which are newly added instructions. + tensorflow::gtl::FlatMap> + new_instruction_uses; + + // For each newly added instruction, this is the count of the instruction's + // operands that have not yet been scheduled. When this value reaches zero, + // then the instruction may be placed in the schedule. + tensorflow::gtl::FlatMap + unscheduled_operand_count; + // For each computation, this is the set of newly added instructions which + // have no operands. These must be handled specially and are added to the + // beginning of the schedule. + tensorflow::gtl::FlatMap> + new_zero_operand_instructions; + for (const HloComputation* computation : nonfusion_computations) { + new_zero_operand_instructions[computation] = {}; + for (const HloInstruction* instruction : computation->instructions()) { + if (ids_in_schedule.count(instruction->unique_id()) == 0) { + // This is a newly added instruction which is not in the schedule. + for (const HloInstruction* operand : instruction->operands()) { + new_instruction_uses[operand].push_back(instruction); + } + if (instruction->operands().empty()) { + new_zero_operand_instructions[computation].push_back(instruction); + } + unscheduled_operand_count[instruction] = instruction->operand_count(); + } + } + } + + // Update the schedule with the newly added instructions, and remove any + // instructions no longer in the graph. + for (const HloComputation* computation : nonfusion_computations) { + std::vector old_computation_sequence = + std::move(sequence->at(computation)); + sequence->at(computation).clear(); + + // Create a worklist of newly added instructions which are ready to be added + // to the schedule. Initialize worklist with those that have zero operands. + std::queue worklist; + for (const HloInstruction* instruction : + new_zero_operand_instructions.at(computation)) { + worklist.push(instruction); + } + + // Lambda which schedules all instructions on the worklist. + auto schedule_worklist = [&]() { + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop(); + sequence->at(computation).push_back(instruction); + std::vector* new_users = + tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); + if (new_users != nullptr) { + // This just-scheduled instruction has users which are newly added to + // the module. Update the number of unscheduled operands and push the + // newly added instruction to the worklist if it is ready to + // schedule. + for (const HloInstruction* new_user : *new_users) { + unscheduled_operand_count.at(new_user)--; + CHECK_GE(unscheduled_operand_count.at(new_user), 0); + if (unscheduled_operand_count.at(new_user) == 0) { + worklist.push(new_user); + } + } + } + } + }; + + schedule_worklist(); + for (int id : id_sequence.at(computation)) { + auto it = id_to_instruction.find(id); + if (it == id_to_instruction.end()) { + // This instruction in the schedule is no longer in the module. + continue; + } + const HloInstruction* instruction = it->second; + worklist.push(instruction); + schedule_worklist(); + } + } + + TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence)); + return Status::OK(); +} + +Status VerifySchedule( + const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& sequence) { + VLOG(2) << "VerifySchedule()"; + XLA_VLOG_LINES(2, module.ToString()); + VLOG(2) << sequence; + + // Verify the set of computations in the sequence is exactly the set of + // computations in the module. + std::vector nonfusion_computations = + module.MakeNonfusionComputations(); + TF_RET_CHECK(nonfusion_computations.size() == sequence.size()); + tensorflow::gtl::FlatSet computations_in_module( + module.computations().begin(), module.computations().end()); + for (const auto& computation_sequence : sequence) { + TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1); + } + + // For each computation verify the set of instructions is the same and that + // each dependency and control edge is honored. + for (const HloComputation* computation : nonfusion_computations) { + tensorflow::gtl::FlatMap instruction_position; + int pos = 0; + for (const HloInstruction* instruction : sequence.at(computation)) { + TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) + << "Instruction " << instruction->name() + << " appears more than once in the schedule"; + pos++; + } + + TF_RET_CHECK(instruction_position.size() == + computation->instruction_count()); + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(instruction_position.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in schedule"; + } + + for (const HloInstruction* instruction : computation->instructions()) { + for (const HloInstruction* operand : instruction->operands()) { + TF_RET_CHECK(instruction_position.at(operand) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its operand " << operand->name(); + } + + for (const HloInstruction* pred : instruction->control_predecessors()) { + TF_RET_CHECK(instruction_position.at(pred) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its control predecessor " + << pred->name(); + } + } + } + + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 2b33ccc8bfb895286bb3747aab0a16cf25e2cfae..d06b8d9a5cdef82380bd68ae0991a3957db80f48 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -85,6 +85,43 @@ StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); +// Transforms the given schedule such that it is (again) a valid schedule for +// the module. This is used to update a schedule after the HLO module has been +// transformed in some way. In general, the only transformations to the module +// for which a schedule can be updated is the addition or removal of +// instructions to/from the module. Updating the schedule after new dependencies +// between existing instructions in the module is not supported and may result +// in an error status returned. +// +// Instructions in the module which also exist in the given schedule will remain +// in the same order in the updated schedule. Instructions which exist in the +// module but not in the given schedule will be placed as early as possible in +// the updated schedule. +// +// 'id_sequence' is a mirror of the given schedule 'sequence' but with +// HloInstruction ids rather than HloInstruction pointers. This should be +// constructed using ComputeIdSchedule below after the schedule is constructed +// but before the HLO module is transformed. +Status UpdateSchedule( + const HloModule& module, + const tensorflow::gtl::FlatMap>& + id_sequence, + SequentialHloOrdering::HloModuleSequence* sequence); + +// Constructs a copy of the given schedule but with HloInstruction unique ids +// rather than HloInstruction pointers. This is necessary for updating a +// schedule as HloInstruction points in the schedule may become invalid if +// instructions are removed from the module. Used by UpdateSchedule above.. +// TODO(b/113175018): Remove this function when HLO schedule is its own class. +tensorflow::gtl::FlatMap> +ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence); + +// Verifies that the given schedule is valid for the given module. Specifically, +// the schedule contains exactly the instructions in the module and every +// dependency in the module is satisfied in the schedule. +Status VerifySchedule(const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& sequence); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 639c20ad8e181cfdaa80ccf0311215fc64b52829..930801288a0ea0fa7fd75dd38610430ae7010b5a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -415,5 +417,251 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { .ValueOrDie()); } +TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) { + // Updating the schedule of an unchanged HLO module should not affect the + // schedule at all. + const string module_str = R"( +HloModule UpdateScheduleUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + std::vector entry_schedule = sequence.begin()->second; + + EXPECT_EQ(entry_schedule.size(), 6); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(entry_schedule, sequence.begin()->second); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) { + // Add some additional instructions to a module and verify the schedule can be + // updated. + const string module_str = R"( +HloModule UpdateScheduleWithNewInstructions + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + HloComputation* entry = module->entry_computation(); + const Shape shape = entry->root_instruction()->shape(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, constant, entry->root_instruction())); + entry->set_root_instruction(sub); + + auto in_schedule = [&](const HloInstruction* hlo) { + return std::find(sequence.at(entry).begin(), sequence.at(entry).end(), + hlo) != sequence.at(entry).end(); + }; + + EXPECT_EQ(sequence.at(entry).size(), 6); + EXPECT_FALSE(in_schedule(constant)); + EXPECT_FALSE(in_schedule(sub)); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 8); + EXPECT_TRUE(in_schedule(constant)); + EXPECT_TRUE(in_schedule(sub)); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) { + // Add and delete some instructions from a module and verify that the schedule + // can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithAddedAndDeletedInstruction + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + // Set the entry root to some expression containing just a parameter and a + // constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* new_root = entry->AddInstruction( + HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, + constant, entry->parameter_instruction(0))); + entry->set_root_instruction(new_root); + + // DCE should remove everything but the parameters and the newly added code. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(entry).size(), 6); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 4); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) { + // Completely replace a module with an entirely new set of instructions and + // verify that the schedule can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithCompletelyReplacedModule + +ENTRY main { + a = f32[] constant(42.0) + b = f32[] constant(123.0) + ROOT sum = f32[] add(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + // Replace the entry computation with the negation of a constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + entry->set_root_instruction(new_root); + + // DCE the old instructions. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(entry).size(), 3); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 2); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) { + // Create changes to more than one computation in an HLO module and verify + // that the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + const HloInstruction* xla_while = + module->entry_computation()->root_instruction()->operand(0); + HloComputation* body = xla_while->while_body(); + HloComputation* cond = xla_while->while_condition(); + + // Negate the root of the cond. + cond->set_root_instruction(cond->AddInstruction( + HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kNot, cond->root_instruction()))); + + // Replace the body with a computation which just passes through its + // parameter. + body->set_root_instruction(body->parameter_instruction(0)); + + // DCE the dead code in the body. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(body).size(), 7); + EXPECT_EQ(sequence.at(cond).size(), 4); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(body).size(), 1); + EXPECT_EQ(sequence.at(cond).size(), 5); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 903fbbec1a0fd2cd696a5aac14521849f5903df2..980dae07ceec20a945f7db5f1377c6f5c08af47a 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -15,13 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; +using absl::StrCat; +using absl::StrJoin; HloSharding HloSharding::AssignDevice(int64 device_id) { return HloSharding(device_id); @@ -71,12 +72,9 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, const HloSharding& sharding) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); CHECK(!sharding.IsTuple()) << sharding.ToString(); - int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + int64 leaf_count = RequiredLeaves(tuple_shape); std::vector flattened_list; - flattened_list.reserve(leaf_count); - for (int64 i = 0; i < leaf_count; ++i) { - flattened_list.push_back(sharding); - } + flattened_list.resize(leaf_count, sharding); return HloSharding(flattened_list); } @@ -92,7 +90,7 @@ string HloSharding::ToString() const { for (const HloSharding& element : tuple_elements_) { parts.push_back(element.ToString()); } - return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); + return StrCat("{", absl::StrJoin(parts, ", "), "}"); } if (replicated_) { @@ -101,8 +99,8 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } else { - return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]", - Join(tile_assignment_, ","), "}"); + return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), + "]", StrJoin(tile_assignment_, ","), "}"); } } @@ -445,7 +443,7 @@ absl::optional HloSharding::ExtractSingleSharding() const { } for (int64 i = 1; i < tuple_elements_.size(); ++i) { if (tuple_elements_[0] != tuple_elements_[i]) { - return absl::optional(); + return absl::nullopt; } } return tuple_elements_.front(); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 4c64ac60c5f907d3fb6ff35e9faaea28eaab3cb7..be51c3f55b59aa65dbb15210b494a5e795f0cd3e 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -260,9 +260,9 @@ class HloSharding { bool maximal_; bool tuple_; Array tile_assignment_; - // Only non-empty when tuple_ is true, but because empty tuples are allowed - // may also be empty even then. This is a flattened list of all the leaf - // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). + // Only non-empty when tuple_ is true. If a tuple is empty then one entry is + // present for the root. This is a flattened list of all the leaf shardings in + // a tuple shape, by pre-order walk (ShapeTree iterator order). std::vector tuple_elements_; }; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 6f0353ee5f5cccd45459f1a54436b64774e58ca8..6e9b96488cf6343d641405fbda6744d021dd1855 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -24,6 +24,23 @@ namespace xla { namespace { +// AssignmentKind and kUnassignedDevice are used during tuple domain sharding +// propagation in order to distinguish among three cases: +// kUnassigned: no assignment has occurred +// kAssigned: at least an assignment has occurred +// kConflict: no assignment has occurred because of conflicting propagations, +// which occurs when multiple users of an instruction have different +// shardings. +enum class AssignmentKind { kUnassigned, kAssigned, kConflict }; + +// kUnassignedDevice can only be assigned to tuple leaf shardings to indicate +// absence of sharding information for that particular sub-sharding during +// sharding propagation. It is used to be able to express tuple shardings with +// partial information. At the end of the propagation the sharding of +// tuple-shaped instructions using kUnassignedDevice's is cleared. +// TODO(b/112883246): Centralized enum of reserved devices. +constexpr int64 kUnassignedDevice = -2; + struct PassThrough { PassThrough(HloInstruction* user, HloInstruction* operand) : user(user), operand(operand) {} @@ -118,13 +135,17 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, return Status::OK(); } -std::unique_ptr CloneShardingForDomain( - const HloSharding& sharding) { - auto single_sharding = sharding.ExtractSingleSharding(); +// For tuple shardings if every element have the same sharsing then we want to +// treat them as single element sharsings to insert less domain separation as a +// domain can prevent some optimizations and we want to minimize that from +// happening. +std::shared_ptr CloneShardingForDomain( + std::shared_ptr sharding) { + auto single_sharding = sharding->ExtractSingleSharding(); if (!single_sharding) { - return absl::make_unique(sharding); + return sharding; } - return absl::make_unique(*single_sharding); + return std::make_shared(*single_sharding); } Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, @@ -143,108 +164,174 @@ Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, return Status::OK(); } -// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. -// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() -// sharding will be returned. -ShapeTree GetTupleSharding(HloInstruction* tuple) { - if (tuple->has_sharding()) { - return tuple->sharding().GetAsShapeTree(tuple->shape()); +// Return the ShapeTree of the user argument. The user argument +// is assumed to be a user of the instruction argument. +// If user is a tuple instruction, return the tuple subsharding corresponding to +// the operand matching the instruction argument, because that is the +// subsharding corresponding to instruction. +ShapeTree GetShardingTreeFromUser( + const HloInstruction& instruction, const HloInstruction& user) { + if (user.opcode() == HloOpcode::kTuple) { + return user.sharding() + .GetSubSharding(user.shape(), {user.operand_index(&instruction)}) + .GetAsShapeTree(instruction.shape()); + } + return user.sharding().GetAsShapeTree(user.shape()); +} + +// Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice) +// then no assignment is made. Therefore kUnassignedDevice is never propagated. +// kConflict is returned if lhs is already assigned and rhs is assigned to a +// different device. +StatusOr AssignLeafSharding(HloSharding* lhs, + const HloSharding& rhs) { + TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple()); + if (rhs.UsesDevice(kUnassignedDevice)) { + return AssignmentKind::kUnassigned; + } + if (lhs->UsesDevice(kUnassignedDevice)) { + *lhs = rhs; + return AssignmentKind::kAssigned; + } + return lhs->UniqueDevice() != rhs.UniqueDevice() + ? AssignmentKind::kConflict + : AssignmentKind::kUnassigned; +} + +// Assigns the whole rhs tree to lhs_tree, starting at lhs_it. +// In case of conflicting assignment AssignmentKind::kConflict is returned. In +// this case lhs_tree is partially assigned, up to the conflicting leaf. It is +// up to the caller to discard the partial assignment in case of conflict. +StatusOr AssignTreeSharding( + ShapeTree* lhs_tree, ShapeTree::iterator lhs_it, + const ShapeTree& rhs_tree) { + AssignmentKind assigned = AssignmentKind::kUnassigned; + auto rhs_it = rhs_tree.begin(); + for (; lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end(); + ++lhs_it, ++rhs_it) { + // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it) + if (rhs_tree.IsLeaf(rhs_it->first)) { + TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first)); + TF_ASSIGN_OR_RETURN(AssignmentKind sub_assigned, + AssignLeafSharding(&lhs_it->second, rhs_it->second)); + if (sub_assigned == AssignmentKind::kConflict) { + // In case of conflict we return conflict to the caller. At this point + // partial assignments to lhs_tree may have been made already. It is up + // to the caller to discard the partial assignment in case of conflict. + return AssignmentKind::kConflict; + } else if (sub_assigned == AssignmentKind::kAssigned) { + assigned = sub_assigned; + } + } } - return ShapeTree(tuple->shape(), HloSharding::Replicate()); + TF_RET_CHECK(rhs_it == rhs_tree.end()); + return assigned; } -// Retrieves the sharding of operand, asked from a user instruction which is -// within domain. If operand is a kDomain, it means that sharding argument is -// the operand sharding, otherwise the operand's own sharding will be returned. -const HloSharding* GetOperandSharding(const HloInstruction* operand, +StatusOr ApplyShardingFromUsers(HloInstruction* instruction, const DomainMetadata::Domain& domain, - const HloSharding& sharding) { - // Here the user of operand is within the domain instruction set, and since it - // is user of operand, we need to look into the enter_domains set. If this is - // not a kDomain within the user domains set, then return the operand - // sharding, if any. - if (operand->opcode() != HloOpcode::kDomain || - domain.enter_domains.count(const_cast(operand)) == 0) { - return operand->has_sharding() ? &operand->sharding() : nullptr; + const HloSharding& domain_sharding) { + if (instruction->users().empty()) { + // No sharding from users, use domain_sharding, after checking + // compatibility. + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) && + ShapeUtil::GetLeafCount(instruction->shape()) == + domain_sharding.tuple_elements().size()); + instruction->set_sharding(domain_sharding); + return true; + } + AssignmentKind assigned = AssignmentKind::kUnassigned; + // The sharding_tree leaves are initialized to kUnassignedDevice. Only Tuple + // subshardings can result in a final sharding assignment containing + // kUnassignedDevice leaves, in case some tuple indexes are not used, or are + // used by users that don't have a sharding. + // Non-tuple shardings are either assigned to a real sharding, or are not + // assigned at all. As such they will never get assigned to kUnassignedDevice. + // In any case, kUnassignedDevice is never propagated, from the implementation + // of AssignLeafSharding. + ShapeTree sharding_tree( + instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(const_cast(user)) > 0) { + // If a user is a domain and it is registered in the domain exits, then + // the instruction sharding is taken directly from the domain, and no + // further users need to be visited. + instruction->set_sharding(domain_sharding); + return true; + } + if (!user->has_sharding()) { + continue; + } + AssignmentKind sub_assigned = AssignmentKind::kUnassigned; + ShapeTree user_sharding_tree = + GetShardingTreeFromUser(*instruction, *user); + if (ShapeUtil::IsTuple(instruction->shape())) { + // For tuple-shaped instructions collect individual tuple subshardings + // from the uses, and then combine them into the tuple sharding. + // If the user is a GTE its sharding concerns only the subtree of + // sharding_tree at index user->tuple_index, otherwise the whole + // sharding_tree is affected. + ShapeTree::iterator sharding_tree_begin = + user->opcode() == HloOpcode::kGetTupleElement + ? sharding_tree.find({user->tuple_index()}) + : sharding_tree.begin(); + TF_ASSIGN_OR_RETURN( + sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin, + user_sharding_tree)); + } else { + // Non-tuple shape: assign common users sharding. + TF_RET_CHECK(user_sharding_tree.leaf_count() == 1) + << "Expected non-tuple user sharding"; + TF_ASSIGN_OR_RETURN( + sub_assigned, + AssignTreeSharding(&sharding_tree, sharding_tree.begin(), + user_sharding_tree)); + } + + if (sub_assigned == AssignmentKind::kConflict) { + // In case of conflict we don't assign any sharding. + return false; + } else if (sub_assigned == AssignmentKind::kAssigned) { + assigned = sub_assigned; + } + } + + if (assigned == AssignmentKind::kAssigned) { + if (ShapeUtil::IsTuple(instruction->shape())) { + instruction->set_sharding(HloSharding::Tuple(sharding_tree)); + } else { + TF_RET_CHECK(sharding_tree.leaf_count() == 1); + instruction->set_sharding(sharding_tree.leaf_begin()->second); + } + return true; } - // At this point operand is a kDomain of the currently processed domain, so we - // can refer to sharding as the domain sharding. - return &sharding; + return false; } // Tries to propagate the sharding information into the instructions that are -// part of the domain, in a post order manner (operand propagate to user). +// part of the domain, in a reverse post order manner (users propoagate to +// instruction). StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, - const HloSharding& sharding) { + const HloSharding& domain_sharding) { int64 assigned = 0; - for (HloInstruction* instruction : domain.instructions) { + // domain.instructions are ordered in a post-order manner. As we do + // user->operand propagation we process instructions in reverse order. In so + // doing we are guaranteed to process all users before their operands. + for (auto it = domain.instructions.rbegin(); it != domain.instructions.rend(); + ++it) { + HloInstruction* instruction = *it; if (instruction->has_sharding()) { continue; } - if (instruction->opcode() == HloOpcode::kGetTupleElement) { - HloInstruction* tuple = instruction->mutable_operand(0); - const HloSharding* tuple_sharding = - GetOperandSharding(tuple, domain, sharding); - if (tuple_sharding != nullptr) { - if (tuple_sharding->IsTuple()) { - HloSharding sub_sharding = tuple_sharding->GetSubSharding( - tuple->shape(), {instruction->tuple_index()}); - VLOG(4) << " " << instruction->name() << " to sharding " - << sub_sharding; - instruction->set_sharding(sub_sharding); - } else { - SetSingleSharding(instruction, *tuple_sharding); - } - ++assigned; - } - } else if (instruction->opcode() == HloOpcode::kTuple) { - int64 tuple_assigned = 0; - ShapeTree shape_tree = GetTupleSharding(instruction); - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const HloSharding* operand_sharding = - GetOperandSharding(instruction->operand(i), domain, sharding); - if (operand_sharding != nullptr) { - HloSharding operand_subsharding = HloSharding::Replicate(); - if (operand_sharding == &sharding) { - operand_subsharding = - sharding.GetSubSharding(instruction->shape(), {i}); - operand_sharding = &operand_subsharding; - } - if (shape_tree.element({i}) != *operand_sharding) { - *shape_tree.mutable_element({i}) = *operand_sharding; - ++tuple_assigned; - } - } - } - if (tuple_assigned > 0) { - HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); - VLOG(4) << " " << instruction->name() << " to sharding " - << tuple_sharding; - instruction->set_sharding(tuple_sharding); - ++assigned; - } - } else { - // If all the operand of the given instruction has the same single device - // assignment, assign that device to this instruction as well. - const HloSharding* common_sharding = nullptr; - for (const HloInstruction* operand : instruction->operands()) { - const HloSharding* operand_sharding = - GetOperandSharding(operand, domain, sharding); - if (operand_sharding != nullptr) { - if (common_sharding != nullptr && - *common_sharding != *operand_sharding) { - common_sharding = nullptr; - break; - } - common_sharding = operand_sharding; - } - } - if (common_sharding != nullptr) { - VLOG(4) << " " << instruction->name() << " to sharding " - << *common_sharding; - instruction->set_sharding(*common_sharding); - ++assigned; - } + // Take the sharding from the users. + TF_ASSIGN_OR_RETURN( + bool instruction_assigned, + ApplyShardingFromUsers(instruction, domain, domain_sharding)); + if (instruction_assigned) { + ++assigned; + VLOG(4) << " " << instruction->name() << " to sharding " + << instruction->sharding(); } } return assigned; @@ -262,84 +349,40 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, return ApplyDomainSingleSharding(domain, *single_sharding); } VLOG(1) << "Assigning non-trivial sharding " << sharding; - for (;;) { - TF_ASSIGN_OR_RETURN(int64 assigned, - ApplyDomainShardingPass(domain, sharding)); - if (assigned == 0) { - break; - } - } + TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status()); + int64 unassigned = 0; for (HloInstruction* instruction : domain.instructions) { if (!instruction->has_sharding()) { LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); ++unassigned; + } else { + // Un-set sharding of tuples whose sub-sgardings are assigned to + // kUnassignedDevice. Indeed in case of doubt it is better to leave the + // entire tuple unassigned, and let the device placer decide for it. + if (instruction->sharding().UsesDevice(kUnassignedDevice)) { + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape())) + << "Only tuples can have kUnassignedDevice sub shardings"; + instruction->clear_sharding(); + } } } // Should we error out if unassigned > 0? return Status::OK(); } -// Creates a kDomain instruction to be placed between instruction and operand. -// The kDomain instruction will be created only if the sharding differ between -// the instruction and the operand. -std::unique_ptr CreateDomain(HloInstruction* instruction, - HloInstruction* root, - HloInstruction* operand) { - const HloSharding* instruction_sharding = - instruction->has_sharding() ? &instruction->sharding() : nullptr; - const HloSharding* root_sharding = - root->has_sharding() ? &root->sharding() : nullptr; - // No need for domain if they both have no sharding. - if (instruction_sharding == nullptr && root_sharding == nullptr) { - return nullptr; - } - // No need for domain if they match. - if (instruction_sharding != nullptr && root_sharding != nullptr && - ShardingMatches(*instruction_sharding, *root_sharding)) { - return nullptr; - } - std::unique_ptr real_instruction_sharding; - std::unique_ptr real_operand_sharding; - if (instruction_sharding != nullptr) { - real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); - } - if (root_sharding != nullptr) { - real_operand_sharding = CloneShardingForDomain(*root_sharding); - } - VLOG(3) << "Creating domain:"; - VLOG(3) << " Instruction: " << instruction->name(); - VLOG(3) << " Operand: " << operand->name(); - VLOG(3) << " User side sharding: " - << (real_instruction_sharding != nullptr - ? real_instruction_sharding->ToString() - : "None"); - VLOG(3) << " Operand side sharding: " - << (real_operand_sharding != nullptr - ? real_operand_sharding->ToString() - : "None"); - - std::unique_ptr operand_side_metadata = - absl::make_unique(std::move(real_operand_sharding)); - std::unique_ptr user_side_metadata = - absl::make_unique(std::move(real_instruction_sharding)); - return HloInstruction::CreateDomain(operand->shape(), operand, - std::move(operand_side_metadata), - std::move(user_side_metadata)); -} - -StatusOr> ExtractOriginalCommonSharding( +StatusOr> ExtractOriginalCommonSharding( tensorflow::gtl::ArraySlice instructions) { // If we are here, all the instructions being passed had the same sharding // (or no sharding), by the means of the ShardingMatches() API. // As such, no kDomain was inserted, and here we are asked to extract the // original common sharding. // All the instructions passed to this API are part of the same computation. - const HloSharding* sharding = nullptr; + std::shared_ptr sharding; for (HloInstruction* instruction : instructions) { if (instruction->has_sharding()) { if (sharding == nullptr) { - sharding = &instruction->sharding(); + sharding = instruction->sharding_ptr(); } else { TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) << "Sharding " << *sharding << " does not match the one in " @@ -348,10 +391,10 @@ StatusOr> ExtractOriginalCommonSharding( } } if (sharding == nullptr) { - return std::unique_ptr(); + return std::shared_ptr(); } VLOG(4) << "Extracted sharding is " << *sharding; - return CloneShardingForDomain(*sharding); + return CloneShardingForDomain(sharding); } } // namespace @@ -405,7 +448,7 @@ Status ShardingMetadata::NormalizeShardingDomain( TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding)); } } else { - TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + TF_ASSIGN_OR_RETURN(std::shared_ptr sharding, ExtractOriginalCommonSharding(domain.instructions)); if (sharding != nullptr) { VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString(); @@ -417,10 +460,75 @@ Status ShardingMetadata::NormalizeShardingDomain( return Status::OK(); } -std::unique_ptr CreateShardingDomain( - HloInstruction* instruction, HloInstruction* root, - HloInstruction* operand) { - return CreateDomain(instruction, root, operand); +// Creates a kDomain instruction to be placed between instruction and operand. +// The kDomain instruction will be created only if the sharding differ between +// the instruction and the operand. +HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction, + HloInstruction* root, + HloInstruction* operand) { + auto instruction_sharding = instruction->sharding_ptr(); + auto root_sharding = root->sharding_ptr(); + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && root_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. + if (instruction_sharding != nullptr && root_sharding != nullptr && + ShardingMatches(*instruction_sharding, *root_sharding)) { + return nullptr; + } + + if (instruction_sharding != nullptr) { + instruction_sharding = CloneShardingForDomain(instruction_sharding); + } + if (root_sharding != nullptr) { + root_sharding = CloneShardingForDomain(root_sharding); + } + + auto it = domain_cse_map_.find({operand, instruction_sharding}); + if (it != domain_cse_map_.end()) { + return it->second; + } + + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (instruction_sharding != nullptr ? instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (root_sharding != nullptr ? root_sharding->ToString() : "None"); + + HloInstruction* domain = + operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, + absl::make_unique(root_sharding), + absl::make_unique(instruction_sharding))); + domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding}, + domain); + return domain; +} + +bool ShardingDomainCreator::DomainCseMapKey::operator==( + const ShardingDomainCreator::DomainCseMapKey& other) const { + if (instruction != other.instruction) { + return false; + } + if (sharding == nullptr && other.sharding == nullptr) { + return true; + } + if (sharding == nullptr || other.sharding == nullptr) { + return false; + } + return *sharding == *other.sharding; +} + +size_t ShardingDomainCreator::DomainCseMapHasher::operator()( + const ShardingDomainCreator::DomainCseMapKey& key) const { + return tensorflow::Hash64Combine( + std::hash{}(key.instruction), + key.sharding ? key.sharding->Hash() + : static_cast(0x297814aaad196e6dULL)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index dc258e40949d68be57aad955889d00a567db7e61..7a6b0d9abcbf1f8206654fc66e6dd99f82696556 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -27,12 +27,12 @@ namespace xla { // A DomainMetadata implementation that internally wraps a sharding attribute. class ShardingMetadata : public DomainMetadata { public: - explicit ShardingMetadata(std::unique_ptr sharding) + explicit ShardingMetadata(std::shared_ptr sharding) : sharding_(std::move(sharding)) {} std::unique_ptr Clone() const override; - tensorflow::StringPiece Kind() const override { return KindName(); } + absl::string_view Kind() const override { return KindName(); } bool Matches(const DomainMetadata& other) const override; @@ -40,7 +40,7 @@ class ShardingMetadata : public DomainMetadata { const HloSharding* sharding() const { return sharding_.get(); } - static tensorflow::StringPiece KindName() { return "sharding"; } + static absl::string_view KindName() { return "sharding"; } static StatusOr ToShardingMetadata( const DomainMetadata* metadata); @@ -55,15 +55,33 @@ class ShardingMetadata : public DomainMetadata { const DomainMetadata* metadata); private: - std::unique_ptr sharding_; + std::shared_ptr sharding_; }; -// Given an HLO graph edge between instruction and one of its operands, creates -// a ShardingMetadata based kDomain instruction if the sharding between -// instruction and parent changes. Returns nullptr if there is no need for a -// domain separation. -std::unique_ptr CreateShardingDomain( - HloInstruction* instruction, HloInstruction* root, HloInstruction* operand); +// If the sharding between root and instruction changes then returns a +// ShardingMetadata based kDomain instruction what can be used to separate +// operand and instruction. +// Returns nullptr if there is no need for a domain separation. +class ShardingDomainCreator { + public: + HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root, + HloInstruction* operand); + + private: + // Map from instruction and user sharding to domain users to CSE identical + // domains. + struct DomainCseMapKey { + const HloInstruction* instruction; + std::shared_ptr sharding; + + bool operator==(const DomainCseMapKey& other) const; + }; + struct DomainCseMapHasher { + size_t operator()(const DomainCseMapKey& key) const; + }; + std::unordered_map + domain_cse_map_; +}; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 45fc300fcaf5a301fe11768da77a7c0907919c39..2341f8ada0dba4e5a5f39e991498a2ee44303dbd 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -115,6 +115,13 @@ TEST_F(HloShardingTest, Tile) { } } +// Tests that empty tuple is supported. +TEST_F(HloShardingTest, EmptySingleTuple) { + HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), + HloSharding::AssignDevice(0)); + EXPECT_TRUE(sharding.ExtractSingleSharding()); +} + TEST_F(HloShardingTest, NestedTuple) { // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h index 2ef38821af632180714911c0ff22731fd559b915..d1cf644f8273e632e2952cca0da749616e9b6233 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h @@ -24,7 +24,7 @@ namespace xla { // one arbitrarily to use and delete the others. class HloSubcomputationUnification : public HloPassInterface { public: - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "subcomputation-unification"; } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index b78bfa0cdf4db605576fa11e18ce6c654c6a0b6d..487653344976a10e18ba667085525ba1ecbb8612 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -21,28 +23,25 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" - -using ::tensorflow::GraphDef; -using ::tensorflow::NodeDef; -using ::tensorflow::TensorShapeProto; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; -using ::tensorflow::str_util::Join; namespace xla { namespace hlo_graph_dumper { namespace { +using absl::StrAppend; +using absl::StrCat; +using tensorflow::GraphDef; +using tensorflow::NodeDef; +using tensorflow::TensorShapeProto; + string GetOpDefName(const HloInstruction* instruction) { string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); - tensorflow::str_util::TitlecaseString(&name, "-"); + tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); if (instruction->opcode() == HloOpcode::kFusion) { string fusion_name = ToString(instruction->fusion_kind()); - StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1)); + StrAppend(&name, absl::string_view(fusion_name).substr(1)); } return name; } @@ -166,7 +165,9 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); } else { layout_string = StrCat( - "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}"); + "{", + absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","), + "}"); } attrs["layout"].set_s(layout_string); } diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 14703aaf64bdbfee4e737331dd47d5def95e1d4b..e0c13261772cf7eb9f71cd02182dc3166ba172ed 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,16 +32,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; const Shape& HloPosition::shape() const { return ShapeUtil::GetSubshape(instruction->shape(), index); @@ -216,10 +215,11 @@ void HloValueSet::SortAndUniquifyValues() { } string HloValueSet::ToString() const { - return StrCat("HloValueSet: ", - Join(values_, ", ", [](string* result, const HloValue* value) { - result->append(value->ToShortString()); - })); + return StrCat( + "HloValueSet: ", + absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); } bool HloValueSet::AssignUnionOf( diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 7acf58e25291ae19ba5790bcdd0a18207419dddf..f1b29c255970b1f0838dc5ad8214192bc536b7e3 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,11 +15,13 @@ limitations under the License. #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -115,6 +117,11 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { ShapeInference::InferAllToAllTupleShape(operand_shapes)); } +Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( + hlo->operand(0)->shape())); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), @@ -122,39 +129,32 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -namespace { - -Status CheckIsTokenOperand(const HloInstruction* instruction, - int64 operand_no) { +Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no) { const HloInstruction* token = instruction->operand(operand_no); if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { return InternalError( - "Expected operand %lld to be token-shaped, actual shape is " + "Expected operand %d to be token-shaped, actual shape is " "%s:\n%s", - operand_no, ShapeUtil::HumanString(token->shape()).c_str(), - instruction->ToString().c_str()); + operand_no, StringifyShape(token->shape()), instruction->ToString()); } return Status::OK(); } -Status CheckOperandAndParameter(const HloInstruction* instruction, - int64 operand_number, - const HloComputation* computation, - int64 parameter_number) { +Status ShapeVerifier::CheckOperandAndParameter( + const HloInstruction* instruction, int64 operand_number, + const HloComputation* computation, int64 parameter_number) { const HloInstruction* operand = instruction->operand(operand_number); const HloInstruction* parameter = computation->parameter_instruction(parameter_number); - if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) { + if (!ShapesSame(operand->shape(), parameter->shape())) { return InternalError("Operand %s shape does not match parameter's %s in %s", - operand->ToString().c_str(), - parameter->ToString().c_str(), - instruction->ToString().c_str()); + operand->ToString(), parameter->ToString(), + instruction->ToString()); } return Status::OK(); } -} // namespace - Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -171,14 +171,12 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { // Outfeed has a separate shape field for the value which is outfed to the // host. The shape of the instruction itself is always a token. - if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), - outfeed->operand(0)->shape())) { + if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { return InternalError( - "Expected outfeed shape to be compatible with operand's shape %s, " + "Expected outfeed shape to be equal to operand's shape %s, " "actual shape is %s:\n%s", - ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), - outfeed->ToString().c_str()); + StringifyShape(outfeed->operand(0)->shape()), + StringifyShape(outfeed->outfeed_shape()), outfeed->ToString()); } return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); } @@ -196,7 +194,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, Status ShapeVerifier::HandleRng(HloInstruction* instruction) { if (instruction->operand_count() != 2) { return InternalError("Expected two operands for Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } const Shape& shape_0 = instruction->operand(0)->shape(); @@ -204,14 +202,14 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) { return InternalError( "Expected scalar types for the two operands of Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) { return InternalError( "Expected compatible element types for the result and the two operands" " of Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } PrimitiveType element_type = shape_0.element_type(); @@ -224,7 +222,7 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { "Element type not supported." " Expected element to be of floating point type, integral type or" " predicate type for RngUniform: %s", - instruction->ToString().c_str()); + instruction->ToString()); } break; @@ -233,13 +231,13 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { return InternalError( "Element type not supported." " Expected element to be FloatingPointType for RngNormal: %s", - instruction->ToString().c_str()); + instruction->ToString()); } break; default: return InternalError( "Invalid Rng distribution %s", - RandomDistribution_Name(instruction->random_distribution()).c_str()); + RandomDistribution_Name(instruction->random_distribution())); } return Status::OK(); @@ -258,8 +256,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { return InternalError( "Expected sort to have to have the same dimensions for the keys and " "the values. Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + StringifyShape(sort->operand(0)->shape()), + StringifyShape(sort->operand(1)->shape())); } return CheckVariadicShape(sort); } @@ -268,10 +266,18 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { return CheckShape(constant, constant->literal().shape()); } -Status ShapeVerifier::HandleIota(HloInstruction* iota) { - return ShapeUtil::Rank(iota->shape()) == 1 - ? Status::OK() - : InternalError("Iota only supports arrays of rank 1."); +Status ShapeVerifier::HandleIota(HloInstruction* instruction) { + auto* iota = Cast(instruction); + const int64 rank = ShapeUtil::Rank(iota->shape()); + if (rank == 0) { + return InternalError("Iota does not support scalars."); + } + int64 iota_dimension = iota->iota_dimension(); + if (iota_dimension >= rank) { + return InternalError( + "The iota dimension cannot go beyond the operation rank."); + } + return Status::OK(); } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { @@ -333,7 +339,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { + for (HloInstruction* fused_param : fusion->fused_parameters()) { + int64 param_no = fused_param->parameter_number(); + if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { + return InternalError( + "Shape mismatch between parameter number %d and its operand in " + "%s.", + param_no, fusion->ToString().c_str()); + } + } + return Status::OK(); +} Status ShapeVerifier::HandleCall(HloInstruction* call) { for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) { @@ -415,12 +432,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); const Shape& conditional_shape = xla_while->while_condition()->root_instruction()->shape(); - if (!ShapeUtil::Compatible(conditional_shape, - ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) { return InternalError( "Conditional computation shape does not lead to a scalar predicate " "shape: %s", - ShapeUtil::HumanString(conditional_shape).c_str()); + StringifyShape(conditional_shape)); } // The shape of kWhile should match the shape of the body computation it // calls. @@ -551,7 +567,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", - instruction->ToString().c_str()); + instruction->ToString()); } return Status::OK(); })); @@ -598,53 +614,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } // Check if the output shape matches the expected shape. - bool compatible; + // // We treat BF16 and F32 as compatible types if mixed precision is allowed, // but only when the instruction defines the BF16/F32 buffer. - switch (instruction->opcode()) { - case HloOpcode::kTupleSelect: - // TupleSelect only defines the top-level buffer, which in this case is - // the tuple, so we cannot allow mixed precision. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - case HloOpcode::kGetTupleElement: - case HloOpcode::kTuple: - // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed - // precision is disallowed. - case HloOpcode::kConstant: - case HloOpcode::kBitcast: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kConvert: - case HloOpcode::kCustomCall: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kParameter: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kWhile: - // The above opcodes should match the expected shapes exactly. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - default: - if (allow_mixed_precision_) { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision( - instruction->shape(), inferred_shape); - } else { - compatible = - ShapeUtil::Compatible(instruction->shape(), inferred_shape); - } - } - if (!compatible) { + bool equal = [&] { + switch (instruction->opcode()) { + // The opcodes below can't have implicit layout conversions, nor can they + // implicitly transform f32 -> bf16. Fundamentally these are either + // reinterpreting existing data (e.g. kBitcast) or shuffling data around + // without modifying it (e.g. kGetTupleElement, kTupleSelect). + case HloOpcode::kBitcast: + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return ShapesSame(instruction->shape(), inferred_shape); + + // We allow arbitrary layout and f32->bf16 transformations on all other + // instructions, although this may be made more strict pending discussion + // in b/112709536. + default: + if (allow_mixed_precision_) { + return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(), + inferred_shape); + } else { + return ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } + } + }(); + if (!equal) { return InternalError( - "Expected instruction to have shape compatible with %s, actual " + "Expected instruction to have shape equal to %s, actual " "shape is %s:\n%s", - ShapeUtil::HumanString(inferred_shape).c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - instruction->ToString().c_str()); + StringifyShape(inferred_shape), StringifyShape(instruction->shape()), + instruction->ToString()); } return Status::OK(); } @@ -688,10 +702,10 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { string ComputationsToString( tensorflow::gtl::ArraySlice computations) { - return tensorflow::str_util::Join( - computations, ",", [](string* s, const HloComputation* computation) { - s->append(computation->name()); - }); + return absl::StrJoin(computations, ",", + [](string* s, const HloComputation* computation) { + s->append(computation->name()); + }); } // Verifies various invariants about the structure of the HLO: @@ -709,23 +723,23 @@ Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { return InternalError("Computation %s has a null parent pointer", - computation->name().c_str()); + computation->name()); } if (computation->parent() != module) { return InternalError( "Computation %s parent() does not point to parent module", - computation->name().c_str()); + computation->name()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { return InternalError("Instruction %s has a null parent pointer", - instruction->name().c_str()); + instruction->name()); } if (instruction->parent() != computation) { return InternalError( "Instruction %s parent() does not point to parent computation", - instruction->name().c_str()); + instruction->name()); } } } @@ -742,9 +756,8 @@ Status VerifyHloStructure(HloModule* module) { return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", - i, operand->name().c_str(), instruction->name().c_str(), - operand->parent()->name().c_str(), - instruction->parent()->name().c_str()); + i, operand->name(), instruction->name(), + operand->parent()->name(), instruction->parent()->name()); } } } @@ -760,7 +773,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { "Instruction of fused computation does not match expected " "instruction " "%s.", - fusion->ToString().c_str()); + fusion->ToString()); } // Fused root instruction and fused parameters must all be owned by the @@ -774,7 +787,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (fused_root == instruction) { if (root_owned) { return InternalError("Root appears more than once in %s.", - fusion->ToString().c_str()); + fusion->ToString()); } root_owned = true; } @@ -782,7 +795,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { return InternalError("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + fusion->ToString()); } parameter_owned[i] = true; } @@ -790,20 +803,19 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } if (!root_owned) { return InternalError("Root not found in computation of %s.", - fusion->ToString().c_str()); + fusion->ToString()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { return InternalError("Parameter %d not found in computation of %s.", i, - fusion->ToString().c_str()); + fusion->ToString()); } } // Fused root must have no users. if (fused_root->user_count() != 0) { - return InternalError("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", fusion->ToString()); } // All uses of fused instructions must be in the fusion computation, and @@ -813,54 +825,46 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (instruction != fused_root) { if (instruction->user_count() == 0) { return InternalError("Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), - fusion->ToString().c_str()); + instruction->ToString(), fusion->ToString()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { return InternalError( "Non-root instruction %s in %s may not have external users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + instruction->ToString(), fusion->ToString()); } } } } // Fused parameter instructions must be numbered contiguously and match up - // (shapes compatible) with their respective operand. + // (shapes equal) with their respective operand. CHECK_EQ(fusion->operands().size(), fused_parameters.size()); std::vector parameter_numbers(fused_parameters.size(), false); for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return InternalError("Unexpected negative parameter number %lld in %s.", - param_no, fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %d in %s.", + param_no, fusion->ToString()); } if (param_no >= fused_parameters.size()) { return InternalError( - "Unexpected parameter number %lld in %s: higher then number of " + "Unexpected parameter number %d in %s: higher then number of " "parameters %lu.", - param_no, fusion->ToString().c_str(), fused_parameters.size()); + param_no, fusion->ToString(), fused_parameters.size()); } if (parameter_numbers[param_no]) { return InternalError( - "Did not expect parameter number %lld more than once in %s.", - param_no, fusion->ToString().c_str()); + "Did not expect parameter number %d more than once in %s.", param_no, + fusion->ToString()); } parameter_numbers[param_no] = true; - if (!ShapeUtil::Compatible(fused_param->shape(), - fusion->operand(param_no)->shape())) { - return InternalError( - "Shape mismatch between parameter number %lld and its operand in " - "%s.", - param_no, fusion->ToString().c_str()); - } } // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { return InternalError("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + fusion->ToString()); } } @@ -875,18 +879,18 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { auto* while_body = instruction->while_body(); if (while_cond->num_parameters() != 1) { return FailedPrecondition( - "While condition must have exactly 1 parameter; had %lld : %s", - while_cond->num_parameters(), while_cond->ToString().c_str()); + "While condition must have exactly 1 parameter; had %d : %s", + while_cond->num_parameters(), while_cond->ToString()); } if (while_body->num_parameters() != 1) { return FailedPrecondition( - "While body must have exactly 1 parameter; had %lld : %s", - while_body->num_parameters(), while_body->ToString().c_str()); + "While body must have exactly 1 parameter; had %d : %s", + while_body->num_parameters(), while_body->ToString()); } if (instruction->operand_count() != 1) { return FailedPrecondition( - "While loop must have exactly one operand; had %lld : %s", - instruction->operand_count(), instruction->ToString().c_str()); + "While loop must have exactly one operand; had %d : %s", + instruction->operand_count(), instruction->ToString()); } return Status::OK(); } @@ -894,16 +898,14 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) { if (instruction->true_computation()->num_parameters() != 1) { return FailedPrecondition( - "True computation %s of %s must have 1 parameter insted of %lld", - instruction->true_computation()->name().c_str(), - instruction->ToString().c_str(), + "True computation %s of %s must have 1 parameter insted of %d", + instruction->true_computation()->name(), instruction->ToString(), instruction->true_computation()->num_parameters()); } if (instruction->false_computation()->num_parameters() != 1) { return FailedPrecondition( - "False computation %s of %s must have 1 parameter insted of %lld", - instruction->false_computation()->name().c_str(), - instruction->ToString().c_str(), + "False computation %s of %s must have 1 parameter insted of %d", + instruction->false_computation()->name(), instruction->ToString(), instruction->false_computation()->num_parameters()); } return Status::OK(); @@ -916,11 +918,11 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { return FailedPrecondition( "Implicit broadcast is not allowed in HLO." - "Found non-compatible shapes for instruction %s.\n" + "Found different shapes for instruction %s.\n" "output: %s\noperand: %s\n", - HloOpcodeString(instruction->opcode()).c_str(), - ShapeUtil::HumanString(out_shape).c_str(), - ShapeUtil::HumanString(operand_shape).c_str()); + HloOpcodeString(instruction->opcode()), + ShapeUtil::HumanString(out_shape), + ShapeUtil::HumanString(operand_shape)); } } return Status::OK(); @@ -951,7 +953,7 @@ Status VerifyEntryAndExitShapes(const HloModule& module) { if (ShapeContainsToken(param->shape())) { return InternalError( "Entry parameter %d is or contains a token shape: %s", i, - ShapeUtil::HumanString(param->shape()).c_str()); + ShapeUtil::HumanString(param->shape())); } } return Status::OK(); @@ -963,9 +965,9 @@ Status CheckSameChannel(const HloInstruction* instr1, if (instr1->channel_id() != instr2->channel_id()) { return InternalError( "Expected to have the same channel id, actual channel ids are: %s " - "(%lld), %s (%lld)", - instr1->ToString().c_str(), instr1->channel_id(), - instr2->ToString().c_str(), instr2->channel_id()); + "(%d), %s (%d)", + instr1->ToString(), instr1->channel_id(), instr2->ToString(), + instr2->channel_id()); } return Status::OK(); } @@ -986,7 +988,7 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1, "Expected instructions to have the same is-host-transfer property: " "%s, " "%s ", - instr1->ToString().c_str(), instr2->ToString().c_str()); + instr1->ToString(), instr2->ToString()); } return Status::OK(); } @@ -1003,12 +1005,12 @@ Status VerifySendsAndRecvs(const HloModule& module) { host_channels.insert({sendrecv->channel_id(), sendrecv}); if (!it_inserted.second) { return FailedPrecondition( - "Channel %lld is used for multiple host send/recv instructions: " + "Channel %d is used for multiple host send/recv instructions: " "%s " "and " "%s", - sendrecv->channel_id(), sendrecv->ToString().c_str(), - it_inserted.first->second->ToString().c_str()); + sendrecv->channel_id(), sendrecv->ToString(), + it_inserted.first->second->ToString()); } } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 523bf4d70cd335a969a4d46f92408caf470db8a6..42e3027bf14a827bd0a791510c2d9c107d989ab9 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -28,9 +28,9 @@ namespace xla { // TODO(b/26024837): Check output shape for all instruction types. class ShapeVerifier : public DfsHloVisitor { public: - explicit ShapeVerifier() : allow_mixed_precision_(false) {} - explicit ShapeVerifier(bool allow_mixed_precision) - : allow_mixed_precision_(allow_mixed_precision) {} + explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision) + : layout_sensitive_(layout_sensitive), + allow_mixed_precision_(allow_mixed_precision) {} Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; @@ -47,6 +47,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; + Status HandleCollectivePermute(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -106,13 +107,42 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckVariadicShape(const HloInstruction* instruction); private: - // Return true if the shapes of the two operands have the same element type, - // and the result shape either has the same element type as the operand - // shapes or mixed precision is allowed and the result shape and the operand - // shapes have floating point element types. + // Helpers that switch on layout_sensitive_. + bool ShapesSame(const Shape& a, const Shape& b) { + return layout_sensitive_ ? ShapeUtil::Equal(a, b) + : ShapeUtil::Compatible(a, b); + } + bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) { + return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b) + : ShapeUtil::CompatibleIgnoringFpPrecision(a, b); + } + string StringifyShape(const Shape& s) { + return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) + : ShapeUtil::HumanString(s); + } + + // Checks that the given operand of the given instruction is of type TOKEN. + Status CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no); + + // Checks that the shape of the given operand of the given instruction matches + // the given parameter of the given computation. + Status CheckOperandAndParameter(const HloInstruction* instruction, + int64 operand_number, + const HloComputation* computation, + int64 parameter_number); + + // Returns true if the shapes of the two operands have the same element type, + // and the result shape either has the same element type as the operand shapes + // or mixed precision is allowed and the result shape and the operand shapes + // have floating point element types. bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, const Shape& result_shape); + // If the verifier is layout-sensitive, shapes must be equal to what's + // expected. Otherwise, the shapes must simply be compatible. + bool layout_sensitive_; + // Whether the inputs and output of an instruction can contain both F32s and // BF16s. Tuples that include both F32s and BF16s are allowed regardless of // this flag. @@ -125,14 +155,10 @@ class HloVerifier : public HloPassInterface { public: using ShapeVerifierFactory = std::function()>; - // Uses standard shape inference. - explicit HloVerifier() - : shape_verifier_factory_( - [] { return absl::make_unique(false); }) {} - - explicit HloVerifier(bool allow_mixed_precision) - : shape_verifier_factory_([allow_mixed_precision] { - return absl::make_unique(allow_mixed_precision); + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { + return absl::make_unique(layout_sensitive, + allow_mixed_precision); }) {} // Uses custom shape verification. @@ -140,10 +166,9 @@ class HloVerifier : public HloPassInterface { : shape_verifier_factory_(std::move(shape_verifier_factory)) {} ~HloVerifier() override = default; - tensorflow::StringPiece name() const override { return "verifier"; } + absl::string_view name() const override { return "verifier"; } - // Note: always returns false (no instructions are ever modified by this - // pass). + // Never returns true; no instructions are ever modified by this pass. StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index d764964f3c3dc58a54bd0307f8b625076c14f3e5..70b741353d043bbe6bcc6d4bf55e9cf9d0d8d3c3 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -37,13 +37,15 @@ using ::testing::HasSubstr; class HloVerifierTest : public HloTestBase { public: HloVerifierTest() - : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/false) {} }; class HloVerifierTestAllowMixedPrecision : public HloTestBase { public: HloVerifierTestAllowMixedPrecision() - : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; TEST_F(HloVerifierTest, NullInstructionParent) { diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index bb5b40a8a87c5eab5a5b1599581a81bbd064511b..e76b93107c923b41666f6b0a388dda143a8cb50a 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -14,27 +14,27 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/metric_table_report.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { -using tensorflow::strings::Appendf; +using absl::StrAppend; +using absl::StrAppendFormat; +using absl::StrCat; +using absl::StrFormat; using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; string HumanReadableProfileBuilder::ToString() const { string s; - Appendf(&s, "Execution profile for %s: (%s @ f_nom)\n", - computation_name_.c_str(), - HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); + StrAppendFormat(&s, "Execution profile for %s: (%s @ f_nom)\n", + computation_name_, + HumanReadableElapsedTime(CyclesToSeconds(total_cycles_))); int64 cumulative_cycles = 0; auto print_op = [&](const OpInfo& op, bool is_total = false) { @@ -56,7 +56,7 @@ string HumanReadableProfileBuilder::ToString() const { if (op.bytes_accessed > op.cycles) { bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle"); } else { - bytes_per_cycle = Printf("%.3fB/cycle", bpc); + bytes_per_cycle = StrFormat("%.3fB/cycle", bpc); } } @@ -77,27 +77,24 @@ string HumanReadableProfileBuilder::ToString() const { // columns in the output. cycles_percent_str = "100.% 100Σ"; } else { - cycles_percent_str = - Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent); + cycles_percent_str = StrFormat("%5.2f%% %2.0fΣ", cycles_percent, + cumulative_cycles_percent); } double nsecs = op.cycles / clock_rate_ghz_; - Appendf( + StrAppendFormat( &s, - "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " + "%15d cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " "%16s :: %s\n", - op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles), + op.cycles, cycles_percent_str, CyclesToMicroseconds(op.cycles), op.optimal_seconds < 0 ? "" - : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), - op.flop_count <= 0 - ? "" - : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), + : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6), + op.flop_count <= 0 ? "" : HumanReadableNumFlops(op.flop_count, nsecs), op.transcendental_count <= 0 ? "" - : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) - .c_str(), - bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); + : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs), + bytes_per_sec, bytes_per_cycle, op.name); }; float optimal_seconds_sum = 0.0; diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index 6f56c3aa82e9d1c942fd67ff7a5948cf2e54370d..925111fa1f1e48650b0089f402d92e431043eabe 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -29,10 +29,10 @@ namespace xla { // computation, suitable for consumption by humans. class HumanReadableProfileBuilder { public: - explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, + explicit HumanReadableProfileBuilder(absl::string_view computation_name, int64 total_cycles, double clock_rate_ghz) - : computation_name_(std::string(computation_name)), + : computation_name_(computation_name), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -43,15 +43,13 @@ class HumanReadableProfileBuilder { // Adds an operation to the profile. If you don't know the number of // floating-point ops or bytes touched by the op, or if you don't know how // fast it would run optimally, pass -1 for that param. - void AddOp(tensorflow::StringPiece op_name, - tensorflow::StringPiece short_name, - tensorflow::StringPiece category, int64 cycles, int64 flop_count, + void AddOp(absl::string_view op_name, absl::string_view short_name, + absl::string_view category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { - op_infos_.push_back({std::string(op_name), std::string(short_name), - std::string(category), cycles, flop_count, - transcendental_count, bytes_accessed, - optimal_seconds}); + op_infos_.push_back({string(op_name), string(short_name), string(category), + cycles, flop_count, transcendental_count, + bytes_accessed, optimal_seconds}); } // Gets the human-readable profile. diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h index aa325dc8a353c5bfbfded0c2774c66bfcc71c9cb..85bb4a8b2450a48d461f1d84e0609a38a6818d9c 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -30,7 +30,7 @@ class ImplicitBroadcastRemover : public HloPassInterface { ImplicitBroadcastRemover() {} ~ImplicitBroadcastRemover() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "implicit-broadcast-remover"; } diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index f85d31d5225b8012b68f851b2bfec219d736ba0d..df88587492e256b5a4176971b2f443fda8f43421 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -26,6 +26,11 @@ namespace xla { namespace { class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { + public: + ImplicitBroadcastRemoverTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: ImplicitBroadcastRemover remover_; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 256c8e5573a940b8a49b9ad9a4d10c5049f5dacc..43ef30d1eb645b5d12c1776f8fef28d00452349c 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -17,12 +17,13 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace gtl = ::tensorflow::gtl; @@ -33,32 +34,30 @@ using UnknownArray = Analysis::UnknownArray; using ConstantArray = Analysis::ConstantArray; using ReshapedArray = Analysis::ReshapedArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; +using absl::StrJoin; using tensorflow::gtl::ArraySlice; -using tensorflow::str_util::Join; } // namespace string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { switch (root->kind()) { case Array::kUnknown: { auto* unknown_tensor = root->as(); - return tensorflow::strings::StrCat("%", - unknown_tensor->instruction().name()); + return absl::StrCat("%", unknown_tensor->instruction().name()); } case Array::kConstant: { if (print_constants) { string contents = root->as()->literal()->ToString(); - return tensorflow::strings::StrCat( - "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents, - ")"); + return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), + " ", contents, ")"); } - return tensorflow::strings::StrCat( - "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), + ")"); } case Array::kReshaped: { ReshapedArray* reshaped_array = root->as(); - return tensorflow::strings::StrCat( + return absl::StrCat( "(reshape ", ToString(reshaped_array->operand(), print_constants), " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")"); } @@ -69,11 +68,11 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { string name = root->kind() == Array::kScalarIndexedConstant ? "scalar-indexed-const" : "scalar-indexed"; - return tensorflow::strings::StrCat( + return absl::StrCat( "(", name, " ", ToString(indexed_array->source(), print_constants), " ", ToString(indexed_array->indices(), print_constants), " ", indexed_array->source_dim(), "->[", - Join(indexed_array->output_dims(), ","), "])"); + StrJoin(indexed_array->output_dims(), ","), "])"); } } } @@ -379,8 +378,8 @@ std::vector ComputeReshapePassthroughDimPairs( CHECK_NE(candidate_operand_dim, 0) << "result_dim = " << result_dim << ", result_subarray_size = " << result_subarray_size - << ", result_shape = [" << Join(result_shape, ",") << "]" - << ", operand_shape = [" << Join(operand_shape, ",") << "]"; + << ", result_shape = [" << StrJoin(result_shape, ",") << "]" + << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]"; if (candidate_operand_dim != -1 && result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { @@ -396,12 +395,13 @@ std::vector ComputeReshapePassthroughDimPairs( std::vector result_strings; absl::c_transform(result, std::back_inserter(result_strings), [](ReshapePassthroughDimPair value) { - return tensorflow::strings::StrCat( - value.result_dim, "->", value.operand_dim); + return absl::StrCat(value.result_dim, "->", + value.operand_dim); }); - VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" - << Join(result_shape, ",") << "] passthrough indices are [" - << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; + VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to [" + << StrJoin(result_shape, ",") << "] passthrough indices are [" + << StrJoin(result_strings, ",") + << "] (legend: `result`->`operand`)"; } DCHECK(absl::c_is_sorted( @@ -443,7 +443,7 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, ArraySlice result_shape, int64 source_passthrough_dim) { VLOG(3) << "FindSourcePositionForPassthroughResultDim([" - << Join(operand_shape, ",") << "], [" << Join(result_shape, ",") + << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",") << "], " << source_passthrough_dim << ")"; int64 indexed_source_subarray_size = @@ -755,9 +755,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( if (source_dim_for_new_scalar_indexed_node == -1) { VLOG(3) << "Could not compute the source dim for the new scalar indexed " "node: scalar_indexed_source_shape = [" - << Join(scalar_indexed_source_shape.dimensions(), ",") + << StrJoin(scalar_indexed_source_shape.dimensions(), ",") << "] and new_scalar_indexed_source_shape = [" - << Join(new_scalar_indexed_source_shape, ",") << "]"; + << StrJoin(new_scalar_indexed_source_shape, ",") << "]"; return nullptr; } @@ -997,8 +997,7 @@ absl::optional GetOnlyNonContractingNonBatchDim( // `contracting_dims` and `batch_dims` are the contracting and batch dimensions // of whatever operand `indexed_array` is to the dot (LHS or RHS). bool CanFoldDotIntoIndexedArray( - tensorflow::StringPiece tag, - Analysis::ScalarIndexedConstantArray* indexed_array, + absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array, ArraySlice contracting_dims, ArraySlice batch_dims) { absl::optional non_contracting_non_batch_dim = GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), @@ -1135,7 +1134,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( return nullptr; } -tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { +absl::string_view IndexedArrayAnalysisPrinterPass::name() const { return "indexed-array-analysis-printer-pass"; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 675eb31d2666b52e21394a06ff95e7dc7cd1987a..3fa7d749e1984cc5d7249499e304593b5413cfe2 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -371,7 +371,7 @@ class IndexedArrayAnalysis { // unconditionally add to the regular HLO pass pipeline. class IndexedArrayAnalysisPrinterPass : public HloPassInterface { public: - tensorflow::StringPiece name() const override; + absl::string_view name() const override; StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 97052edf7d783491888cad3f57621e4cd6b045bc..c34c32f7d3361efbfca1fdfe5c286a4c03b5dc60 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -22,6 +22,11 @@ limitations under the License. namespace xla { namespace { class IndexedArrayAnalysisTest : public HloVerifiedTestBase { + public: + IndexedArrayAnalysisTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { @@ -634,9 +639,9 @@ ENTRY main { AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( (scalar-indexed-const (constant f32[3,4] f32[3,4] { - { 0.761594176, 0.964027584, 0.995054781, 0.999329329 }, - { 0.761594176, 0.995054781, 0.964027584, 0.999329329 }, - { 0.999329329, 0.995054781, 0.964027584, 0.761594176 } + { 0.761594, 0.964028, 0.995055, 0.999329 }, + { 0.761594, 0.995055, 0.964028, 0.999329 }, + { 0.999329, 0.995055, 0.964028, 0.761594 } }) %indices 0->[0]))"); } diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h index a523811f6c141a7dc24b1c88897d82d046aa1a2d..efa8ed3abcc6cd7cd8d31ec2170eae8752988c09 100644 --- a/tensorflow/compiler/xla/service/inliner.h +++ b/tensorflow/compiler/xla/service/inliner.h @@ -27,7 +27,7 @@ namespace xla { class Inliner : public HloPassInterface { public: ~Inliner() override = default; - tensorflow::StringPiece name() const override { return "inline"; } + absl::string_view name() const override { return "inline"; } // Run inlining on the given computation. Returns whether the computation was // changed. diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index be59ce82816c1c30e079449599406705a55400c0..83313c7ec1868677190b0671c411f8c82535f590 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -122,6 +122,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: @@ -189,13 +190,13 @@ bool InstructionFusion::CanFuseOnAllPaths( if (consumer == producer) { return true; } - if (!consumer->IsFusable()) { + if (!consumer->IsFusible()) { return false; } for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { auto* consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter - // whether it's fusable. + // whether it's fusible. if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } @@ -205,7 +206,7 @@ bool InstructionFusion::CanFuseOnAllPaths( } // The producer is reachable from consumer_operand which means we need // to be able to fuse consumer_operand into consumer in order for - // producer to be fusable into consumer on all paths. + // producer to be fusible into consumer on all paths. // Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { @@ -216,7 +217,7 @@ bool InstructionFusion::CanFuseOnAllPaths( } InstructionFusion::HloInstructionSet -InstructionFusion::ComputeGloballyUnfusable( +InstructionFusion::ComputeGloballyUnfusible( tensorflow::gtl::ArraySlice post_order) { // Forbid fusion of producers that: // a) Need to be duplicated, unless they can be fused into all consumers @@ -270,19 +271,19 @@ InstructionFusion::ComputeGloballyUnfusable( // all of its consumers on all paths. // // That means, that for: - // A --> B (fusable) - // \-> C (non-fusable) + // A --> B (fusible) + // \-> C (non-fusible) // A will be not allowed to be fused into B, as it cannot be fused into C. // // Similarly, for: // A -------------> B // \-> C -> D -/ // If: - // - A is fusable into B and C, and D is fusable into B - // - C is *not* fusable into D + // - A is fusible into B and C, and D is fusible into B + // - C is *not* fusible into D // A will be not allowed to be fused into B, as it cannot be fused via // all paths. - if (producer->IsFusable() && + if (producer->IsFusible() && CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { continue; } @@ -318,7 +319,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { InsertOrDie(&post_order_index, post_order[i], i); } - HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order); + HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -341,7 +342,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { // consistent. post_order_index.erase(instruction); - if (!instruction->IsFusable() && + if (!instruction->IsFusible() && instruction->opcode() != HloOpcode::kFusion) { continue; } @@ -413,7 +414,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); - if (!operand->IsFusable()) { + if (!operand->IsFusible()) { continue; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index f73ca9adf768ed26f9ec9f162e01b7b160f50daf..9802d4cfc1b2f4b221a4bc2827bfa90ff023b200 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -36,7 +36,7 @@ class InstructionFusion : public HloPassInterface { bool may_duplicate = true) : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {} ~InstructionFusion() override = default; - tensorflow::StringPiece name() const override { return "fusion"; } + absl::string_view name() const override { return "fusion"; } // Run instruction fusion on the given computation. Returns whether the // computation was changed (instructions were fused). @@ -122,7 +122,7 @@ class InstructionFusion : public HloPassInterface { // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. - HloInstructionSet ComputeGloballyUnfusable( + HloInstructionSet ComputeGloballyUnfusible( tensorflow::gtl::ArraySlice post_order); // Used to determine if an HLO is expensive. Expensive operations will not be diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 9e7a15f0330d3f06779c850a4b575f84fe0b9505..da1ad90959dc0ab1a840b3390281ce9d4999651e 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -158,7 +158,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) { HloComputation::Builder builder(TestName()); auto shape = ShapeUtil::MakeShape(F32, {16, 16}); auto param0 = @@ -216,7 +216,7 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); } -TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { // Make sure we do not duplicate the add, as we cannot fuse through the rng. // // p0 -> add -------------------------> sub @@ -309,7 +309,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); // A variant of the above that allows the algorithm to put add2 into the set - // of unfusable ops to short-circuit the decision whether add1 should be fused + // of unfusible ops to short-circuit the decision whether add1 should be fused // into sub2. // // /---------------\ diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index e57a9b3672391e11b130b1c16307a80a0a5b5e77..c9b40d3c6195f80a19272a0d98890049d02315b9 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status_macros.h" -#include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" @@ -77,9 +77,9 @@ XlaInterpreterPlatform::GetUncachedExecutor( if (!init_status.ok()) { return port::Status{ port::error::INTERNAL, - port::Printf( + absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", - config.ordinal, init_status.ToString().c_str())}; + config.ordinal, init_status.ToString())}; } return std::move(executor); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index c75bffc63d71c8018ad71b035d4e9a0886c0f4a6..5e5c93e3a21b55cb39ce4a0112ea83ba0cd29e88 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -27,6 +27,9 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -49,20 +52,11 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { -// For now moving only one API here, but we should have a single top level -// anonymous namespace, instead of three or four spread all over this file. -namespace { - -} // namespace - std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint) { out << constraint.ToString(); @@ -77,9 +71,8 @@ BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, } string BufferLayoutConstraint::ToString() const { - return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s", - buffer_->ToString().c_str(), - LayoutUtil::HumanString(layout_).c_str()); + return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(), + LayoutUtil::HumanString(layout_)); } OperandLayoutConstraint::OperandLayoutConstraint( @@ -98,15 +91,14 @@ OperandLayoutConstraint::OperandLayoutConstraint( } string OperandLayoutConstraint::ToString() const { - return tensorflow::strings::Printf( - "OperandLayoutConstraint %s, operand %lld: %s", - instruction_->name().c_str(), operand_no_, - shape_layout_.ToString().c_str()); + return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s", + instruction_->name(), operand_no_, + shape_layout_.ToString()); } string ResultLayoutConstraint::ToString() const { - return tensorflow::strings::Printf("ResultLayoutConstraint: %s", - shape_layout_.ToString().c_str()); + return absl::StrFormat("ResultLayoutConstraint: %s", + shape_layout_.ToString()); } LayoutConstraints::LayoutConstraints( @@ -174,8 +166,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, return FailedPrecondition( "Layout of buffer %s cannot be constrained because buffer is not " "array-shaped, has shape: %s", - buffer.ToString().c_str(), - ShapeUtil::HumanString(buffer.shape()).c_str()); + buffer.ToString(), ShapeUtil::HumanString(buffer.shape())); } TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); @@ -191,9 +182,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", - buffer.ToString().c_str(), - LayoutUtil::HumanString(curr_constraint.layout()).c_str(), - LayoutUtil::HumanString(layout).c_str()); + buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()), + LayoutUtil::HumanString(layout)); } iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } else { @@ -227,11 +217,11 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, } if (curr_shape_layout->mandatory()) { return FailedPrecondition( - "Operand %lld of instruction %s already has a layout constraint " + "Operand %d of instruction %s already has a layout constraint " "%s, cannot add incompatible constraint %s", - operand_no, instruction->name().c_str(), - curr_shape_layout->shape_layout().ToString().c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + operand_no, instruction->name(), + curr_shape_layout->shape_layout().ToString(), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } } @@ -240,9 +230,9 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, // layouts beyond this immediate use and is complicated to handle. if (OperandBufferForwarded(instruction, operand_no)) { return FailedPrecondition( - "Cannot constraint layout of operand %lld of instruction %s " + "Cannot constraint layout of operand %d of instruction %s " "because instruction forwards operand's LogicalBuffer(s)", - operand_no, instruction->name().c_str()); + operand_no, instruction->name()); } auto key = std::make_pair(instruction, operand_no); @@ -284,8 +274,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, return FailedPrecondition( "Result of computation %s already has the layout constraint %s, " "cannot add incompatible constraint %s", - computation_->name().c_str(), curr_shape_layout->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + computation_->name(), curr_shape_layout->ToString(), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } // New constraint matches existing constraint. Nothing to do. return Status::OK(); @@ -307,9 +297,8 @@ Status LayoutConstraints::SetInstructionLayout( if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) { return FailedPrecondition( "Instruction %s of shape %s cannot be assigned incompatible layout %s", - instruction->name().c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + instruction->name(), ShapeUtil::HumanString(instruction->shape()), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } // Create a BufferLayoutConstraint for each array shape in the output of the @@ -368,31 +357,27 @@ const ShapeLayout* LayoutConstraints::ResultLayout() const { string LayoutConstraints::ToString() const { string output; - tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ", - computation_->name(), ":\n"); + absl::StrAppend(&output, "LayoutConstraints for computation ", + computation_->name(), ":\n"); for (auto* instruction : computation_->MakeInstructionPostOrder()) { - tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(), - "\n"); + absl::StrAppend(&output, " ", instruction->ToShortString(), "\n"); for (int64 i = 0; i < instruction->operand_count(); ++i) { if (OperandLayout(instruction, i) != nullptr) { - tensorflow::strings::StrAppend( - &output, " operand (", i, - "): ", OperandLayout(instruction, i)->ToString(), "\n"); + absl::StrAppend(&output, " operand (", i, + "): ", OperandLayout(instruction, i)->ToString(), "\n"); } } for (const LogicalBuffer* buffer : points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { if (BufferLayout(*buffer) != nullptr) { - tensorflow::strings::StrAppend( - &output, " ", buffer->ToString(), " : ", - LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); + absl::StrAppend(&output, " ", buffer->ToString(), " : ", + LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); } } } if (ResultLayout() != nullptr) { - tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(), - "\n"); + absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n"); } return output; } @@ -763,7 +748,7 @@ Status CheckParameterLayout(HloInstruction* parameter, return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", - parameter->ToString().c_str(), parameter_layout.ToString().c_str()); + parameter->ToString(), parameter_layout.ToString()); } return Status::OK(); } @@ -774,8 +759,8 @@ Status CheckConstantLayout(HloInstruction* constant) { constant->shape())) { return InternalError( "constant instruction %s does not match the layout of its literal %s", - constant->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str()); + constant->ToString(), + ShapeUtil::HumanStringWithLayout(constant->literal().shape())); } return Status::OK(); } @@ -908,13 +893,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { return InternalError( "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str(), - buffer->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction_subshape) - .c_str(), - ShapeUtil::HumanStringWithLayout(buffer->shape()) - .c_str()); + instruction->name(), absl::StrJoin(index, ","), + buffer->ToString(), + ShapeUtil::HumanStringWithLayout(instruction_subshape), + ShapeUtil::HumanStringWithLayout(buffer->shape())); } } } @@ -998,16 +980,17 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( CHECK(ShapeUtil::IsArray(instruction->shape())); CHECK(ShapeUtil::IsArray(operand->shape())); - if (instruction->IsElementwiseOnOperand(operand_no) && - !ShapeUtil::IsScalar(operand->shape()) && + if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == - ShapeUtil::Rank(instruction->shape())) { - // Assign operands the same layout as the instruction, so that + ShapeUtil::Rank(instruction->shape()) && + InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) { + // Propagate the result layout to the operand layout if the instruction + // requires the same layout out for the result and the operand. + // + // For elementwise operations, using the same layout for the operands and + // the result also has the following benefits: // 1) the elementwise operation can reuse its operand's buffer, and // 2) the input and output elements can reuse the same linear index. - // - // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit - // from assigning the same layout to input and output. return absl::make_unique(output_layout); } @@ -1076,9 +1059,9 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( CHECK(ShapeUtil::IsArray(user->shape()) && ShapeUtil::IsArray(operand->shape())); - if (user->IsElementwiseOnOperand(operand_no) && - !ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) { + if (!ShapeUtil::IsScalar(operand->shape()) && + ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && + InstructionRequiresInputLayoutEqualToOutputLayout(user)) { // Assign users the same layout as the operand. return absl::make_unique(operand_layout); } @@ -1385,7 +1368,7 @@ StatusOr InferArrayLayout( // This should not happen because we've assigned layouts to all // instructions preceding this one. return InternalError("LogicalBuffer %s does not have a layout", - source_buffer->ToString().c_str()); + source_buffer->ToString()); } if (first_buffer_layout == nullptr) { @@ -1400,9 +1383,8 @@ StatusOr InferArrayLayout( return FailedPrecondition( "Array at index {%s} in instruction %s aliases buffers %s " "and %s which have different layouts", - tensorflow::str_util::Join(index, ",").c_str(), - instruction->name().c_str(), source_buffers[0]->ToString().c_str(), - source_buffer->ToString().c_str()); + absl::StrJoin(index, ","), instruction->name(), + source_buffers[0]->ToString(), source_buffer->ToString()); } } @@ -1570,7 +1552,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // present in the IR before layout assignment is a bug. return InternalError( "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString().c_str()); + instruction->ToString()); } if (instruction->opcode() != HloOpcode::kInfeed) { LayoutUtil::ClearLayout(instruction->mutable_shape()); @@ -1822,6 +1804,107 @@ StatusOr LayoutAssignment::Run(HloModule* module) { return true; } +bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kAnd: + case HloOpcode::kAtan2: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kClz: + case HloOpcode::kComplex: + case HloOpcode::kConcatenate: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCos: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kCustomCall: + case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kEq: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFft: + case HloOpcode::kFloor: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLe: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kLt: + case HloOpcode::kMap: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kPad: + case HloOpcode::kPower: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kReduceWindow: + case HloOpcode::kRemainder: + case HloOpcode::kReverse: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSelect: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return true; + case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kCall: + case HloOpcode::kConstant: + case HloOpcode::kConvolution: + case HloOpcode::kCopy: + case HloOpcode::kDomain: + case HloOpcode::kDot: + case HloOpcode::kFusion: + case HloOpcode::kGather: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kIota: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReduce: + case HloOpcode::kReshape: + case HloOpcode::kRng: + case HloOpcode::kScatter: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kAfterAll: + case HloOpcode::kTrace: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return false; + } +} + Status LayoutAssignment::Init() { computation_layouts_.clear(); *entry_computation_layout_ = saved_entry_computation_layout_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index f9e8dbea2f8aa224318adf3cf4b5e493792d3093..cf545031d3c7c66770ea4a2392a2df3b8c24cd38 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -297,12 +297,17 @@ class LayoutAssignment : public HloPassInterface { ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} - tensorflow::StringPiece name() const override { return "layout-assignment"; } + absl::string_view name() const override { return "layout-assignment"; } // Assign layouts to the given module. Returns whether the module was changed // (any layouts were changed). StatusOr Run(HloModule* module) override; + // Returns true if the instruction requires that operands with the same rank + // as the output have to have the same layout as the output. + virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction); + protected: // These methods, invoked by PropagateConstraints, propagate a layout // constraint to its neighbors (i.e. operands and users) in order to minimize diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index a16fa75e3032cfa4257d9b5608dd176fdb4ddbdb..7505d7a5b35fc437592ce842c79731beada04053 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -59,7 +59,7 @@ class LayoutAssignmentTest : public HloTestBase { EXPECT_IS_OK(layout_assignment.Run(module).status()); } - std::vector LayoutOf(HloModule* module, tensorflow::StringPiece name) { + std::vector LayoutOf(HloModule* module, absl::string_view name) { auto minor_to_major = FindInstruction(module, name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); @@ -861,5 +861,115 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } +TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopySliceOperandToAvoidImplicitLayoutChange + + ENTRY CopySliceOperandToAvoidImplicitLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[4,5]{0,1} parameter(1) + slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]} + ROOT add0 = f32[3,4]{1,0} add(par0,slice0) + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + auto slice = FindInstruction(module.get(), "slice0"); + EXPECT_EQ(slice->operand(0), copy); + EXPECT_TRUE( + LayoutUtil::Equal(slice->shape().layout(), copy->shape().layout())); +} + +TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopyDSliceOperandToAvoidImplicitLayoutChange + + ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[4,5]{0,1} parameter(1) + par2 = s32[2] parameter(2) + dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4} + ROOT add0 = f32[3,4]{1,0} add(par0,dslice0) + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + auto dslice = FindInstruction(module.get(), "dslice0"); + EXPECT_EQ(dslice->operand(0), copy); + EXPECT_TRUE( + LayoutUtil::Equal(dslice->shape().layout(), copy->shape().layout())); +} + +TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopyConcatOperandToAvoidImplicitLayoutChange + + ENTRY CopyConcatOperandToAvoidImplicitLayoutChange { + par0 = f32[3,8]{1,0} parameter(0) + par1 = f32[3,5]{0,1} parameter(1) + par2 = f32[3,3]{1,0} parameter(2) + concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2), + dimensions={1} + ROOT add0 = f32[3,8]{1,0} add(par0,concat0) + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + auto concat = FindInstruction(module.get(), "concat0"); + EXPECT_EQ(concat->operand(0), copy); + EXPECT_TRUE( + LayoutUtil::Equal(concat->shape().layout(), copy->shape().layout())); +} + +TEST_F(LayoutAssignmentTest, + ConvolutionOperandWithImplicitLayoutChangeNotCopied) { + const char* module_str = R"( + HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied + + ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied { + par0 = f32[128,3,230,230]{2,3,1,0} parameter(0) + par1 = f32[7,7,3,64]{3,2,0,1} parameter(1) + ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1), + window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01, + feature_group_count=1 + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + EXPECT_EQ(copy, nullptr); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 539a9522c173977716a032feec7824e998febae9..be12d7c90ccfc90f0721458d3af600f7ddc823ff 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -38,6 +38,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -69,6 +70,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:support", "@llvm//:target", @@ -89,6 +91,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -104,6 +107,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -121,6 +125,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:core", ], ) @@ -192,6 +197,7 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", "//tensorflow/compiler/xla/service/gpu:partition_assignment", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@llvm//:core", ], @@ -219,7 +225,7 @@ cc_library( deps = [ ":llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -230,6 +236,7 @@ cc_library( hdrs = ["buffer_assignment_util.h"], deps = [ "//tensorflow/compiler/xla/service:buffer_assignment", + "@com_google_absl//absl/strings", ], ) @@ -242,3 +249,12 @@ cc_library( "@llvm//:core", ], ) + +cc_library( + name = "ir_builder_mixin", + srcs = [], + hdrs = ["ir_builder_mixin.h"], + deps = [ + "@llvm//:core", + ], +) diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index fe9eab93aae95557e3ee27a64c09b78f37ac2348..8d9fa99d82b4e49b653d9f05cc9baa5e3fdcefa6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ +#include "absl/strings/str_cat.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace llvm_ir { diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index 4eb5d9fb4750927ca189e02f312b2d6be7fdd418..bdce4a171b8a58f617f1d56e6cf6db5354846703 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "absl/strings/str_cat.h" namespace xla { namespace llvm_ir { @@ -48,7 +49,7 @@ string ConstantBufferAllocationToGlobalName( c = '_'; } } - return tensorflow::strings::StrCat("buffer_for_", instr_name); + return absl::StrCat("buffer_for_", instr_name); } const Literal& LiteralForConstantAllocation( diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 27fbb11e2ede66a1268e7e949634b2c7d29cbc1c..ad350613dd23f4a477c422a6311f1b03bc681574 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -40,7 +40,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& update_shape, const ElementGenerator& start_indices_generator, bool is_signed, ElementGenerator update_array_generator, const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, - tensorflow::StringPiece name, llvm::IRBuilder<>* b) { + absl::string_view name, llvm::IRBuilder<>* b) { const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. @@ -101,8 +101,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( Status EmitDynamicUpdateSliceInPlace( tensorflow::gtl::ArraySlice operand_arrays, - const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* b) { + const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b) { VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; // No need to use operand_arrays[0], the input array of the diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index 3502577d236a099e0b721b98217b758696966821..e1631a62ae8486f03a4fe8fcb32f1b49d5dd2339 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -65,8 +65,7 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace( // modify the input/output buffer without touching any of the other elements. Status EmitDynamicUpdateSliceInPlace( tensorflow::gtl::ArraySlice operand_arrays, - const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* b); + const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b); // Given a loop-fusion node whose root is a dynamic-update-slice op whose // array-to-be-updated and output share the same buffer slice, emits diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 72ede377e1a505d5e4916915e18827e1a0f3fdf9..6d637cad6df6e8913167329d59c8a589311c32c9 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -98,7 +98,7 @@ Status FusedIrEmitter::HandleGetTupleElement( return Unimplemented( "GetTupleElement fusion currently only supports" " parameter operands, but found operand: %s", - operand->name().c_str()); + operand->name()); } // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 2b6caee6aa72f426cf85c8c56c3ef500ff8c5d3d..6971220022d9d3fe5caded731977df4dfffd2992 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -342,9 +342,9 @@ llvm::Value* IrArray::Index::Linearize( return logical_linear_index; } -llvm::Value* IrArray::EmitArrayElementAddress( - const IrArray::Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name) const { +llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, + llvm::IRBuilder<>* b, + absl::string_view name) const { if (ShapeUtil::IsScalar(*shape_)) { // Special handling of scalars: a scalar pretends to have the same value for // every index, thus effectively implementing broadcasting of its value @@ -402,7 +402,7 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name) const { + absl::string_view name) const { llvm::Value* element_address = EmitArrayElementAddress(index, b, name); llvm::LoadInst* load = b->CreateLoad(element_address); AnnotateLoadStoreInstructionWithMetadata(load); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index cbfd2e701235c9a5e65378eab4e1be469b1e9256..e913c109b3ff0e4e7192e501a314aa381a4268b0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -20,12 +20,12 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -241,7 +241,7 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name = "") const; + absl::string_view name = "") const; // Attach metadata this IrArray instance knows about to "instruction". void AnnotateLoadStoreInstructionWithMetadata( @@ -255,7 +255,7 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name = "") const; + absl::string_view name = "") const; // Emit IR to write the given value to the array element at the given index. void EmitWriteArrayElement(const Index& index, llvm::Value* value, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h new file mode 100644 index 0000000000000000000000000000000000000000..abc06fb7b4245294df2dc20d25a22ac4fdaeb4cf --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -0,0 +1,400 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ + +#include "llvm/IR/IRBuilder.h" + +namespace xla { + +// Mixin class that injects more ergonomic versions of llvm::IRBuilder methods +// into a class. Intended to be used as a CRTP base class, like: +// +// class MyIrEmitter : public IrBuilderMixin { +// llvm::IRBuilder<>* builder() { return builder_; } +// +// void EmitFoo(HloInstruction* foo) { +// Add(Mul(...), FPToUI(...)); +// } +// }; + +template +class IrBuilderMixin { + protected: + template + llvm::Value* Add(Args&&... args) { + return mixin_builder()->CreateAdd(std::forward(args)...); + } + + template + llvm::LoadInst* AlignedLoad(Args&&... args) { + return mixin_builder()->CreateAlignedLoad(std::forward(args)...); + } + + template + llvm::StoreInst* AlignedStore(Args&&... args) { + return mixin_builder()->CreateAlignedStore(std::forward(args)...); + } + + template + llvm::AllocaInst* Alloca(Args&&... args) { + return mixin_builder()->CreateAlloca(std::forward(args)...); + } + + template + llvm::Value* And(Args&&... args) { + return mixin_builder()->CreateAnd(std::forward(args)...); + } + + template + llvm::Value* AtomicCmpXchg(Args&&... args) { + return mixin_builder()->CreateAtomicCmpXchg(std::forward(args)...); + } + + template + llvm::Value* AtomicRMW(Args&&... args) { + return mixin_builder()->CreateAtomicRMW(std::forward(args)...); + } + + template + llvm::Value* BitCast(Args&&... args) { + return mixin_builder()->CreateBitCast(std::forward(args)...); + } + + template + llvm::Value* Br(Args&&... args) { + return mixin_builder()->CreateBr(std::forward(args)...); + } + + llvm::CallInst* Call(llvm::Value* callee, + llvm::ArrayRef args = llvm::None, + const llvm::Twine& name = "", + llvm::MDNode* fp_math_tag = nullptr) { + return mixin_builder()->CreateCall(callee, args, name, fp_math_tag); + } + + template + llvm::BranchInst* CondBr(Args&&... args) { + return mixin_builder()->CreateCondBr(std::forward(args)...); + } + + template + llvm::Value* ConstInBoundsGEP1_32(Args&&... args) { + return mixin_builder()->CreateConstInBoundsGEP1_32( + std::forward(args)...); + } + + template + llvm::Value* FAdd(Args&&... args) { + return mixin_builder()->CreateFAdd(std::forward(args)...); + } + + template + llvm::Value* FMul(Args&&... args) { + return mixin_builder()->CreateFMul(std::forward(args)...); + } + + llvm::Value* GEP(llvm::Value* ptr, llvm::ArrayRef idx_list, + const llvm::Twine& name = "") { + return mixin_builder()->CreateGEP(ptr, idx_list, name); + } + + template + llvm::Value* ICmpEQ(Args&&... args) { + return mixin_builder()->CreateICmpEQ(std::forward(args)...); + } + + template + llvm::Value* ICmpNE(Args&&... args) { + return mixin_builder()->CreateICmpNE(std::forward(args)...); + } + + template + llvm::Value* ICmpULE(Args&&... args) { + return mixin_builder()->CreateICmpULE(std::forward(args)...); + } + + template + llvm::Value* ICmpULT(Args&&... args) { + return mixin_builder()->CreateICmpULT(std::forward(args)...); + } + + llvm::Value* InBoundsGEP(llvm::Value* ptr, + llvm::ArrayRef idx_list, + const llvm::Twine& name = "") { + return mixin_builder()->CreateInBoundsGEP(ptr, idx_list, name); + } + + llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef idxs, + const llvm::Twine& name = "") { + return mixin_builder()->CreateExtractValue(agg, idxs, name); + } + + llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val, + llvm::ArrayRef idxs, + const llvm::Twine& name = "") { + return mixin_builder()->CreateInsertValue(agg, val, idxs, name); + } + + template + llvm::Value* IntToPtr(Args&&... args) { + return mixin_builder()->CreateIntToPtr(std::forward(args)...); + } + + template + llvm::LoadInst* Load(Args&&... args) { + return mixin_builder()->CreateLoad(std::forward(args)...); + } + + template + llvm::CallInst* MemCpy(Args&&... args) { + return mixin_builder()->CreateMemCpy(std::forward(args)...); + } + + template + llvm::Value* Mul(Args&&... args) { + return mixin_builder()->CreateMul(std::forward(args)...); + } + + template + llvm::Value* NSWAdd(Args&&... args) { + return mixin_builder()->CreateNSWAdd(std::forward(args)...); + } + + template + llvm::Value* NSWMul(Args&&... args) { + return mixin_builder()->CreateNSWMul(std::forward(args)...); + } + + template + llvm::Value* NSWSub(Args&&... args) { + return mixin_builder()->CreateNSWSub(std::forward(args)...); + } + + template + llvm::Value* Or(Args&&... args) { + return mixin_builder()->CreateOr(std::forward(args)...); + } + + template + llvm::Value* PointerCast(Args&&... args) { + return mixin_builder()->CreatePointerCast(std::forward(args)...); + } + + template + llvm::Value* PtrToInt(Args&&... args) { + return mixin_builder()->CreatePtrToInt(std::forward(args)...); + } + + template + llvm::Value* SDiv(Args&&... args) { + return mixin_builder()->CreateSDiv(std::forward(args)...); + } + + template + llvm::Value* Select(Args&&... args) { + return mixin_builder()->CreateSelect(std::forward(args)...); + } + + template + llvm::Value* SRem(Args&&... args) { + return mixin_builder()->CreateSRem(std::forward(args)...); + } + + template + llvm::StoreInst* Store(Args&&... args) { + return mixin_builder()->CreateStore(std::forward(args)...); + } + + template + llvm::Value* UDiv(Args&&... args) { + return mixin_builder()->CreateUDiv(std::forward(args)...); + } + + template + llvm::Value* URem(Args&&... args) { + return mixin_builder()->CreateURem(std::forward(args)...); + } + + template + llvm::Value* VectorSplat(Args&&... args) { + return mixin_builder()->CreateVectorSplat(std::forward(args)...); + } + + template + llvm::Value* ZExtOrTrunc(Args&&... args) { + return mixin_builder()->CreateZExtOrTrunc(std::forward(args)...); + } + + template + llvm::Value* AShr(Args&&... args) { + return mixin_builder()->CreateAShr(std::forward(args)...); + } + + template + llvm::Value* FCmpOEQ(Args&&... args) { + return mixin_builder()->CreateFCmpOEQ(std::forward(args)...); + } + + template + llvm::Value* FCmpOLT(Args&&... args) { + return mixin_builder()->CreateFCmpOLT(std::forward(args)...); + } + + template + llvm::Value* FCmpONE(Args&&... args) { + return mixin_builder()->CreateFCmpONE(std::forward(args)...); + } + + template + llvm::Value* FCmpUNE(Args&&... args) { + return mixin_builder()->CreateFCmpUNE(std::forward(args)...); + } + + template + llvm::Value* FDiv(Args&&... args) { + return mixin_builder()->CreateFDiv(std::forward(args)...); + } + + template + llvm::Value* FNeg(Args&&... args) { + return mixin_builder()->CreateFNeg(std::forward(args)...); + } + + template + llvm::Value* FPCast(Args&&... args) { + return mixin_builder()->CreateFPCast(std::forward(args)...); + } + + template + llvm::Value* FPToSI(Args&&... args) { + return mixin_builder()->CreateFPToSI(std::forward(args)...); + } + + template + llvm::Value* FPToUI(Args&&... args) { + return mixin_builder()->CreateFPToUI(std::forward(args)...); + } + + template + llvm::Value* FPTrunc(Args&&... args) { + return mixin_builder()->CreateFPTrunc(std::forward(args)...); + } + + template + llvm::Value* FRem(Args&&... args) { + return mixin_builder()->CreateFRem(std::forward(args)...); + } + + template + llvm::Value* FSub(Args&&... args) { + return mixin_builder()->CreateFSub(std::forward(args)...); + } + + template + llvm::Value* ICmpSGE(Args&&... args) { + return mixin_builder()->CreateICmpSGE(std::forward(args)...); + } + + template + llvm::Value* ICmpSLT(Args&&... args) { + return mixin_builder()->CreateICmpSLT(std::forward(args)...); + } + + template + llvm::Value* IntCast(Args&&... args) { + return mixin_builder()->CreateIntCast(std::forward(args)...); + } + + template + llvm::Value* LShr(Args&&... args) { + return mixin_builder()->CreateLShr(std::forward(args)...); + } + + template + llvm::Value* MemSet(Args&&... args) { + return mixin_builder()->CreateMemSet(std::forward(args)...); + } + + template + llvm::Value* Neg(Args&&... args) { + return mixin_builder()->CreateNeg(std::forward(args)...); + } + + template + llvm::Value* Not(Args&&... args) { + return mixin_builder()->CreateNot(std::forward(args)...); + } + + template + llvm::PHINode* PHI(Args&&... args) { + return mixin_builder()->CreatePHI(std::forward(args)...); + } + + template + llvm::Value* RetVoid(Args&&... args) { + return mixin_builder()->CreateRetVoid(std::forward(args)...); + } + + template + llvm::Value* SExtOrTrunc(Args&&... args) { + return mixin_builder()->CreateSExtOrTrunc(std::forward(args)...); + } + + template + llvm::Value* Shl(Args&&... args) { + return mixin_builder()->CreateShl(std::forward(args)...); + } + + template + llvm::Value* SIToFP(Args&&... args) { + return mixin_builder()->CreateSIToFP(std::forward(args)...); + } + + template + llvm::Value* Sub(Args&&... args) { + return mixin_builder()->CreateSub(std::forward(args)...); + } + + template + llvm::Value* Trunc(Args&&... args) { + return mixin_builder()->CreateTrunc(std::forward(args)...); + } + + template + llvm::Value* UIToFP(Args&&... args) { + return mixin_builder()->CreateUIToFP(std::forward(args)...); + } + + template + llvm::Value* Unreachable(Args&&... args) { + return mixin_builder()->CreateUnreachable(std::forward(args)...); + } + + template + llvm::Value* Xor(Args&&... args) { + return mixin_builder()->CreateXor(std::forward(args)...); + } + + private: + llvm::IRBuilder<>* mixin_builder() { + return static_cast(this)->builder(); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index b79567369aa532c4963e3941f6cb9844cd1476dd..bd0139f85b6a5c5dc23dad962263038451921e65 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -19,7 +19,7 @@ limitations under the License. namespace xla { Status KernelSupportLibrary::For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { return If(b_->CreateICmpSLT(start, end), [&]() -> Status { @@ -30,7 +30,7 @@ Status KernelSupportLibrary::For( } Status KernelSupportLibrary::For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator) { @@ -56,7 +56,7 @@ Status KernelSupportLibrary::For( } Status KernelSupportLibrary::If( - tensorflow::StringPiece name, llvm::Value* condition, + absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_); @@ -70,7 +70,7 @@ Status KernelSupportLibrary::If( void KernelSupportLibrary::EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, + absl::string_view kernel_name, KernelSupportLibrary::ArgumentVector arguments, const std::function& kernel_body_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index c5354a8c427e503f591ba724eee295d1c51cfc13..b152cf9275c86ece2e049d193c45e07db22a1170 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -18,12 +18,12 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { // A thin wrapper around llvm_loop.h to make code generating structured control @@ -49,13 +49,13 @@ class KernelSupportLibrary { // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator); void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { @@ -67,7 +67,7 @@ class KernelSupportLibrary { })); } - Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + Status For(absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { @@ -77,7 +77,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { ForReturnVoid(name, /*start=*/b_->getInt64(start), @@ -99,13 +99,13 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& for_body_generator); - void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + void ForReturnVoid(absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function& @@ -129,7 +129,7 @@ class KernelSupportLibrary { peel_first_iteration, for_body_generator); } - void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + void ForReturnVoid(absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, bool peel_first_iteration, const std::function& @@ -140,7 +140,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { return For(name, start, end, step, @@ -151,7 +151,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { ForReturnVoid(name, start, end, step, @@ -162,8 +162,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), /*peel_first_iteration=*/false, @@ -173,8 +172,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { ForReturnVoid(name, start, end, llvm::ConstantInt::get(start->getType(), step), @@ -182,7 +180,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { return For(name, /*start=*/b_->getInt64(start), /*end=*/b_->getInt64(end), @@ -190,7 +188,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { ForReturnVoid(name, /*start=*/b_->getInt64(start), /*end=*/b_->getInt64(end), @@ -203,7 +201,7 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(tensorflow::StringPiece name, llvm::Value* condition, + Status If(absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator = []() -> Status { return Status::OK(); }); @@ -222,7 +220,7 @@ class KernelSupportLibrary { IfReturnVoid("", condition, true_block_generator, false_block_generator); } - void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition, + void IfReturnVoid(absl::string_view name, llvm::Value* condition, const std::function& true_block_generator, const std::function& false_block_generator = []() { }) { @@ -259,13 +257,13 @@ class KernelSupportLibrary { // Currently we only support at most one nullptr value in `arguments`. static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, ArgumentVector arguments, + absl::string_view kernel_name, ArgumentVector arguments, const std::function& kernel_body_generator); // Thin wrappers around the more general EmitAndCallOutlinedKernel above. static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, const std::function& kernel_body_generator) { @@ -278,7 +276,7 @@ class KernelSupportLibrary { static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, llvm::Value* arg3, const std::function& kernel_body_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index ba7f94834c7fd04d97cec012537244323308b8ce..9f3329e7f0e0f5a1605d64ba7d4c177a6e45601f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -25,19 +26,17 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace llvm_ir { -ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, +ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) - : prefix_(std::string(prefix)), - suffix_(std::string(suffix)), + : prefix_(prefix), + suffix_(suffix), start_index_(start_index), end_index_(end_index), step_(step), @@ -46,9 +45,9 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( - tensorflow::StringPiece prefix, llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, - UnrollMode unroll_mode, bool prevent_vectorization) { + absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index, + llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode, + bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, end_index, step, unroll_mode, prevent_vectorization)); @@ -168,16 +167,16 @@ std::vector ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) { return result; } -string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { +string ForLoop::GetQualifiedName(absl::string_view name) { return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); } -llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, +llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b) { return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b); } -std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, +std::unique_ptr ForLoopNest::AddLoop(absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode, @@ -186,12 +185,9 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, unroll_mode, prevent_vectorization); } -std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - llvm::Value* stride, - UnrollMode unroll_mode, - bool prevent_vectorization) { +std::unique_ptr ForLoopNest::AddLoop( + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, + llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); @@ -216,7 +212,7 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); @@ -227,7 +223,7 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); @@ -238,7 +234,7 @@ std::unique_ptr ForLoopNest::AddLoop(int64 start_index, } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, - tensorflow::StringPiece suffix) { + absl::string_view suffix) { std::vector dimensions(ShapeUtil::Rank(shape)); std::iota(dimensions.begin(), dimensions.end(), 0); return AddLoopsForShapeOnDimensions(shape, dimensions, suffix); @@ -246,14 +242,14 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( const Shape& shape, tensorflow::gtl::ArraySlice dimensions, - tensorflow::StringPiece suffix) { + absl::string_view suffix) { llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr loop = AddLoop( /*start_index=*/0, /*end_index=*/shape.dimensions(dimension), /*suffix=*/ - llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension))); + llvm_ir::IrName(suffix, absl::StrCat(dimension))); index[dimension] = loop->GetIndVarValue(); } return index; @@ -261,7 +257,7 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, - tensorflow::StringPiece name_suffix) { + absl::string_view name_suffix) { // Prepares the dimension list we will use to emit the loop nest. Outermost // loops are added first. Add loops in major-to-minor order, and skip the // 'dimension_to_skip' dimension. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index a4fed5c8dc55d38d25031252e3960404a5bf84e6..0a406bd90b98979d270e21d03fd70251ae4caac1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -19,15 +19,15 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -78,7 +78,7 @@ class ForLoop { // `unroll_mode` specifies the desired LLVM unrolling behavior for generated // loop. static std::unique_ptr EmitForLoop( - tensorflow::StringPiece prefix, llvm::Value* start_index, + absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -133,19 +133,18 @@ class ForLoop { // Allow ForLoopNest to call this private constructor. friend class ForLoopNest; - ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, + ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* b); - llvm::BasicBlock* CreateLoopBB(tensorflow::StringPiece name, - llvm::IRBuilder<>* b); + llvm::BasicBlock* CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b); // Creates a name for an LLVM construct, appending prefix_ and suffix_, if // they are set. - string GetQualifiedName(tensorflow::StringPiece name); + string GetQualifiedName(absl::string_view name); // Return a list of metadata nodes that should be associated with the // llvm::Loop for this `ForLoop`. @@ -182,9 +181,9 @@ class ForLoopNest { SetIndexType(index_ty); } - ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* b, + ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) - : name_(std::string(name)), + : name_(name), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), @@ -197,14 +196,14 @@ class ForLoopNest { // been added then emit loop inside the body of the last added loop. // unroll_mode is used to emit metadata that controls LLVM unrolling. std::unique_ptr AddLoop( - tensorflow::StringPiece suffix, llvm::Value* start_index, + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( - tensorflow::StringPiece suffix, llvm::Value* start_index, + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -213,13 +212,13 @@ class ForLoopNest { // end index are constant. std::unique_ptr AddLoop( int64 start_index, int64 end_index, int64 stride, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop( - int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + int64 start_index, int64 end_index, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -234,8 +233,7 @@ class ForLoopNest { // within the shape. One possible order for that sequence would be: // // (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) - IrArray::Index AddLoopsForShape(const Shape& shape, - tensorflow::StringPiece suffix); + IrArray::Index AddLoopsForShape(const Shape& shape, absl::string_view suffix); // Add a loop for each dimension in "dimensions". "suffix" is the // name suffix of the indvar and basic blocks in this new loop nest. @@ -245,7 +243,7 @@ class ForLoopNest { // dimension that is not in "dimensions". IrArray::Index AddLoopsForShapeOnDimensions( const Shape& shape, tensorflow::gtl::ArraySlice dimensions, - tensorflow::StringPiece suffix); + absl::string_view suffix); // Emits a series of nested loops for iterating over an operand array. Loops // are constructed in major to minor dimension layout order. No loop is @@ -256,7 +254,7 @@ class ForLoopNest { // basic blocks) constructed by this method. IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, - tensorflow::StringPiece name_suffix); + absl::string_view name_suffix); // Convenience methods which return particular basic blocks of the outermost // or innermost loops. These methods return nullptr if no loops have been diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index e6126881af8b8123e08a4eaa934b52a7fd378ce6..f0db2a3761afd3e887979d307fb3b9a557eea491 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/MDBuilder.h" @@ -34,8 +36,6 @@ limitations under the License. #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -61,7 +61,7 @@ string AsString(const std::string& str) { return string(str.data(), str.length()); } -llvm::StringRef AsStringRef(tensorflow::StringPiece str) { +llvm::StringRef AsStringRef(absl::string_view str) { return llvm::StringRef(str.data(), str.size()); } @@ -262,15 +262,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b, int alignment) { return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment); } -llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( - llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* b, int alignment) { +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, + llvm::Value* element_count, + absl::string_view name, + llvm::IRBuilder<>* b, + int alignment) { llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP(); llvm::Function* function = b->GetInsertBlock()->getParent(); b->SetInsertPoint(&function->getEntryBlock(), @@ -285,7 +287,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( } llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b) { return llvm::BasicBlock::Create( /*Context=*/b->getContext(), @@ -294,27 +296,25 @@ llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, /*InsertBefore*/ insert_before); } -LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, +LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name, llvm::IRBuilder<>* b, bool emit_else) { llvm_ir::LlvmIfData if_data; if_data.if_block = b->GetInsertBlock(); if_data.true_block = - CreateBasicBlock(nullptr, tensorflow::strings::StrCat(name, "-true"), b); + CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b); if_data.false_block = - emit_else ? CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-false"), b) + emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b) : nullptr; // Add a terminator to the if block, if necessary. if (if_data.if_block->getTerminator() == nullptr) { b->SetInsertPoint(if_data.if_block); - if_data.after_block = CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-after"), b); + if_data.after_block = + CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b); b->CreateBr(if_data.after_block); } else { if_data.after_block = if_data.if_block->splitBasicBlock( - b->GetInsertPoint(), - AsStringRef(tensorflow::strings::StrCat(name, "-after"))); + b->GetInsertPoint(), AsStringRef(absl::StrCat(name, "-after"))); } // Our basic block should now end with an unconditional branch. Remove it; @@ -413,14 +413,14 @@ string IrName(string a) { return a; } -string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b) { +string IrName(absl::string_view a, absl::string_view b) { if (!a.empty() && !b.empty()) { - return IrName(tensorflow::strings::StrCat(a, ".", b)); + return IrName(absl::StrCat(a, ".", b)); } - return IrName(tensorflow::strings::StrCat(a, b)); + return IrName(absl::StrCat(a, b)); } -string IrName(const HloInstruction* a, tensorflow::StringPiece b) { +string IrName(const HloInstruction* a, absl::string_view b) { return IrName(a->name(), b); } @@ -556,7 +556,7 @@ std::map MergeMetadata( return result; } -static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { +static string GetProcessUniqueIrFileName(absl::string_view prefix) { static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); @@ -584,18 +584,16 @@ Status DumpIRToDirectory(const string& directory_name, // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously // dumped from the same process in such cases. string unique_and_safe_file_name = GetProcessUniqueIrFileName( - tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", - optimized ? "with" : "no", "-opt")); + absl::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", + optimized ? "with" : "no", "-opt")); string ir_file_name = tensorflow::io::JoinPath( - directory_name, - tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); + directory_name, absl::StrCat(unique_and_safe_file_name, ".ll")); // For some models the embedded constants can be huge, so also dump the module // with the constants stripped to get IR that is easier to manipulate. string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( - directory_name, - tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll")); + directory_name, absl::StrCat(unique_and_safe_file_name, "-noconst.ll")); TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( directory_name, ir_file_name, DumpModuleToString(llvm_module))); @@ -607,8 +605,7 @@ Status DumpIRToDirectory(const string& directory_name, llvm::Function* CreateFunction(llvm::FunctionType* function_type, llvm::GlobalValue::LinkageTypes linkage, bool enable_fast_math, bool optimize_for_size, - tensorflow::StringPiece name, - llvm::Module* module) { + absl::string_view name, llvm::Module* module) { llvm::Function* function = llvm::Function::Create(function_type, linkage, AsStringRef(name), module); function->setCallingConv(llvm::CallingConv::C); @@ -638,7 +635,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { fake_argv_storage.push_back(""); for (const auto& it : options) { // Skip options the XLA backend itself consumes. - if (!tensorflow::str_util::StartsWith(it.first, "xla_")) { + if (!absl::StartsWith(it.first, "xla_")) { if (it.second.empty()) { fake_argv_storage.push_back(it.first); } else { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 09583985342033d486d50910b6f5ca732a9a3756..dde50e19d1c77491fb843710ea765ecb2e8af932 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" @@ -47,11 +47,11 @@ namespace llvm_ir { // Convert a std::string (used by LLVM's interfaces) to string. string AsString(const std::string& str); -// Convert a tensorflow::StringPiece to a llvm::StringRef. Note: both -// tensorflow::StringPiece and llvm::StringRef are non-owning pointers into a +// Convert a absl::string_view to a llvm::StringRef. Note: both +// absl::string_view and llvm::StringRef are non-owning pointers into a // string in memory. This method is used to feed strings to LLVM // & Clang APIs that expect llvm::StringRef. -llvm::StringRef AsStringRef(tensorflow::StringPiece str); +llvm::StringRef AsStringRef(absl::string_view str); template llvm::ArrayRef AsArrayRef(const std::vector& vec) { @@ -88,8 +88,8 @@ string DumpModuleToString(const llvm::Module& module); // - removing all '%'s. // string IrName(string a); -string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b); -string IrName(const HloInstruction* a, tensorflow::StringPiece b = ""); +string IrName(absl::string_view a, absl::string_view b); +string IrName(const HloInstruction* a, absl::string_view b = ""); // Removes special characters from a function name. // @@ -164,21 +164,23 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, // This can be useful to avoid e.g. executing an alloca every time // through a loop. llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b, int alignment = 0); // As EmitAllocaAtFunctionEntry, but allocates element_count entries // instead of a single element. -llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( - llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* b, int alignment = 0); +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, + llvm::Value* element_count, + absl::string_view name, + llvm::IRBuilder<>* b, + int alignment = 0); // Creates a basic block with the same context and function as for the // builder. Inserts at the end of the function if insert_before is // null. llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b); // Struct with data on a conditional branch in a diamond shape created @@ -210,7 +212,7 @@ struct LlvmIfData { // Currently the insertion point of the builder must be a well-formed // block with a terminator. If you need to use this for a // non-terminated block, just make the function able to do that too. -LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, +LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name, llvm::IRBuilder<>* b, bool emit_else = true); // Emits a compare operation between "lhs" and "rhs" with the given predicate, @@ -285,8 +287,7 @@ Status DumpIRToDirectory(const string& directory_name, llvm::Function* CreateFunction(llvm::FunctionType* function_type, llvm::GlobalValue::LinkageTypes linkage, bool enable_fast_math, bool optimize_for_size, - tensorflow::StringPiece name, - llvm::Module* module); + absl::string_view name, llvm::Module* module); // Extracts the xla_backend_extra_options from `config` and passes those that // don't start with xla_ to LLVM. diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 36f5fa195224c20e30a14f72b32eb42a681bb5e9..1553b4fc91eeeb69a94780b20e94e8a871cfab52 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -86,7 +86,7 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { + absl::string_view loop_name, llvm::Type* index_type) { CHECK_NE(index_type, nullptr); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. @@ -105,7 +105,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + /*suffix=*/absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } @@ -122,7 +122,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, +Status LoopEmitter::EmitLoop(absl::string_view loop_name, llvm::Type* index_type) { if (index_type == nullptr) { index_type = b_->getInt64Ty(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index c4f5c82086ccfa233e0be118b1de10cce55a51b1..57d9d8bbc61014d423822ab5c1e4d251349df89c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -69,10 +69,10 @@ class LoopEmitter { } virtual std::vector EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type); + absl::string_view loop_name, llvm::Type* index_type); // Emits a complete loop nest for every element in the given shape. - Status EmitLoop(tensorflow::StringPiece loop_name = "", + Status EmitLoop(absl::string_view loop_name = "", llvm::Type* index_type = nullptr); protected: diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index c333311a7e406a44335bf2b9c540b7dc2fe7c284..00dd3f16389156afcf3824af0ce57763a82c0ad4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -88,7 +88,7 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const absl::optional& values_array, - tensorflow::StringPiece name, llvm::Value* xor_mask, + absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions) { const Shape& keys_shape = keys_array.GetShape(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 39fffea93115ae4d76b86af2a9fc95db96547a64..527ed10374ce9482045a8459e38fd041e0e83001 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -32,7 +32,7 @@ namespace llvm_ir { // the inner compare loop will not be parallelized. Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, const absl::optional& values_array, - tensorflow::StringPiece name, llvm::Value* xor_mask, + absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions); } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index b7cb782a7e1eac57ccba523e860866f9b94891c2..768105d9e11dbf4420494c4cb8796e4677e9dc4c 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" @@ -37,7 +39,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -149,7 +150,7 @@ StatusOr> LocalService::CompileExecutable( // Validate incoming layouts. if (argument_layouts.size() != program_shape.parameters_size()) { return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", + "Invalid number of arguments for computation: expected %d, got %u.", program_shape.parameters_size(), argument_layouts.size()); } @@ -167,16 +168,15 @@ StatusOr> LocalService::CompileExecutable( CHECK(metadata.value() != nullptr); const OpMetadata& m = *metadata.value(); if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); + return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line()); } return ""; }; return InvalidArgument( "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); + metadata_string(), + ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(argument_shape)); } } if (build_options.result_layout() != nullptr) { @@ -214,7 +214,7 @@ StatusOr LocalService::GlobalDataToShapedBuffer( TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); if (replica_number >= buffers.size()) { return InvalidArgument( - "replica_number %d out of range; must be less than num_replicas = %zu.", + "replica_number %d out of range; must be less than num_replicas = %u.", replica_number, buffers.size()); } return buffers[replica_number]; diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index c742d35a7bcafa66692195a513992c9cfbb39335..e1f56727bd209797c60f7b3f10c3e232992d01e0 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -15,11 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -34,11 +34,10 @@ LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { string color_string; if (has_color()) { - color_string = tensorflow::strings::StrCat(" @", color().value()); + color_string = absl::StrCat(" @", color().value()); } - return tensorflow::strings::StrCat(instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "](#", id(), color_string, ")"); + return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","), + "](#", id(), color_string, ")"); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 6aa639a954d3a359ff3b3de69b454fc6c0ec1792..4c8cb7d379d4f82224ef5896fbd937d4aa482606 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -48,9 +48,7 @@ class MultiOutputFusion : public HloPassInterface { public: MultiOutputFusion(int64 fuel) : fuel_(fuel) {} - tensorflow::StringPiece name() const override { - return "multi_output_fusion"; - } + absl::string_view name() const override { return "multi_output_fusion"; } // Run multi-output fusion on the given module. Returns whether the module // was changed. diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index f6e7578a89551ec2f23d4d8c8b488c3c10e0bf1c..bd8fb17a235ea6eeb0e1809e8cb9ad83145fd8d6 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -52,8 +53,8 @@ NameUniquer::NameUniquer(const string& separator) { return result; } -string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { - string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); +string NameUniquer::GetUniqueName(absl::string_view prefix) { + string root = GetSanitizedName(prefix.empty() ? "name" : string(prefix)); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. @@ -63,20 +64,22 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { if (separator_index != string::npos && (separator_index > 0) && (separator_index < root.size() - 1)) { string after_suffix = root.substr(separator_index + 1); - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { has_numeric_suffix = true; // Remove numeric suffix from root. root = root.substr(0, separator_index); + } else { + // absl::SimpleAtoi may modify numeric_suffix even if it returns false. + numeric_suffix = 0; } } SequentialIdGenerator& id_generator = generated_names_[root]; numeric_suffix = id_generator.RegisterId(numeric_suffix); if (numeric_suffix == 0) { - return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0) - : root; + return has_numeric_suffix ? absl::StrCat(root, separator_, 0) : root; } - tensorflow::strings::StrAppend(&root, separator_, numeric_suffix); + absl::StrAppend(&root, separator_, numeric_suffix); return root; } diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 4423d6106920eaeab830bd9dc08529ff409a5161..6dd89c240f81c9f0ccac66e50c7f244bfd5429f1 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" @@ -38,7 +38,7 @@ class NameUniquer { // Get a sanitized unique name in a string, with an optional prefix for // convenience. - string GetUniqueName(tensorflow::StringPiece prefix = ""); + string GetUniqueName(absl::string_view prefix = ""); // Sanitizes and returns the name. Unallowed characters will be replaced with // '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index ac6ea4c72f61a47726b3ae7dd000837d3fba1b93..ccc06ce613cb133d0be982bbb58bbc64d42a27c1 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -622,7 +622,7 @@ template class HloInstructionPatternNameImpl { public: explicit HloInstructionPatternNameImpl(const Previous& previous, - tensorflow::StringPiece name) + absl::string_view name) : previous_(previous), name_(name) {} bool Match(const ::xla::HloInstruction* inst) const { @@ -631,7 +631,7 @@ class HloInstructionPatternNameImpl { private: Previous previous_; - tensorflow::StringPiece name_; + absl::string_view name_; }; // An HloInstructionPattern implementation that matches only if the instruction @@ -784,7 +784,7 @@ class HloInstructionPattern { // Modifies the pattern to match only if the instruction has the given name. HloInstructionPattern> - WithName(tensorflow::StringPiece name) const { + WithName(absl::string_view name) const { return HloInstructionPattern>( HloInstructionPatternNameImpl(impl_, name), matched_inst_); diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 39fe3c7835d1c74c0f1e5bc0ebf5916ec734c24a..ae1e13d8a6c0ac6c1bce903e72a3f492fe126571 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -19,20 +19,19 @@ limitations under the License. #include #include +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { -using tensorflow::str_util::Lowercase; - // Minimum supported CUDA compute capability is 3.5. constexpr int kMinCudaComputeCapabilityMajor = 3; constexpr int kMinCudaComputeCapabilityMinor = 5; @@ -43,7 +42,7 @@ constexpr char kInterpreter[] = "interpreter"; namespace { string CanonicalPlatformName(const string& name) { - string platform_str = Lowercase(name); + string platform_str = absl::AsciiStrToLower(name); // "cpu" and "host" mean the same thing. if (platform_str == "cpu") { platform_str = "host"; @@ -94,12 +93,12 @@ PlatformUtil::GetSupportedPlatforms() { } // Multiple platforms present and we can't pick a reasonable default. - string platforms_string = tensorflow::str_util::Join( + string platforms_string = absl::StrJoin( platforms, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", - platforms_string.c_str()); + platforms_string); } /* static */ StatusOr PlatformUtil::GetDefaultPlatform() { @@ -110,21 +109,21 @@ PlatformUtil::GetSupportedPlatforms() { return platforms[0]; } else if (platforms.size() == 2) { for (int i = 0; i < 2; i++) { - if (Lowercase(platforms[i]->Name()) == kInterpreter && - Lowercase(platforms[1 - i]->Name()) != kInterpreter) { + if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter && + absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) { return platforms[1 - i]; } } } // Multiple platforms present and we can't pick a reasonable default. - string platforms_string = tensorflow::str_util::Join( + string platforms_string = absl::StrJoin( platforms, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform (except for the " "interpreter platform) found: %s", - platforms_string.c_str()); + platforms_string); } /*static*/ StatusOr PlatformUtil::GetPlatform( @@ -132,11 +131,11 @@ PlatformUtil::GetSupportedPlatforms() { string platform_str = CanonicalPlatformName(platform_name); TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); for (se::Platform* platform : platforms) { - if (Lowercase(platform->Name()) == platform_str) { + if (absl::AsciiStrToLower(platform->Name()) == platform_str) { return platform; } } - return InvalidArgument("platform %s not found", platform_name.c_str()); + return InvalidArgument("platform %s not found", platform_name); } /*static*/ StatusOr PlatformUtil::GetPlatformExceptFor( @@ -146,23 +145,23 @@ PlatformUtil::GetSupportedPlatforms() { TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); std::vector matched; for (se::Platform* platform : platforms) { - if (Lowercase(platform->Name()) != platform_name) { + if (absl::AsciiStrToLower(platform->Name()) != platform_name) { matched.push_back(platform); } } if (matched.empty()) { return InvalidArgument("unable to find platform that is not %s", - platform_name.c_str()); + platform_name); } if (matched.size() == 1) { return matched[0]; } - string matched_string = tensorflow::str_util::Join( + string matched_string = absl::StrJoin( matched, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "found multiple platforms %s, but expected one platform except for %s", - matched_string.c_str(), platform_name.c_str()); + matched_string, platform_name); } // Returns whether the device underlying the given StreamExecutor is supported @@ -193,7 +192,7 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) { PlatformUtil::GetStreamExecutors(se::Platform* platform) { int device_count = platform->VisibleDeviceCount(); if (device_count <= 0) { - return NotFound("no %s devices found", platform->Name().c_str()); + return NotFound("no %s devices found", platform->Name()); } if (platform->id() == se::host::kHostPlatformId) { // On host "devices", StreamExecutor exports a device for each hardware @@ -232,7 +231,7 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { if (std::all_of(stream_executors.begin(), stream_executors.end(), [](se::StreamExecutor* s) { return s == nullptr; })) { return InternalError("no supported devices found for platform %s", - platform->Name().c_str()); + platform->Name()); } return stream_executors; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index afde3cf95c721b59a39b74b4e1ff3f15a335fe97..256b231e3af43a2ee85c97a5efab1f022d4de4b1 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -59,7 +59,7 @@ class ReducePrecisionInsertion : public HloPassInterface { ~ReducePrecisionInsertion() override{}; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "reduce-precision-insertion"; } diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index 1f59e3b3147facb6f2fae00d6c810bf54d560e5c..1e86a0823a56a9e52421a5c8bd49e0adb98a2c70 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -26,7 +26,7 @@ namespace xla { // them inputward also. class ReshapeMover : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "reshape-mover"; } + absl::string_view name() const override { return "reshape-mover"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 7534a3f7e32aa84e5b47695b3eef23a8e749ee63..a395dd5333f9b6b5f71a561b52cd9312a3faef2d 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -28,13 +28,18 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" - -namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using ReshapeMoverTest = HloVerifiedTestBase; + +namespace op = xla::testing::opcode_matchers; + +class ReshapeMoverTest : public HloVerifiedTestBase { + public: + ReshapeMoverTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 338f0c09e9e7f59127023144ff30ac62aff55ee1..2077b57c05e225e17e89a6305eb829615f0f745f 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -291,7 +291,7 @@ StatusOr ScatterExpander::ExpandScatter( return Unimplemented( "Scatter operations with more than 2147483647 scatter indices are not " "supported. This error occurred for %s.", - scatter->ToString().c_str()); + scatter->ToString()); } // Canonicalize the scatter_indices, after which the size of its most-major diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h index 8f735e877d270c10b494e1cd974904c4e2d960c9..14f062c89cfd4657097c1a933621a3e945f89c53 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -22,7 +22,7 @@ namespace xla { class ScatterExpander : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "scatter_expander"; } + absl::string_view name() const override { return "scatter_expander"; } StatusOr Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 18d1b7732bb2f54eb4b1bf74e1eed1d96221913c..e10c1d9927edcc841d42f462a5b585e3d0fd1941 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -21,6 +21,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" @@ -46,8 +48,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -55,13 +55,12 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/ptr_util.h" -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; +using absl::StrFormat; + // Records the arguments used to invoke a computation in an HloSnapshot proto. Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, @@ -148,19 +147,19 @@ Service::Service(const ServiceOptions& options, CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas()) << "Requested more replicas than there are devices."; } - LOG(INFO) << Printf( + LOG(INFO) << StrFormat( "XLA service %p executing computations on platform %s. Devices:", this, - execute_backend_->platform()->Name().c_str()); + execute_backend_->platform()->Name()); for (int i = 0; i < execute_backend_->device_count(); ++i) { if (execute_backend_->device_ordinal_supported(i)) { se::StreamExecutor* executor = execute_backend_->stream_executor(i).ValueOrDie(); const auto& description = executor->GetDeviceDescription(); - LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i, - description.name().c_str(), - description.platform_version().c_str()); + LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, + description.name(), + description.platform_version()); } else { - LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i); + LOG(INFO) << StrFormat(" StreamExecutor device (%d) not supported", i); } } } else { @@ -200,8 +199,8 @@ Status Service::ValidateResultShape(const Shape& client_shape, return InvalidArgument( "Shape used to set computation result layout %s is not compatible " "with result shape %s", - ShapeUtil::HumanStringWithLayout(client_shape).c_str(), - ShapeUtil::HumanString(result_shape).c_str()); + ShapeUtil::HumanStringWithLayout(client_shape), + ShapeUtil::HumanString(result_shape)); } return Status::OK(); } @@ -231,9 +230,9 @@ Service::ResolveAndValidateArguments( return InvalidArgument( "argument %lu is on device %s:%d but computation will be executed " "on device %s", - i, shaped_buffer->platform()->Name().c_str(), + i, shaped_buffer->platform()->Name(), shaped_buffer->device_ordinal(), - execute_backend_->device_name(replica_device_ordinal).c_str()); + execute_backend_->device_name(replica_device_ordinal)); } replicated_arguments[replica].push_back(shaped_buffer); } @@ -249,7 +248,7 @@ StatusOr> Service::CreateModuleConfig( ComputationLayout* computation_layout = config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { - return InvalidArgument("computation takes %d parameters, but %zu given", + return InvalidArgument("computation takes %d parameters, but %u given", program_shape.parameters_size(), argument_shapes.size()); } @@ -261,8 +260,8 @@ StatusOr> Service::CreateModuleConfig( return InvalidArgument( "Argument does not match shape of computation parameter %d: want " "%s, got %s", - i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*argument_shapes[i]).c_str()); + i, ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(*argument_shapes[i])); } TF_RETURN_IF_ERROR( computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -314,7 +313,7 @@ StatusOr>> Service::BuildExecutables( std::vector> module_configs, Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p", this); + VLOG(1) << StrFormat("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. std::vector> hlo_snapshots; @@ -329,9 +328,8 @@ StatusOr>> Service::BuildExecutables( auto hlo_snapshot = absl::make_unique(); *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; if (!directory_path.empty()) { - string filename = - Printf("computation_%lld__%s", module_protos[i]->id(), - module_protos[i]->entry_computation_name().c_str()); + string filename = StrFormat("computation_%d__%s", module_protos[i]->id(), + module_protos[i]->entry_computation_name()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); } @@ -454,8 +452,8 @@ Service::ExecuteParallelAndRegisterResult( for (int64 i = 0; i < streams.size(); ++i) { Status block_status = streams[i]->BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError("failed to complete execution for stream %lld: %s", - i, block_status.error_message().c_str()); + return InternalError("failed to complete execution for stream %d: %s", i, + block_status.error_message()); } } @@ -580,7 +578,7 @@ StatusOr> Service::GetExecutors( if (requests_size > 1 && execution_options.device_handles_size() > 1) { return InvalidArgument( "Parallel requests with multiple device handles is not supported. " - "Found %lld parallel requests, with request %lld containing %d device " + "Found %d parallel requests, with request %d containing %d device " "handles.", requests_size, request_index, execution_options.device_handles_size()); } @@ -745,8 +743,8 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, } if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( - "Requested device count (%lld) exceeds the number of available devices " - "on the target (%lld)", + "Requested device count (%d) exceeds the number of available devices " + "on the target (%d)", arg->device_count(), available_device_count); } @@ -796,9 +794,9 @@ StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf( + VLOG(1) << StrFormat( "BuildExecutable on service %p with serialized module proto: %s", this, - module_proto.name().c_str()); + module_proto.name()); // Dump computation proto state if flag is set. auto hlo_snapshot = absl::make_unique(); @@ -809,8 +807,8 @@ StatusOr> Service::BuildExecutable( if (!directory_path.empty() || !execution_directory_path.empty()) { *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s", module_proto.id(), - module_proto.entry_computation_name().c_str()); + string filename = StrFormat("computation_%d__%s", module_proto.id(), + module_proto.entry_computation_name()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); } @@ -1010,8 +1008,7 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, "%s", StrCat("The replica_id=", arg->replica_id(), " on TransferToInfeedRequest not in range [0, replica_count=", - replica_count, ").") - .c_str()); + replica_count, ").")); } se::StreamExecutor* executor; @@ -1037,8 +1034,7 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( - "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, " - "%lld)", + "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)", arg->replica_id(), replica_count); } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index ec6aa6df55460fb9bb5d468dbc4fa69be34524b2..a04af8b0aac3fa343cb2be184d925053d3bf8b78 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -22,6 +22,10 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -29,32 +33,26 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" -using tensorflow::str_util::Join; -using tensorflow::strings::Printf; - namespace xla { - namespace { +using absl::StrFormat; +using absl::StrJoin; + // Returns true if no element is present in slice more than once. bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { +Status ExpectArray(const Shape& shape, absl::string_view op_type) { if (!ShapeUtil::IsArray(shape)) { return InvalidArgument("Expected array argument for %s, but got %s.", - std::string(op_type).c_str(), - ShapeUtil::HumanString(shape).c_str()); + string(op_type), ShapeUtil::HumanString(shape)); } return Status::OK(); } @@ -66,7 +64,7 @@ Status VerifyReducerShape( int64 inputs) { if (reducer_shape.parameters_size() != inputs * 2) { return InvalidArgument( - "Reduction function must take %lld parameters, but " + "Reduction function must take %d parameters, but " "takes %d parameter(s).", inputs * 2, reducer_shape.parameters_size()); } @@ -76,7 +74,7 @@ Status VerifyReducerShape( if (ShapeUtil::IsArray(accumulator_shape)) { if (inputs != 1) { return InvalidArgument( - "Reduction function must produce a tuple with %lld elements, but " + "Reduction function must produce a tuple with %d elements, but " "produces a scalar", inputs); } @@ -84,8 +82,8 @@ Status VerifyReducerShape( } else if (ShapeUtil::IsTuple(accumulator_shape)) { if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { return InvalidArgument( - "Reduction function must produce a tuple with %lld elements, but has " - "%lld elements", + "Reduction function must produce a tuple with %d elements, but has " + "%d elements", inputs, ShapeUtil::TupleElementCount(accumulator_shape)); } for (const Shape& element_shape : accumulator_shape.tuple_shapes()) { @@ -95,7 +93,7 @@ Status VerifyReducerShape( return InvalidArgument( "Reduction function must produce a scalar or tuple of scalars, but has " "shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); + ShapeUtil::HumanString(accumulator_shape)); } for (const Shape* element_shape : accumulator_subshapes) { @@ -103,7 +101,7 @@ Status VerifyReducerShape( return InvalidArgument( "Reduction function must return a scalar or tuple of scalars but " "returns shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); + ShapeUtil::HumanString(accumulator_shape)); } } @@ -114,19 +112,19 @@ Status VerifyReducerShape( if (!ShapeUtil::Compatible(*accumulator_subshapes[i], reducer_shape.parameters(i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape differs from the " + "Reduction function's %d-th parameter shape differs from the " "result shape: %s vs %s", - i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + i, ShapeUtil::HumanString(reducer_shape.parameters(i)), + ShapeUtil::HumanString(*accumulator_subshapes[i])); } // Check that init_value's shapes are suitable for reducer_shape. if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i], *init_value_shapes[i])) { return InvalidArgument( - "Reduction function's accumulator shape at index %lld differs from " + "Reduction function's accumulator shape at index %d differs from " "the init_value shape: %s vs %s", - i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(), - ShapeUtil::HumanString(*init_value_shapes[i]).c_str()); + i, ShapeUtil::HumanString(*accumulator_subshapes[i]), + ShapeUtil::HumanString(*init_value_shapes[i])); } // Check that the inputs can be passed in as the non-accumulator arguments. const Shape input_element_shape = @@ -134,11 +132,11 @@ Status VerifyReducerShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision( input_element_shape, reducer_shape.parameters(inputs + i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape differs from the " + "Reduction function's %d-th parameter shape differs from the " "input type element type: %s vs %s", inputs + i, - ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), - ShapeUtil::HumanString(input_element_shape).c_str()); + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)), + ShapeUtil::HumanString(input_element_shape)); } // Check that the accumulator and inputs to the reducer function match. // If the accumulator is scalar, it must have the same type as the inputs @@ -148,11 +146,11 @@ Status VerifyReducerShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision( *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape must " + "Reduction function's %d-th parameter shape must " "match the result shape, but got %s vs %s.", inputs + i, - ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), - ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)), + ShapeUtil::HumanString(*accumulator_subshapes[i])); } } @@ -165,7 +163,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, bool allow_negative_padding) { if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) { return InvalidArgument( - "Window has dimension %d but base shape has dimension %lld.", + "Window has dimension %d but base shape has dimension %d.", window.dimensions_size(), ShapeUtil::Rank(base_shape)); } @@ -174,29 +172,29 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, const auto& dim = window.dimensions(i); if (dim.size() <= 0) { return InvalidArgument("Window %s has a non-positive dimension.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.stride() <= 0) { return InvalidArgument("Window %s has a non-positive stride.", - window.DebugString().c_str()); + window.DebugString()); } if (!allow_negative_padding && dim.padding_low() < 0) { return InvalidArgument("Window %s has a negative low padding.", - window.DebugString().c_str()); + window.DebugString()); } if (!allow_negative_padding && dim.padding_high() < 0) { return InvalidArgument("Window %s has a negative high padding.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.base_dilation() < 1) { return InvalidArgument( "Window %s has a non-positive base area dilation factor.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.window_dilation() < 1) { return InvalidArgument( "Window %s has a non-positive window dilation factor.", - window.DebugString().c_str()); + window.DebugString()); } const int64 dilated_base = window_util::DilatedBound( @@ -234,11 +232,12 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, switch (opcode) { case HloOpcode::kFloor: case HloOpcode::kCeil: + case HloOpcode::kRoundNearestAfz: if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "Expected element type in shape to be floating for floor/ceil " - "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "Expected element type in shape to be floating for %s operation; " + "got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kCos: @@ -251,9 +250,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (!ShapeUtil::ElementIsFloating(shape) && !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( - "Expected element type in shape to be floating or complex for " - "sin/cos/exp/log/tanh operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "Expected element type in shape to be floating or complex for %s " + "operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kReal: @@ -265,19 +264,47 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } else { return InvalidArgument( "Expected element type in shape to be floating or complex for " - "real/imag operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } case HloOpcode::kAbs: if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( shape, primitive_util::ComplexComponentType(shape.element_type())); + } else if (ShapeUtil::ElementIsSigned(shape)) { + return shape; + } else { + return InvalidArgument( + "Expected element type in shape to be floating or complex for " + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } - return shape; case HloOpcode::kClz: + if (!ShapeUtil::ElementIsIntegral(shape)) { + return InvalidArgument( + "Expected an integral element type in argument to Clz " + "operation; got %s.", + PrimitiveType_Name(shape.element_type())); + } + return shape; case HloOpcode::kNegate: - case HloOpcode::kRoundNearestAfz: + if (!ShapeUtil::ElementIsIntegral(shape) && + !ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "Expected element type in shape to be integral, floating or " + "complex for %s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); + } + return shape; case HloOpcode::kSign: + if (!ShapeUtil::ElementIsSigned(shape) && + !ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "Expected element type in shape to be signed or complex for " + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); + } return shape; case HloOpcode::kNot: @@ -286,7 +313,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected pred or an integral element type in argument to Not " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return shape; @@ -296,14 +323,14 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, "Expected element type in shape to be floating " "point for IsFinite " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - HloOpcodeString(opcode).c_str()); + HloOpcodeString(opcode)); } } @@ -314,7 +341,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument("Concatenate expects at least one argument."); } if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { - return InvalidArgument("Concatenate dimension out of bounds: %lld.", + return InvalidArgument("Concatenate dimension out of bounds: %d.", dimension); } const Shape* arg_shape = nullptr; @@ -328,17 +355,16 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( - "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " + "Cannot concatenate arrays with different ranks: %d (%s) vs %d " "(%s).", - ShapeUtil::Rank(*arg_shape), - ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), - ShapeUtil::HumanString(*shape).c_str()); + ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape), + ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( "Cannot concatenate arrays with different element types: %s vs %s.", - PrimitiveType_Name(arg_shape->element_type()).c_str(), - PrimitiveType_Name(shape->element_type()).c_str()); + PrimitiveType_Name(arg_shape->element_type()), + PrimitiveType_Name(shape->element_type())); } for (int64 dimension_number = 0; dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) { @@ -351,9 +377,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Cannot concatenate arrays that differ in dimensions other than " "the one being concatenated (the other array dimensions must be " - "the same): %s vs %s in dimension %lld.", - ShapeUtil::HumanString(*arg_shape).c_str(), - ShapeUtil::HumanString(*shape).c_str(), dimension); + "the same): %s vs %s in dimension %d.", + ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape), + dimension); } } element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape); @@ -385,8 +411,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, !primitive_util::IsComplexType(new_element_type)) { return Unimplemented( "Conversion from complex to real type %s => %s is not implemented.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (!ShapeUtil::IsArray(operand_shape) || !primitive_util::IsArrayType(new_element_type)) { @@ -395,8 +421,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // are valid. For now we just reject them, though. return InvalidArgument( "Convert does not allow non-arrays, so cannot convert from %s to %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); @@ -408,8 +434,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, if (primitive_util::IsComplexType(old_element_type) != primitive_util::IsComplexType(new_element_type)) { return InvalidArgument("Conversion from complex to real type %s => %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (!ShapeUtil::IsArray(operand_shape) || !primitive_util::IsArrayType(new_element_type)) { @@ -418,15 +444,15 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, // are valid. For now we just reject them, though. return InvalidArgument( "Cannot convert from or to tuple type; requested conversion: %s => %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (primitive_util::BitWidth(old_element_type) != primitive_util::BitWidth(new_element_type)) { return InvalidArgument( "Cannot bitcast types with different bit-widths: %s => %s.", - PrimitiveType_Name(old_element_type).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + PrimitiveType_Name(old_element_type), + PrimitiveType_Name(new_element_type)); } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); @@ -439,7 +465,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be floating point for " "ReducePrecision operation; got %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (exponent_bits < 1) { // One exponent bit is necessary to distinguish 0 from infinity. Having @@ -471,8 +497,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "The rank of the operand and the padding configuration do not match: " "%s vs %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - padding_config.ShortDebugString().c_str()); + ShapeUtil::HumanString(operand_shape), + padding_config.ShortDebugString()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, padding_value_shape)) { @@ -539,7 +565,7 @@ Status ValidateDotDimensionNumbers( !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is out of range in Dot: %s.", - dimension_numbers.DebugString().c_str()); + dimension_numbers.DebugString()); } // Check that dimension numbers are unique. @@ -557,7 +583,7 @@ Status ValidateDotDimensionNumbers( if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is not unique in Dot: %s.", - dimension_numbers.DebugString().c_str()); + dimension_numbers.DebugString()); } // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. @@ -602,14 +628,13 @@ Status ValidateDotDimensionNumbers( TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); auto fail = [lhs, rhs](const string& addendum) -> Status { - string message = tensorflow::strings::Printf( - "Cannot infer shape for dot operation: %s %s.", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + string message = + StrFormat("Cannot infer shape for dot operation: %s %s.", + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); if (!addendum.empty()) { message += " " + addendum; } - return InvalidArgument("%s", message.c_str()); + return InvalidArgument("%s", message); }; // Check if both element types are the same. @@ -705,9 +730,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", - HloOpcodeString(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + HloOpcodeString(operation), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), @@ -722,14 +746,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // the user to provide an explicit broadcast dimension in this case. // See b/25177275 for more details. return InvalidArgument("Automatic shape inference not supported: %s and %s", - ShapeUtil::HumanString(smaller_shape).c_str(), - ShapeUtil::HumanString(larger_shape).c_str()); + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { return InvalidArgument( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " - " lower-rank operand's rank is %lld, size of broadcast_dimensions is " - "%zu.", + " lower-rank operand's rank is %d, size of broadcast_dimensions is " + "%u.", ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); } @@ -779,12 +803,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, int64 dimension_to_match = broadcast_dimensions.at(i); if (dimension_to_match < 0) { return InvalidArgument( - "Broadcast dimension number (%lld) cannot be negative.", + "Broadcast dimension number (%d) cannot be negative.", dimension_to_match); } if (dimension_to_match >= larger_shape.dimensions_size()) { return InvalidArgument( - "Broadcast dimension number (%lld) too large; higher-rank " + "Broadcast dimension number (%d) too large; higher-rank " "operand has rank %d.", dimension_to_match, larger_shape.dimensions_size()); } @@ -796,16 +820,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (small_dimension_size != large_dimension_size && small_dimension_size != 1 && large_dimension_size != 1) { return InvalidArgument( - "Broadcast dimension %d mismatch: %lld != %lld; %s and %s.", i, + "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i, small_dimension_size, large_dimension_size, - ShapeUtil::HumanString(smaller_shape).c_str(), - ShapeUtil::HumanString(larger_shape).c_str()); + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); } // Make sure the broadcast dimensions are listed in a strictly increasing // order. if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { return InvalidArgument( - "Broadcast dimensions order is wrong: %lld comes after %lld.", + "Broadcast dimensions order is wrong: %d comes after %d.", dimension_to_match, broadcast_dimensions.at(i - 1)); } @@ -824,8 +848,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Binary op %s with different element types: %s and %s.", - HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + HloOpcodeString(operation), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { @@ -875,20 +899,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str(), - Join(broadcast_dimensions, ", ").c_str()); + HloOpcodeString(opcode), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", ")); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR( - ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ", - HloOpcodeString(opcode)))); - TF_RETURN_IF_ERROR( - ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ", - HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR(ExpectArray( + lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR(ExpectArray( + rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode)))); switch (opcode) { case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -910,7 +931,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected element type in shape to be floating for complex compose " "operation; got %s.", - PrimitiveType_Name(lhs.element_type()).c_str()); + PrimitiveType_Name(lhs.element_type())); } TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(opcode, lhs, rhs, @@ -929,7 +950,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected pred or integral type in argument to and/or operation; " "got %s.", - PrimitiveType_Name(lhs.element_type()).c_str()); + PrimitiveType_Name(lhs.element_type())); } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); @@ -947,8 +968,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, default: return Unimplemented( "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", - HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), - rhs.ShortDebugString().c_str()); + HloOpcodeString(opcode), lhs.ShortDebugString(), + rhs.ShortDebugString()); } } @@ -971,8 +992,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kTupleSelect: return InferTupleSelectShape(lhs, rhs, ehs); default: - return InvalidArgument("Unknown operation %s.", - HloOpcodeString(opcode).c_str()); + return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode)); } } @@ -1011,8 +1031,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Sort keys and values dimensions must match. " "Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(*operand_shapes[0]).c_str(), - ShapeUtil::HumanString(*operand_shapes[1]).c_str()); + ShapeUtil::HumanString(*operand_shapes[0]), + ShapeUtil::HumanString(*operand_shapes[1])); } return ShapeUtil::MakeTupleShape( {*operand_shapes[0], *operand_shapes[1]}); @@ -1020,8 +1040,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("Unexpected number of operands for sort"); } default: - return InvalidArgument("Unknown operation %s.", - HloOpcodeString(opcode).c_str()); + return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode)); } } @@ -1059,7 +1078,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Map operation requires all operands to have the same shape; got: " "%s.", - Join(pieces, ", ").c_str()); + StrJoin(pieces, ", ")); } // Check that dimensions.size == arg_shape.dimensions_size() (we currently @@ -1067,7 +1086,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions.size() != arg_shape->dimensions_size()) { return InvalidArgument( "Map applied to a subset of dimensions currently not supported: " - "arg_dimension_size: %d, requested_map_dimensions_size: %zu.", + "arg_dimension_size: %d, requested_map_dimensions_size: %u.", arg_shape->dimensions_size(), dimensions.size()); } @@ -1076,7 +1095,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions[i] != i) { return InvalidArgument( "Map requires monotonically increasing dimension numbers; got: %s.", - Join(dimensions, ", ").c_str()); + StrJoin(dimensions, ", ")); } } @@ -1084,7 +1103,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (arg_shapes.size() != to_apply.parameters_size()) { return InvalidArgument( "Map applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu.", + "arity: %d, arguments: %u.", to_apply.parameters_size(), arg_shapes.size()); } @@ -1093,7 +1112,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::IsScalar(output_shape)) { return InvalidArgument( "Mapped computation's result has to be a scalar; got: %s.", - ShapeUtil::HumanString(output_shape).c_str()); + ShapeUtil::HumanString(output_shape)); } for (int i = 0; i < to_apply.parameters_size(); ++i) { @@ -1103,7 +1122,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Mapped computation's parameter has to be a scalar; " "got parameter %d shape: %s.", - i, ShapeUtil::HumanString(parameter_shape).c_str()); + i, ShapeUtil::HumanString(parameter_shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, @@ -1111,8 +1130,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Mapped computation's parameter type has to match argument element " "type; got parameter %d shape: %s, argument shape: %s.", - i, ShapeUtil::HumanString(parameter_shape).c_str(), - ShapeUtil::HumanString(*arg_shape).c_str()); + i, ShapeUtil::HumanString(parameter_shape), + ShapeUtil::HumanString(*arg_shape)); } } @@ -1141,35 +1160,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-training to " - "be a non-negative number, got %lld.", + "be a non-negative number, got %d.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-training to be at least 1; got %lld.", + "batch-norm-training to be at least 1; got %d.", ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(offset_shape) != 1) { return InvalidArgument( "Offset input of batch-norm-training must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(offset_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-training must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } @@ -1177,7 +1196,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-training must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, @@ -1186,8 +1205,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(offset_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(offset_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1196,8 +1215,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1207,16 +1226,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { return InvalidArgument( "The size of offset factor should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1251,35 +1270,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-inference to " - "be a non-negative number, got %lld.", + "be a non-negative number, got %d.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-inference to be at least 1; got %lld.", + "batch-norm-inference to be at least 1; got %d.", ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(offset_shape) != 1) { return InvalidArgument( "Offset input of batch-norm-inference must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(offset_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-inference must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } @@ -1287,7 +1306,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-inference must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, @@ -1297,8 +1316,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of offset factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(offset_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(offset_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1308,8 +1327,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of scale factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, @@ -1319,8 +1338,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of mean is %s " "and the shape of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape, @@ -1330,8 +1349,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of variance is %s " "and the shape of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(variance_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(variance_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1341,32 +1360,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { return InvalidArgument( "The size of offset factor should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { return InvalidArgument( "The size of mean should be the same as feature count," - "but the size of mean is %lld " - "and the feature count is %lld.", + "but the size of mean is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) { return InvalidArgument( "The size of variance should be the same as feature count," - "but the size of variance is %lld " - "and the feature count is %lld.", + "but the size of variance is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(variance_shape, 0), feature_count); } @@ -1396,36 +1415,36 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) { return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" - " output_grad_shape; got rank(oprand_shape) %lld, and" - " rank(output_grad_shape) %lld.", + " output_grad_shape; got rank(oprand_shape) %d, and" + " rank(output_grad_shape) %d.", ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); } if (ShapeUtil::Rank(mean_shape) != 1) { return InvalidArgument( "Mean input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(mean_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } if (ShapeUtil::Rank(var_shape) != 1) { return InvalidArgument( "Var input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(var_shape)); } @@ -1433,14 +1452,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-grad must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::ElementIsFloating(output_grad_shape)) { return InvalidArgument( "The output_grad to batch-norm-grad must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(output_grad_shape.element_type()).c_str()); + PrimitiveType_Name(output_grad_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape, @@ -1449,8 +1468,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " "and the element type of operand is %s.", - PrimitiveType_Name(output_grad_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(output_grad_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1459,8 +1478,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " "and the element type of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, @@ -1469,8 +1488,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " "and the element type of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape, @@ -1479,8 +1498,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " "and the element type of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1491,24 +1510,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { return InvalidArgument( "The size of mean should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) { return InvalidArgument( "The size of variance should be the same as feature count," - "but the size of variance is %lld " - "and the feature count is %lld.", + "but the size of variance is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(var_shape, 0), feature_count); } @@ -1518,8 +1537,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::GetDimension(output_grad_shape, i)) { return InvalidArgument( "The bounds of operand shape should be the same as output_grad's," - "but the bound of operand_shape at dimension %lld is %lld " - "and the bound of output_grad_shape is %lld.", + "but the bound of operand_shape at dimension %d is %d " + "and the bound of output_grad_shape is %d.", i, ShapeUtil::GetDimension(operand_shape, i), ShapeUtil::GetDimension(output_grad_shape, i)); } @@ -1538,15 +1557,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); } if (dnums.input_spatial_dimensions_size() != dnums.kernel_spatial_dimensions_size()) { return InvalidArgument( "Both arguments to convolution must have same number of dimensions.\n" "Window: %s", - window.DebugString().c_str()); + window.DebugString()); } const int num_spatial_dims = dnums.input_spatial_dimensions_size(); @@ -1554,19 +1572,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" "Window: %s\nDimension numbers: %s.", - window.DebugString().c_str(), dnums.DebugString().c_str()); + window.DebugString(), dnums.DebugString()); } const int num_dims = num_spatial_dims + 2; if (ShapeUtil::Rank(lhs) != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs).c_str()); + num_dims, ShapeUtil::HumanString(lhs)); } if (ShapeUtil::Rank(rhs) != num_dims) { return InvalidArgument( "The RHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs).c_str()); + num_dims, ShapeUtil::HumanString(lhs)); } TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ -1603,26 +1621,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (input_dnums != expected_dnums) { return InvalidArgument( "Input dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (window_dnums != expected_dnums) { return InvalidArgument( "Window dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (output_dnums != expected_dnums) { return InvalidArgument( "Output dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } std::vector input_spatial_dims(num_spatial_dims); @@ -1643,13 +1661,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (input_features != kernel_input_features * feature_group_count) { return InvalidArgument( - "Expected LHS feature dimension (value %lld) to match RHS " - "input feature dimension * feature_group_count (value %lld); " + "Expected LHS feature dimension (value %d) to match RHS " + "input feature dimension * feature_group_count (value %d); " "got (%s, %s)\n" "Dimension numbers: {%s}.", input_features, kernel_input_features * feature_group_count, - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); } std::vector window_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -1661,8 +1679,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "RHS shape: %s\n\t" "Window: {%s}\n\t" "Dimension numbers: {%s}.", - ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(), - dnums.ShortDebugString().c_str()); + ShapeUtil::HumanString(rhs), window.ShortDebugString(), + dnums.ShortDebugString()); } Shape base_shape = @@ -1688,29 +1706,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const tensorflow::gtl::ArraySlice fft_length) { const int64 fft_rank = fft_length.size(); if (fft_rank < 1 || fft_rank > 3) { - return InvalidArgument("FFT only supports ranks 1-3; got %lld.", fft_rank); + return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank); } -#define RET_CHECK_RANK(x) \ - if (x.dimensions_size() < fft_rank) { \ - return InvalidArgument( \ - "FFT of rank %lld requires input of at least " \ - "same rank; got input of rank %d", \ - fft_rank, x.dimensions_size()); \ +#define RET_CHECK_RANK(x) \ + if (x.dimensions_size() < fft_rank) { \ + return InvalidArgument( \ + "FFT of rank %d requires input of at least " \ + "same rank; got input of rank %d", \ + fft_rank, x.dimensions_size()); \ } switch (fft_type) { case FFT: case IFFT: if (in.element_type() != C64) { return InvalidArgument("%s requires C64 input type, found %s.", - FftType_Name(fft_type).c_str(), - PrimitiveType_Name(in.element_type()).c_str()); + FftType_Name(fft_type), + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); return in; case RFFT: { if (in.element_type() != F32) { return InvalidArgument("RFFT requires F32 input type, found %s.", - PrimitiveType_Name(in.element_type()).c_str()); + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); for (int i = 0; i < fft_rank; i++) { @@ -1718,7 +1736,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]) { return InvalidArgument( "RFFT requires innermost dimensions match fft_length but " - "dimension %lld is %lld and should be %lld.", + "dimension %d is %d and should be %d.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1732,7 +1750,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case IRFFT: { if (in.element_type() != C64) { return InvalidArgument("IRFFT requires C64 input type, found %s.", - PrimitiveType_Name(in.element_type()).c_str()); + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); Shape result = ShapeUtil::ComplexComponentShape(in); @@ -1741,7 +1759,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]) { return InvalidArgument( "IRFFT requires all but one innermost dimensions match " - "fft_length, but dimension %lld is %lld and should be %lld.", + "fft_length, but dimension %d is %d and should be %d.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1751,7 +1769,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[fft_rank - 1] / 2 + 1) { return InvalidArgument( "IRFFT requires innermost dimension matches fft_length/2+1, but " - "dimension %d is %lld and should be %lld.", + "dimension %d is %d and should be %d.", in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), fft_length[fft_rank - 1] / 2 + 1); } @@ -1787,18 +1805,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(split_count > 0); if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { return InvalidArgument( - "AllToAll split_dimension %lld is out-of-bounds in shape %s.", - split_dimension, ShapeUtil::HumanString(shape).c_str()); + "AllToAll split_dimension %d is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape)); } if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { return InvalidArgument( - "AllToAll concat_dimension %lld is out-of-bounds in shape %s.", - concat_dimension, ShapeUtil::HumanString(shape).c_str()); + "AllToAll concat_dimension %d is out-of-bounds in shape %s.", + concat_dimension, ShapeUtil::HumanString(shape)); } if (shape.dimensions(split_dimension) % split_count != 0) { return InvalidArgument( - "AllToAll split dimension size %lld must be dividable by split_count " - "%lld.", + "AllToAll split dimension size %d must be dividable by split_count " + "%d.", shape.dimensions(split_dimension), split_count); } std::vector new_dimensions(shape.dimensions().begin(), @@ -1818,14 +1836,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "HLO all-to-all has operands with different shapes: the 0th " "operand shape %s, but the %dth operand has shape %s.", - ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i, - ShapeUtil::HumanString(*operand_shapes[i]).c_str()); + ShapeUtil::HumanString(*operand_shapes[0]), i, + ShapeUtil::HumanString(*operand_shapes[i])); } } return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); } +/* static */ StatusOr ShapeInference::InferCollectivePermuteShape( + const Shape& shape) { + TF_RET_CHECK(ShapeUtil::IsArray(shape)); + return shape; +} + /* static */ StatusOr ShapeInference::InferReduceShape( tensorflow::gtl::ArraySlice arg_shapes, tensorflow::gtl::ArraySlice dimensions_to_reduce, @@ -1848,9 +1872,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { return InvalidArgument( "All reduced tensors must have the sime dimension. Tensor 0 has " - "shape %s, Tensor %lld has shape %s", - ShapeUtil::HumanString(*reduced_args[0]).c_str(), i, - ShapeUtil::HumanString(*reduced_args[i]).c_str()); + "shape %s, Tensor %d has shape %s", + ShapeUtil::HumanString(*reduced_args[0]), i, + ShapeUtil::HumanString(*reduced_args[i])); } } @@ -1860,9 +1884,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { - return InvalidArgument( - "Reducing out-of-bounds dimension %lld in shape %s.", dimension, - ShapeUtil::HumanString(arg).c_str()); + return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.", + dimension, ShapeUtil::HumanString(arg)); } } @@ -1935,16 +1958,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Select function's first parameter shape currently must " "match the operand element shape, but got %s vs %s.", - ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), - ShapeUtil::HumanString(operand_element_shape).c_str()); + ShapeUtil::HumanString(select_shape.parameters(0)), + ShapeUtil::HumanString(operand_element_shape)); } if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(1))) { return InvalidArgument( "Select function's second parameter shape currently must " "match the operand element shape, but got %s vs %s.", - ShapeUtil::HumanString(select_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(operand_element_shape).c_str()); + ShapeUtil::HumanString(select_shape.parameters(1)), + ShapeUtil::HumanString(operand_element_shape)); } // Check if the scatter function has a proper shape as a reduction. @@ -1962,8 +1985,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Source shape does not match the shape of window-reduced operand: " "source(%s), window-reduced operand(%s).", - ShapeUtil::HumanString(source_shape).c_str(), - ShapeUtil::HumanString(window_result_shape).c_str()); + ShapeUtil::HumanString(source_shape), + ShapeUtil::HumanString(window_result_shape)); } return operand_shape; } @@ -1976,29 +1999,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " "{%s}; strides: {%s}.", - message.c_str(), ShapeUtil::HumanString(arg).c_str(), - Join(starts, ",").c_str(), Join(limits, ",").c_str(), - Join(strides, ",").c_str()); + message, ShapeUtil::HumanString(arg), StrJoin(starts, ","), + StrJoin(limits, ","), StrJoin(strides, ",")); }; TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); - VLOG(2) << tensorflow::strings::Printf( - "slicing shape %s starts={%s} limits={%s}", - ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), - Join(limits, ", ").c_str()); + VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}", + ShapeUtil::HumanString(arg), StrJoin(starts, ", "), + StrJoin(limits, ", ")); if (starts.size() != limits.size()) { - return error(Printf("slice start and limit sizes differ: %zu vs %zu", - starts.size(), limits.size())); + return error(StrFormat("slice start and limit sizes differ: %u vs %u", + starts.size(), limits.size())); } if (starts.size() != strides.size()) { - return error(Printf("slice start and strides sizes differ: %zu vs %zu", - starts.size(), strides.size())); + return error(StrFormat("slice start and strides sizes differ: %u vs %u", + starts.size(), strides.size())); } if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( - "Slice index count does not match argument rank: %zu vs %lld.", + "Slice index count does not match argument rank: %u vs %d.", starts.size(), ShapeUtil::Rank(arg)); } @@ -2008,27 +2029,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, int64 limit_index = limits[dimension]; int64 stride = strides[dimension]; if (start_index < 0) { - return InvalidArgument("Negative start index to slice: %lld.", - start_index); + return InvalidArgument("Negative start index to slice: %d.", start_index); } if (limit_index > arg.dimensions(dimension)) { return error( - Printf("limit index (%lld) must be less than or equal to dimension " - "size (%lld)", - limit_index, arg.dimensions(dimension))); - } - VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension, - start_index); - VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, - limit_index); + StrFormat("limit index (%d) must be less than or equal to dimension " + "size (%d)", + limit_index, arg.dimensions(dimension))); + } + VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index); + VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index); if (start_index > limit_index) { return error( - Printf("limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive stride", - limit_index, start_index)); + StrFormat("limit index (%d) must be greater or equal to " + "start index (%d) in slice with positive stride", + limit_index, start_index)); } if (stride <= 0) { - return InvalidArgument("Stride (%lld) must be positive.", stride); + return InvalidArgument("Stride (%d) must be positive.", stride); } sizes.push_back((limit_index - start_index + stride - 1) / stride); } @@ -2043,15 +2061,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR( ExpectArray(start_indices_shape, "start indices of dynamic slice")); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", - ShapeUtil::HumanString(operand_shape).c_str(), - ShapeUtil::HumanString(start_indices_shape).c_str(), - Join(slice_sizes, ", ").c_str()); + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "Dynamic slice start indices of rank %lld must be rank1.", + "Dynamic slice start indices of rank %d must be rank1.", ShapeUtil::Rank(start_indices_shape)); } @@ -2063,16 +2080,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "Dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::Rank(operand_shape), - ShapeUtil::HumanString(operand_shape).c_str()); + "Dynamic slice start number of dimensions %d (%s) must match rank " + "%d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); } if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "Dynamic slice index count does not match argument rank: %zu vs %lld.", + "Dynamic slice index count does not match argument rank: %u vs %d.", slice_sizes.size(), ShapeUtil::Rank(operand_shape)); } @@ -2080,16 +2096,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 input_dim_size = operand_shape.dimensions(dim); const int64 slice_dim_size = slice_sizes[dim]; if (slice_dim_size < 0) { - return InvalidArgument("Negative size index to dynamic slice: %lld.", + return InvalidArgument("Negative size index to dynamic slice: %d.", slice_dim_size); } if (slice_dim_size > input_dim_size) { return InvalidArgument( - "Slice dim size %lld greater than dynamic slice dimension: %lld.", + "Slice dim size %d greater than dynamic slice dimension: %d.", slice_dim_size, input_dim_size); } - VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim, - slice_dim_size); + VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size); } return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes); @@ -2105,16 +2120,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, "start indices of dynamic update slice")); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "updating slice of shape %s at dynamic start_indices %s with update " "shape %s", - ShapeUtil::HumanString(operand_shape).c_str(), - ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::HumanString(update_shape).c_str()); + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::HumanString(update_shape)); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "Dynamic update slice start indices of rank %lld must be rank1.", + "Dynamic update slice start indices of rank %d must be rank1.", ShapeUtil::Rank(start_indices_shape)); } @@ -2126,17 +2141,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "Dynamic update slice start number of dimensions %lld (%s) must match " - "rank %lld of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::Rank(operand_shape), - ShapeUtil::HumanString(operand_shape).c_str()); + "Dynamic update slice start number of dimensions %d (%s) must match " + "rank %d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); } if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( "Dynamic update slice update rank does not match argument rank: " - "%lld vs %lld.", + "%d vs %d.", ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } @@ -2145,8 +2159,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Dynamic update slice update element type does not match argument. " "operand.element_type: %s vs update.element_type: %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str(), - PrimitiveType_Name(update_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type()), + PrimitiveType_Name(update_shape.element_type())); } for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { @@ -2154,16 +2168,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { return InvalidArgument( - "Size index %lld to dynamic update slice must be >= 0.", + "Size index %d to dynamic update slice must be >= 0.", update_dim_size); } if (update_dim_size > input_dim_size) { return InvalidArgument( - "Update dim size %lld greater than dynamic slice dimension: %lld.", + "Update dim size %d greater than dynamic slice dimension: %d.", update_dim_size, input_dim_size); } - VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim, - update_dim_size); + VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size); } return operand_shape; @@ -2178,8 +2191,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64 dimension : dimensions) { if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { return InvalidArgument( - "One of the reverse dimensions (%lld) is out-of-bounds in shape %s.", - dimension, ShapeUtil::HumanString(operand_shape).c_str()); + "One of the reverse dimensions (%d) is out-of-bounds in shape %s.", + dimension, ShapeUtil::HumanString(operand_shape)); } } return operand_shape; @@ -2190,14 +2203,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::IsTuple(arg)) { return InvalidArgument( "Cannot infer shape: attempting to index into non-tuple: %s.", - ShapeUtil::HumanString(arg).c_str()); + ShapeUtil::HumanString(arg)); } if (index >= arg.tuple_shapes_size()) { return InvalidArgument( - "Cannot infer shape: attempt to index out of tuple bounds: %lld " + "Cannot infer shape: attempt to index out of tuple bounds: %d " ">= %d in shape %s.", - index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str()); + index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg)); } return arg.tuple_shapes(index); @@ -2217,17 +2230,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } auto shape_string = [&]() { - return tensorflow::strings::Printf( - "Condition: %s; body: %s; init: %s.", - ShapeUtil::HumanString(condition).c_str(), - ShapeUtil::HumanString(body).c_str(), - ShapeUtil::HumanString(init).c_str()); + return StrFormat( + "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition), + ShapeUtil::HumanString(body), ShapeUtil::HumanString(init)); }; // Check the shapes of computation parameters and return types. if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { return InvalidArgument("Condition must return a boolean; got %s.", - shape_string().c_str()); + shape_string()); } if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) || !ShapeUtil::Compatible(body.result(), body.parameters(0)) || @@ -2235,7 +2246,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The parameter of condition and body, the result of the body, and init " "must all have the same shape; got %s.", - shape_string().c_str()); + shape_string()); } return init; @@ -2247,7 +2258,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const ProgramShape& false_computation) { if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { return InvalidArgument("Predicate must be a boolean; got %s.", - ShapeUtil::HumanString(predicate).c_str()); + ShapeUtil::HumanString(predicate)); } if (true_computation.parameters_size() != 1) { @@ -2256,15 +2267,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { auto true_shape_string = [&]() { - return tensorflow::strings::Printf( - "true_operand: %s; true_computation: %s", - ShapeUtil::HumanString(true_operand).c_str(), - ShapeUtil::HumanString(true_computation).c_str()); + return StrFormat("true_operand: %s; true_computation: %s", + ShapeUtil::HumanString(true_operand), + ShapeUtil::HumanString(true_computation)); }; return InvalidArgument( "true_operand must match the shape of the only parameter of " "true_computation: got %s.", - true_shape_string().c_str()); + true_shape_string()); } if (false_computation.parameters_size() != 1) { @@ -2273,28 +2283,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { auto false_shape_string = [&]() { - return tensorflow::strings::Printf( - "false_operand: %s; false_computation: %s", - ShapeUtil::HumanString(false_operand).c_str(), - ShapeUtil::HumanString(false_computation).c_str()); + return StrFormat("false_operand: %s; false_computation: %s", + ShapeUtil::HumanString(false_operand), + ShapeUtil::HumanString(false_computation)); }; return InvalidArgument( "false_operand must match the shape of the only parameter of " "false_computation: got %s.", - false_shape_string().c_str()); + false_shape_string()); } if (!ShapeUtil::Compatible(true_computation.result(), false_computation.result())) { auto shape_string = [&]() { - return tensorflow::strings::Printf( + return StrFormat( "true_computation result: %s; false_computation result: %s.", - ShapeUtil::HumanString(true_computation.result()).c_str(), - ShapeUtil::HumanString(false_computation.result()).c_str()); + ShapeUtil::HumanString(true_computation.result()), + ShapeUtil::HumanString(false_computation.result())); }; return InvalidArgument( "the result of true_computation and false_computation must have the " "same shape: got %s.", - shape_string().c_str()); + shape_string()); } return true_computation.result(); } @@ -2304,7 +2313,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); for (int64 size : broadcast_sizes) { if (size < 0) { - return InvalidArgument("Broadcast with negative dimension size %lld.", + return InvalidArgument("Broadcast with negative dimension size %d.", size); } } @@ -2329,11 +2338,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "Reshape operation has mismatched element counts: from=%lld (%s) " - "to=%lld (%s).", - ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), + "Reshape operation has mismatched element counts: from=%d (%s) " + "to=%d (%s).", + ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand), ShapeUtil::ElementsIn(inferred_shape), - ShapeUtil::HumanString(inferred_shape).c_str()); + ShapeUtil::HumanString(inferred_shape)); } std::vector indices(ShapeUtil::Rank(operand)); @@ -2344,7 +2353,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Reshape dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", - Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str()); + StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } return inferred_shape; @@ -2379,9 +2388,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("Clamp with different operand types: %s, %s, %s.", - ShapeUtil::HumanString(min).c_str(), - ShapeUtil::HumanString(operand).c_str(), - ShapeUtil::HumanString(max).c_str()); + ShapeUtil::HumanString(min), + ShapeUtil::HumanString(operand), + ShapeUtil::HumanString(max)); } if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || ShapeUtil::IsScalar(min)) && @@ -2398,9 +2407,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::ChangeElementType(min, operand.element_type()); } } - return Unimplemented( - "%s, %s %s is not implemented.", min.ShortDebugString().c_str(), - max.ShortDebugString().c_str(), operand.ShortDebugString().c_str()); + return Unimplemented("%s, %s %s is not implemented.", + min.ShortDebugString(), max.ShortDebugString(), + operand.ShortDebugString()); } // TODO(b/36794510): Make broadcast semantics more consistent, by supporting @@ -2411,13 +2420,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( "Operands to select must be the same shape; got %s and %s.", - ShapeUtil::HumanString(on_true).c_str(), - ShapeUtil::HumanString(on_false).c_str()); + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } if (pred.element_type() != PRED) { return InvalidArgument( "Select's pred operand must have PRED element type; got %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || ShapeUtil::IsScalar(pred)) { @@ -2430,7 +2438,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Select operation with non-scalar predicate with dimensionality " " different from the other operands: %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } } @@ -2441,18 +2449,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::Compatible(on_true, on_false)) { return InvalidArgument( "Operands to tuple-select must be the same shape; got %s and %s.", - ShapeUtil::HumanString(on_true).c_str(), - ShapeUtil::HumanString(on_false).c_str()); + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } if (pred.element_type() != PRED) { return InvalidArgument( "TupleSelect's pred operand must have PRED element type; got %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } if (!ShapeUtil::IsScalar(pred)) { return InvalidArgument( "TupleSelect operation with non-scalar predicate: %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } return on_true; } @@ -2464,15 +2471,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (arg_shapes.size() != to_apply.parameters_size()) { string computation_signature = ShapeUtil::HumanString(to_apply); string argument_shapes = - Join(arg_shapes, ", ", [](string* out, const Shape* shape) { - tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape)); + StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) { + absl::StrAppend(out, ShapeUtil::HumanString(*shape)); }); return InvalidArgument( "Call applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu; computation signature: %s; argument " + "arity: %d, arguments: %u; computation signature: %s; argument " "shapes: [%s].", - to_apply.parameters_size(), arg_shapes.size(), - computation_signature.c_str(), argument_shapes.c_str()); + to_apply.parameters_size(), arg_shapes.size(), computation_signature, + argument_shapes); } // All arguments must be compatible with the program shape. @@ -2483,8 +2490,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Call parameter must match argument; got parameter %d shape: %s, " "argument shape: %s.", - i, ShapeUtil::HumanString(param_shape).c_str(), - ShapeUtil::HumanString(arg_shape).c_str()); + i, ShapeUtil::HumanString(param_shape), + ShapeUtil::HumanString(arg_shape)); } } @@ -2498,14 +2505,14 @@ static Status ValidateGatherDimensionNumbers( if (!absl::c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", - Join(dim_numbers.offset_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ")); } if (absl::c_adjacent_find(dim_numbers.offset_dims()) != dim_numbers.offset_dims().end()) { return InvalidArgument( "Output window dimensions in gather op must not repeat; got: %s.", - Join(dim_numbers.offset_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ")); } const int64 output_offset_dim_count = dim_numbers.offset_dims_size(); @@ -2516,9 +2523,9 @@ static Status ValidateGatherDimensionNumbers( int64 offset_dim = dim_numbers.offset_dims(i); if (offset_dim < 0 || offset_dim >= output_shape_rank) { return InvalidArgument( - "Offset dimension %d in gather op is out of bounds; got %lld, but " + "Offset dimension %d in gather op is out of bounds; got %d, but " "should " - "have been in [0,%lld).", + "have been in [0,%d).", i, offset_dim, output_shape_rank); } } @@ -2527,8 +2534,8 @@ static Status ValidateGatherDimensionNumbers( start_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( "Gather op has %d elements in start_index_map and the " - "bound of dimension index_vector_dim=%lld of start_indices is " - "%lld. These two numbers must be equal.", + "bound of dimension index_vector_dim=%d of start_indices is " + "%d. These two numbers must be equal.", dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(), start_indices_shape[dim_numbers.index_vector_dim()]); } @@ -2538,7 +2545,7 @@ static Status ValidateGatherDimensionNumbers( if (operand_dim_for_start_index_i < 0 || operand_dim_for_start_index_i >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid start_index_map; domain is [0, %d), got: %d->%lld.", + "Invalid start_index_map; domain is [0, %d), got: %d->%d.", input_shape.dimensions_size(), i, operand_dim_for_start_index_i); } } @@ -2554,14 +2561,14 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Repeated dimensions are not allowed in start_index_map; " "got: %s.", - Join(dim_numbers.start_index_map(), ", ").c_str()); + StrJoin(dim_numbers.start_index_map(), ", ")); } for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) { if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) { return InvalidArgument( "Invalid collapsed_slice_dims set in gather op; valid range is [0, " - "%d), got: %lld.", + "%d), got: %d.", input_shape.dimensions_size(), collapsed_dim); } } @@ -2569,7 +2576,7 @@ static Status ValidateGatherDimensionNumbers( if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) { return InvalidArgument( "collapsed_slice_dims in gather op must be sorted; got: %s", - Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); + StrJoin(dim_numbers.collapsed_slice_dims(), ", ")); } if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) != @@ -2577,7 +2584,7 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Repeated dimensions not allowed in collapsed_slice_dims in gather op; " "got: %s.", - Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); + StrJoin(dim_numbers.collapsed_slice_dims(), ", ")); } return Status::OK(); @@ -2595,7 +2602,7 @@ static Status ValidateGatherDimensionNumbers( if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( "Gather indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(start_indices_shape).c_str()); + ShapeUtil::HumanString(start_indices_shape)); } // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if @@ -2608,7 +2615,7 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Gather index leaf dimension must be within [0, rank(start_indices) + " "1). rank(start_indices) is %d and gather index leaf dimension is " - "%lld.", + "%d.", start_indices_shape.dimensions_size(), gather_dim_numbers.index_vector_dim()); } @@ -2639,8 +2646,8 @@ static Status ValidateGatherDimensionNumbers( "All components of the offset index in a gather op must either be a " "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, " "output_slice_sizes=%s, collapsed_slice_dims=%s.", - slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(), - Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str()); + slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","), + StrJoin(gather_dim_numbers.collapsed_slice_dims(), ",")); } for (int i = 0; i < slice_sizes.size(); i++) { @@ -2649,7 +2656,7 @@ static Status ValidateGatherDimensionNumbers( if (slice_size < 0 || slice_size > corresponding_input_size) { return InvalidArgument( "Slice size at index %d in gather op is out of range, must be " - "within [0, %lld), got %lld.", + "within [0, %d), got %d.", i, corresponding_input_size + 1, slice_size); } } @@ -2658,7 +2665,7 @@ static Status ValidateGatherDimensionNumbers( if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) { return InvalidArgument( "Gather op can only collapse slice dims with bound 1, but bound is " - "%lld for index %lld at position %d.", + "%d for index %d at position %d.", slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)], gather_dim_numbers.collapsed_slice_dims(i), i); } @@ -2703,20 +2710,20 @@ Status ValidateScatterDimensionNumbers( if (!absl::c_is_sorted(dim_numbers.update_window_dims())) { return InvalidArgument( "update_window_dims in scatter op must be sorted; got: %s.", - Join(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ")); } if (absl::c_adjacent_find(dim_numbers.update_window_dims()) != dim_numbers.update_window_dims().end()) { return InvalidArgument( "update_window_dims in scatter op must not repeat; got: %s.", - Join(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ")); } const int64 updates_rank = ShapeUtil::Rank(updates_shape); for (int64 window_dim : dim_numbers.update_window_dims()) { if (window_dim < 0 || window_dim >= updates_rank) { return InvalidArgument( "Invalid update_window_dims set in scatter op; valid range is [0, " - "%lld). got: %lld.", + "%d). got: %d.", updates_rank, window_dim); } } @@ -2725,19 +2732,19 @@ Status ValidateScatterDimensionNumbers( if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) { return InvalidArgument( "inserted_window_dims in scatter op must be sorted; got: %s.", - Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ")); } if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) != dim_numbers.inserted_window_dims().end()) { return InvalidArgument( "inserted_window_dims in scatter op must not repeat; got: %s.", - Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ")); } for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { return InvalidArgument( "Invalid inserted_window_dims set in scatter op; valid range is [0, " - "%d), got: %lld.", + "%d), got: %d.", operand_shape.dimensions_size(), inserted_dim); } } @@ -2747,7 +2754,7 @@ Status ValidateScatterDimensionNumbers( scatter_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( "Scatter op has %d elements in scatter_dims_to_operand_dims and the " - "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. " + "bound of dimension index_vector_dim=%d of scatter_indices is %d. " "These two numbers must be equal.", dim_numbers.scatter_dims_to_operand_dims_size(), dim_numbers.index_vector_dim(), @@ -2760,7 +2767,7 @@ Status ValidateScatterDimensionNumbers( scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { return InvalidArgument( "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld.", + "got: %d->%d.", operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); } } @@ -2773,7 +2780,7 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " "got: %s.", - Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", ")); } return Status::OK(); @@ -2794,7 +2801,7 @@ Status ValidateScatterDimensionNumbers( if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { return InvalidArgument( "Scatter indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(scatter_indices_shape).c_str()); + ShapeUtil::HumanString(scatter_indices_shape)); } if (scatter_indices_shape.dimensions_size() < @@ -2803,7 +2810,7 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Scatter index leaf dimension must be within [0, rank(scatter_indices)" " + 1). rank(scatter_indices) is %d and scatter index leaf dimension " - "is %lld.", + "is %d.", scatter_indices_shape.dimensions_size(), scatter_dim_numbers.index_vector_dim()); } @@ -2825,7 +2832,7 @@ Status ValidateScatterDimensionNumbers( int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + scatter_dim_numbers.update_window_dims_size(); if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { - return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + return InvalidArgument("Updates tensor must be of rank %d; got %d.", expected_updates_rank, ShapeUtil::Rank(updates_shape)); } @@ -2851,7 +2858,7 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Bounds of the window dimensions of updates must not exceed the " "bounds of the corresponding dimensions of operand. For dimension " - "%lld, updates bound is %lld, operand bound is %lld.", + "%d, updates bound is %d, operand bound is %d.", update_window_dim, updates_shape.dimensions(update_window_dim), max_update_slice_sizes[i]); } @@ -2872,8 +2879,8 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Bounds of the scatter dimensions of updates must be same as the " "bounds of the corresponding dimensions of scatter indices. For " - "scatter dimension %lld, updates bound is %lld, scatter_indices " - "bound is %lld.", + "scatter dimension %d, updates bound is %d, scatter_indices " + "bound is %d.", i, updates_shape.dimensions(i), expanded_scatter_indices_shape[scatter_dims_seen]); } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 4974ac9916abaea25f8d455b24f7c0904277f5f7..235b1a4cf3f3506edadf3abb869e76a32f459cdc 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -136,6 +136,9 @@ class ShapeInference { static StatusOr InferAllToAllTupleShape( tensorflow::gtl::ArraySlice operand_shapes); + // Infers the shape of a collective permute operation. + static StatusOr InferCollectivePermuteShape(const Shape& shape); + // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. // diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 70714ffff06b4ba4c13aae22290eff049ed3385c..921a984589bb4fb64058a2a56adfe84fe14af69b 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -19,19 +19,18 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::strings::Appendf; - ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, const se::Platform* platform, int device_ordinal) @@ -76,7 +75,7 @@ void ShapedBuffer::clear() { } string ShapedBuffer::ToString() const { - string s = tensorflow::strings::StrCat( + string s = absl::StrCat( "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), ", on-device shape=" + @@ -92,9 +91,9 @@ string ShapedBuffer::ToString() const { shape_str = ShapeUtil::HumanStringWithLayout(subshape); } const se::DeviceMemoryBase& memory = buffer(index); - Appendf(&s, " %s%p (%lld bytes) : %s\n", - string(index.size() * 2, ' ').c_str(), memory.opaque(), - memory.size(), shape_str.c_str()); + absl::StrAppendFormat(&s, " %s%p (%d bytes) : %s\n", + string(index.size() * 2, ' '), memory.opaque(), + memory.size(), shape_str); }); return s; } diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc index 8cbaac7b3760717bcacb57adc8782a5755c0aa6d..dd53c7531bea4273b5f8dc1c993e7720eb1afeb2 100644 --- a/tensorflow/compiler/xla/service/source_map_util.cc +++ b/tensorflow/compiler/xla/service/source_map_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/source_map_util.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -26,11 +27,10 @@ Status InvalidParameterArgumentV(const OpMetadata& op_metadata, string message; tensorflow::strings::Appendv(&message, format, args); if (!op_metadata.source_file().empty()) { - tensorflow::strings::Appendf(&message, " (%s:%d)", - op_metadata.source_file().c_str(), - op_metadata.source_line()); + absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), + op_metadata.source_line()); } - return InvalidArgument("%s", message.c_str()); + return InvalidArgument("%s", message); } } // namespace diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h index 84607cd012a9cff4eee5759b4235b2563692f84f..c5a7e17cb44c2b3b5ef145da0d66b4b3160f9531 100644 --- a/tensorflow/compiler/xla/service/source_map_util.h +++ b/tensorflow/compiler/xla/service/source_map_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/macros.h" @@ -23,6 +24,19 @@ limitations under the License. namespace xla { namespace source_map_util { +// Creates an INVALID_ARGUMENT status with the given format string. +template +Status InvalidParameterArgument(const OpMetadata& op_metadata, + const absl::FormatSpec& format, + const Args&... args) { + string message = absl::StrFormat(format, args...); + if (!op_metadata.source_file().empty()) { + absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), + op_metadata.source_line()); + } + return InvalidArgument("%s", message); +} + // Creates an INVALID_ARGUMENT status with the given format string. // // Also, attempts to extract the OpMetadata for parameter_number on executable @@ -30,15 +44,19 @@ namespace source_map_util { // // executable may be nullptr, but parameter_number should not be out of bounds // or a CHECK-failure may occur. +template Status InvalidParameterArgument(Executable* executable, int parameter_number, - const char* format, ...) - TF_PRINTF_ATTRIBUTE(3, 4); - -// As above, but takes the parameter metadata directly instead of extracting it -// from the executable. -Status InvalidParameterArgument(const OpMetadata& op_metadata, - const char* format, ...) - TF_PRINTF_ATTRIBUTE(2, 3); + const absl::FormatSpec& format, + const Args&... args) { + if (executable != nullptr && executable->has_module()) { + const HloModule& module = executable->module(); + const HloComputation& computation = *module.entry_computation(); + HloInstruction* param = computation.parameter_instruction(parameter_number); + const OpMetadata& metadata = param->metadata(); + return InvalidParameterArgument(metadata, format, args...); + } + return InvalidArgument(format, args...); +} } // namespace source_map_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index e0f995fd0d7cbabe5d1abd6af3d0c0005a8c9d48..b8d2d546e5d4dc67e3f314dfc6dcd4e8df5451c5 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -28,7 +29,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/notification.h" -using ::tensorflow::strings::StrCat; +using absl::StrCat; namespace xla { /* static */ tensorflow::mutex @@ -148,7 +149,7 @@ Status TransferManager::TransferArrayToDeviceAsync( if (dest.size() < GetByteSizeRequirement(on_device_shape)) { return FailedPrecondition( "Allocation on device not large enough for array: " - "%lld < %lld", + "%d < %d", dest.size(), GetByteSizeRequirement(on_device_shape)); } ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, @@ -165,12 +166,12 @@ void TransferManager::TransferArrayFromDevice( auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", ShapeUtil::HumanString(HostShapeToDeviceShape(shape))); - return done(FailedPrecondition("%s", error.c_str())); + return done(FailedPrecondition("%s", error)); } if (source.size() < GetByteSizeRequirement(shape)) { return done( FailedPrecondition("Allocation on device not large enough for array: " - "%lld < %lld", + "%d < %d", source.size(), GetByteSizeRequirement(shape))); } ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, @@ -202,7 +203,7 @@ void TransferManager::TransferArrayFromDevice( return NotFound( "could not find registered transfer manager for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } if (it->second.manager == nullptr) { @@ -253,7 +254,7 @@ Status TransferManager::TransferBufferFromDevice( if (source.size() < size) { return FailedPrecondition( "Source allocation on device not large enough for data tranfer: " - "%lld < %lld", + "%d < %d", source.size(), size); } stream->ThenMemcpy(destination, source, size); @@ -266,7 +267,7 @@ Status TransferManager::TransferBufferToDevice( if (destination->size() < size) { return FailedPrecondition( "Destination allocation on device not large enough for data tranfer: " - "%lld < %lld", + "%d < %d", destination->size(), size); } stream->ThenMemcpy(destination, source, size); @@ -277,9 +278,8 @@ StatusOr TransferManager::AllocateScopedShapedBuffer( const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal) { if (!LayoutUtil::HasLayout(on_host_shape)) { - return InvalidArgument( - "Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(on_host_shape).c_str()); + return InvalidArgument("Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(on_host_shape)); } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index 71e8446452f072c22bb730cbda65a1743a95cd4c..3e5aa2db60ee31d9fbccf8f7256b15c1b8465335 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -49,7 +49,7 @@ class TransposeFolding : public HloPassInterface { explicit TransposeFolding( TransposableGemmOperandsFn transposable_gemm_operands, TransposableConvOperandsFn transposable_conv_operands); - tensorflow::StringPiece name() const override { return "transpose-folding"; } + absl::string_view name() const override { return "transpose-folding"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 0c2f2112af5cdebe998f0d723528076b3c73d260..cf00ca102b1b4fd7e4953c6cff35f2b45a2caf2a 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -20,6 +20,9 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -27,17 +30,13 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { string BufferAlias::ToString() const { - return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "])"); + return absl::StrCat("BufferAlias(", instruction_->name(), "[", + absl::StrJoin(index_, ","), "])"); } std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { @@ -463,21 +462,20 @@ Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { return FailedPrecondition( "LogicalBuffer %s is ill-defined: instruction %s does not define a " "buffer at that index", - buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + buffer.ToString(), buffer.instruction()->name()); } } if (buffer.id() < 0 || buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: invalid id %lld", - buffer.ToString().c_str(), buffer.id()); + return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d", + buffer.ToString(), buffer.id()); } if (GetBuffer(buffer.id()).instruction() != buffer.instruction() || GetBuffer(buffer.id()).index() != buffer.index()) { return FailedPrecondition( "LogicalBuffer %s is ill-defined: buffer with same id differs: %s", - buffer.ToString().c_str(), GetBuffer(buffer.id()).ToString().c_str()); + buffer.ToString(), GetBuffer(buffer.id()).ToString()); } return Status::OK(); @@ -496,8 +494,7 @@ StatusOr TuplePointsToAnalysis::GetBufferDefinedAt( if (buffers.size() != 1 || buffers[0]->instruction() != instruction) { return FailedPrecondition( "instruction %s does not define buffer at index {%s}", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str()); + instruction->name(), absl::StrJoin(index, ",")); } return buffers[0]; } @@ -558,13 +555,12 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( } string TuplePointsToAnalysis::ToString() const { - string output = tensorflow::strings::Printf( - "TuplePointsToSet for module %s:\n", module_->name().c_str()); + string output = + absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name()); for (const auto* computation : module_->MakeNonfusionComputations()) { const char* entry = computation == module_->entry_computation() ? "entry " : ""; - tensorflow::strings::StrAppend(&output, entry, "computation ", - computation->name(), ":\n"); + absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n"); for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { InstructionToString(instruction, &output); @@ -576,12 +572,11 @@ string TuplePointsToAnalysis::ToString() const { } } - tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n"); + absl::StrAppend(&output, "LogicalBuffers:\n"); for (const auto& b : logical_buffer_analysis_->logical_buffers()) { - tensorflow::strings::StrAppend(&output, " buffer ", b->ToString(), ":\n"); + absl::StrAppend(&output, " buffer ", b->ToString(), ":\n"); for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) { - tensorflow::strings::StrAppend(&output, " alias ", alias.ToString(), - "\n"); + absl::StrAppend(&output, " alias ", alias.ToString(), "\n"); } } return output; @@ -590,20 +585,18 @@ string TuplePointsToAnalysis::ToString() const { void TuplePointsToAnalysis::InstructionToString( const HloInstruction* instruction, string* output) const { const string prefix = instruction->IsFused() ? " " : ""; - tensorflow::strings::StrAppend(output, prefix, " instruction ", - instruction->ToShortString(), ":\n"); + absl::StrAppend(output, prefix, " instruction ", + instruction->ToShortString(), ":\n"); const PointsToSet& points_to_set = GetPointsToSet(instruction); points_to_set.ForEachElement([&prefix, &output]( const ShapeIndex& index, const PointsToSet::BufferList& points_to) { - tensorflow::strings::StrAppend( - output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ", - tensorflow::str_util::Join( - points_to, ", ", - [](string* out, const LogicalBuffer* source) { - out->append(source->ToString()); - }), - "\n"); + absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ", + absl::StrJoin(points_to, ", ", + [](string* out, const LogicalBuffer* source) { + out->append(source->ToString()); + }), + "\n"); }); } diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index 750950188312c5077d487f2feef0606f07839432..8c91d6e69de637d58fa2ffc1a32ea65f09d3b6d8 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -30,7 +30,7 @@ class TupleSimplifier : public HloPassInterface { TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} - tensorflow::StringPiece name() const override { return "tuple-simplifier"; } + absl::string_view name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 21fb8568a84985692026e145c363500a154a1599..2dba7d7f7574742a301e3503e353bbe57d72a203 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -54,7 +54,7 @@ class WhileLoopConstantSinking : public HloPassInterface { public: ~WhileLoopConstantSinking() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "while-loop-invariant-code-motion"; } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8e6cc8787576e4f041229da5cf8dd2b09194eb2a..2cdf20ce80362c0aeb9d8324573e7e9826cc018c 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -38,7 +38,7 @@ class WhileLoopInvariantCodeMotion : public HloPassInterface { : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "while-loop-invariant-code-motion"; } StatusOr Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 32e69c335b713c438bd7fcb2053709b0624f58ed..e14014b961d44cf723e1363e27c19c2e149c9057 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -28,6 +28,10 @@ namespace op = xla::testing::opcode_matchers; class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { public: + WhileLoopInvariantCodeMotionTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + // Makes a computation which has one parameter, of the given shape, and always // returns PRED[]{true}. This is useful as a dummy loop condition. HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index a24e2b0116ef7b03eda9878c8ad684469e8b19e3..6a7bfe3f129d97866ccc54897d584fab0f7c683e 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -236,12 +236,11 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { << "Instruction " << user->ToString(print_no_metadata) << " should be unused (except by root of while body), but has " "users: {" - << tensorflow::str_util::Join( - user->users(), ", ", - [&](string* out, const HloInstruction* instr) { - tensorflow::strings::StrAppend( - out, instr->ToString(print_no_metadata)); - }) + << absl::StrJoin(user->users(), ", ", + [&](string* out, const HloInstruction* instr) { + absl::StrAppend( + out, instr->ToString(print_no_metadata)); + }) << "}"; replacements.emplace(user, nullptr); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 3d3e1d60f294c3a2574513c1c2f071805a341ad1..78024f14dc89ff40a11bbc3602072fda1fe6f312 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -33,9 +33,7 @@ namespace xla { class WhileLoopSimplifier : public HloPassInterface { public: ~WhileLoopSimplifier() override {} - tensorflow::StringPiece name() const override { - return "simplify-while-loops"; - } + absl::string_view name() const override { return "simplify-while-loops"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 2e1571943e537f772ee7dcd95c80ba540445b76e..cfe4104f6d0afbb2a1c31aaf94ec53a0ba5e178e 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -15,11 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -27,6 +28,11 @@ namespace { namespace op = xla::testing::opcode_matchers; class WhileLoopSimplifierTest : public HloVerifiedTestBase { + public: + WhileLoopSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: // Makes an HloModule that contains a loop with `num_iters` iteration. void MakeModuleWithSimpleLoop(int num_iters); @@ -64,10 +70,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { } )"; - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); + string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); ParseAndVerifyModule(hlo_string); } @@ -103,10 +107,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( } )"; - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); + string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); ParseAndVerifyModule(hlo_string); } diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 52d9c3e5ae71cc7d06acddd4717c16d3fbe9e8be..e8f76ff745a7871cd75294ff63c336cf1ce36f19 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -15,15 +15,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::strings::StrCat; +using absl::StrCat; static StatusOr WidenWhileCondition( HloComputation* narrow_condition, const Shape& wide_shape) { diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 8763e588c484011ba2ccbc7cad8f29817347a605..a7f0e207eb5a81b04bb28977d6f5e38864ad2d6a 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -24,7 +24,7 @@ namespace xla { class ZeroSizedHloElimination : public HloPassInterface { public: StatusOr Run(HloModule* module) override; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "zero_sized_hlo_elimination"; } }; diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index caad31d6ce7ce35fa362ec364b0d7f1d95973715..d44db89d571891ecef554cd45c050017833982bb 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -25,8 +25,8 @@ namespace xla { Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { if (!ShapeUtil::Compatible(other_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(other_shape).c_str(), - ShapeUtil::HumanString(shape()).c_str()); + ShapeUtil::HumanString(other_shape), + ShapeUtil::HumanString(shape())); } shape_ = other_shape; return Status::OK(); @@ -35,8 +35,8 @@ Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(*to_shape).c_str(), - ShapeUtil::HumanString(shape()).c_str()); + ShapeUtil::HumanString(*to_shape), + ShapeUtil::HumanString(shape())); } *to_shape = shape_; return Status::OK(); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 7244be80d9d53809398b2bf6e8b3fd14c86adb01..5477a78a9a44219eb9bce2ea56d31418c555a015 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,13 @@ limitations under the License. #include #include +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -31,25 +38,22 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" namespace xla { -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); } string ShapeIndexView::ToString() const { - return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); + return StrCat("{", absl::StrJoin(indices_, ","), "}"); } bool ShapeIndexView::operator==(const ShapeIndexView& other) const { @@ -143,7 +147,7 @@ StatusOr MakeShapeWithLayoutInternal( } if (element_type == OPAQUE || element_type == TUPLE) { return InvalidArgument("Unsupported element type: %s", - PrimitiveType_Name(element_type).c_str()); + PrimitiveType_Name(element_type)); } Shape shape = ShapeUtil::MakeShape(element_type, dimensions); auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); @@ -449,14 +453,14 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( namespace { // Class to memoize the computation of -// tensorflow::str_util::Lowercase(PrimitiveType_Name(p)) +// absl::AsciiStrToLower(PrimitiveType_Name(p)) // for all PrimitiveType values "p" class PrimitiveTypeNameGenerator { public: PrimitiveTypeNameGenerator() { for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { if (PrimitiveType_IsValid(i)) { - lowercase_name_[i] = tensorflow::str_util::Lowercase( + lowercase_name_[i] = absl::AsciiStrToLower( PrimitiveType_Name(static_cast(i))); } } @@ -487,8 +491,7 @@ StatusOr StringToPrimitiveType(const string& name) { }(); auto found = name_to_type->find(name); if (found == name_to_type->end()) { - return InvalidArgument("Invalid element type string: \"%s\".", - name.c_str()); + return InvalidArgument("Invalid element type string: \"%s\".", name); } return found->second; } @@ -507,7 +510,7 @@ StatusOr StringToPrimitiveType(const string& name) { return text; } return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", - tensorflow::str_util::Join(shape.dimensions(), ","), "]"); + absl::StrJoin(shape.dimensions(), ","), "]"); } /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { @@ -543,30 +546,29 @@ StatusOr StringToPrimitiveType(const string& name) { : "(unknown)", ": ", HumanString(shape))); } - return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ", HumanString(program_shape.result())); } namespace { // Parses shapes with simple recursive descent structure -- consumes from the // front of s and passes that view recursively as required. -StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { - tensorflow::str_util::RemoveLeadingWhitespace(s); +StatusOr ParseShapeStringInternal(absl::string_view* s) { + *s = StripLeadingAsciiWhitespace(*s); - if (tensorflow::str_util::ConsumePrefix(s, "(")) { // Tuple. + if (absl::ConsumePrefix(s, "(")) { // Tuple. std::vector shapes; bool must_end = false; while (true) { - if (tensorflow::str_util::ConsumePrefix(s, ")")) { + if (absl::ConsumePrefix(s, ")")) { break; } else if (must_end) { - return InvalidArgument("Expected end of tuple; got: \"%s\"", - std::string(*s).c_str()); + return InvalidArgument("Expected end of tuple; got: \"%s\"", *s); } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); - tensorflow::str_util::RemoveLeadingWhitespace(s); - must_end = !tensorflow::str_util::ConsumePrefix(s, ","); + *s = StripLeadingAsciiWhitespace(*s); + must_end = !absl::ConsumePrefix(s, ","); } return ShapeUtil::MakeTupleShape(shapes); } @@ -575,9 +577,9 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { string dimensions_string; string format_string; string layout_string; - // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so + // absl::string_view is not compatible with internal RE2 StringPiece, so // we convert in to the RE2-consumable type and then consume the corresponding - // amount from our StringPiece type. + // amount from our string_view type. static LazyRE2 shape_pattern = { "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"}; tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); @@ -585,12 +587,12 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { &dimensions_string, &format_string, &layout_string)) { size_t consumed = s->size() - s_consumable.size(); s->remove_prefix(consumed); - auto string_to_int64 = [&s](const string& input) -> StatusOr { + auto string_to_int64 = [&s](absl::string_view input) -> StatusOr { int64 element; - if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { + if (!absl::SimpleAtoi(input, &element)) { return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - input.c_str(), std::string(*s).c_str()); + "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", input, + *s); } return element; }; @@ -598,7 +600,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { auto comma_list_to_int64s = [string_to_int64](const string& input) -> StatusOr> { std::vector results; - for (const string& piece : tensorflow::str_util::Split(input, ',')) { + for (const auto& piece : absl::StrSplit(input, ',', absl::SkipEmpty())) { TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); results.push_back(element); } @@ -614,7 +616,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { StringToPrimitiveType(element_type_string)); if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { return InvalidArgument("Invalid element type string: \"%s\".", - element_type_string.c_str()); + element_type_string); } Shape result; @@ -644,17 +646,14 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return std::move(result); } - return InvalidArgument("Invalid shape string to parse: \"%s\"", - std::string(*s).c_str()); + return InvalidArgument("Invalid shape string to parse: \"%s\"", *s); } } // namespace -/* static */ StatusOr ShapeUtil::ParseShapeString( - tensorflow::StringPiece s) { +/* static */ StatusOr ShapeUtil::ParseShapeString(absl::string_view s) { TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); if (!s.empty()) { - return InvalidArgument("Invalid shape string to parse: \"%s\"", - std::string(s).c_str()); + return InvalidArgument("Invalid shape string to parse: \"%s\"", s); } return shape; } @@ -819,7 +818,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { const Shape& shape) { if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString().c_str()); + shape.ShortDebugString()); } if (shape.element_type() == TUPLE) { if (shape.dimensions_size() != 0) { @@ -842,21 +841,21 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape.dimensions_size() != 0) { return InvalidArgument( "shape has %s element type, but has dimensions field: %s", - LowercasePrimitiveTypeName(shape.element_type()).c_str(), - shape.ShortDebugString().c_str()); + LowercasePrimitiveTypeName(shape.element_type()), + shape.ShortDebugString()); } if (shape.has_layout()) { return InvalidArgument( "shape has %s element type, but has layout field: %s", - LowercasePrimitiveTypeName(shape.element_type()).c_str(), - shape.ShortDebugString().c_str()); + LowercasePrimitiveTypeName(shape.element_type()), + shape.ShortDebugString()); } return Status::OK(); } if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( - "shape's rank is mismatched with dimension count; rank=%lld " + "shape's rank is mismatched with dimension count; rank=%d " "dimensions_size=%d", Rank(shape), shape.dimensions_size()); } @@ -864,9 +863,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { int64 dimension = shape.dimensions(i); if (dimension < 0) { return InvalidArgument( - "shape's dimensions must not be < 0; dimension at index %lld was " - "%lld", - i, dimension); + "shape's dimensions must not be < 0; dimension at index %d was %d", i, + dimension); } } @@ -931,7 +929,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape_size < 0) { return InvalidArgument("Shape %s size may overflow int64.", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } VLOG(3) << "Shape size is valid: " << shape_size; @@ -991,7 +989,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { i >= return_shape->tuple_shapes_size()) { return InvalidArgument( "Shape index %s not a valid subshape index for tuple with shape %s", - index.ToString().c_str(), shape.DebugString().c_str()); + index.ToString(), shape.DebugString()); } return_shape = &return_shape->tuple_shapes(i); } @@ -1172,8 +1170,7 @@ Status ForEachMutableSubshapeHelper( CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation))) << "shape=" << HumanStringWithLayout(shape) << ", new_shape=" << HumanStringWithLayout(new_shape) - << ", permutation={" << tensorflow::str_util::Join(permutation, ",") - << "}"; + << ", permutation={" << absl::StrJoin(permutation, ",") << "}"; } return new_shape; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index cb72fbbb0e2a289a23b61d3035df67442f96a792..83e58545bf9065aeb302328f296c416e7a7dd979 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -131,12 +131,12 @@ class ShapeIndexView { } ShapeIndexView ConsumeFront() const { ShapeIndexView result = *this; - result.indices_.pop_front(); + result.indices_.remove_prefix(1); return result; } ShapeIndexView ConsumeBack() const { ShapeIndexView result = *this; - result.indices_.pop_back(); + result.indices_.remove_suffix(1); return result; } ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); } @@ -228,7 +228,7 @@ class ShapeUtil { // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. - static StatusOr ParseShapeString(tensorflow::StringPiece s); + static StatusOr ParseShapeString(absl::string_view s); // Returns whether the LHS and RHS shapes have the same dimensions; note: does // not check element type. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index e5dd62ae9a3dd9b961a7ae03a99c19220dbd43e7..7549ba9c78025de06624f01d0e87956db27f4f9a 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" @@ -23,8 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { @@ -849,13 +849,13 @@ TEST(ShapeUtilTest, PermuteDimensionsLayout) { std::iota(layout.begin(), layout.end(), 0); do { Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout); - SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s))); + SCOPED_TRACE(absl::StrCat("s=", ShapeUtil::HumanString(s))); std::vector permutation(3); std::iota(permutation.begin(), permutation.end(), 0); do { - SCOPED_TRACE(tensorflow::strings::StrCat( - "permutation=", tensorflow::str_util::Join(permutation, ","))); + SCOPED_TRACE( + absl::StrCat("permutation=", absl::StrJoin(permutation, ","))); // TransposeIsBitcast takes the inverse of the permutation that // PermuteDimensions takes. diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc index a6b1f9004f096abb3b01d315938b0a23bea1ca48..b88fe367d7416a26c1147fd5e10fb20772814fe5 100644 --- a/tensorflow/compiler/xla/status_macros.cc +++ b/tensorflow/compiler/xla/status_macros.cc @@ -17,9 +17,8 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stacktrace.h" @@ -37,8 +36,7 @@ static void LogError(const Status& status, const char* filename, int line, if (TF_PREDICT_TRUE(log_severity != tensorflow::NUM_SEVERITIES)) { string stack_trace; if (should_log_stack_trace) { - stack_trace = - tensorflow::strings::StrCat("\n", tensorflow::CurrentStackTrace()); + stack_trace = absl::StrCat("\n", tensorflow::CurrentStackTrace()); } switch (log_severity) { case tensorflow::INFO: @@ -142,17 +140,15 @@ Status MakeErrorStream::Impl::GetStatus() { is_done_ = true; const string& stream_str = stream_.str(); - const string str = - prior_message_handling_ == kAppendToPriorMessage - ? tensorflow::strings::StrCat(prior_message_, stream_str) - : tensorflow::strings::StrCat(stream_str, prior_message_); + const string str = prior_message_handling_ == kAppendToPriorMessage + ? absl::StrCat(prior_message_, stream_str) + : absl::StrCat(stream_str, prior_message_); if (TF_PREDICT_FALSE(str.empty())) { - return MakeError(file_, line_, code_, - tensorflow::strings::StrCat( - str, "Error without message at ", file_, ":", line_), - true /* should_log */, - tensorflow::ERROR /* log_severity */, - should_log_stack_trace_); + return MakeError( + file_, line_, code_, + absl::StrCat(str, "Error without message at ", file_, ":", line_), + true /* should_log */, tensorflow::ERROR /* log_severity */, + should_log_stack_trace_); } else { return MakeError(file_, line_, code_, str, should_log_, log_severity_, should_log_stack_trace_); diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 8918350135fbb86973b228b35f5873fea8695b2f..3ede5e6e38a7a9e922fc0744f014c395dbd2324c 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 6baf95d6317273c42f069b0ad7b5f8887160dd09..a0829b0d02562f97b957c0ff8ba536fff47b49c6 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -43,6 +43,7 @@ cc_library( "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], alwayslink = True, ) @@ -98,6 +99,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", ], ) @@ -205,6 +207,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -391,6 +394,7 @@ xla_test( "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -557,6 +561,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -671,6 +676,7 @@ xla_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -689,7 +695,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -697,6 +702,7 @@ xla_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -746,7 +752,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -754,6 +759,7 @@ xla_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -829,7 +835,10 @@ xla_test( timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"], + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], ) xla_test( @@ -839,7 +848,10 @@ xla_test( backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, backends = ["gpu"], shard_count = 25, - deps = CONVOLUTION_TEST_DEPS + ["@com_google_absl//absl/memory"], + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], ) xla_test( @@ -924,6 +936,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1001,6 +1014,8 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1110,7 +1125,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", @@ -1128,6 +1142,8 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1157,6 +1173,7 @@ xla_test_library( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1231,12 +1248,12 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - ":client_library_test_base", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1247,12 +1264,12 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - ":client_library_test_base", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1413,7 +1430,6 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", @@ -1425,6 +1441,8 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1494,6 +1512,8 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1648,6 +1668,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1660,7 +1681,6 @@ xla_test( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:global_data", @@ -1671,6 +1691,7 @@ xla_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1851,13 +1872,9 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", - "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1866,6 +1883,7 @@ xla_test( "//tensorflow/core:lib", "//tensorflow/core:test", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2026,6 +2044,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -2112,19 +2131,13 @@ xla_test( xla_test( name = "iota_test", srcs = ["iota_test.cc"], - blacklisted_backends = [ - "cpu", - "gpu", - ], + shard_count = 30, tags = [ "enable_for_xla_interpreter", ], deps = [ ":client_library_test_base", - ":literal_test_util", ":xla_internal_test_main", - "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:lib", - "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 316ab26a1fc59467fbcee31e6d4b3dbd9975045d..577fd1ab3b9268a66ea3f0c7e62b7d2644136d6e 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -296,6 +296,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()}); } +XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) { + XlaBuilder b(TestName()); + + std::vector lhs{static_cast(0x8000000000000000ULL)}; + std::unique_ptr lhs_literal = LiteralUtil::CreateR1({lhs}); + auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param"); + + std::vector rhs{static_cast(0x7FFFFFFFFFFFFFFFULL)}; + std::unique_ptr rhs_literal = LiteralUtil::CreateR1({rhs}); + auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param"); + + Lt(lhs_param, rhs_param); + + ComputeAndCompare(&b, {std::move(*lhs_literal), std::move(*rhs_literal)}); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); XlaBuilder builder(TestName()); @@ -498,8 +514,7 @@ XLA_TEST_F(IntegerDivideOpTest, DivS32s) { TestDivRem(dividends, divisors, quotients, remainders); } -XLA_TEST_F(IntegerDivideOpTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(SignedOverflow))) { +XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) { std::vector dividends = {5, INT32_MIN}, divisors = {0, -1}, quotients = {-1, INT32_MIN}, remainders = {5, 0}; @@ -529,8 +544,7 @@ XLA_TEST_F(IntegerDivideOpTest, DivU32s) { TestDivRem(dividends, divisors, quotients, remainders); } -XLA_TEST_F(IntegerDivideOpTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(UnsignedOverflow))) { +XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) { std::vector dividends = {5}, divisors = {0}, quotients = {-1}, remainders = {5}; diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 24b17b71007a1872462bed1f6b86ae1a5bb9922c..ac90a3adb6dbad30e3ef0b11438fb9a6fd6f8574 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -382,7 +382,7 @@ struct BatchNormTestParam { friend ::std::ostream& operator<<(::std::ostream& os, const BatchNormTestParam& p) { - os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; + os << "bounds={" << absl::StrJoin(p.bounds, ", ") << "}, "; os << "feature_index=" << p.feature_index << ", "; os << "random_value_mean=" << p.random_value_mean << ", "; os << "random_value_var=" << p.random_value_var; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 2cab3264a7ebe6ef515783a5df55ac5609cbe106..9cd974fd9bbb9f0f9bf316feb1c735106ed2bf07 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -196,8 +196,8 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( AsInt64Slice(expected.shape().dimensions()), minor_to_major); TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, &layout)); - verify_output(*actual, tensorflow::strings::StrCat( - "Test with output layout: ", + verify_output(*actual, + absl::StrCat("Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); return Status::OK(); @@ -258,7 +258,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( output_with_layout)); string error_message = "Test with input layouts: "; for (const auto& str : layout_strings) { - tensorflow::strings::StrAppend(&error_message, str, " "); + absl::StrAppend(&error_message, str, " "); } verify_output(*actual, error_message); return Status::OK(); @@ -391,7 +391,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } void ClientLibraryTestBase::ComputeAndCompareR1U8( - XlaBuilder* builder, tensorflow::StringPiece expected, + XlaBuilder* builder, absl::string_view expected, tensorflow::gtl::ArraySlice arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 24d0325929b66659f6b02ee5fd26ed6558b276e1..ac96d3e325b84a51201158906fe9342df736aec0 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -36,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" @@ -202,7 +202,7 @@ class ClientLibraryTestBase : public ::testing::Test { // Compare the result of the computation to a strings. In XLA strings are // represented using rank-1 U8 shapes. void ComputeAndCompareR1U8( - XlaBuilder* builder, tensorflow::StringPiece expected, + XlaBuilder* builder, absl::string_view expected, tensorflow::gtl::ArraySlice arguments); // Convenience method for running a built computation, transferring the diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 5a06d061f0d83fff547502495ff8ab13fb421b70..8226b6de3f780197bc0f1145b617dba99803927f 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/strings/match.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -145,8 +145,8 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), - "depends on a parameter")) + EXPECT_TRUE( + absl::StrContains(value.status().ToString(), "depends on a parameter")) << value.status(); } } @@ -161,8 +161,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { EXPECT_FALSE(IsConstant(computation, &b)); auto value = ComputeConstantScalar(client, computation, &b); - EXPECT_TRUE(tensorflow::str_util::StrContains(value.status().ToString(), - "depends on a parameter")) + EXPECT_TRUE( + absl::StrContains(value.status().ToString(), "depends on a parameter")) << value.status(); } } diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc index b27c1044baf2c0002f166c53a81e4361c60d012a..25d10ab00af11b8ebb8147917e7cdbb21f9a42c4 100644 --- a/tensorflow/compiler/xla/tests/conditional_test.cc +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -642,5 +642,57 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) { test_swap(11.24f, 5.55f); } +// Test conditional that duplicates tuple elements in the then and else +// computations. This is a regression test for b/112550242. +XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) { + const Shape scalar = ShapeUtil::MakeShape(S32, {}); + const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar}); + XlaComputation then_comp; + { + XlaBuilder builder(TestName() + ".then"); + auto p = Parameter(&builder, 0, tuple2, "then.p"); + auto e0 = GetTupleElement(p, 0); + auto e1 = GetTupleElement(p, 1); + Tuple(&builder, {e0, e1, e0}); + then_comp = builder.Build().ConsumeValueOrDie(); + } + XlaComputation else_comp; + { + XlaBuilder builder(TestName() + ".else"); + auto p = Parameter(&builder, 0, tuple2, "else.p"); + auto e0 = GetTupleElement(p, 0); + auto e1 = GetTupleElement(p, 1); + Tuple(&builder, {e0, e1, e1}); + else_comp = builder.Build().ConsumeValueOrDie(); + } + + { + // Pred is true case. + std::vector args; + args.push_back(std::move( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(123).get(), + LiteralUtil::CreateR0(-42).get()}))); + args.push_back(std::move(*LiteralUtil::CreateR0(true))); + XlaBuilder builder(TestName() + ".main"); + auto p = Parameter(&builder, 0, tuple2, "p0"); + auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); + Conditional(p_pred, p, then_comp, p, else_comp); + ComputeAndCompare(&builder, args); + } + { + // Pred is false case. + std::vector args; + args.push_back(std::move( + *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(123).get(), + LiteralUtil::CreateR0(-42).get()}))); + args.push_back(std::move(*LiteralUtil::CreateR0(false))); + XlaBuilder builder(TestName() + ".main"); + auto p = Parameter(&builder, 0, tuple2, "p0"); + auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1"); + Conditional(p_pred, p, then_comp, p, else_comp); + ComputeAndCompare(&builder, args); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 40658c3b775de0a38df4d6a629cab29b1fc83f2b..d2c6478b02423c93860244bc5eb91e652a3eac2e 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -35,8 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0e9e92ed996fbb34826d19b670c7c4920a1aad13..5873516442fa63de47360acaa353abb3a97fe881 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -261,16 +262,14 @@ string PrintDotTestParam( const ::testing::TestParamInfo& test_param) { const DotTestParam& param = test_param.param; if (param.has_addend) { - return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, - "_MajorToMinor", - param.dot_lhs_row_major ? "T" : "F", - param.dot_rhs_row_major ? "T" : "F", - param.addend_row_major ? "T" : "F"); + return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F", + param.addend_row_major ? "T" : "F"); } else { - return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, - "_MajorToMinor", - param.dot_lhs_row_major ? "T" : "F", - param.dot_rhs_row_major ? "T" : "F"); + return absl::StrCat(param.m, "x", param.k, "x", param.n, "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F"); } } diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 39cc6c5927f1d416e31f689487efc10c20371abe..4a835a8e219d4b64fa144e12e9b4cbc41f45946f 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -16,13 +16,13 @@ limitations under the License. #include #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -39,8 +39,7 @@ class FloorCeilTest : public ClientLibraryTestBase { // Runs a computation and comparison on expected vs f(input) void TestR1F32(tensorflow::gtl::ArraySlice input, tensorflow::gtl::ArraySlice expected, Function f) { - LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") - << "}"; + LOG(INFO) << "input: {" << absl::StrJoin(expected, ", ") << "}"; XlaBuilder builder(TestName()); auto c = ConstantR1(&builder, input); if (f == kCeil) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 5635c3fe86e87b1899d27d55d7231a793e00d425..93ea144438afa2d6f2f6c696f54d1ab1073081b8 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -43,7 +43,7 @@ namespace xla { namespace { using absl::optional; -using tensorflow::StringPiece; +using absl::string_view; using tensorflow::gtl::ArraySlice; constexpr char kInterpreter[] = "interpreter"; @@ -86,16 +86,20 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace -HloTestBase::HloTestBase(bool allow_mixed_precision_in_hlo_verifier) +HloTestBase::HloTestBase(bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) : HloTestBase(GetTestPlatform(), GetReferencePlatform(), + verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier) {} HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier) : test_runner_(test_platform), reference_runner_(reference_platform) { - hlo_verifier_ = - absl::make_unique(allow_mixed_precision_in_hlo_verifier); + hlo_verifier_ = absl::make_unique( + /*layout_sensitive=*/verifier_layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); } std::unique_ptr HloTestBase::CreateNewModule(const string& name) { @@ -239,7 +243,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompare( - const StringPiece hlo_string, const absl::optional& error, + string_view hlo_string, const absl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); @@ -252,7 +256,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( reference_preprocessor); } -::testing::AssertionResult HloTestBase::Run(const StringPiece hlo_string) { +::testing::AssertionResult HloTestBase::Run(string_view hlo_string) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); if (!module_or_status.ok()) { @@ -289,7 +293,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - const StringPiece hlo_string, const absl::optional& error, + string_view hlo_string, const absl::optional& error, const std::function& reference_preprocessor) { auto module_or_status = HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); @@ -316,7 +320,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( } HloComputation* HloTestBase::FindComputation(HloModule* module, - tensorflow::StringPiece name) { + absl::string_view name) { auto computations = module->computations(); auto it = absl::c_find_if( computations, [&](HloComputation* c) { return c->name() == name; }); @@ -327,7 +331,7 @@ HloComputation* HloTestBase::FindComputation(HloModule* module, } HloInstruction* HloTestBase::FindInstruction(HloModule* module, - tensorflow::StringPiece name) { + absl::string_view name) { for (const HloComputation* c : module->computations()) { auto instructions = c->instructions(); auto it = absl::c_find_if( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index d88abf561a57b686eee309766fb7296ac42878e4..06bcc397417e0666c8c97f4286aba7d0b42a2d98 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -85,12 +85,14 @@ class HloTestBase : public ::testing::Test { // automatically finds another supported backend as the test backend. If the // interpreter is the only supported backend, it will be both the test backend // and the reference backend. - HloTestBase(bool allow_mixed_precision_in_hlo_verifier = true); + HloTestBase(bool verifier_layout_sensitive = false, + bool allow_mixed_precision_in_hlo_verifier = true); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, + bool verifier_layout_sensitive = false, bool allow_mixed_precision_in_hlo_verifier = true); ~HloTestBase() override {} @@ -169,18 +171,18 @@ class HloTestBase : public ::testing::Test { // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. ::testing::AssertionResult RunAndCompare( - const tensorflow::StringPiece hlo_string, + const absl::string_view hlo_string, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; - ::testing::AssertionResult Run(const tensorflow::StringPiece hlo_string) + ::testing::AssertionResult Run(const absl::string_view hlo_string) TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareFromFile( const string& filename, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; ::testing::AssertionResult RunAndCompareNoHloPasses( - const tensorflow::StringPiece hlo_string, + const absl::string_view hlo_string, const absl::optional& error, const std::function& reference_preprocessor = nullptr) TF_MUST_USE_RESULT; @@ -228,10 +230,8 @@ class HloTestBase : public ::testing::Test { // // This is useful for tests which create HLOs from a string and then want to // inspect a particular computation or instruction. - HloComputation* FindComputation(HloModule* module, - tensorflow::StringPiece name); - HloInstruction* FindInstruction(HloModule* module, - tensorflow::StringPiece name); + HloComputation* FindComputation(HloModule* module, absl::string_view name); + HloInstruction* FindInstruction(HloModule* module, absl::string_view name); // Return an HLO verifier constructed for the test backend. HloVerifier& verifier() const { return *hlo_verifier_; } diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index a509ee32078551c850232d0f36380e25321e00a0..8f86c528d0f346b0264948d592660911880f96d1 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -25,8 +25,11 @@ limitations under the License. namespace xla { -HloVerifiedTestBase::HloVerifiedTestBase() - : shape_verifier_(absl::make_unique()) {} +HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision) + : HloTestBase( + /*verifier_layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {} HloVerifiedTestBase::~HloVerifiedTestBase() { // We can't call the ASSERT or EXPECT test macros in destructors, so we @@ -51,8 +54,7 @@ void HloVerifiedTestBase::TearDown() { } void HloVerifiedTestBase::VerifyModule(HloModule* module) { - HloVerifier verifier(/*allow_mixed_precision=*/true); - xla::StatusOr mutated = verifier.Run(module); + xla::StatusOr mutated = verifier().Run(module); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { @@ -73,7 +75,7 @@ HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { return modules_.back().get(); } -void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text, +void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 5b28c01c369fa1ae1c7941f5c8139882c4dbed08..cc6967feed47b74846814454d550b38a474f3a04 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -29,7 +29,8 @@ namespace xla { // performs verification on that module on tear-down. class HloVerifiedTestBase : public HloTestBase { protected: - HloVerifiedTestBase(); + explicit HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision); ~HloVerifiedTestBase() override; // Constructs a default shape verifier. @@ -44,32 +45,28 @@ class HloVerifiedTestBase : public HloTestBase { // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). HloModule& module(); - void ParseAndVerifyModule(tensorflow::StringPiece hlo_text, + void ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config = HloModuleConfig()); - // Sets the shape-size function used during hlo verification. If this isn't - // called, a default ShapeVerifier is used instead. - void SetShapeVerifier(std::unique_ptr shape_verifier) { - shape_verifier_ = std::move(shape_verifier); - } - // Creates a new module for a test, and stores it in modules_ so it can be // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent // creation of unverified modules. HloModule* CreateNewModule(const string& name = TestName()); + private: + void VerifyModule(HloModule* module); + // It is confusing to store modules created by module() and CreateNewModule() // in different fields, but it allows us to migrate tests to // HloVerifiedTestBase more easily, so it's a win because we can verify more // modules. See b/80488902. - private: + // // Lazily populated. Access via module(). std::unique_ptr module_; // Populated by calls to CreateNewModule. std::vector> modules_; - std::unique_ptr shape_verifier_; + bool tear_down_called_ = false; - static void VerifyModule(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/iota_test.cc b/tensorflow/compiler/xla/tests/iota_test.cc index 17ac95ae0198d98490b25f7f2edd32d1e0495803..07c3c6b878866191b3e0a389b440e11ce7454bf6 100644 --- a/tensorflow/compiler/xla/tests/iota_test.cc +++ b/tensorflow/compiler/xla/tests/iota_test.cc @@ -23,40 +23,95 @@ limitations under the License. namespace xla { namespace { -class IotaTest : public ClientLibraryTestBase { - public: - explicit IotaTest(se::Platform* platform = nullptr) - : ClientLibraryTestBase(platform) {} - template - std::vector GetExpected(const int64 num_elements) { - std::vector result(num_elements); - std::iota(result.begin(), result.end(), 0); - return result; +template +std::vector GetR1Expected(const int64 num_elements) { + std::vector result(num_elements); + std::iota(result.begin(), result.end(), 0); + return result; +} + +class IotaR1Test + : public ClientLibraryTestBase, + public ::testing::WithParamInterface> {}; + +TEST_P(IotaR1Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + IotaGen(&builder, element_type, num_elements); + if (element_type == F32) { + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), {}, + ErrorSpec{0.0001}); + } else if (element_type == U32) { + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), + {}); + } else { + CHECK_EQ(element_type, S32); + ComputeAndCompareR1(&builder, GetR1Expected(num_elements), + {}); } -}; - -XLA_TEST_F(IotaTest, SimpleR1) { - for (int num_elements = 1; num_elements < 10000001; num_elements *= 10) { - { - XlaBuilder builder(TestName() + "_f32"); - IotaGen(&builder, F32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), {}, - ErrorSpec{0.0001}); - } - { - XlaBuilder builder(TestName() + "_u32"); - IotaGen(&builder, U32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), - {}); - } - { - XlaBuilder builder(TestName() + "_s32"); - IotaGen(&builder, S32, num_elements); - ComputeAndCompareR1(&builder, GetExpected(num_elements), - {}); - } +} + +INSTANTIATE_TEST_CASE_P(IotaR1TestInstantiation, IotaR1Test, + ::testing::Combine(::testing::Values(F32, U32, S32), + ::testing::Range(/*start=*/10, + /*end=*/10001, + /*step=*/10))); + +class IotaR2Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(IotaR2Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + const int64 iota_dim = std::get<2>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); } } +INSTANTIATE_TEST_CASE_P(IotaR2TestInstantiation, IotaR2Test, + ::testing::Combine(::testing::Values(F32, S32), + ::testing::Range(/*start=*/10, + /*end=*/1001, + /*step=*/10), + ::testing::Values(0, 1))); + +class IotaR3Test : public ClientLibraryTestBase, + public ::testing::WithParamInterface< + std::tuple> {}; + +TEST_P(IotaR3Test, DoIt) { + const auto& spec = GetParam(); + const auto element_type = std::get<0>(spec); + const int64 num_elements = std::get<1>(spec); + const int64 iota_dim = std::get<2>(spec); + XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type)); + std::vector dimensions = {42, 19}; + dimensions.insert(dimensions.begin() + iota_dim, num_elements); + IotaGen(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim); + if (primitive_util::IsFloatingPointType(element_type)) { + ComputeAndCompare(&builder, {}, ErrorSpec{0.0001}); + } else { + ComputeAndCompare(&builder, {}); + } +} + +INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test, + ::testing::Combine(::testing::Values(F32, S32), + ::testing::Range(/*start=*/10, + /*end=*/1001, + /*step=*/10), + ::testing::Values(0, 1, 2))); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index a4e3a998fc48c364b8a61169167039d1c1ed28de..554eb24d44168caa7d7252015e3d99f2d567df9b 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -35,8 +35,7 @@ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { int64 now_usec = tensorflow::Env::Default()->NowMicros(); string filename = tensorflow::io::JoinPath( tensorflow::testing::TmpDir(), - tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(), - now_usec, name.c_str())); + absl::StrFormat("tempfile-%s-%x-%s", get_hostname(), now_usec, name)); TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, literal.ToProto())); LOG(ERROR) << "wrote to " << name << " file: " << filename; diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index f297b2b847f570d26e71ddcd8e34bc626f982e1f..4151bfae0332ffc706ba730d181c487eabab856f 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -80,7 +80,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { std::vector results; TF_CHECK_OK(env->GetMatchingPaths(pattern, &results)); - LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]"; + LOG(INFO) << "results: [" << absl::StrJoin(results, ", ") << "]"; EXPECT_EQ(3, results.size()); for (const string& result : results) { LiteralProto literal_proto; @@ -105,8 +105,10 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { auto actual = LiteralUtil::CreateR1({4, 5, 6}); ::testing::AssertionResult result = LiteralTestUtil::Equal(*expected, *actual); - EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); - EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}")); + EXPECT_THAT(result.message(), + ::testing::HasSubstr("Expected literal:\n{1, 2, 3}")); + EXPECT_THAT(result.message(), + ::testing::HasSubstr("Actual literal:\n{4, 5, 6}")); } TEST(LiteralTestUtilTest, NearComparatorR1) { diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index b6035a21a6709120c4b950382a6d248435f970c8..edb592f43ec778a3fe6e5ef936827dd612791760 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -32,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -135,8 +136,7 @@ class TestLinspaceMaxParametric MakeLinspaceArray2D(from, to, rows, cols); auto arhs = absl::make_unique>(rows, cols, static_cast(1.0f)); - XlaBuilder builder( - tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); + XlaBuilder builder(absl::StrFormat("max_%dx%d_linspace", rows, cols)); auto lhs = ConstantR2FromArray2D(&builder, *alhs); auto rhs = ConstantR2FromArray2D(&builder, *arhs); Max(lhs, rhs); @@ -158,7 +158,7 @@ class TestLinspaceMaxParametric string PrintTestLinspaceMaxParam( const ::testing::TestParamInfo& test_param) { const TestLinspaceMaxParam& param = test_param.param; - return tensorflow::strings::StrCat(param.rows, "r", param.cols, "c"); + return absl::StrCat(param.rows, "r", param.cols, "c"); } #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index cadf1c5523afdd61e4252185a123defdd8aa2c27..16b77e965d11fa136529e70796d11c520962ef28 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -52,12 +53,22 @@ class MultiOutputFusionTest : public HloTestBase { protected: MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; } + // Layout assignment assumes that there are no fusions in the input graph. + // Since the purpose of this test is to send pre-fused graphs to XLA, we have + // to do layout assignment ourselves. + DebugOptions GetDebugOptionsForTest() override { + auto opts = HloTestBase::GetDebugOptionsForTest(); + opts.add_xla_disable_hlo_passes("layout-assignment"); + return opts; + } + void RunTest2D(bool manual_fusion, int64 size) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - const Shape elem_shape0 = ShapeUtil::MakeShape(F32, {}); - const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size}); + const Shape elem_shape0 = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); + const Shape elem_shape2 = + ShapeUtil::MakeShapeWithLayout(F32, {size, size}, {1, 0}); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(8.0f))); @@ -100,10 +111,10 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal arg1(ShapeUtil::MakeShape(F32, {size, size})); + Literal arg1(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); arg1.PopulateWithValue(2.5f); - Literal expect(ShapeUtil::MakeShape(F32, {size, size})); + Literal expect(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, size})); expect.PopulateWithValue(size * 1.5f * 3.5f); auto actual = ExecuteAndTransfer(std::move(hlo_module), @@ -115,8 +126,10 @@ class MultiOutputFusionTest : public HloTestBase { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); - const Shape elem_shape_F32 = ShapeUtil::MakeShape(F32, {size}); - const Shape elem_shape_U8 = ShapeUtil::MakeShape(F64, {size}); + const Shape elem_shape_F32 = + ShapeUtil::MakeShapeWithDescendingLayout(F32, {size}); + const Shape elem_shape_U8 = + ShapeUtil::MakeShapeWithDescendingLayout(F64, {size}); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, elem_shape_F32, "0")); auto param1 = builder.AddInstruction( @@ -136,12 +149,13 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {size, 1}), add)); + ShapeUtil::MakeShapeWithDescendingLayout(F32, {size, 1}), add)); DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); + ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape, + dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -161,9 +175,9 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input0(ShapeUtil::MakeShape(F32, {size})); + Literal input0(ShapeUtil::MakeShapeWithDescendingLayout(F32, {size})); input0.PopulateWithValue(2.5f); - Literal input1(ShapeUtil::MakeShape(F64, {size})); + Literal input1(ShapeUtil::MakeShapeWithDescendingLayout(F64, {size})); input1.PopulateWithValue(1.); Literal expect = @@ -291,7 +305,7 @@ const char* const kScalarOps = R"( XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -323,7 +337,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -355,7 +369,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -388,7 +402,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -422,7 +436,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -457,7 +471,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -494,7 +508,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) init1 = f32[] parameter(1) @@ -529,7 +543,7 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) { - const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + const string testcase = absl::StrCat(kScalarOps, R"( fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) { p0 = f16[2,2,2]{2,1,0} parameter(0) convert = f32[2,2,2]{2,1,0} convert(p0) diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index a080dd1732bde21712cf47b4b57538cf4040f30e..9af9ea4a2229bb6ca7c3561350f11837f5072a2c 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -15,11 +15,11 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -29,16 +29,13 @@ limitations under the License. namespace xla { namespace { -namespace str_util = tensorflow::str_util; -namespace strings = tensorflow::strings; - struct ReduceLayout { std::array input_minor_to_major; std::array output_minor_to_major; string ToString() const { - return strings::StrCat(str_util::Join(input_minor_to_major, "x"), "_", - str_util::Join(output_minor_to_major, "x")); + return absl::StrCat(absl::StrJoin(input_minor_to_major, "x"), "_", + absl::StrJoin(output_minor_to_major, "x")); } }; diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 531648fe3eb8e3941c5e3c012847ee68c616590f..0916a07f4fa99af6cf25441fa8558a558bfa032f 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -57,8 +58,8 @@ static const int mantissa_sizes[] = {23, 10, 23, 10}; string TestDataToString(const ::testing::TestParamInfo data) { int i = data.param; - return tensorflow::strings::StrCat(exponent_sizes[i], "_exponent_bits_", - mantissa_sizes[i], "_mantissa_bits"); + return absl::StrCat(exponent_sizes[i], "_exponent_bits_", mantissa_sizes[i], + "_mantissa_bits"); } // The FPVAL macro allows us to write out the binary representation of the diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 2065271a7f686c52c88df80b0efe8f2e1542d198..346f70248864306dada5276b309482a0dd65e63e 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -32,6 +32,8 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -556,12 +558,11 @@ struct BoundsLayout { }; void PrintTo(const BoundsLayout& spec, std::ostream* os) { - *os << tensorflow::strings::Printf( - "R%luToR%lu%s_%s_Reduce%s", spec.bounds.size(), - spec.bounds.size() - spec.reduce_dims.size(), - tensorflow::str_util::Join(spec.bounds, "x").c_str(), - tensorflow::str_util::Join(spec.layout, "").c_str(), - tensorflow::str_util::Join(spec.reduce_dims, "").c_str()); + *os << absl::StrFormat("R%uToR%u%s_%s_Reduce%s", spec.bounds.size(), + spec.bounds.size() - spec.reduce_dims.size(), + absl::StrJoin(spec.bounds, "x"), + absl::StrJoin(spec.layout, ""), + absl::StrJoin(spec.reduce_dims, "")); } // Add-reduces a broadcasted scalar matrix among dimension 1 and 0. diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index ebf7fa30be43016217eca781054f01f9c3f536b1..60167619a4eb89b3275cc728300c41419ce80c60 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -579,21 +581,20 @@ string R4ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), // - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), // - "__layout_", tensorflow::str_util::Join(param.layout, "_"), // + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // + "__strides_", absl::StrJoin(param.strides, "x"), // + "__pad_low_", absl::StrJoin(param.pad_low, "x"), // + "__pad_high_", absl::StrJoin(param.pad_high, "x"), // + "__layout_", absl::StrJoin(param.layout, "_"), // (param.reducer == kAdd) ? "_add" : "_max"); CHECK(param.reducer == kAdd || param.reducer == kMax); // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } @@ -935,15 +936,15 @@ string R3ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), - "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), - "__strides_", tensorflow::str_util::Join(param.strides, "x"), - "__padding_", param.padding == Padding::kSame ? "same" : "valid", - "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2], - "__reducer_", param.reducer == kAdd ? "add" : "max"); + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_", + absl::StrJoin(param.window_bounds, "x"), "__strides_", + absl::StrJoin(param.strides, "x"), "__padding_", + param.padding == Padding::kSame ? "same" : "valid", "__layout_", + param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", + param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } @@ -1069,17 +1070,16 @@ string R2ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), - "__layout_", param.layout[0], "_", param.layout[1], // + string str = absl::StrCat( + "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), // + "__strides_", absl::StrJoin(param.strides, "x"), // + "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_", + absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_", + param.layout[1], // "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } @@ -1274,15 +1274,15 @@ string R1ReduceWindowTestDataToString( const ::testing::TestParamInfo< ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); - string str = tensorflow::strings::StrCat( - "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), - "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), - "__strides_", tensorflow::str_util::Join(param.strides, "x"), - "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), - "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), - "__reducer_", param.reducer == kAdd ? "add" : "max"); + string str = + absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"), + "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), + "__strides_", absl::StrJoin(param.strides, "x"), + "__pad_low_", absl::StrJoin(param.pad_low, "x"), + "__pad_high_", absl::StrJoin(param.pad_high, "x"), + "__reducer_", param.reducer == kAdd ? "add" : "max"); if (::testing::get<1>(data.param)) { - str = tensorflow::strings::StrCat(str, "_bfloat16"); + str = absl::StrCat(str, "_bfloat16"); } return str; } diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 41e49b4003236d55d85592315652a0ddefd5c485..c755ff63c904c893928ba08bd5e0fbedc4f2b70f 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -15,6 +15,8 @@ limitations under the License. #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -42,11 +44,9 @@ struct ReverseSpec { bool use_bfloat16; string ToTestCaseName() const { - return tensorflow::strings::Printf( - "reverse_%s_in_dims_%s_%s", - tensorflow::str_util::Join(input_dims, "x").c_str(), - tensorflow::str_util::Join(reversal, "x").c_str(), - use_bfloat16 ? "bf16" : "f32"); + return absl::StrFormat( + "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"), + absl::StrJoin(reversal, "x"), use_bfloat16 ? "bf16" : "f32"); } }; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index e42c71eb284deb2e50d6ea4b47fa707e4bc14ffc..cf2d453f43cda88ca05ab211a9b8be6e9c3e7c63 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index d865c414fd6ebeb98490278354c8f8a2c6571a23..69585ae39a72a87bd141d63c3926413ba05fe8c0 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -19,6 +19,9 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -27,15 +30,12 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -using ::tensorflow::str_util::Join; - class SliceTest : public ClientLibraryTestBase {}; TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) { @@ -223,9 +223,8 @@ class SliceR1LargeTest : public SliceR1Test {}; string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { const R1Spec& spec = data.param; - return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, - spec.slice_start, spec.slice_limit, - spec.slice_stride); + return absl::StrFormat("%d_%d_%d_%d", spec.input_dim0, spec.slice_start, + spec.slice_limit, spec.slice_stride); } XLA_TEST_P(SliceR1Test, DoIt_F32) { Run(GetParam()); } @@ -449,13 +448,11 @@ struct R4Spec { string R4SpecToString(const ::testing::TestParamInfo& data) { const R4Spec& spec = data.param; - return tensorflow::strings::StrCat( // - "input_", Join(spec.input_dims, "x"), // - "__layout_", Join(spec.input_layout, ""), // - "__starts_", Join(spec.slice_starts, "x"), // - "__limits_", Join(spec.slice_limits, "x"), // - "__strides_", Join(spec.slice_strides, "x") // - ); + return absl::StrCat("input_", absl::StrJoin(spec.input_dims, "x"), + "__layout_", absl::StrJoin(spec.input_layout, ""), + "__starts_", absl::StrJoin(spec.slice_starts, "x"), + "__limits_", absl::StrJoin(spec.slice_limits, "x"), + "__strides_", absl::StrJoin(spec.slice_strides, "x")); } class SliceR4Test : public ClientLibraryTestBase, diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index be35ec6c6ee4c015755622b2dc9bb92e23af7c85..a9874a918659f1d7403ba0c5cb968e62d7091936 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -20,7 +20,9 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/strings/str_util.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" @@ -44,7 +46,7 @@ ManifestT ReadManifest() { string contents((std::istreambuf_iterator(file_stream)), std::istreambuf_iterator()); - std::vector lines = tensorflow::str_util::Split(contents, '\n'); + std::vector lines = absl::StrSplit(contents, '\n'); for (string& line : lines) { auto comment = line.find("//"); if (comment != string::npos) { @@ -53,8 +55,8 @@ ManifestT ReadManifest() { if (line.empty()) { continue; } - tensorflow::str_util::StripTrailingWhitespace(&line); - std::vector pieces = tensorflow::str_util::Split(line, ' '); + absl::StripTrailingAsciiWhitespace(&line); + std::vector pieces = absl::StrSplit(line, ' '); CHECK_GE(pieces.size(), 1); auto& platforms = manifest[pieces[0]]; for (int64 i = 1; i < pieces.size(); ++i) { @@ -73,8 +75,7 @@ string PrependDisabledIfIndicated(const string& test_case_name, // First try full match: test_case_name.test_name // If that fails, try to find just the test_case_name; this would disable all // tests in the test case. - auto it = manifest.find( - tensorflow::strings::StrCat(test_case_name, ".", test_name)); + auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name)); if (it == manifest.end()) { it = manifest.find(test_case_name); if (it == manifest.end()) { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 2f1d97b25d5c3e5116256a6303859bbcdb45218e..776f93d9f73430be34bce9e5b7e64c19fe53d07c 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -194,7 +194,7 @@ StatusOr> MakeFakeLiteralInternal( break; default: return Unimplemented("Unsupported type for fake literal generation: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } return std::move(literal); } @@ -342,7 +342,7 @@ StatusOr> CreateLiteralForConstrainedUses( default: return Unimplemented( "Constrained operand generation not implemented for %s.", - use->ToString().c_str()); + use->ToString()); } } int constraint_count = 0; @@ -408,8 +408,12 @@ StatusOr>> MakeFakeArguments( return std::move(arguments); } -Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) { - return HloVerifier(allow_mixed_precision).Run(module).status(); +Status VerifyHloModule(HloModule* const module, bool layout_sensitive, + bool allow_mixed_precision) { + return HloVerifier(/*layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision) + .Run(module) + .status(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 1aca1d8ef7e714c7ebb4d522f0d2dd28992fd16b..277d53d4231d471897d4f0c47d297653ff5561d3 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -95,8 +95,8 @@ StatusOr>> MakeFakeArguments( // Check that a given module satisfies various constraints before trying to // execute it. -Status VerifyHloModule(HloModule* const module, - bool allow_mixed_precision = false); +Status VerifyHloModule(HloModule* const module, bool layout_sensitive, + bool allow_mixed_precision); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc index 2bdbd08309a81b201fc224110805549f7fb5bb55..c7eb9e2dbe0e27b7933f5861280a3401cd268c08 100644 --- a/tensorflow/compiler/xla/tests/token_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -15,11 +15,10 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -67,7 +66,10 @@ XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( status.error_message(), @@ -84,7 +86,10 @@ XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { "param")); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT( status.error_message(), @@ -101,7 +106,10 @@ XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); module->AddEntryComputation(builder.Build()); - Status status = HloVerifier().Run(module.get()).status(); + Status status = + HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status(); ASSERT_IS_NOT_OK(status); EXPECT_THAT(status.error_message(), ::testing::HasSubstr( diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 20ae68ab74026936c43e5f525eb796eb402a19cb..8f80a9f3e466d73f2b718452d9a0d64a80c3b36f 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -190,25 +190,6 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { SignAbsTestHelper(); } -XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { - XlaBuilder builder(TestName()); - auto arg = ConstantR1( - &builder, {2, 25, 0, 123, std::numeric_limits::max()}); - Abs(arg); - - ComputeAndCompareR1( - &builder, {2, 25, 0, 123, std::numeric_limits::max()}, {}); -} - -XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { - XlaBuilder builder(TestName()); - auto arg = ConstantR1( - &builder, {2, 25, 0, 123, std::numeric_limits::max()}); - Sign(arg); - - ComputeAndCompareR1(&builder, {1, 1, 0, 1, 1}, {}); -} - XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { XlaBuilder builder(TestName()); auto arg = ConstantR2(&builder, {{1.0, -2.0}, {-3.0, 4.0}}); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index e12e095ecdef1d79d29e619f1cf88e91a577e0fd..6a7ddd9b55b8ff72a61df5f718f501f02b37302e 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -17,6 +17,9 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -30,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -82,8 +84,7 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, gtl::FlatMap* parsed_results, - tensorflow::gtl::ArraySlice opcodes_to_ignore = - {}) { + tensorflow::gtl::ArraySlice opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; string match_percentage = R"(\d+\.\d*% +\d+Σ)"; string match_cycles = R"((\d+) cycles +\( *()" + match_percentage + R"()\))"; @@ -100,7 +101,7 @@ Status ParseOneProfileOutputLine( string match_opcode = expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])"; - string regexp_pattern = tensorflow::strings::StrCat( + string regexp_pattern = absl::StrCat( " +", match_cycles, separator, match_usecs, separator, match_flops, separator, match_trops, separator, match_bytes_per_sec, separator, match_bytes_per_cycle, separator, match_opcode); @@ -205,7 +206,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { rhs_shape); std::vector profile_output_lines = - tensorflow::str_util::Split(profile_output, '\n'); + absl::StrSplit(profile_output, '\n'); gtl::FlatMap parsed_profile_lines; @@ -292,22 +293,20 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { matrix_shape); std::vector profile_output_lines = - tensorflow::str_util::Split(profile_output, '\n'); + absl::StrSplit(profile_output, '\n'); auto while_body_profile_start = - absl::c_find_if(profile_output_lines, [](tensorflow::StringPiece s) { - return tensorflow::str_util::StartsWith(s, - "Execution profile for body"); + absl::c_find_if(profile_output_lines, [](absl::string_view s) { + return absl::StartsWith(s, "Execution profile for body"); }); ASSERT_NE(while_body_profile_start, profile_output_lines.cend()); - auto while_body_profile_end = - std::find_if(while_body_profile_start, profile_output_lines.end(), - [](tensorflow::StringPiece s) { - return tensorflow::str_util::StartsWith( - s, "********** microseconds report **********"); - }); + auto while_body_profile_end = std::find_if( + while_body_profile_start, profile_output_lines.end(), + [](absl::string_view s) { + return absl::StartsWith(s, "********** microseconds report **********"); + }); // We emit a blank line before the "********** microseconds report **********" // line. diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index a075195618c42aaa11f7b1c17730e67889a2c308..15603619b62d8f45cdce97ac7d83924a78f88cf3 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -32,16 +32,14 @@ GTEST_API_ int main(int argc, char** argv) { // If the --benchmarks flag is passed in then only run the benchmarks, not the // tests. for (int i = 1; i < argc; i++) { - tensorflow::StringPiece arg(argv[i]); - if (arg == "--benchmarks" || - tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + absl::string_view arg(argv[i]); + if (arg == "--benchmarks" || absl::StartsWith(arg, "--benchmarks=")) { const char* pattern = nullptr; - if (tensorflow::str_util::StartsWith(arg, "--benchmarks=")) { + if (absl::StartsWith(arg, "--benchmarks=")) { pattern = argv[i] + strlen("--benchmarks="); } else { // Handle flag of the form '--benchmarks foo' (no '='). - if (i + 1 >= argc || - tensorflow::str_util::StartsWith(argv[i + 1], "--")) { + if (i + 1 >= argc || absl::StartsWith(argv[i + 1], "--")) { LOG(ERROR) << "--benchmarks flag requires an argument."; return 2; } diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 7de2c39b3892dc40d09adfed1c39e4aca449039d..442e66321ee732f3d9cdfe4931433bd864b7fa82 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -21,24 +21,27 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { StatusOr> TextLiteralReader::ReadPath( - tensorflow::StringPiece path) { - CHECK(!tensorflow::str_util::EndsWith(path, ".gz")) + absl::string_view path) { + CHECK(!absl::EndsWith(path, ".gz")) << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; Status s = @@ -54,33 +57,6 @@ StatusOr> TextLiteralReader::ReadPath( TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file) : file_(file) {} -namespace { -// This is an optimized version of tensorflow::str_util::Split which uses -// StringPiece for the delimited strings and uses an out parameter for the -// result to avoid vector creation/destruction. -void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim, - std::vector* result) { - result->clear(); - - if (text.empty()) { - return; - } - - // The following loop is a little strange: its bound is text.size() + 1 - // instead of the more typical text.size(). - // The final iteration of the loop (when i is equal to text.size()) handles - // the trailing token. - size_t token_start = 0; - for (size_t i = 0; i < text.size() + 1; i++) { - if (i == text.size() || text[i] == delim) { - tensorflow::StringPiece token(text.data() + token_start, i - token_start); - result->push_back(token); - token_start = i + 1; - } - } -} -} // namespace - StatusOr> TextLiteralReader::ReadAllLines() { tensorflow::io::RandomAccessInputStream stream(file_.get()); tensorflow::io::BufferedInputStream buf(&stream, 65536); @@ -90,61 +66,55 @@ StatusOr> TextLiteralReader::ReadAllLines() { return s; } - tensorflow::StringPiece sp(shape_string); - if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) { - string tmp = std::string(sp); - shape_string = tmp; - } + absl::StripAsciiWhitespace(&shape_string); TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); if (shape.element_type() != F32) { return Unimplemented( "unsupported element type for text literal reading: %s", - ShapeUtil::HumanString(shape).c_str()); + ShapeUtil::HumanString(shape)); } auto result = absl::make_unique(shape); const float fill = std::numeric_limits::quiet_NaN(); result->PopulateWithValue(fill); - std::vector pieces; - std::vector coordinates; + std::vector pieces; + std::vector coordinates; std::vector coordinate_values; string line; while (buf.ReadLine(&line).ok()) { - SplitByDelimToStringPieces(line, ':', &pieces); - tensorflow::StringPiece coordinates_string = pieces[0]; - tensorflow::StringPiece value_string = pieces[1]; - tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string); - tensorflow::str_util::RemoveWhitespaceContext(&value_string); - if (!tensorflow::str_util::ConsumePrefix(&coordinates_string, "(")) { + pieces = absl::StrSplit(line, ':'); + absl::string_view coordinates_string = + absl::StripAsciiWhitespace(pieces[0]); + absl::string_view value_string = absl::StripAsciiWhitespace(pieces[1]); + if (!absl::ConsumePrefix(&coordinates_string, "(")) { return InvalidArgument( - "expected '(' at the beginning of coordinates: \"%s\"", line.c_str()); + "expected '(' at the beginning of coordinates: \"%s\"", line); } - if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) { + if (!absl::ConsumeSuffix(&coordinates_string, ")")) { return InvalidArgument("expected ')' at the end of coordinates: \"%s\"", - line.c_str()); + line); } float value; - if (!tensorflow::strings::safe_strtof(std::string(value_string).c_str(), - &value)) { + if (!absl::SimpleAtof(value_string, &value)) { return InvalidArgument("could not parse value as float: \"%s\"", - std::string(value_string).c_str()); + value_string); } - SplitByDelimToStringPieces(coordinates_string, ',', &coordinates); + coordinates = absl::StrSplit(coordinates_string, ','); coordinate_values.clear(); - for (tensorflow::StringPiece piece : coordinates) { + for (absl::string_view piece : coordinates) { int64 coordinate_value; - if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) { + if (!absl::SimpleAtoi(piece, &coordinate_value)) { return InvalidArgument( "could not parse coordinate member as int64: \"%s\"", - std::string(piece).c_str()); + std::string(piece)); } coordinate_values.push_back(coordinate_value); } if (coordinate_values.size() != shape.dimensions_size()) { return InvalidArgument( - "line did not have expected number of coordinates; want %d got %zu: " + "line did not have expected number of coordinates; want %d got %u: " "\"%s\"", - shape.dimensions_size(), coordinate_values.size(), line.c_str()); + shape.dimensions_size(), coordinate_values.size(), line); } result->Set(coordinate_values, value); } diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h index 708e8c80d8b5c09454eb64d4e12df51a5b7ea628..b265640802c88847ce57e9f942f9f0859b873ae8 100644 --- a/tensorflow/compiler/xla/text_literal_reader.h +++ b/tensorflow/compiler/xla/text_literal_reader.h @@ -18,11 +18,11 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" @@ -41,8 +41,7 @@ class TextLiteralReader { public: // See class comment -- reads a file in its entirety (there must be only one // literal in the text file path provided). - static StatusOr> ReadPath( - tensorflow::StringPiece path); + static StatusOr> ReadPath(absl::string_view path); private: // Ownership of file is transferred. diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 24e0784741a4c9779b0adb7a7740c3d6e2fb033a..00147015a6b2bf41205a81dddd0b16f5ab434130 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -17,23 +17,23 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" namespace xla { -/* static */ Status TextLiteralWriter::WriteToPath( - const Literal& literal, tensorflow::StringPiece path) { +/* static */ Status TextLiteralWriter::WriteToPath(const Literal& literal, + absl::string_view path) { std::unique_ptr f; - auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); + auto s = tensorflow::Env::Default()->NewWritableFile(string(path), &f); if (!s.ok()) { return s; } @@ -51,11 +51,10 @@ namespace xla { if (!status.ok()) { return; } - string coordinates = tensorflow::strings::StrCat( - "(", tensorflow::str_util::Join(indices, ", "), ")"); + string coordinates = + absl::StrCat("(", absl::StrJoin(indices, ", "), ")"); - status = f_ptr->Append( - tensorflow::strings::StrCat(coordinates, ": ", value, "\n")); + status = f_ptr->Append(absl::StrCat(coordinates, ": ", value, "\n")); }); auto ignored = f->Close(); return status; diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 159ac1b7e1b6f9c07dac795fb640cd0b2d284bcb..34de8572d638067b327711017ee173b16c8da21e 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ #define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -37,8 +37,7 @@ namespace xla { // This should be readable by xla::TextLiteralReader. class TextLiteralWriter { public: - static Status WriteToPath(const Literal& literal, - tensorflow::StringPiece path); + static Status WriteToPath(const Literal& literal, absl::string_view path); private: TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 40d28a57bfddd3403cad8252df985b746362631f..f23c5b3ef1f3eed1f03097d68d0a760ecc2d4a0f 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -24,6 +24,7 @@ tf_cc_binary( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", ], ) @@ -191,6 +192,8 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index f20dcef382b86d27d7c176ae7e4132ad1db7b901..d15b71b7925d0e0f6c88e9484393c4a3239bb0b3 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -78,7 +78,7 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index f0af0580c1fbca455c6ed5f87f82971faee50a06..c446b27a040419059328def17b51fbfa2850ccff 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -30,8 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -44,16 +44,14 @@ class OperationDumper : public DfsHloVisitorWithDefault { explicit OperationDumper(const string& path) : path_(path) {} Status DefaultAction(HloInstruction* hlo) override { - string params = tensorflow::str_util::Join( + string params = absl::StrJoin( hlo->operands(), ", ", [](string* out, const HloInstruction* operand) { - tensorflow::strings::StrAppend( - out, ShapeUtil::HumanString(operand->shape())); + absl::StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); // Spit `op_name(params...) -> result_type :: path` to stdout. - std::cout << tensorflow::strings::Printf( - "%s :: (%s) -> %s :: %s\n", HloOpcodeString(hlo->opcode()).c_str(), - params.c_str(), ShapeUtil::HumanString(hlo->shape()).c_str(), - path_.c_str()); + std::cout << absl::StrFormat("%s :: (%s) -> %s :: %s\n", + HloOpcodeString(hlo->opcode()), params, + ShapeUtil::HumanString(hlo->shape()), path_); return Status::OK(); } @@ -107,7 +105,7 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index f03e1b1f965af761c101555fd0275bc0425b9cf0..d86a4474b32f75a04fb398b13c2a34aa1b33df17 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -103,7 +103,7 @@ int main(int argc, char** argv) { QCHECK(argc > 1) << "\nERROR: must specify at least one module\n" << usage; tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args, compile); return 0; } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index dc5c106d02cb679f3e6f5b2bea40bbb42f8bd1cc..bd8b89542ff8863a015b1331be602adbdca49615 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -79,7 +79,7 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc index eb7bff053b1fc028fdb6930dbc496c3b6d9fae47..75b63c3b84c21005f64b770c44219d92ffce99df 100644 --- a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc +++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/platform/env.h" @@ -67,7 +67,7 @@ int main(int argc, char** argv) { floats.push_back(value); } - tensorflow::StringPiece content( + tensorflow::StringPiece content( // non-absl ok tensorflow::bit_cast(floats.data()), floats.size() * sizeof(float)); TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 311a1bee8daa3a5d126f00dcabe0675f791adeaa..e826d6fa9361e9ea6f2fdbd6d70d0396d3849b29 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -250,7 +250,7 @@ StatusOr ParseInputFile(const string& filename, } fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", filename.c_str()); - return InvalidArgument("Could not parse %s.", filename.c_str()); + return InvalidArgument("Could not parse %s.", filename); } int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { @@ -345,6 +345,6 @@ int main(int argc, char** argv) { } tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + args.remove_prefix(1); // Pop off the binary name, argv[0] return xla::tools::RealMain(args, opts); } diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 4e53fafcc97ff53afc5713e7ed8ee5222fac316b..10e7202acfbac2a3157007e129ead5502a779697 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -67,7 +67,7 @@ int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); tensorflow::gtl::ArraySlice args(argv, argc); - args.pop_front(); // Pop off the binary name, argv[0] + args.remove_prefix(1); // Pop off the binary name, argv[0] xla::tools::RealMain(args); return 0; } diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index e43498e381b8e63543e2ddda08ca7c0df91817e4..0f607a0c8afd0aa23053a15c3a274fe5d5fdfdbb 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -18,12 +18,13 @@ limitations under the License. #include #include +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stacktrace.h" @@ -54,108 +55,25 @@ ScopedLoggingTimer::~ScopedLoggingTimer() { } } -Status AddStatus(Status prior, tensorflow::StringPiece context) { +Status AddStatus(Status prior, absl::string_view context) { CHECK(!prior.ok()); - return Status{prior.code(), tensorflow::strings::StrCat( - context, ": ", prior.error_message())}; + return Status{prior.code(), + absl::StrCat(context, ": ", prior.error_message())}; } -Status AppendStatus(Status prior, tensorflow::StringPiece context) { +Status AppendStatus(Status prior, absl::string_view context) { CHECK(!prior.ok()); - return Status{prior.code(), tensorflow::strings::StrCat(prior.error_message(), - ": ", context)}; + return Status{prior.code(), + absl::StrCat(prior.error_message(), ": ", context)}; } -// Implementation note: we can't common these out (without using macros) because -// they all need to va_start/va_end their varargs in their frame. - -Status InvalidArgumentV(const char* format, va_list args) { - string message; - tensorflow::strings::Appendv(&message, format, args); - return WithLogBacktrace(tensorflow::errors::InvalidArgument(message)); -} - -Status InvalidArgument(const char* format, ...) { - va_list args; - va_start(args, format); - Status result = InvalidArgumentV(format, args); - va_end(args); - return result; -} - -Status Unimplemented(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Unimplemented(message)); -} - -Status InternalError(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Internal(message)); -} - -Status FailedPrecondition(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::FailedPrecondition(message)); -} - -Status Cancelled(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Cancelled(message)); -} - -Status ResourceExhausted(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::ResourceExhausted(message)); -} - -Status NotFound(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::NotFound(message)); -} - -Status Unavailable(const char* format, ...) { - string message; - va_list args; - va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); - va_end(args); - return WithLogBacktrace(tensorflow::errors::Unavailable(message)); -} - -string Reindent(tensorflow::StringPiece original, - const tensorflow::StringPiece indentation) { - std::vector pieces = tensorflow::str_util::Split( - tensorflow::StringPiece(original.data(), original.size()), '\n'); - return tensorflow::str_util::Join( - pieces, "\n", [indentation](string* out, string s) { - tensorflow::StringPiece piece(s); - tensorflow::str_util::RemoveWhitespaceContext(&piece); - tensorflow::strings::StrAppend(out, indentation, piece); - }); +string Reindent(absl::string_view original, + const absl::string_view indentation) { + std::vector pieces = + absl::StrSplit(absl::string_view(original.data(), original.size()), '\n'); + return absl::StrJoin(pieces, "\n", [indentation](string* out, string s) { + absl::StrAppend(out, indentation, absl::StripAsciiWhitespace(s)); + }); } bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank) { @@ -234,20 +152,20 @@ bool HasInteriorPadding(const PaddingConfig& config) { namespace { string HumanReadableNumOps(double flops, double nanoseconds, - tensorflow::StringPiece op_prefix) { + absl::string_view op_prefix) { if (nanoseconds == 0) { - return tensorflow::strings::StrCat("NaN ", op_prefix, "OP/s"); + return absl::StrCat("NaN ", op_prefix, "OP/s"); } double nano_flops = flops / nanoseconds; string throughput = tensorflow::strings::HumanReadableNum( static_cast(nano_flops * 1e9)); - tensorflow::StringPiece sp(throughput); + absl::string_view sp(throughput); // Use the more common "G(FLOPS)", rather than "B(FLOPS)" - if (tensorflow::str_util::EndsWith(sp, "B") || // Ends in 'B', ignoring case - tensorflow::str_util::EndsWith(sp, "b")) { + if (absl::EndsWith(sp, "B") || // Ends in 'B', ignoring case + absl::EndsWith(sp, "b")) { *throughput.rbegin() = 'G'; } - throughput += tensorflow::strings::StrCat(op_prefix, "OP/s"); + throughput += absl::StrCat(op_prefix, "OP/s"); return throughput; } } // namespace @@ -260,8 +178,7 @@ string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) { return HumanReadableNumOps(trops, nanoseconds, "TR"); } -void LogLines(int sev, tensorflow::StringPiece text, const char* fname, - int lineno) { +void LogLines(int sev, absl::string_view text, const char* fname, int lineno) { const int orig_sev = sev; if (sev == tensorflow::FATAL) { sev = tensorflow::ERROR; @@ -275,7 +192,7 @@ void LogLines(int sev, tensorflow::StringPiece text, const char* fname, size_t cur = 0; while (cur < text.size()) { size_t eol = text.find('\n', cur); - if (eol == tensorflow::StringPiece::npos) { + if (eol == absl::string_view::npos) { eol = text.size(); } auto msg = text.substr(cur, eol - cur); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index efeafbc53a28b46eff91568807ab7a8bf82b7b52..62f486369f1b7f402e69373ed1561f8213b459ab 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -26,16 +26,18 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" @@ -202,46 +204,76 @@ void StridedCopy(tensorflow::gtl::MutableArraySlice dest, int64 dest_base, // Adds some context information to the error message in a // Status. This is useful as Statuses are // propagated upwards. -Status AddStatus(Status prior, tensorflow::StringPiece context); -Status AppendStatus(Status prior, tensorflow::StringPiece context); - -// Status error shorthands -- printfs the arguments to be -// used as an error message and returns a status in the canonical -// error space. -Status InvalidArgument(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Unimplemented(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status InternalError(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status FailedPrecondition(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Cancelled(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); -Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); - -// Passed-varargs variant of the InvalidArgument factory above. -Status InvalidArgumentV(const char* format, va_list args); +Status AddStatus(Status prior, absl::string_view context); +Status AppendStatus(Status prior, absl::string_view context); + +// Status error shorthands -- StrFormat's the arguments to be used as an error +// message and returns a status in the canonical error space. +template +Status InvalidArgument(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::InvalidArgument(absl::StrFormat(format, args...))); +} +template +Status Unimplemented(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Unimplemented(absl::StrFormat(format, args...))); +} +template +Status InternalError(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Internal(absl::StrFormat(format, args...))); +} +template +Status FailedPrecondition(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::FailedPrecondition(absl::StrFormat(format, args...))); +} +template +Status Cancelled(const absl::FormatSpec& format, const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Cancelled(absl::StrFormat(format, args...))); +} +template +Status ResourceExhausted(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::ResourceExhausted(absl::StrFormat(format, args...))); +} +template +Status NotFound(const absl::FormatSpec& format, const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::NotFound(absl::StrFormat(format, args...))); +} +template +Status Unavailable(const absl::FormatSpec& format, + const Args&... args) { + return WithLogBacktrace( + tensorflow::errors::Unavailable(absl::StrFormat(format, args...))); +} template Status InvalidArgumentStrCat(Args&&... concat) { - return InvalidArgument( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return InvalidArgument("%s", absl::StrCat(std::forward(concat)...)); } template Status UnimplementedStrCat(Args&&... concat) { - return Unimplemented( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return Unimplemented("%s", absl::StrCat(std::forward(concat)...)); } template Status InternalErrorStrCat(Args&&... concat) { - return InternalError( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return InternalError("%s", absl::StrCat(std::forward(concat)...)); } template Status ResourceExhaustedStrCat(Args&&... concat) { - return ResourceExhausted( - "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); + return ResourceExhausted("%s", absl::StrCat(std::forward(concat)...)); } // Splits the lines of the original, replaces leading whitespace with the prefix @@ -250,8 +282,7 @@ Status ResourceExhaustedStrCat(Args&&... concat) { // // Note: even different amounts of leading whitespace on different lines will be // uniformly replaced with "indentation". -string Reindent(tensorflow::StringPiece original, - tensorflow::StringPiece indentation); +string Reindent(absl::string_view original, absl::string_view indentation); // Checks whether permutation is a permutation of the [0, rank) integer range. bool IsPermutation(tensorflow::gtl::ArraySlice permutation, int64 rank); @@ -313,7 +344,7 @@ string CommaSeparatedString(const Container& c, const char* prefix = "", string comma_separated = prefix; const char* separator = ""; for (const auto& entry : c) { - tensorflow::strings::StrAppend(&comma_separated, separator, entry); + absl::StrAppend(&comma_separated, separator, entry); separator = ", "; } comma_separated += suffix; @@ -395,8 +426,7 @@ string HumanReadableNumTranscendentalOps(double trops, double nanoseconds); // Split the text into multiple lines and log each line with the given // severity, filename, and line number. -void LogLines(int sev, tensorflow::StringPiece text, const char* fname, - int lineno); +void LogLines(int sev, absl::string_view text, const char* fname, int lineno); template inline bool IsPowerOfTwo(T x) { diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index f11123ca24849af1d9c4fd49809a986eb7202bd5..268dc5db01a3ebb8868444eccc71515ab04c7c97 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -17,11 +17,9 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace window_util { @@ -49,8 +47,8 @@ PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice sizes) { } /* static */ string ToString(const WindowDimension& dim) { - using tensorflow::strings::StrAppend; - using tensorflow::strings::StrCat; + using absl::StrAppend; + using absl::StrCat; string str = StrCat("(size=", dim.size()); if (dim.stride() != 1) { StrAppend(&str, ",stride=", dim.stride()); @@ -75,8 +73,8 @@ PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice sizes) { } string ToString(const Window& window) { - using tensorflow::strings::StrAppend; - using tensorflow::strings::StrCat; + using absl::StrAppend; + using absl::StrCat; string str; const auto add_field = diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 9451e0c315a882ce61af130e645198ba2fc7ca03..aaba5aa92e0c4247f02b8a2a24b4eeac37a08afd 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -570,6 +570,12 @@ message ReplicaGroup { repeated int64 replica_ids = 1; } +// Describes the source target pair in the collective permute op. +message SourceTarget { + int64 source = 1; + int64 target = 2; +} + // Used to indicate the precision configuration. It has backend specific // meaning. message PrecisionConfigProto { diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 222e66cebeeca9f50299925916a8732fcc8c6a4c..66983801bf81188f81b9d4149eec5f0d20a296b4 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -61,7 +61,6 @@ py_library( "//tensorflow/contrib/integrate:integrate_py", "//tensorflow/contrib/keras", "//tensorflow/contrib/kernel_methods", - "//tensorflow/contrib/kfac", "//tensorflow/contrib/labeled_tensor", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", @@ -182,6 +181,7 @@ cc_library( "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", "//tensorflow/contrib/coder:all_ops", "//tensorflow/contrib/data:dataset_ops_op_lib", + "//tensorflow/contrib/data:indexed_dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/hadoop:dataset_ops_op_lib", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 45a7680160251b37fbfb923eb23a5d68ccb2c5fb..5f477a79a3d960bc2cd2df2d288ae80e30671d75 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -51,7 +51,6 @@ from tensorflow.contrib import input_pipeline from tensorflow.contrib import integrate from tensorflow.contrib import keras from tensorflow.contrib import kernel_methods -from tensorflow.contrib import kfac from tensorflow.contrib import labeled_tensor from tensorflow.contrib import layers from tensorflow.contrib import learn diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc index 513d519eabbd54f46fde9ec0f004247c02277732..d14b2126a0ff9b130ad5eaf3cb8dbdbe63ba1d68 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.cc +++ b/tensorflow/contrib/android/asset_manager_filesystem.cc @@ -28,7 +28,7 @@ string RemoveSuffix(const string& name, const string& suffix) { string output(name); StringPiece piece(output); str_util::ConsumeSuffix(&piece, suffix); - return piece.ToString(); + return string(piece); } // Closes the given AAsset when variable is destructed. @@ -231,7 +231,7 @@ string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) { string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) { StringPiece piece(name); str_util::ConsumePrefix(&piece, prefix_); - return piece.ToString(); + return string(piece); } bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) { diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index f7dd3183b01095fb3bc43c73d3ad20b14bc53e1d..3530fbb2ecc5ac8de5ff8b3c94fdf6b84a4cd77b 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -217,7 +217,7 @@ class ControlFlowTransformer(converter.Base): cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE) cond_closure = set() - for s in cond_scope.referenced: + for s in cond_scope.used: for root in s.support_set: if root not in body_scope.created: cond_closure.add(root) @@ -250,6 +250,7 @@ class ControlFlowTransformer(converter.Base): node_body = ast_util.rename_symbols(node.body, ssf_map) test = ast_util.rename_symbols(node.test, ssf_map) + # TODO(b/113118541) investigate the need-for and correctness-of extra_deps template = """ def test_name(state_ssf): return test @@ -310,7 +311,9 @@ class ControlFlowTransformer(converter.Base): template = """ def extra_test_name(state_ssf): return extra_test_expr - def body_name(iterate, state_ssf): + def body_name(loop_vars, state_ssf): + # Workaround for PEP-3113 + iterate = loop_vars body return state_ssf, state_ast_tuple = ag__.for_stmt( diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 02bc00dbc8607070e353acfe383136c79bcfec51..1d04ba3ba610ff1694e8ef9a7f52cfda06571184 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -48,6 +48,24 @@ class ControlFlowTest(converter_testing.TestCase): self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5)) + def test_while_nested(self): + + def test_fn(n): + i = 0 + j = 0 + s = 0 + while i < n: + while j < i: + j += 3 + u = i + j # 'u' is not defined within the inner loop + s += u + i += 1 + j = 0 + return s, i, j, n + + self.assertTransformedResult(test_fn, constant_op.constant(5), + (25, 5, 0, 5)) + def test_while_single_output(self): def test_fn(n): @@ -217,5 +235,13 @@ class ControlFlowTest(converter_testing.TestCase): with self.assertRaises(transformer.AutographParseError): control_flow.transform(node, ctx) + def test_for_tuple_unpacking(self): + def test_fn(x_list): + z = tf.constant(0) # pylint:disable=undefined-variable + for i, x in enumerate(x_list): + z = z + x + i + return z + + self.assertTransformedResult(test_fn, [3, 3], 7) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/contrib/autograph/pyct/testing/BUILD index 9ef1ac9663eac8febffd697d7164425716b65d9d..29a92444bbc911a4f3c4afbc64410d9fe802801c 100644 --- a/tensorflow/contrib/autograph/pyct/testing/BUILD +++ b/tensorflow/contrib/autograph/pyct/testing/BUILD @@ -34,8 +34,10 @@ py_test( srcs = ["codegen_test.py"], srcs_version = "PY2AND3", tags = [ + "manual", "no_windows", "nomsan", + "notap", ], deps = [ ":testing", diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md index b9abfa8295f9013cd8e92f87466a73952ccceb10..f33eaf7e3df356e10939f591ef75cb4f17978144 100644 --- a/tensorflow/contrib/bigtable/README.md +++ b/tensorflow/contrib/bigtable/README.md @@ -324,8 +324,14 @@ If you encounter a log line that includes the following: "filename":"/usr/share/grpc/roots.pem" ``` -you likely need to copy the [gRPC `roots.pem` file][grpcPem] to -`/usr/share/grpc/roots.pem` on your local machine. +you can solve it via either of the following approaches: + +* copy the [gRPC `roots.pem` file][grpcPem] to + `/usr/share/grpc/roots.pem` on your local machine, which is the default + location where gRPC will look for this file +* export the environment variable `GRPC_DEFAULT_SSL_ROOTS_FILE_PATH` to point to + the full path of the gRPC `roots.pem` file on your file system if it's in a + different location [grpcPem]: https://github.com/grpc/grpc/blob/master/etc/roots.pem diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 64349cfca39eba0abbab3865fbcdf4b8e7add4cb..d0fd39fa30e48d7031f3f386910703eee05da0d2 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= +#include #include #include #include @@ -325,13 +326,21 @@ class BuildDenseInequalitySplitsOp : public OpKernel { } float best_gain = std::numeric_limits::lowest(); - int64 best_bucket_idx = 0; + int64 best_bucket_id = 0; std::vector best_right_node_stats(num_elements, NodeStats(0)); std::vector best_left_node_stats(num_elements, NodeStats(0)); std::vector current_left_node_stats(num_elements, NodeStats(0)); std::vector current_right_node_stats(num_elements, NodeStats(0)); - int64 current_bucket_id = 0; + int64 current_bucket_id = std::numeric_limits::max(); int64 last_bucket_id = -1; + // Find the lowest bucket id, this is going to be the first bucket id to + // try. + for (int root_idx = 0; root_idx < num_elements; root_idx++) { + const int start_index = partition_boundaries[root_idx]; + if (bucket_ids(start_index, 0) < current_bucket_id) { + current_bucket_id = bucket_ids(start_index, 0); + } + } // Indexes offsets for each of the partitions that can be used to access // gradients of a partition for a current bucket we consider. std::vector current_layer_offsets(num_elements, 0); @@ -373,6 +382,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel { best_gain = gain_of_split; best_left_node_stats = current_left_node_stats; best_right_node_stats = current_right_node_stats; + best_bucket_id = current_bucket_id; } current_bucket_id = next_bucket_id; } @@ -387,8 +397,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel { oblivious_split_info.mutable_split_node() ->mutable_oblivious_dense_float_binary_split(); oblivious_dense_split->set_feature_column(state->feature_column_group_id()); - oblivious_dense_split->set_threshold( - bucket_boundaries(bucket_ids(best_bucket_idx, 0))); + oblivious_dense_split->set_threshold(bucket_boundaries(best_bucket_id)); (*gains)(0) = best_gain; for (int root_idx = 0; root_idx < num_elements; root_idx++) { @@ -400,6 +409,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel { const int start_index = partition_boundaries[root_idx]; (*output_partition_ids)(root_idx) = partition_ids(start_index); + oblivious_split_info.add_children_parent_id(partition_ids(start_index)); } oblivious_split_info.SerializeToString(&(*output_splits)(0)); } diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index bb5ae78d9bfdc7ea17bd3f7b312360d4e70d97c4..ab2853352a70073648f47e9835f8a66852ff584f 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= +#include + #include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h" #include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" #include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h" @@ -772,20 +774,32 @@ class GrowTreeEnsembleOp : public OpKernel { // The number of new children. int num_children = 1 << (depth + 1); auto split_info = split->oblivious_split_info; - CHECK(num_children == split_info.children_size()) - << "Wrong number of new children: " << num_children - << " != " << split_info.children_size(); - for (int idx = 0; idx < num_children; idx += 2) { - // Old leaf is at position depth + idx / 2. + CHECK(num_children >= split_info.children_size()) + << "Too many new children, expected <= " << num_children << " and got " + << split_info.children_size(); + std::vector new_leaves; + new_leaves.reserve(num_children); + int next_id = 0; + for (int idx = 0; idx < num_children / 2; idx++) { trees::Leaf old_leaf = - *tree_config->mutable_nodes(depth + idx / 2)->mutable_leaf(); - // Update left leaf. - *split_info.mutable_children(idx) = - *MergeLeafWeights(old_leaf, split_info.mutable_children(idx)); - // Update right leaf. - *split_info.mutable_children(idx + 1) = - *MergeLeafWeights(old_leaf, split_info.mutable_children(idx + 1)); + *tree_config->mutable_nodes(depth + idx)->mutable_leaf(); + // Check if a split was made for this leaf. + if (next_id < split_info.children_parent_id_size() && + depth + idx == split_info.children_parent_id(next_id)) { + // Add left leaf. + new_leaves.push_back(*MergeLeafWeights( + old_leaf, split_info.mutable_children(2 * next_id))); + // Add right leaf. + new_leaves.push_back(*MergeLeafWeights( + old_leaf, split_info.mutable_children(2 * next_id + 1))); + next_id++; + } else { + // If there is no split for this leaf, just duplicate it. + new_leaves.push_back(old_leaf); + new_leaves.push_back(old_leaf); + } } + CHECK(next_id == split_info.children_parent_id_size()); TreeNodeMetadata* split_metadata = split_info.mutable_split_node()->mutable_node_metadata(); split_metadata->set_gain(split->gain); @@ -804,11 +818,10 @@ class GrowTreeEnsembleOp : public OpKernel { if (idx + depth + 1 < nodes_size) { // Update leaves that were already there. *tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() = - *split_info.mutable_children(idx); + new_leaves[idx]; } else { // Add new leaves. - *tree_config->add_nodes()->mutable_leaf() = - *split_info.mutable_children(idx); + *tree_config->add_nodes()->mutable_leaf() = new_leaves[idx]; } } } diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index d9caebb645ab95cdee6bbd40a1e03b8fd64ff195..5532bd026ab695d166bc2e2872ecc551920978d5 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -186,14 +186,15 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): with self.test_session() as sess: # The data looks like the following: # Example | Gradients | Partition | Dense Quantile | - # i0 | (0.2, 0.12) | 0 | 2 | - # i1 | (-0.5, 0.07) | 0 | 2 | - # i2 | (1.2, 0.2) | 0 | 0 | - # i3 | (4.0, 0.13) | 1 | 1 | - dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52]) + # i0 | (0.2, 0.12) | 1 | 3 | + # i1 | (-0.5, 0.07) | 1 | 3 | + # i2 | (1.2, 0.2) | 1 | 1 | + # i3 | (4.0, 0.13) | 2 | 2 | + dense_column = array_ops.placeholder( + dtypes.float32, shape=(4, 1), name="dense_column") gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0]) hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) - partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) + partition_ids = array_ops.constant([1, 1, 1, 2], dtype=dtypes.int32) class_id = -1 gradient_shape = tensor_shape.scalar() @@ -230,31 +231,35 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_1]): are_splits_ready = split_handler.make_splits( np.int64(0), np.int64(1), class_id)[0] + # Forcing the creation of four buckets. + are_splits_ready = sess.run( + [are_splits_ready], + feed_dict={dense_column: [[0.2], [0.62], [0.3], [0.52]]})[0] - with ops.control_dependencies([are_splits_ready]): - update_2 = split_handler.update_stats_sync( - 1, - partition_ids, - gradients, - hessians, - empty_gradients, - empty_hessians, - example_weights, - is_active=array_ops.constant([True, True])) + update_2 = split_handler.update_stats_sync( + 1, + partition_ids, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( split_handler.make_splits(np.int64(1), np.int64(2), class_id)) - are_splits_ready, are_splits_ready2, partitions, gains, splits = ( - sess.run([ - are_splits_ready, are_splits_ready2, partitions, gains, splits - ])) + # Only using the last three buckets. + are_splits_ready2, partitions, gains, splits = ( + sess.run( + [are_splits_ready2, partitions, gains, splits], + feed_dict={dense_column: [[0.62], [0.62], [0.3], [0.52]]})) # During the first iteration, inequality split handlers are not going to # have any splits. Make sure that we return not_ready in that case. self.assertFalse(are_splits_ready) self.assertTrue(are_splits_ready2) - self.assertAllEqual([0, 1], partitions) + self.assertAllEqual([1, 2], partitions) oblivious_split_info = split_info_pb2.ObliviousSplitInfo() oblivious_split_info.ParseFromString(splits[0]) @@ -263,52 +268,57 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertAllClose(0.3, split_node.threshold, 0.00001) self.assertEqual(0, split_node.feature_column) - # Check the split on partition 0. + # Check the split on partition 1. # -(1.2 - 0.1) / (0.2 + 1) - expected_left_weight_0 = -0.9166666666666666 + expected_left_weight_1 = -0.9166666666666666 - # expected_left_weight_0 * -(1.2 - 0.1) - expected_left_gain_0 = 1.008333333333333 + # expected_left_weight_1 * -(1.2 - 0.1) + expected_left_gain_1 = 1.008333333333333 # (-0.5 + 0.2 + 0.1) / (0.19 + 1) - expected_right_weight_0 = 0.1680672 + expected_right_weight_1 = 0.1680672 - # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1)) - expected_right_gain_0 = 0.033613445378151252 + # expected_right_weight_1 * -(-0.5 + 0.2 + 0.1)) + expected_right_gain_1 = 0.033613445378151252 # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1) - expected_bias_gain_0 = 0.46043165467625896 + expected_bias_gain_1 = 0.46043165467625896 left_child = oblivious_split_info.children[0].vector right_child = oblivious_split_info.children[1].vector - self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001) + self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001) - self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001) + self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001) - # Check the split on partition 1. - expected_left_weight_1 = 0 - expected_left_gain_1 = 0 + # Check the split on partition 2. + expected_left_weight_2 = 0 + expected_left_gain_2 = 0 # -(4 - 0.1) / (0.13 + 1) - expected_right_weight_1 = -3.4513274336283186 - # expected_right_weight_1 * -(4 - 0.1) - expected_right_gain_1 = 13.460176991150442 + expected_right_weight_2 = -3.4513274336283186 + # expected_right_weight_2 * -(4 - 0.1) + expected_right_gain_2 = 13.460176991150442 # (-4 + 0.1) ** 2 / (0.13 + 1) - expected_bias_gain_1 = 13.460176991150442 + expected_bias_gain_2 = 13.460176991150442 left_child = oblivious_split_info.children[2].vector right_child = oblivious_split_info.children[3].vector - self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001) + self.assertAllClose([expected_left_weight_2], left_child.value, 0.00001) - self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001) + self.assertAllClose([expected_right_weight_2], right_child.value, 0.00001) # The layer gain is the sum of the gains of each partition layer_gain = ( - expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + ( - expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + ( + expected_left_gain_2 + expected_right_gain_2 - expected_bias_gain_2) self.assertAllClose(layer_gain, gains[0], 0.00001) + # We have examples in both partitions, then we get both ids. + self.assertEqual(2, len(oblivious_split_info.children_parent_id)) + self.assertEqual(1, oblivious_split_info.children_parent_id[0]) + self.assertEqual(2, oblivious_split_info.children_parent_id[1]) + def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self): with self.test_session() as sess: # The data looks like the following: diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index 69bb8fd4ada861a42a0ccc3f287a47d91be5c879..8d71a6cdbc495aab9c29b3b1f3b70d32c04573ec 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -36,12 +36,6 @@ class WeightedQuantilesSummary { struct SummaryEntry { SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min, const WeightType& max) { - // Explicitly initialize all of memory (including padding from memory - // alignment) to allow the struct to be msan-resistant "plain old data". - // - // POD = http://en.cppreference.com/w/cpp/concept/PODType - memset(this, 0, sizeof(*this)); - value = v; weight = w; min_rank = min; @@ -49,8 +43,6 @@ class WeightedQuantilesSummary { } SummaryEntry() { - memset(this, 0, sizeof(*this)); - value = ValueType(); weight = 0; min_rank = 0; diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto index 65448996bffae02f077616db6467957b951131c0..784977af39501af247526619af8ab0cb29422ab7 100644 --- a/tensorflow/contrib/boosted_trees/proto/split_info.proto +++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto @@ -21,4 +21,8 @@ message SplitInfo { message ObliviousSplitInfo { tensorflow.boosted_trees.trees.TreeNode split_node = 1; repeated tensorflow.boosted_trees.trees.Leaf children = 2; + // For each child, children_parent_id stores the node_id of its parent when it + // was a leaf. For the idx-th child it corresponds the idx/2-th + // children_parent_id. + repeated int32 children_parent_id = 3; } diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index 278dc1f7560639d658408c83aa1321fee995bb2c..b3e4c2e5f7a907892d66ad4181eb6ed8589bab6e 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -91,7 +91,8 @@ def _gen_dense_split_info(fc, threshold, left_weight, right_weight): return split.SerializeToString() -def _gen_dense_oblivious_split_info(fc, threshold, leave_weights): +def _gen_dense_oblivious_split_info(fc, threshold, leave_weights, + children_parent_id): split_str = """ split_node { oblivious_dense_float_binary_split { @@ -107,6 +108,9 @@ def _gen_dense_oblivious_split_info(fc, threshold, leave_weights): } }""" % ( weight) + for x in children_parent_id: + split_str += """ + children_parent_id: %d""" % (x) split = split_info_pb2.ObliviousSplitInfo() text_format.Merge(split_str, split) return split.SerializeToString() @@ -432,14 +436,18 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): handler1_partitions = np.array([0], dtype=np.int32) handler1_gains = np.array([7.62], dtype=np.float32) handler1_split = [ - _gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143]) + _gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143], [0]) ] handler2_partitions = np.array([0], dtype=np.int32) handler2_gains = np.array([0.63], dtype=np.float32) - handler2_split = [_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24])] + handler2_split = [ + _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24], [0]) + ] handler3_partitions = np.array([0], dtype=np.int32) handler3_gains = np.array([7.62], dtype=np.float32) - handler3_split = [_gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143])] + handler3_split = [ + _gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143], [0]) + ] # Grow tree ensemble. grow_op = training_ops.grow_tree_ensemble( @@ -1675,17 +1683,20 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): handler1_partitions = np.array([0], dtype=np.int32) handler1_gains = np.array([1.4], dtype=np.float32) handler1_split = [ - _gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5]) + _gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5], + [1, 2]) ] handler2_partitions = np.array([0], dtype=np.int32) handler2_gains = np.array([2.7], dtype=np.float32) handler2_split = [ - _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4]), + _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4], + [1, 2]) ] handler3_partitions = np.array([0], dtype=np.int32) handler3_gains = np.array([1.7], dtype=np.float32) handler3_split = [ - _gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1]) + _gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1], + [1, 2]) ] # Grow tree ensemble layer by layer. @@ -1797,6 +1808,528 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual(stats.attempted_layers, 2) self.assertProtoEquals(expected_result, tree_ensemble_config) + def testGrowEnsembleWithEmptyNodesMiddleCase(self): + """Test case: The middle existing leaves don't have examples.""" + with self.test_session() as session: + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 4 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 1 + threshold: 0.23 + } + node_metadata { + gain: 2.7 + original_oblivious_leaves { + vector { + value: 7.143 + } + } + original_oblivious_leaves { + vector { + value: -4.375 + } + } + } + } + nodes { + leaf { + vector { + value: 6.543 + } + } + } + nodes { + leaf { + vector { + value: 7.5 + } + } + } + nodes { + leaf { + vector { + value: -4.075 + } + } + } + nodes { + leaf { + vector { + value: -3.975 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 2 + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + } + """, tree_ensemble_config) + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=6, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER) + + # Prepare handler inputs. + handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([1.8], dtype=np.float32) + handler1_split = [ + _gen_dense_oblivious_split_info(0, 0.9, [1.0, 2.0, 3.0, 4.0], [2, 5]) + ] + # The tree currently has depth 2, so the ids for the four leaves are in + # the range [2, 6). In this test case we are assuming that our examples + # only fall in leaves 2 and 5. + + # Grow tree ensemble layer by layer. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=0.1, + partition_ids=[handler1_partitions], + gains=[handler1_gains], + splits=[handler1_split], + learner_config=learner_config.SerializeToString(), + dropout_seed=123, + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + session.run(grow_op) + + new_stamp, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + stats = session.run( + training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1)) + tree_ensemble_config.ParseFromString(serialized) + expected_result = """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 4 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 1 + threshold: 0.23 + } + node_metadata { + gain: 2.7 + original_oblivious_leaves { + vector { + value: 7.143 + } + } + original_oblivious_leaves { + vector { + value: -4.375 + } + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 0 + threshold: 0.9 + } + node_metadata { + gain: 1.8 + original_oblivious_leaves { + vector { + value: 6.543 + } + } + original_oblivious_leaves { + vector { + value: 7.5 + } + } + original_oblivious_leaves { + vector { + value: -4.075 + } + } + original_oblivious_leaves { + vector { + value: -3.975 + } + } + } + } + nodes { + leaf { + vector { + value: 7.543 + } + } + } + nodes { + leaf { + vector { + value: 8.543 + } + } + } + nodes { + leaf { + vector { + value: 7.5 + } + } + } + nodes { + leaf { + vector { + value: 7.5 + } + } + } + nodes { + leaf { + vector { + value: -4.075 + } + } + } + nodes { + leaf { + vector { + value: -4.075 + } + } + } + nodes { + leaf { + vector { + value: -0.975 + } + } + } + nodes { + leaf { + vector { + value: 0.025 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 3 + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 3 + } + """ + self.assertEqual(new_stamp, 1) + self.assertEqual(stats.num_trees, 0) + self.assertEqual(stats.num_layers, 3) + self.assertEqual(stats.active_tree, 1) + self.assertEqual(stats.active_layer, 3) + self.assertEqual(stats.attempted_trees, 1) + self.assertEqual(stats.attempted_layers, 3) + self.assertProtoEquals(expected_result, tree_ensemble_config) + + def testGrowEnsembleWithEmptyNodesBorderCase(self): + """Test case: The first and last existing leaves don't have examples.""" + with self.test_session() as session: + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 4 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 1 + threshold: 0.23 + } + node_metadata { + gain: 2.7 + original_oblivious_leaves { + vector { + value: 7.143 + } + } + original_oblivious_leaves { + vector { + value: -4.375 + } + } + } + } + nodes { + leaf { + vector { + value: 6.543 + } + } + } + nodes { + leaf { + vector { + value: 7.5 + } + } + } + nodes { + leaf { + vector { + value: -4.075 + } + } + } + nodes { + leaf { + vector { + value: -3.975 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 2 + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 2 + } + """, tree_ensemble_config) + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=6, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER) + + # Prepare handler inputs. + handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([1.8], dtype=np.float32) + handler1_split = [ + _gen_dense_oblivious_split_info(0, 0.9, [1.0, 2.0, 3.0, 4.0], [3, 4]) + ] + # The tree currently has depth 2, so the ids for the four leaves are in + # the range [2, 6). In this test case we are assuming that our examples + # only fall in leaves 3 and 4. + + # Grow tree ensemble layer by layer. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=0.1, + partition_ids=[handler1_partitions], + gains=[handler1_gains], + splits=[handler1_split], + learner_config=learner_config.SerializeToString(), + dropout_seed=123, + center_bias=True, + max_tree_depth=learner_config.constraints.max_tree_depth, + weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE) + session.run(grow_op) + + new_stamp, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + stats = session.run( + training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1)) + tree_ensemble_config.ParseFromString(serialized) + expected_result = """ + trees { + nodes { + oblivious_dense_float_binary_split { + feature_column: 4 + threshold: 7 + } + node_metadata { + gain: 7.62 + original_oblivious_leaves { + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 1 + threshold: 0.23 + } + node_metadata { + gain: 2.7 + original_oblivious_leaves { + vector { + value: 7.143 + } + } + original_oblivious_leaves { + vector { + value: -4.375 + } + } + } + } + nodes { + oblivious_dense_float_binary_split { + feature_column: 0 + threshold: 0.9 + } + node_metadata { + gain: 1.8 + original_oblivious_leaves { + vector { + value: 6.543 + } + } + original_oblivious_leaves { + vector { + value: 7.5 + } + } + original_oblivious_leaves { + vector { + value: -4.075 + } + } + original_oblivious_leaves { + vector { + value: -3.975 + } + } + } + } + nodes { + leaf { + vector { + value: 6.543 + } + } + } + nodes { + leaf { + vector { + value: 6.543 + } + } + } + nodes { + leaf { + vector { + value: 8.5 + } + } + } + nodes { + leaf { + vector { + value: 9.5 + } + } + } + nodes { + leaf { + vector { + value: -1.075 + } + } + } + nodes { + leaf { + vector { + value: -0.075 + } + } + } + nodes { + leaf { + vector { + value: -3.975 + } + } + } + nodes { + leaf { + vector { + value: -3.975 + } + } + } + } + tree_weights: 0.1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 3 + } + growing_metadata { + num_trees_attempted: 1 + num_layers_attempted: 3 + } + """ + self.assertEqual(new_stamp, 1) + self.assertEqual(stats.num_trees, 0) + self.assertEqual(stats.num_layers, 3) + self.assertEqual(stats.active_tree, 1) + self.assertEqual(stats.active_layer, 3) + self.assertEqual(stats.attempted_trees, 1) + self.assertEqual(stats.attempted_layers, 3) + self.assertProtoEquals(expected_result, tree_ensemble_config) + def testGrowExistingEnsembleTreeFinalizedWithDropout(self): """Test growing an existing ensemble with the last tree finalized.""" with self.cached_session() as session: diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc index 58fadffce32f9a8fec047d1e99f9f4eb5a710d91..e57a66b99f6c8e9451a81d920da96e729d02c684 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -33,7 +33,7 @@ bool IsPartitionEmpty(const BigQueryTablePartition& partition) { Status ParseJson(StringPiece json, Json::Value* result) { Json::Reader reader; - if (!reader.parse(json.ToString(), *result)) { + if (!reader.parse(string(json), *result)) { return errors::Internal("Couldn't parse JSON response from BigQuery."); } return Status::OK(); diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 07934ef3247c6e05323de8ccea70e0264561441f..fb871acae9963978485afef52dbba089aea4fd40 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -247,10 +247,6 @@ tensorflow/contrib/kernel_methods/python tensorflow/contrib/kernel_methods/python/mappers tensorflow/contrib/kinesis/python tensorflow/contrib/kinesis/python/ops -tensorflow/contrib/kfac -tensorflow/contrib/kfac/examples -tensorflow/contrib/kfac/python -tensorflow/contrib/kfac/python/ops tensorflow/contrib/labeled_tensor tensorflow/contrib/labeled_tensor/python tensorflow/contrib/labeled_tensor/python/ops diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index 855c824ead2f7de4c37db2d2a3648a9ee00fb9e9..4bfd753bb1d1fc254c66a4f7eb1d6ac83a40cb70 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -3,6 +3,7 @@ package(default_visibility = [ "//learning/brain:__subpackages__", + "//research/vision/piedpiper:__subpackages__", "//tensorflow:__subpackages__", ]) diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD index bcee0b04c8430588c2dcbc199504bede0436f8f1..d7583be6d8ed996ac894d3a8601f716cc27bdd86 100644 --- a/tensorflow/contrib/compiler/BUILD +++ b/tensorflow/contrib/compiler/BUILD @@ -8,6 +8,7 @@ package_group( packages = ["//tensorflow/..."], ) +load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( @@ -46,3 +47,36 @@ cuda_py_test( ], xla_enabled = True, ) + +py_library( + name = "xla", + srcs = ["xla.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python/estimator:model_fn", + ], +) + +tf_py_test( + name = "xla_test", + srcs = ["xla_test.py"], + additional_deps = [ + ":xla", + "@six_archive//:six", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:control_flow_util", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + ], + tags = ["no_pip"], +) diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py new file mode 100644 index 0000000000000000000000000000000000000000..60f5af166234ba69e21a4a64cd3b3c102f66aef4 --- /dev/null +++ b/tensorflow/contrib/compiler/xla.py @@ -0,0 +1,208 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""xla provides experimental xla support API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat + +_XLA_COMPILE_ATTR = '_xla_compile_id' +_MAX_WARNING_LINES = 5 + +# Operations that indicate some error in the users graph. For example, XLA +# computation should not have any Placeholder op. +_BLACKLISTED_OPS = set([ + 'Placeholder', +]) + +# XLA doesn't currently support reading of intermediate tensors, thus some ops +# are not supported. +_UNSUPPORTED_OPS = set([ + 'AudioSummary', + 'AudioSummaryV2', + 'HistogramSummary', + 'ImageSummary', + 'MergeSummary', + 'Print', + 'ScalarSummary', + 'TensorSummary', + 'TensorSummaryV2', +]) + + +class XLACompileContext(control_flow_ops.XLAControlFlowContext): + """A `ControlFlowContext` for nodes inside an XLA computation cluster. + + THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY. + + The primary role of `XLACompileContext` is to mark operators inside a + xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is + a unique name. + + `ControlFlowContext` is used to perform the annotation since it integrates + with Tensorflow constructs like ResourceVariables. For example, if a + `ResourceVariable` is constructed inside a xla.compile() block, the + `ResourceVariable` implementation can use + `with ops.control_dependencies(None)` to build the variable's definition + outside the compiled computation. + """ + + def __init__(self, name, pivot): + """Builds a new XLACompileContext. + + Args: + name: a unique name for the context, used to populate the + `_xla_compile_id` attribute. + pivot: a pivot node. Nodes in the XLACompileContext that do not have any + inputs will have a control dependency on the pivot node. This ensures + that nodes are correctly included in any enclosing control flow + contexts. + """ + super(XLACompileContext, self).__init__() + self._name = name + self._name_as_bytes = compat.as_bytes(name) + self._unsupported_ops = [] + self._pivot = pivot + + def report_unsupported_operations(self): + if self._unsupported_ops: + op_str = '\n'.join([ + ' %s (%s)' % (op.type, op.name) + for op in self._unsupported_ops[:_MAX_WARNING_LINES] + ]) + logging.warning('%d unsupported operations found: \n%s', + len(self._unsupported_ops), op_str) + if len(self._unsupported_ops) > _MAX_WARNING_LINES: + logging.warning('... and %d more', + len(self._unsupported_ops) - _MAX_WARNING_LINES) + + def AddOp(self, op): + """Create op in XLACompileContext and notifies outer context recursively.""" + # pylint: disable=protected-access + if op.type in _BLACKLISTED_OPS: + logging.error( + 'Operation of type %s (%s) is not supported in XLA. Execution will ' + 'fail if this op is used in the graph. ', op.type, op.name) + + # TODO(ycao): Automatically disable summaries instead of reporting them. + if op.type in _UNSUPPORTED_OPS: + self._unsupported_ops.append(op) + + if any(x.dtype._is_ref_dtype for x in op.inputs): + raise NotImplementedError( + 'Non-resource Variables are not supported inside XLA computations ' + '(operator name: %s)' % op.name) + + if _XLA_COMPILE_ATTR in op.node_def.attr: + raise ValueError('XLA compiled computations cannot be nested, (operator ' + 'name: %s)' % op.name) + + op._set_attr( + _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes)) + + op.graph.prevent_feeding(op) + op.graph.prevent_fetching(op) + + # Remove any control edges from outer control flow contexts. These may cause + # mismatched frame errors. An example is when one of op's inputs is + # generated in a different While control flow context. + (internal_control_inputs, + external_control_inputs) = self._RemoveExternalControlEdges(op) + + if not op.inputs: + # Add a control edge from the control pivot to this op. + if not internal_control_inputs: + # pylint: disable=protected-access + op._add_control_input(self._pivot) + # pylint: enable=protected-access + else: + for index in xrange(len(op.inputs)): + x = op.inputs[index] + real_x = self.AddValue(x) + if real_x != x: + op._update_input(index, real_x) # pylint: disable=protected-access + + if external_control_inputs: + # Use an identity to pull control inputs as data inputs. Note that we + # ignore ops which don't have outputs. TODO(phawkins): fix that. + with ops.control_dependencies(None): + self.Enter() + external_control_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_control_inputs + if x.outputs + ] + self.Exit() + # pylint: disable=protected-access + op._add_control_inputs(external_control_inputs) + # pylint: enable=protected-access + + # Mark op's outputs as seen by this context and any outer contexts. + output_names = [x.name for x in op.outputs] + context = self + while context is not None: + # pylint: disable=protected-access + context._values.update(output_names) + context = context._outer_context + # pylint: enable=protected-access + + if self._outer_context: + self._outer_context.AddInnerOp(op) + + def AddValue(self, val): + """Add `val` to the current context and its outer context recursively.""" + if val.name in self._values: + # Use the real value if it comes from outer context. + result = self._external_values.get(val.name) + return val if result is None else result + + result = val + self._values.add(val.name) + if self._outer_context: + result = self._outer_context.AddValue(val) + self._values.add(result.name) + + self._external_values[val.name] = result + + return result + + def AddInnerOp(self, op): + self.AddOp(op) + if self._outer_context: + self._outer_context.AddInnerOp(op) + + @property + def grad_state(self): + # Define the gradient loop state associated with the XLACompileContext to + # be None as the XLACompileContext does not get nested nor does the + # grad_state outside the XLACompileContext affect the graph inside so the + # grad_state should be as if this is the top-level gradient state. + return None + + @property + def back_prop(self): + """Forwards to the enclosing while context, if any.""" + if self.GetWhileContext(): + return self.GetWhileContext().back_prop + return False diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a306b56f63bd3b135b0231da89fb2e3445570740 --- /dev/null +++ b/tensorflow/contrib/compiler/xla_test.py @@ -0,0 +1,180 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for contrib.compiler.xla.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.compiler import xla +from tensorflow.python import summary +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import summary_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class XLACompileContextTest(test.TestCase): + + def create_test_xla_compile_context(self): + computation_name = ops.get_default_graph().unique_name('computation') + pivot = control_flow_ops.no_op(name=computation_name + '/pivot') + return xla.XLACompileContext(name=computation_name, pivot=pivot) + + def test_report_unsupported_operations(self): + """Tests that unsupported operations are detected.""" + context = self.create_test_xla_compile_context() + context.Enter() + dummy_tensor = constant_op.constant(1.1) + audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5) + histogram_summary = summary.histogram('histogram_summary', dummy_tensor) + image_summary = summary.image('image_summary', dummy_tensor) + scalar_summary = summary.scalar('scalar_summary', dummy_tensor) + tensor_summary = summary_ops.tensor_summary('tensor_summary', dummy_tensor) + summary.merge( + [ + audio_summary, histogram_summary, image_summary, scalar_summary, + tensor_summary + ], + name='merge_summary') + logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op') + context.Exit() + + unsupported_ops_names = [op.name for op in context._unsupported_ops] + self.assertEqual(unsupported_ops_names, [ + u'audio_summary', u'histogram_summary', u'image_summary', + u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary', + u'print_op' + ]) + + def test_resource_variable(self): + """Tests that resource variable usage is allowed.""" + a = variable_scope.get_variable( + name='variable_a', shape=(1), use_resource=True) + + context = self.create_test_xla_compile_context() + context.Enter() + state_ops.assign(a, a + 1) + context.Exit() + + def test_non_resource_variable_error(self): + """Tests that non-resource variable usage is disallowed.""" + a = variable_scope.get_variable( + name='variable_a', shape=(1), use_resource=False) + + context = self.create_test_xla_compile_context() + context.Enter() + with self.assertRaisesRegexp( + NotImplementedError, 'Non-resource Variables are not supported inside ' + r'XLA computations \(operator name: Assign\)'): + state_ops.assign(a, a + 1) + context.Exit() + + def test_nested_xla_compile_error(self): + """Tests that nested XLA computation leads to fatal error.""" + context1 = self.create_test_xla_compile_context() + context1.Enter() + + context2 = self.create_test_xla_compile_context() + context2.Enter() + with self.assertRaisesRegexp(ValueError, + 'XLA compiled computations cannot be nested'): + constant_op.constant(1) + context2.Exit() + context1.Exit() + + def test_xla_compile_attr(self): + """Tests that ops are tagged with XLA compile ID attribute.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + self.assertIn('_xla_compile_id', op.op.node_def.attr) + + def test_op_without_input(self): + """Tests that ops without inputs depend on pivot correctly.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + + self.assertIn(context._pivot, op.op.control_inputs) + + def test_external_control_edges(self): + """Tests that external control edges are handled correctly.""" + i = constant_op.constant(1) + op1 = constant_op.constant(1) + + with ops.control_dependencies([op1]): + op2 = constant_op.constant(1) + self.assertIn(op1.op, op2.op.control_inputs) + + def while_body(i): + del i # unused + context = self.create_test_xla_compile_context() + context.Enter() + with ops.control_dependencies([op1]): + op3 = constant_op.constant(1) + context.Exit() + self.assertNotIn(op1.op, op3.op.control_inputs) + return op3 + + control_flow_ops.while_loop( + cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i]) + + def test_op_output_marked_as_seen(self): + """Tests that any op output is marked as seen in context.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + + self.assertIn(op.name, context._values) + + def testOpIsInContext(self): + """Tests that XLACompileContext is recognized as an XLA context.""" + op1 = constant_op.constant(1) + context = self.create_test_xla_compile_context() + context.Enter() + op2 = constant_op.constant(2) + context.Exit() + self.assertFalse(control_flow_util.IsInXLAContext(op1.op)) + self.assertTrue(control_flow_util.IsInXLAContext(op2.op)) + + def testOpPreventFeeding(self): + """Tests that ops created inside XLACompileContext can not be fed.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + self.assertFalse(op.graph.is_feedable(op.op)) + + def testOpPreventFetching(self): + """Tests that ops created inside XLACompileContext can not be fetched.""" + context = self.create_test_xla_compile_context() + context.Enter() + op = constant_op.constant(1) + context.Exit() + self.assertFalse(op.graph.is_fetchable(op.op)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 8bdbba83ef6a8541158d956e36caf6a9be435c5b..9f710613dd0d549d4f93bae8780427f7878234a6 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -33,14 +33,22 @@ cc_library( tf_custom_op_library( name = "_dataset_ops.so", - srcs = ["ops/dataset_ops.cc"], - deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"] + - if_static( - extra_deps = [":lib_proto_parsing_for_dataset_ops"], - otherwise = [], - ), + srcs = [ + "ops/dataset_ops.cc", + "ops/indexed_dataset_ops.cc", + ], + deps = [ + "//tensorflow/contrib/data/kernels:dataset_kernels", + "//tensorflow/contrib/data/kernels:indexed_dataset", + ] + if_static( + extra_deps = [":lib_proto_parsing_for_dataset_ops"], + otherwise = [], + ), ) tf_gen_op_libs( - op_lib_names = ["dataset_ops"], + op_lib_names = [ + "dataset_ops", + "indexed_dataset_ops", + ], ) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 5821d51bca491b1e5c5388c0c82088ca0eb8fed3..5e6c1520a2fc1c21678625c9d4aae04164b198f6 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -25,6 +25,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@Counter @@CheckpointInputPipelineHook @@CsvDataset +@@LMDBDataset @@RandomDataset @@Reducer @@SqlDataset @@ -49,6 +50,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@map_and_batch @@padded_batch_and_drop_remainder @@parallel_interleave +@@parse_example_dataset @@prefetch_to_device @@read_batch_features @@rejection_resample @@ -89,10 +91,12 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator +from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device from tensorflow.contrib.data.python.ops.random_ops import RandomDataset from tensorflow.contrib.data.python.ops.readers import CsvDataset +from tensorflow.contrib.data.python.ops.readers import LMDBDataset from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import make_csv_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 2e249f5c14ab111ae412ff3288acc25de8d7aa11..ec6cb37193cdfbc888df5dc6787854241daea621 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -6,6 +6,31 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +cc_library( + name = "indexed_dataset_headers", + hdrs = ["indexed_dataset.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + +cc_library( + name = "indexed_dataset", + srcs = [ + "identity_indexed_dataset.cc", + "indexed_dataset.cc", + ], + deps = [ + ":indexed_dataset_headers", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + cc_library( name = "prefetching_kernels", srcs = ["prefetching_kernels.cc"], @@ -51,6 +76,17 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "lmdb_dataset_op", + srcs = ["lmdb_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@lmdb", + "@protobuf_archive//:protobuf_headers", + ], +) + cc_library( name = "threadpool_dataset_op", srcs = ["threadpool_dataset_op.cc"], @@ -91,6 +127,8 @@ cc_library( ":csv_dataset_op", ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", + ":indexed_dataset", + ":lmdb_dataset_op", ":prefetching_kernels", ":threadpool_dataset_op", ":unique_dataset_op", diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index d242cfdf4911ee43051b8aa2f7b960916b40374a..0ba905b92e2d9a14128b540028687955bd96f2f0 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -713,7 +713,7 @@ class CSVDatasetOp : public DatasetOpKernel { component.scalar()() = dataset()->record_defaults_[output_idx].flat()(0); } else { - component.scalar()() = field.ToString(); + component.scalar()() = string(field); } break; } diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc new file mode 100644 index 0000000000000000000000000000000000000000..4718c1c8b9d77b5dbac2a8caf11d9a0604af94c2 --- /dev/null +++ b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc @@ -0,0 +1,153 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/data/kernels/indexed_dataset.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel { + public: + using IndexedDatasetOpKernel::IndexedDatasetOpKernel; + + void MakeIndexedDataset(OpKernelContext* ctx, + IndexedDataset** output) override { + uint64 size = -1; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "size", &size)); + OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0")); + *output = new Dataset(ctx, size); + } + + class Dataset : public IndexedDataset { + public: + Dataset(OpKernelContext* ctx, uint64 size) + : IndexedDataset(DatasetContext(ctx)), size_(size) {} + + Status MaterializeDataset( + std::shared_ptr* materialized) override { + materialized->reset(new Materialized(this)); + return Status::OK(); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr(new Iterator( + {this, strings::StrCat(prefix, "::IdentityIndexedDataset")})); + } + + string DebugString() const override { + return "IdentityIndexedDataset::Dataset"; + } + + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const override { + return errors::Unimplemented( + "identity_indexed_dataset.AsGraphDefInternal"); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (cur_ < dataset()->size_) { + Tensor result_tensor(ctx->allocator({}), DT_UINT64, {}); + result_tensor.scalar()() = cur_++; + out_tensors->emplace_back(std::move(result_tensor)); + *end_of_sequence = false; + return Status::OK(); + } + *end_of_sequence = true; + return Status::OK(); + } + + private: + mutex mu_; + uint64 cur_ GUARDED_BY(mu_); + }; + + class Materialized : public MaterializedIndexedDataset { + public: + explicit Materialized(Dataset* dataset) : dataset_(dataset) { + dataset->Ref(); + } + + ~Materialized() override { + // TODO(saeta): Pull this into MaterializedIndexedDataset + dataset_->Unref(); + } + + const DataTypeVector& output_dtypes() const override { + return dataset_->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return dataset_->output_shapes(); + } + + Status Get(IteratorContext&& ctx, uint64 index, + std::vector* out_tensors) const override { + LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index + << ")"; + if (index >= dataset_->size_) { + // Note: use InvalidArgument instead of OutOfRange error because many + // things consider OutOfRange to be a "clean termination" error. + return errors::InvalidArgument( + "Index ", index, + " is out of range for this dataset. (Size is: ", dataset_->size_, + ".)"); + } + Tensor result_tensor(ctx.allocator({}), DT_UINT64, {}); + result_tensor.scalar()() = index; + out_tensors->emplace_back(std::move(result_tensor)); + return Status::OK(); + } + + Status Size(uint64* size) const override { + *size = dataset_->size_; + return Status::OK(); + } + + private: + const Dataset* const dataset_; // Not owned. + }; + + const uint64 size_; + std::shared_ptr materialized_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU), + IdentityIndexedDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/contrib/data/kernels/indexed_dataset.cc new file mode 100644 index 0000000000000000000000000000000000000000..c69564a31bbc3a07ff56e0da564e7e1b8323f464 --- /dev/null +++ b/tensorflow/contrib/data/kernels/indexed_dataset.cc @@ -0,0 +1,372 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/data/kernels/indexed_dataset.h" + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" + +namespace tensorflow { + +namespace { + +Status VerifyTypesMatch(const DataTypeVector& expected, + const DataTypeVector& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " types but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (expected[i] != received[i]) { + return errors::InvalidArgument("Data type mismatch at component ", i, + ": expected ", DataTypeString(expected[i]), + " but got ", DataTypeString(received[i]), + "."); + } + } + return Status::OK(); +} + +Status VerifyShapesCompatible(const std::vector& expected, + const std::vector& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " shapes but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (!expected[i].IsCompatibleWith(received[i])) { + return errors::InvalidArgument("Incompatible shapes at component ", i, + ": expected ", expected[i].DebugString(), + " but got ", received[i].DebugString(), + "."); + } + } + + return Status::OK(); +} + +class MaterializedDatasetResource : public ResourceBase { + public: + MaterializedDatasetResource( + const DataTypeVector& output_dtypes, + const std::vector& output_shapes) + : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {} + + string DebugString() override { + return "Materialized IndexedDataset resource"; + } + + Status Get(IteratorContext&& ctx, uint64 index, + std::vector* out_tensors) { + std::shared_ptr captured(materialized_); + if (captured) { + return captured->Get(std::move(ctx), index, out_tensors); + } else { + return errors::FailedPrecondition( + "Get() failed because the MaterializedIndexedDataset has not been " + "initialized. Ensure that you have run the materialization operation " + "for this MaterializedIndexedDataset before retrieving elements."); + } + } + + // TODO(saeta): Implement Save and Restore + + const DataTypeVector& output_dtypes() const { return output_dtypes_; } + const std::vector& output_shapes() const { + return output_shapes_; + } + + Status set_materialized_dataset( + const std::shared_ptr& dataset) { + if (dataset) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_dtypes_, dataset->output_dtypes())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, dataset->output_shapes())); + } + materialized_ = dataset; + return Status::OK(); + } + + private: + std::shared_ptr materialized_; + const DataTypeVector output_dtypes_; + const std::vector output_shapes_; +}; + +// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT +// tensor. Objects of the wrapper class own a reference on an instance of an +// `IndexedTensor` and the wrapper's copy constructor and desctructor take care +// of managing the reference count. +// +// NOTE: This is not a feature-complete implementation of the DT_VARIANT +// specification. In particular, we cannot currently serialize an arbitrary +// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not +// implemented. +// +// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just +// use `tensorflow::DatasetVariantWrapper`. +class IndexedDatasetVariantWrapper { + public: + IndexedDatasetVariantWrapper() : dataset_(nullptr) {} + + // Transfers ownership of `dataset` to `*this`. + explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset) + : dataset_(dataset) {} + + IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other) + : dataset_(other.dataset_) { + if (dataset_) dataset_->Ref(); + } + + ~IndexedDatasetVariantWrapper() { + if (dataset_) dataset_->Unref(); + } + + IndexedDataset* get() const { return dataset_; } + + string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; } + string DebugString() const { + if (dataset_) { + return dataset_->DebugString(); + } else { + return ""; + } + } + + void Encode(VariantTensorData* data) const { + LOG(ERROR) << "The Encode() method is not implemented for " + "IndexedDatasetVariantWrapper objects."; + } + + bool Decode(const VariantTensorData& data) { + LOG(ERROR) << "The Decode() method is not implemented for " + "IndexedDatasetVariantWrapper objects."; + return false; + } + + private: + IndexedDataset* const dataset_; // Owns one reference. +}; + +} // namespace + +Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor, + IndexedDataset** out_dataset) { + if (!(tensor.dtype() == DT_VARIANT || + TensorShapeUtils::IsScalar(tensor.shape()))) { + return errors::InvalidArgument( + "IndexedDataset tensor must be a scalar of dtype DT_VARIANT."); + } + const Variant& variant = tensor.scalar()(); + const IndexedDatasetVariantWrapper* wrapper = + variant.get(); + if (wrapper == nullptr) { + return errors::InvalidArgument("Tensor must be an IndexedDataset object."); + } + *out_dataset = wrapper->get(); + if (*out_dataset == nullptr) { + return errors::Internal("Read uninitialized IndexedDataset variant."); + } + return Status::OK(); +} + +Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset, + Tensor* tensor) { + if (!(tensor->dtype() == DT_VARIANT || + TensorShapeUtils::IsScalar(tensor->shape()))) { + return errors::InvalidArgument( + "Dataset tensor must be a scalar of dtype DT_VARIANT."); + } + tensor->scalar()() = IndexedDatasetVariantWrapper(dataset); + return Status::OK(); +} + +void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) { + IndexedDataset* dataset = nullptr; + MakeIndexedDataset(ctx, &dataset); + + if (ctx->status().ok()) { + OP_REQUIRES(ctx, dataset != nullptr, + errors::Internal("MakeIndexedDataset did not correctly " + "construct the IndexedDataset")); + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); + OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output)); + } +} + +namespace { + +class MaterializedHandleOp : public OpKernel { + public: + explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + ~MaterializedHandleOp() override { + if (resource_ != nullptr) { + resource_->Unref(); + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->template Delete( + cinfo_.container(), cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h + } + } + } + } + + void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + if (resource_ == nullptr) { + ResourceMgr* mgr = context->resource_manager(); + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); + + MaterializedDatasetResource* resource; + OP_REQUIRES_OK(context, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this](MaterializedDatasetResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new MaterializedDatasetResource( + output_dtypes_, output_shapes_); + return Status::OK(); + })); + Status s = VerifyResource(resource); + if (TF_PREDICT_FALSE(!s.ok())) { + resource->Unref(); + context->SetStatus(s); + return; + } + + resource_ = resource; + } + } + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + // During the first Compute(), resource is either created or looked up using + // shared_name. In the latter case, the resource found should be verified if + // it is compatible with this op's configuration. The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. + Status VerifyResource(MaterializedDatasetResource* resource) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, resource->output_shapes())); + return Status::OK(); + } + + mutex mu_; + ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. + MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr; + DataTypeVector output_dtypes_; + std::vector output_shapes_; +}; + +// TODO(saeta): Make async. +class MaterializeDatasetOp : public OpKernel { + public: + explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + IndexedDataset* dataset; + OP_REQUIRES_OK(ctx, + GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset)); + + MaterializedDatasetResource* materialized_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), + &materialized_resource)); + core::ScopedUnref unref(materialized_resource); + std::shared_ptr materialized; + OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized)); + OP_REQUIRES_OK( + ctx, materialized_resource->set_materialized_dataset(materialized)); + } +}; + +// TODO(saeta): Make async +class IndexedDatasetGet : public OpKernel { + public: + explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + MaterializedDatasetResource* materialized_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), + &materialized_resource)); + auto cleanup = gtl::MakeCleanup([materialized_resource] { + materialized_resource->Unref(); // Note: can't use core::ScopedUnref. + }); + + const Tensor* index_t; + OP_REQUIRES_OK(ctx, ctx->input("index", &index_t)); + // TODO(saeta): Support batch reads (indexes should be non-scalar!) + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()), + errors::InvalidArgument("index must be a scalar")); + const uint64 index = index_t->scalar()(); + + std::vector out_tensors; + Status s = + materialized_resource->Get(IteratorContext(ctx), index, &out_tensors); + + // Note: Unref materialized_resource to avoid destruction races. (Important + // in a [future] async op implementation.) + cleanup.release()(); + + if (!s.ok()) { + ctx->SetStatus(s); + } else { + auto expected_shapes = materialized_resource->output_shapes(); + auto expected_types = materialized_resource->output_dtypes(); + for (size_t i = 0; i < out_tensors.size(); ++i) { + OP_REQUIRES( + ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()), + errors::Internal( + "Materialized dataset output at index ", i, + " is incompatible with the expected shape. (Expected: ", + expected_shapes[i], ", got: ", out_tensors[i].shape(), ")")); + OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i], + errors::Internal("Materialized dataset output at index ", i, + " was not the expected dtype. (Expected: ", + expected_types[i], + ", got: ", out_tensors[i].dtype(), ")")); + ctx->set_output(i, out_tensors[i]); + } + } + } +}; + +REGISTER_KERNEL_BUILDER( + Name("MaterializedIndexDatasetHandle").Device(DEVICE_CPU), + MaterializedHandleOp); +REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU), + MaterializeDatasetOp); +REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU), + IndexedDatasetGet); +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/contrib/data/kernels/indexed_dataset.h new file mode 100644 index 0000000000000000000000000000000000000000..6149de888cc0a966ead48c790074d63ca028f1e8 --- /dev/null +++ b/tensorflow/contrib/data/kernels/indexed_dataset.h @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_ +#define TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// TODO(saeta): Urgh, this is ugly. +class MaterializedIndexedDataset { + public: + virtual ~MaterializedIndexedDataset() = default; + + // Retrieve the element at a given index. The output tensors are stored in + // out_tensors. + // + // If `index` is greater than `Size()`, tensorflow::errors::OutOfRangeError is + // returned. + // + // Get is thread-safe. + virtual Status Get(IteratorContext&& ctx, uint64 index, + std::vector* out_tensors) const = 0; + + // Size determines the number of elements in this IndexedDataset. + // + // Size is thread-safe. + virtual Status Size(uint64* size) const = 0; + + // Returns a vector of DataType values, representing the respective + // element types of each tuple component in the outputs of this dataset. + virtual const DataTypeVector& output_dtypes() const = 0; + + // Returns a vector of tensor shapes, representing the respective + // (and possibly partially defined) shapes of each tuple component + // in the outputs of this dataset. + virtual const std::vector& output_shapes() const = 0; +}; + +// IndexedDataset represents a dataset that supports random access in addition +// to iterator-based sequential access. +// +// Note: IndexedDatasets are HIGHLY experimental at this time. Expect +// significant (backwards incompatible) changes! +class IndexedDataset : public DatasetBase { + public: + IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {} + + // Materialize (if necessary) the dataset, and return a pointer. + // TODO(saeta): Add in `IteratorContext* ctx` when materializing. + virtual Status MaterializeDataset( + std::shared_ptr* materialized) = 0; +}; + +// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the +// rest of the TensorFlow runtime. +// +// Most IndexedDataset's will be private members of classes inheriting from this +// class. +class IndexedDatasetOpKernel : public OpKernel { + public: + IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) final; + + protected: + // Subclasses should implement this method. It will be called during Compute + // execution. + virtual void MakeIndexedDataset(OpKernelContext* ctx, + IndexedDataset** output) = 0; + + template + Status ParseScalarArgument(OpKernelContext* ctx, + const StringPiece& argument_name, T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); + } +}; + +// Validates and extracts an `IndexedDataset` object from `tensor`. +// +// `tensor` must have been written by a call to +// `StoreIndexedDatasetInVariantTensor` +// +// The retrieved pointer isa borrowed reference to the dataset, which is owned +// by the tensor. The consumer must either acquire its own reference to the +// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not +// destroyed or mutated while the retrieved pointer is in use. +Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor, + IndexedDataset** out_dataset); + +// Stores an `IndexedDataset` object in `tensor.` +// +// The ownership of `dataset` is transferred to `tensor`. +Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset, + Tensor* tensor); + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_ diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..80f39992fbb1ff1395c308f00a5d02903d368891 --- /dev/null +++ b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc @@ -0,0 +1,215 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow/core/platform/file_system.h" + +#include "lmdb.h" // NOLINT(build/include) + +namespace tensorflow { +namespace { + +class LMDBDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + std::vector filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat()(i)); + } + + *output = new Dataset(ctx, filenames); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const std::vector& filenames) + : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {} + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::LMDB")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = + new DataTypeVector({DT_STRING, DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}, {}}); + return *shapes; + } + + string DebugString() const override { return "LMDBDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* filenames = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + if (mdb_cursor_) { + Tensor key_tensor(ctx->allocator({}), DT_STRING, {}); + key_tensor.scalar()() = string( + static_cast(mdb_key_.mv_data), mdb_key_.mv_size); + out_tensors->emplace_back(std::move(key_tensor)); + + Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); + value_tensor.scalar()() = + string(static_cast(mdb_value_.mv_data), + mdb_value_.mv_size); + out_tensors->emplace_back(std::move(value_tensor)); + + int val; + val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT); + if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { + return errors::InvalidArgument(mdb_strerror(val)); + } + if (val == MDB_NOTFOUND) { + ResetStreamsLocked(); + ++current_file_index_; + } + *end_of_sequence = false; + return Status::OK(); + } + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented( + "Checkpointing is currently not supported for LMDBDataset."); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "Checkpointing is currently not supported for LMDBDataset."); + } + + private: + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + const string& filename = dataset()->filenames_[current_file_index_]; + + int val = mdb_env_create(&mdb_env_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; + + struct stat source_stat; + if (stat(filename.c_str(), &source_stat) == 0 && + (source_stat.st_mode & S_IFREG)) { + flags |= MDB_NOSUBDIR; + } + val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST); + if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { + return errors::InvalidArgument(mdb_strerror(val)); + } + if (val == MDB_NOTFOUND) { + ResetStreamsLocked(); + } + return Status::OK(); + } + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (mdb_env_ != nullptr) { + if (mdb_cursor_) { + mdb_cursor_close(mdb_cursor_); + mdb_cursor_ = nullptr; + } + mdb_dbi_close(mdb_env_, mdb_dbi_); + mdb_txn_abort(mdb_txn_); + mdb_env_close(mdb_env_); + mdb_txn_ = nullptr; + mdb_dbi_ = 0; + mdb_env_ = nullptr; + } + } + mutex mu_; + size_t current_file_index_ GUARDED_BY(mu_) = 0; + MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr; + MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr; + MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0; + MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr; + + MDB_val mdb_key_ GUARDED_BY(mu_); + MDB_val mdb_value_ GUARDED_BY(mu_); + }; + + const std::vector filenames_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 74df1e42a8fbca9b6a65aa4800424d27aa90de24..725f8933c94cb42339556f63982d69d1bf0bb504 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -548,7 +548,9 @@ class MultiDeviceIterator : public ResourceBase { devices_(devices), flib_def_(std::move(flib_def)), pflr_(std::move(pflr)), - lib_(lib) {} + lib_(lib) { + CHECK_NOTNULL(lib_); + } string DebugString() override { return strings::StrCat("MultiDeviceIterator for ", devices_.size(), @@ -600,6 +602,11 @@ class MultiDeviceIterator : public ResourceBase { return lib_def_; } + FunctionLibraryRuntime* const lib() { + tf_shared_lock l(mu_); + return lib_; + } + private: // A private class that uses a background thread to keep a per device buffer // full. @@ -930,8 +937,10 @@ class MultiDeviceIteratorInitOp : public OpKernel { core::ScopedUnref unref(resource); std::unique_ptr iterator; - OP_REQUIRES_OK(ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator", - &iterator)); + IteratorContext iter_ctx(ctx); + iter_ctx.set_lib(resource->lib()); + OP_REQUIRES_OK( + ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator)); int64 incarnation_id; OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size, &incarnation_id)); diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index cc5e250ea15bf89be2db9aba14e3b29b72512a73..ae104d55bd813fdbc9829ccbc274612a112c8e1d 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -266,4 +266,13 @@ REGISTER_OP("AssertNextDataset") return shape_inference::ScalarShape(c); }); +REGISTER_OP("LMDBDataset") + .Input("filenames: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..cd9b7c68a04a33ca6dec1e9088c3606deebdb7f4 --- /dev/null +++ b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("IdentityIndexedDataset") + .Input("size: uint64") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn( + shape_inference::ScalarShape); // TODO(saeta): check input shapes. + +/////////////////////////////////////////////////////////////////////////////// +// IndexedDataset Internals +/////////////////////////////////////////////////////////////////////////////// + +// Creates the handle. +REGISTER_OP("MaterializedIndexDatasetHandle") + .Output("handle: resource") + .Attr("container: string") + .Attr("shared_name: string") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + +// Actually materialize the materialize handle. +REGISTER_OP("IndexedDatasetMaterialize") + .Input("dataset: variant") + .Input("materialized: resource") + .SetShapeFn(shape_inference::NoOutputs); + +namespace { + +Status GetShapeFn(shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast(i), output_shape_handle); + } + return Status::OK(); +} + +} // namespace + +REGISTER_OP("IndexedDatasetGet") + .Input("materialized: resource") + .Input("index: uint64") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(GetShapeFn) + .Doc(R"doc( +Gets the element at `index` from `materialized` IndexedDataset. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index cd46e382ebb1e6174be427a2c51f3492aeabf805..b86a543fc3f9504059dde3717ce0492441cd434a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,7 +4,8 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") py_test( name = "batch_dataset_op_test", @@ -133,6 +134,21 @@ py_test( ], ) +py_test( + name = "indexed_dataset_ops_test", + srcs = ["indexed_dataset_ops_test.py"], + deps = [ + "//tensorflow/contrib/data/python/ops:contrib_op_loader", + "//tensorflow/contrib/data/python/ops:gen_dataset_ops", + "//tensorflow/contrib/data/python/ops:indexed_dataset_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + py_test( name = "interleave_dataset_op_test", size = "medium", @@ -178,6 +194,31 @@ py_test( ], ) +py_test( + name = "lmdb_dataset_op_test", + size = "medium", + srcs = ["lmdb_dataset_op_test.py"], + data = ["//tensorflow/core:lmdb_testdata"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "no_windows", + ], + deps = [ + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//third_party/py/numpy", + ], +) + py_test( name = "map_dataset_op_test", size = "medium", @@ -204,6 +245,25 @@ py_test( ], ) +py_test( + name = "filter_dataset_op_test", + size = "medium", + srcs = ["filter_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + py_test( name = "map_defun_op_test", size = "small", @@ -229,19 +289,32 @@ py_test( srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - ":stats_dataset_test_base", - ":test_utils", "//tensorflow/contrib/data/python/ops:optimization", - "//tensorflow/contrib/data/python/ops:stats_ops", - "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "parsing_ops_test", + size = "small", + srcs = ["parsing_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:parsing_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/data/ops:dataset_ops", - "@absl_py//absl/testing:parameterized", + "//tensorflow/python/data/util:nest", + "//third_party/py/numpy", ], ) @@ -331,6 +404,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:string_ops", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6d01bf585c077ba7b24212c6f8e5f603b00d64cc --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Benchmarks FilterDataset input pipeline op.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class FilterBenchmark(test.Benchmark): + + # This benchmark compares the performance of pipeline with multiple chained + # filter with and without filter fusion. + def benchmarkFilters(self): + chain_lengths = [0, 1, 2, 5, 10, 20, 50] + for chain_length in chain_lengths: + self._benchmarkFilters(chain_length, False) + self._benchmarkFilters(chain_length, True) + + def _benchmarkFilters(self, chain_length, optimize_dataset): + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(5).repeat(None) + for _ in range(chain_length): + dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0)) + if optimize_dataset: + dataset = dataset.apply(optimization.optimize(["filter_fusion"])) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(10): + sess.run(next_element.op) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + opt_mark = "opt" if optimize_dataset else "no-opt" + print("Filter dataset {} chain length: {} Median wall time: {}".format( + opt_mark, chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_filter_dataset_chain_latency_{}_{}".format( + opt_mark, chain_length)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..db2ab815eeebb77c159ca8c7d0d9920f2bdcdabd --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py @@ -0,0 +1,78 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for experimental indexed dataset ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.contrib.data.python.ops import indexed_dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class IndexedDatasetOpsTest(test.TestCase): + + def testLowLevelIndexedDatasetOps(self): + identity = gen_dataset_ops.identity_indexed_dataset( + ops.convert_to_tensor(16, dtype=dtypes.uint64)) + handle = gen_dataset_ops.materialized_index_dataset_handle( + container="", + shared_name="", + output_types=[dtypes.uint64], + output_shapes=[[]]) + materialize = gen_dataset_ops.indexed_dataset_materialize(identity, handle) + index = array_ops.placeholder(dtypes.uint64) + get_op = gen_dataset_ops.indexed_dataset_get( + handle, index, output_types=[dtypes.uint64], output_shapes=[[]]) + + with self.test_session() as sess: + sess.run(materialize) + self.assertEqual([3], sess.run(get_op, feed_dict={index: 3})) + + def testIdentityIndexedDataset(self): + ds = indexed_dataset_ops.IdentityIndexedDataset(16) + materialized = ds.materialize() + with self.test_session() as sess: + sess.run(materialized.initializer) + placeholder = array_ops.placeholder(dtypes.uint64, shape=[]) + for i in range(16): + output = sess.run( + materialized.get(placeholder), feed_dict={placeholder: i}) + self.assertEqual([i], output) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(materialized.get(placeholder), feed_dict={placeholder: 16}) + + @unittest.skip("Requisite functionality currently unimplemented.") + def testIdentityIndexedDatasetIterator(self): + ds = indexed_dataset_ops.IdentityIndexedDataset(16) + itr = ds.make_initializable_iterator() + n = itr.get_next() + with self.test_session() as sess: + sess.run(itr.initializer) + for i in range(16): + output = sess.run(n) + self.assertEqual(i, output) + with self.assertRaises(errors.OutOfRangeError): + sess.run(n) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7bc582ebaa50c7418e7624a1a389f002f2cea395 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for LMDBDatasetOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import test +from tensorflow.python.util import compat + +prefix_path = "tensorflow/core/lib" + + +class LMDBDatasetTest(test.TestCase): + + def setUp(self): + super(LMDBDatasetTest, self).setUp() + # Copy database out because we need the path to be writable to use locks. + path = os.path.join(prefix_path, "lmdb", "testdata", "data.mdb") + self.db_path = os.path.join(self.get_temp_dir(), "data.mdb") + shutil.copy(path, self.db_path) + + def testReadFromFile(self): + filename = self.db_path + + filenames = constant_op.constant([filename], dtypes.string) + num_repeats = 2 + + dataset = readers.LMDBDataset(filenames).repeat(num_repeats) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for _ in range(num_repeats): # Dataset is repeated. + for i in range(10): # 10 records. + k = compat.as_bytes(str(i)) + v = compat.as_bytes(str(chr(ord("a") + i))) + self.assertEqual((k, v), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b299e0736fb29d0936680e5905172b0fa95ac586 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD @@ -0,0 +1,61 @@ +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_test( + name = "map_vectorization_test", + size = "small", + srcs = ["map_vectorization_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/kernel_tests:test_utils", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:session", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "map_and_filter_fusion_test", + size = "medium", + srcs = ["map_and_filter_fusion_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python/data/ops:dataset_ops", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "latency_all_edges_test", + size = "small", + srcs = ["latency_all_edges_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1850b6921af0aae8d26fbdfd165fd0e087134e6d --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the LatencyAllEdges optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.contrib.data.python.ops import stats_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): + + def testLatencyStatsOptimization(self): + + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.from_tensors(1).apply( + optimization.assert_next( + ["LatencyStats", "Map", "LatencyStats", "Prefetch", + "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply( + optimization.optimize(["latency_all_edges"])).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertEqual(1 * 1, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, + "record_latency_TensorDataset/_1", 1) + self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4", + 1) + self._assertSummaryHasCount(summary_str, + "record_latency_PrefetchDataset/_6", 1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py new file mode 100644 index 0000000000000000000000000000000000000000..586b4bee5fcb1d8de44e8bc5e78cc21e15870a5c --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -0,0 +1,224 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapAndFilterFusion optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): + + @staticmethod + def map_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + + def increment_and_square(x): + y = x + 1 + return y * y + + functions = [identity, increment, increment_and_square] + tests = [] + for i, fun1 in enumerate(functions): + for j, fun2 in enumerate(functions): + tests.append(( + "test_{}_{}".format(i, j), + [fun1, fun2], + )) + for k, fun3 in enumerate(functions): + tests.append(( + "test_{}_{}_{}".format(i, j, k), + [fun1, fun2, fun3], + )) + + swap = lambda x, n: (n, x) + tests.append(( + "swap1", + [lambda x: (x, 42), swap], + )) + tests.append(( + "swap2", + [lambda x: (x, 42), swap, swap], + )) + return tuple(tests) + + @parameterized.named_parameters(*map_functions.__func__()) + def testMapFusion(self, functions): + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next(["Map", "Prefetch"])) + for function in functions: + dataset = dataset.map(function) + + dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(5): + result = sess.run(get_next) + r = x + for function in functions: + if isinstance(r, tuple): + r = function(*r) # Pass tuple as multiple arguments. + else: + r = function(r) + self.assertAllEqual(r, result) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + @staticmethod + def map_and_filter_functions(): + identity = lambda x: x + increment = lambda x: x + 1 + minus_five = lambda x: x - 5 + + def increment_and_square(x): + y = x + 1 + return y * y + + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + is_odd = lambda x: math_ops.equal(x % 2, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + + functions = [identity, increment, minus_five, increment_and_square] + filters = [take_all, is_zero, is_odd, greater] + tests = [] + + for x, fun in enumerate(functions): + for y, predicate in enumerate(filters): + tests.append(("mixed_{}_{}".format(x, y), fun, predicate)) + + # Multi output + tests.append(("multiOne", lambda x: (x, x), + lambda x, y: constant_op.constant(True))) + tests.append( + ("multiTwo", lambda x: (x, 2), + lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) + return tuple(tests) + + @parameterized.named_parameters(*map_and_filter_functions.__func__()) + def testMapFilterFusion(self, function, predicate): + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", + "FilterByLastComponent"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + self._testMapAndFilter(dataset, function, predicate) + + def _testMapAndFilter(self, dataset, function, predicate): + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(10): + r = function(x) + if isinstance(r, tuple): + b = predicate(*r) # Pass tuple as multiple arguments. + else: + b = predicate(r) + if sess.run(b): + result = sess.run(get_next) + self.assertAllEqual(r, result) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testAdditionalInputs(self): + a = constant_op.constant(3, dtype=dtypes.int64) + b = constant_op.constant(4, dtype=dtypes.int64) + some_tensor = math_ops.mul(a, b) + function = lambda x: x * x + + def predicate(y): + return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor) + + # We are currently not supporting functions with additional inputs. + dataset = dataset_ops.Dataset.range(10).apply( + optimization.assert_next( + ["Map", "Filter"])).map(function).filter(predicate).apply( + optimization.optimize(["map_and_filter_fusion"])) + + self._testMapAndFilter(dataset, function, predicate) + + @staticmethod + def filter_functions(): + take_all = lambda x: constant_op.constant(True) + is_zero = lambda x: math_ops.equal(x, 0) + greater = lambda x: math_ops.greater(x + 5, 0) + + tests = [] + filters = [take_all, is_zero, greater] + identity = lambda x: x + for x, predicate_1 in enumerate(filters): + for y, predicate_2 in enumerate(filters): + tests.append(("mixed_{}_{}".format(x, y), identity, + [predicate_1, predicate_2])) + for z, predicate_3 in enumerate(filters): + tests.append(("mixed_{}_{}_{}".format(x, y, z), identity, + [predicate_1, predicate_2, predicate_3])) + + take_all_multiple = lambda x, y: constant_op.constant(True) + # Multi output + tests.append(("multiOne", lambda x: (x, x), + [take_all_multiple, take_all_multiple])) + tests.append(("multiTwo", lambda x: (x, 2), [ + take_all_multiple, + lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0) + ])) + return tuple(tests) + + @parameterized.named_parameters(*filter_functions.__func__()) + def testFilterFusion(self, map_function, predicates): + dataset = dataset_ops.Dataset.range(5).apply( + optimization.assert_next(["Map", "Filter", + "Prefetch"])).map(map_function) + for predicate in predicates: + dataset = dataset.filter(predicate) + + dataset = dataset.prefetch(0).apply( + optimization.optimize(["filter_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + for x in range(5): + r = map_function(x) + filtered = False + for predicate in predicates: + if isinstance(r, tuple): + b = predicate(*r) # Pass tuple as multiple arguments. + else: + b = predicate(r) + if not sess.run(b): + filtered = True + break + + if not filtered: + result = sess.run(get_next) + self.assertAllEqual(r, result) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e2c9bc82dfb27c68cf780b77d43a90203af602f2 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py @@ -0,0 +1,219 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the MapVectorization optimization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from absl.testing import parameterized +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests import test_utils +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): + + def _get_test_datasets(self, + base_dataset, + map_fn, + num_parallel_calls=None, + expect_optimized=True): + """Given base dataset and map fn, creates test datasets. + + Returns a tuple of (unoptimized, dataset, optimized dataset). The + unoptimized dataset has the assertion that Batch follows Map. The optimized + dataset has the assertion that Map follows Batch, and has the + "map_vectorization" optimization applied. + + Args: + base_dataset: Input dataset to map->batch + map_fn: Map function to use + num_parallel_calls: (Optional.) num_parallel_calls argument for map + expect_optimized: (Optional.) Whether we expect the optimization to take + place, in which case we will assert that Batch is followed by Map, + otherwise Map followed by Batch. Defaults to True. + + Returns: + Tuple of (unoptimized dataset, optimized dataset). + """ + map_node_name = "Map" if num_parallel_calls is None else "ParallelMap" + batch_size = 100 + + def _make_dataset(node_names): + return base_dataset.apply(optimization.assert_next(node_names)).map( + map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size) + + unoptimized = _make_dataset([map_node_name, "Batch"]) + optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else + [map_node_name, "Batch"]).apply( + optimization.optimize(["map_vectorization"])) + + return unoptimized, optimized + + @parameterized.named_parameters( + ("Basic", lambda x: (x, x + 1), None), + ("Parallel", lambda x: (x, x + 1), 12), + ("Gather", lambda x: array_ops.gather(x, 0), 12), + ) + def testOptimization(self, map_fn, num_parallel_calls): + base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], + [3, 4]]).repeat(5) + unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn, + num_parallel_calls) + self._assert_datasets_equal(unoptimized, optimized) + + def testOptimizationBadMapFn(self): + # Test map functions that give an error + def map_fn(x): + # x has leading dimension 5, this will raise an error + return array_ops.gather(x, 10) + + base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch( + 5, drop_remainder=True) + _, optimized = self._get_test_datasets(base_dataset, map_fn) + nxt = optimized.make_one_shot_iterator().get_next() + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r"indices = 10 is not in \[0, 5\)"): + self.evaluate(nxt) + + def testOptimizationWithCapturedInputs(self): + # Tests that vectorization works with captured inputs + def map_fn(x): + return x + y + + y = constant_op.constant(1, shape=(2,)) + base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], + [3, 4]]).repeat(5) + # TODO(rachelim): when this optimization works, turn on expect_optimized + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self._assert_datasets_equal(optimized, unoptimized) + + def testOptimizationIgnoreStateful(self): + + def map_fn(x): + with ops.control_dependencies([check_ops.assert_equal(x, 0)]): + return array_ops.identity(x) + + base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], + [3, 4]]).repeat(5) + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self._assert_datasets_raise_same_error( + unoptimized, optimized, errors.InvalidArgumentError, + [("OneShotIterator", "OneShotIterator_1", 1), + ("IteratorGetNext", "IteratorGetNext_1", 1)]) + + def testOptimizationIgnoreRagged(self): + # Make sure we ignore inputs that might not be uniformly sized + def map_fn(x): + return array_ops.gather(x, 0) + + # output_shape = (?,) + base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False) + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self._assert_datasets_equal(unoptimized, optimized) + + def testOptimizationIgnoreRaggedMap(self): + # Don't optimize when the output of the map fn shapes are unknown. + def map_fn(x): + return array_ops.tile(x, x) + + base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True) + unoptimized, optimized = self._get_test_datasets( + base_dataset, map_fn, expect_optimized=False) + self._assert_datasets_raise_same_error( + unoptimized, optimized, errors.InvalidArgumentError, + [("OneShotIterator", "OneShotIterator_1", 1), + ("IteratorGetNext", "IteratorGetNext_1", 1)]) + + +class MapVectorizationBenchmark(test.Benchmark): + # TODO(rachelim): Add a benchmark for more expensive transformations, such as + # vgg_preprocessing. + + def _run(self, x, num_iters=100, name=None): + deltas = [] + with session.Session() as sess: + for _ in range(5): + # Warm up session... + sess.run(x) + for _ in range(num_iters): + start = time.time() + sess.run(x) + end = time.time() + deltas.append(end - start) + median_time = np.median(deltas) + self.report_benchmark(iters=num_iters, wall_time=median_time, name=name) + return median_time + + def benchmark_CheapFns(self): + + input_sizes = [(10, 10, 3), (10, 100, 300)] + batch_size = 1000 + for input_size in input_sizes: + input_dataset = dataset_ops.Dataset.from_tensor_slices( + (np.random.rand(*input_size), np.random.rand(*input_size))).repeat() + for map_fn, str_id in self._get_known_cheap_fns(): + self._compare(input_dataset, map_fn, batch_size, input_size, str_id) + + def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id): + num_elems = np.prod(input_size) + name_template = "{}__batch_size_{}_input_size_{}_{}" + unoptimized = input_dataset.map(map_fn).batch(batch_size) + unoptimized_op = unoptimized.make_one_shot_iterator().get_next() + + optimized = unoptimized.apply(optimization.optimize(["map_vectorization"])) + optimized_op = optimized.make_one_shot_iterator().get_next() + + unoptimized_time = self._run( + unoptimized_op, + name=name_template.format(str_id, batch_size, num_elems, "unoptimized")) + optimized_time = self._run( + optimized_op, + name=name_template.format(str_id, batch_size, num_elems, "optimized")) + + print("Batch size: {}\n" + "Input size: {}\n" + "Transformation: {}\n" + "Speedup: {}\n".format(batch_size, input_size, str_id, + (unoptimized_time / optimized_time))) + + def _get_known_cheap_fns(self): + return [ + (lambda *args: [array_ops.identity(x) for x in args], "identity"), + (lambda *args: [x + 1 for x in args], "add_const"), + (lambda *args: args[0], "select"), + (lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], + "cast"), + ] + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index 76aa1c3cfdf4f7b4acfc67d2dcd64168ba51fb03..446bf8d7497880307270d1b1f495becdadd15684 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -19,18 +19,10 @@ from __future__ import print_function from absl.testing import parameterized -from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base -from tensorflow.contrib.data.python.kernel_tests import test_utils from tensorflow.contrib.data.python.ops import optimization -from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -109,295 +101,17 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testFunctionLibraryDefinitionModification(self): - dataset = dataset_ops.Dataset.from_tensors(0).map(lambda x: x).apply( - optimization.optimize(["_test_only_function_rename"])) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.NotFoundError, - "Function .* is not defined."): - sess.run(get_next) - - @staticmethod - def map_functions(): - identity = lambda x: x - increment = lambda x: x + 1 - - def increment_and_square(x): - y = x + 1 - return y * y - - functions = [identity, increment, increment_and_square] - tests = [] - for i, fun1 in enumerate(functions): - for j, fun2 in enumerate(functions): - tests.append(( - "test_{}_{}".format(i, j), - [fun1, fun2], - )) - for k, fun3 in enumerate(functions): - tests.append(( - "test_{}_{}_{}".format(i, j, k), - [fun1, fun2, fun3], - )) - - swap = lambda x, n: (n, x) - tests.append(( - "swap1", - [lambda x: (x, 42), swap], - )) - tests.append(( - "swap2", - [lambda x: (x, 42), swap, swap], - )) - return tuple(tests) - - @parameterized.named_parameters(*map_functions.__func__()) - def testMapFusion(self, functions): - dataset = dataset_ops.Dataset.range(5).apply( - optimization.assert_next(["Map", "Prefetch"])) - for function in functions: - dataset = dataset.map(function) - - dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"])) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - with self.test_session() as sess: - for x in range(5): - result = sess.run(get_next) - r = x - for function in functions: - if isinstance(r, tuple): - r = function(*r) # Pass tuple as multiple arguments. - else: - r = function(r) - self.assertAllEqual(r, result) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - @staticmethod - def map_and_filter_functions(): - identity = lambda x: x - increment = lambda x: x + 1 - minus_five = lambda x: x - 5 - - def increment_and_square(x): - y = x + 1 - return y * y - - take_all = lambda x: constant_op.constant(True) - is_zero = lambda x: math_ops.equal(x, 0) - is_odd = lambda x: math_ops.equal(x % 2, 0) - greater = lambda x: math_ops.greater(x + 5, 0) - - functions = [identity, increment, minus_five, increment_and_square] - filters = [take_all, is_zero, is_odd, greater] - tests = [] - - for x, fun in enumerate(functions): - for y, predicate in enumerate(filters): - tests.append(("mixed_{}_{}".format(x, y), fun, predicate)) - - # Multi output - tests.append(("multiOne", lambda x: (x, x), - lambda x, y: constant_op.constant(True))) - tests.append( - ("multiTwo", lambda x: (x, 2), - lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0))) - return tuple(tests) - - @parameterized.named_parameters(*map_and_filter_functions.__func__()) - def testMapFilterFusion(self, function, predicate): + def testStatefulFunctionOptimization(self): dataset = dataset_ops.Dataset.range(10).apply( - optimization.assert_next( - ["Map", - "FilterByLastComponent"])).map(function).filter(predicate).apply( - optimization.optimize(["map_and_filter_fusion"])) - self._testMapAndFilter(dataset, function, predicate) - - def _testMapAndFilter(self, dataset, function, predicate): + optimization.assert_next([ + "MapAndBatch" + ])).map(lambda _: random_ops.random_uniform([])).batch(10).apply( + optimization.optimize(["map_and_batch_fusion"])) iterator = dataset.make_one_shot_iterator() get_next = iterator.get_next() - with self.test_session() as sess: - for x in range(10): - r = function(x) - if isinstance(r, tuple): - b = predicate(*r) # Pass tuple as multiple arguments. - else: - b = predicate(r) - if sess.run(b): - result = sess.run(get_next) - self.assertAllEqual(r, result) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testAdditionalInputs(self): - a = constant_op.constant(3, dtype=dtypes.int64) - b = constant_op.constant(4, dtype=dtypes.int64) - some_tensor = math_ops.mul(a, b) - function = lambda x: x * x - - def predicate(y): - return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor) - - # We are currently not supporting functions with additional inputs. - dataset = dataset_ops.Dataset.range(10).apply( - optimization.assert_next( - ["Map", "Filter"])).map(function).filter(predicate).apply( - optimization.optimize(["map_and_filter_fusion"])) - - self._testMapAndFilter(dataset, function, predicate) - - -class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): - - def testLatencyStatsOptimization(self): - - stats_aggregator = stats_ops.StatsAggregator() - dataset = dataset_ops.Dataset.from_tensors(1).apply( - optimization.assert_next( - ["LatencyStats", "Map", "LatencyStats", "Prefetch", - "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply( - optimization.optimize(["latency_all_edges"])).apply( - stats_ops.set_stats_aggregator(stats_aggregator)) - iterator = dataset.make_initializable_iterator() - get_next = iterator.get_next() - summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run(iterator.initializer) - self.assertEqual(1 * 1, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - summary_str = sess.run(summary_t) - self._assertSummaryHasCount(summary_str, - "record_latency_TensorDataset/_1", 1) - self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4", - 1) - self._assertSummaryHasCount(summary_str, - "record_latency_PrefetchDataset/_6", 1) - - -class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): - - def _get_test_datasets(self, - base_dataset, - map_fn, - num_parallel_calls=None, - expect_optimized=True): - """Given base dataset and map fn, creates test datasets. - - Returns a tuple of (unoptimized, dataset, optimized dataset). The - unoptimized dataset has the assertion that Batch follows Map. The optimized - dataset has the assertion that Map follows Batch, and has the - "map_vectorization" optimization applied. - - Args: - base_dataset: Input dataset to map->batch - map_fn: Map function to use - num_parallel_calls: (Optional.) num_parallel_calls argument for map - expect_optimized: (Optional.) Whether we expect the optimization to take - place, in which case we will assert that Batch is followed by Map, - otherwise Map followed by Batch. Defaults to True. - - Returns: - Tuple of (unoptimized dataset, optimized dataset). - """ - map_node_name = "Map" if num_parallel_calls is None else "ParallelMap" - batch_size = 100 - - def _make_dataset(node_names): - return base_dataset.apply(optimization.assert_next(node_names)).map( - map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size) - - unoptimized = _make_dataset([map_node_name, "Batch"]) - optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else - [map_node_name, "Batch"]).apply( - optimization.optimize(["map_vectorization"])) - - return unoptimized, optimized - - @parameterized.named_parameters( - ("Basic", lambda x: (x, x + 1), None), - ("Parallel", lambda x: (x, x + 1), 12), - ("Gather", lambda x: array_ops.gather(x, 0), 12), - ) - def testOptimization(self, map_fn, num_parallel_calls): - base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], - [3, 4]]).repeat(5) - unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn, - num_parallel_calls) - self._assert_datasets_equal(unoptimized, optimized) - - def testOptimizationBadMapFn(self): - # Test map functions that give an error - def map_fn(x): - # x has leading dimension 5, this will raise an error - return array_ops.gather(x, 10) - - base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch( - 5, drop_remainder=True) - _, optimized = self._get_test_datasets(base_dataset, map_fn) - nxt = optimized.make_one_shot_iterator().get_next() - with self.assertRaisesRegexp(errors.InvalidArgumentError, - r"indices = 10 is not in \[0, 5\)"): - self.evaluate(nxt) - - def testOptimizationWithCapturedInputs(self): - # Tests that vectorization works with captured inputs - def map_fn(x): - return x + y - - y = constant_op.constant(1, shape=(2,)) - base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], - [3, 4]]).repeat(5) - # TODO(rachelim): when this optimization works, turn on expect_optimized - unoptimized, optimized = self._get_test_datasets( - base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_equal(optimized, unoptimized) - - def testOptimizationIgnoreStateful(self): - - def map_fn(x): - with ops.control_dependencies([check_ops.assert_equal(x, 0)]): - return array_ops.identity(x) - - base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2], - [3, 4]]).repeat(5) - _, optimized = self._get_test_datasets( - base_dataset, map_fn, expect_optimized=False) - nxt = optimized.make_one_shot_iterator().get_next() - - # NOTE: Right now, it raises an error because we can't save datasets that - # are stateful, and we rely on this saving mechanism to optimize datasets, - # so stateful functions can't be optimized. - with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"): - self.evaluate(nxt) - - def testOptimizationIgnoreRagged(self): - # Make sure we ignore inputs that might not be uniformly sized - def map_fn(x): - return array_ops.gather(x, 0) - - # output_shape = (?,) - base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False) - unoptimized, optimized = self._get_test_datasets( - base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_equal(unoptimized, optimized) - - def testOptimizationIgnoreRaggedMap(self): - # Don't optimize when the output of the map fn shapes are unknown. - def map_fn(x): - return array_ops.tile(x, x) - - base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True) - unoptimized, optimized = self._get_test_datasets( - base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_raise_same_error(unoptimized, optimized, - errors.InvalidArgumentError) + sess.run(get_next) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f6c4a984b8608b408bc1b1bb4a712ef1c3792696 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py @@ -0,0 +1,850 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.parsing_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import numpy as np + +from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + +# Helpers for creating Example objects +example = example_pb2.Example +feature = feature_pb2.Feature +features = lambda d: feature_pb2.Features(feature=d) +bytes_feature = lambda v: feature(bytes_list=feature_pb2.BytesList(value=v)) +int64_feature = lambda v: feature(int64_list=feature_pb2.Int64List(value=v)) +float_feature = lambda v: feature(float_list=feature_pb2.FloatList(value=v)) +# Helpers for creating SequenceExample objects +feature_list = lambda l: feature_pb2.FeatureList(feature=l) +feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d) +sequence_example = example_pb2.SequenceExample + + +def _compare_output_to_expected(tester, dict_tensors, expected_tensors, + flat_output): + tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys())) + + i = 0 # Index into the flattened output of session.run() + for k, v in sorted(dict_tensors.items()): + # TODO(shivaniagrawal): flat_output is same as v. + expected_v = expected_tensors[k] + tf_logging.info("Comparing key: %s", k) + print("i", i, "flat_output", flat_output[i], "expected_v", expected_v) + if sparse_tensor.is_sparse(v): + # Three outputs for SparseTensor : indices, values, shape. + tester.assertEqual([k, len(expected_v)], [k, 3]) + print("i", i, "flat_output", flat_output[i].indices, "expected_v", + expected_v[0]) + tester.assertAllEqual(expected_v[0], flat_output[i].indices) + tester.assertAllEqual(expected_v[1], flat_output[i].values) + tester.assertAllEqual(expected_v[2], flat_output[i].dense_shape) + else: + # One output for standard Tensor. + tester.assertAllEqual(expected_v, flat_output[i]) + i += 1 + + +class ParseExampleTest(test.TestCase): + + def _test(self, + input_tensor, + feature_val, + expected_values=None, + expected_err=None): + + with self.test_session() as sess: + if expected_err: + with self.assertRaisesWithPredicateMatch(expected_err[0], + expected_err[1]): + dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply( + contrib_parsing_ops.parse_example_dataset(feature_val)) + get_next = dataset.make_one_shot_iterator().get_next() + sess.run(get_next) + return + else: + # Returns dict w/ Tensors and SparseTensors. + # Check values. + dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply( + contrib_parsing_ops.parse_example_dataset(feature_val)) + get_next = dataset.make_one_shot_iterator().get_next() + result = sess.run(get_next) + flattened = nest.flatten(result) + print("result", result, "expected_values", expected_values) + _compare_output_to_expected(self, result, expected_values, flattened) + + # Check shapes; if serialized is a Tensor we need its size to + # properly check. + batch_size = ( + input_tensor.eval().size if isinstance(input_tensor, ops.Tensor) else + np.asarray(input_tensor).size) + for k, f in feature_val.items(): + print("output_shapes as list ", + tuple(dataset.output_shapes[k].as_list())) + if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None: + self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size) + elif isinstance(f, parsing_ops.VarLenFeature): + self.assertEqual(dataset.output_shapes[k].as_list()[1], None) + + def testEmptySerializedWithAllDefaults(self): + sparse_name = "st_a" + a_name = "a" + b_name = "b" + c_name = "c:has_a_tricky_name" + a_default = [0, 42, 0] + b_default = np.random.rand(3, 3).astype(bytes) + c_default = np.random.rand(2).astype(np.float32) + + expected_st_a = ( # indices, values, shape + np.empty( + (0, 2), dtype=np.int64), # indices + np.empty( + (0,), dtype=np.int64), # sp_a is DT_INT64 + np.array( + [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0 + + expected_output = { + sparse_name: expected_st_a, + a_name: np.array(2 * [[a_default]]), + b_name: np.array(2 * [b_default]), + c_name: np.array(2 * [c_default]), + } + + self._test( + ops.convert_to_tensor(["", ""]), { + sparse_name: + parsing_ops.VarLenFeature(dtypes.int64), + a_name: + parsing_ops.FixedLenFeature( + (1, 3), dtypes.int64, default_value=a_default), + b_name: + parsing_ops.FixedLenFeature( + (3, 3), dtypes.string, default_value=b_default), + c_name: + parsing_ops.FixedLenFeature( + (2,), dtypes.float32, default_value=c_default), + }, + expected_values=expected_output) + + def testEmptySerializedWithoutDefaultsShouldFail(self): + input_features = { + "st_a": + parsing_ops.VarLenFeature(dtypes.int64), + "a": + parsing_ops.FixedLenFeature( + (1, 3), dtypes.int64, default_value=[0, 42, 0]), + "b": + parsing_ops.FixedLenFeature( + (3, 3), + dtypes.string, + default_value=np.random.rand(3, 3).astype(bytes)), + # Feature "c" is missing a default, this gap will cause failure. + "c": + parsing_ops.FixedLenFeature( + (2,), dtype=dtypes.float32), + } + + # Edge case where the key is there but the feature value is empty + original = example(features=features({"c": feature()})) + self._test( + [original.SerializeToString()], + input_features, + expected_err=(errors_impl.InvalidArgumentError, + "Feature: c \\(data type: float\\) is required")) + + # Standard case of missing key and value. + self._test( + ["", ""], + input_features, + expected_err=(errors_impl.InvalidArgumentError, + "Feature: c \\(data type: float\\) is required")) + + def testDenseNotMatchingShapeShouldFail(self): + original = [ + example(features=features({ + "a": float_feature([1, 1, 3]), + })), example(features=features({ + "a": float_feature([-1, -1]), + })) + ] + + serialized = [m.SerializeToString() for m in original] + + self._test( + ops.convert_to_tensor(serialized), + {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)}, + expected_err=(errors_impl.InvalidArgumentError, + "Key: a, Index: 1. Number of float values")) + + def testDenseDefaultNoShapeShouldFail(self): + original = [example(features=features({"a": float_feature([1, 1, 3]),})),] + + serialized = [m.SerializeToString() for m in original] + + self._test( + ops.convert_to_tensor(serialized), + {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)}, + expected_err=(ValueError, "Missing shape for feature a")) + + def testSerializedContainingSparse(self): + original = [ + example(features=features({ + "st_c": float_feature([3, 4]) + })), + example(features=features({ + "st_c": float_feature([]), # empty float list + })), + example(features=features({ + "st_d": feature(), # feature with nothing in it + })), + example(features=features({ + "st_c": float_feature([1, 2, -1]), + "st_d": bytes_feature([b"hi"]) + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_st_c = ( # indices, values, shape + np.array( + [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array( + [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array( + [4, 3], dtype=np.int64)) # batch == 2, max_elems = 3 + + expected_st_d = ( # indices, values, shape + np.array( + [[3, 0]], dtype=np.int64), np.array( + ["hi"], dtype=bytes), np.array( + [4, 1], dtype=np.int64)) # batch == 2, max_elems = 1 + + expected_output = { + "st_c": expected_st_c, + "st_d": expected_st_d, + } + + self._test( + ops.convert_to_tensor(serialized), { + "st_c": parsing_ops.VarLenFeature(dtypes.float32), + "st_d": parsing_ops.VarLenFeature(dtypes.string) + }, + expected_values=expected_output) + + def testSerializedContainingSparseFeature(self): + original = [ + example(features=features({ + "val": float_feature([3, 4]), + "idx": int64_feature([5, 10]) + })), + example(features=features({ + "val": float_feature([]), # empty float list + "idx": int64_feature([]) + })), + example(features=features({ + "val": feature(), # feature with nothing in it + # missing idx feature + })), + example(features=features({ + "val": float_feature([1, 2, -1]), + "idx": + int64_feature([0, 9, 3]) # unsorted + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_sp = ( # indices, values, shape + np.array( + [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64), + np.array( + [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array( + [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + + expected_output = {"sp": expected_sp,} + + self._test( + ops.convert_to_tensor(serialized), + {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])}, + expected_values=expected_output) + + def testSerializedContainingSparseFeatureReuse(self): + original = [ + example(features=features({ + "val1": float_feature([3, 4]), + "val2": float_feature([5, 6]), + "idx": int64_feature([5, 10]) + })), + example(features=features({ + "val1": float_feature([]), # empty float list + "idx": int64_feature([]) + })), + ] + + serialized = [m.SerializeToString() for m in original] + + expected_sp1 = ( # indices, values, shape + np.array( + [[0, 5], [0, 10]], dtype=np.int64), np.array( + [3.0, 4.0], dtype=np.float32), np.array( + [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13 + + expected_sp2 = ( # indices, values, shape + np.array( + [[0, 5], [0, 10]], dtype=np.int64), np.array( + [5.0, 6.0], dtype=np.float32), np.array( + [2, 7], dtype=np.int64)) # batch == 2, max_elems = 13 + + expected_output = { + "sp1": expected_sp1, + "sp2": expected_sp2, + } + + self._test( + ops.convert_to_tensor(serialized), { + "sp1": + parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13), + "sp2": + parsing_ops.SparseFeature( + "idx", "val2", dtypes.float32, size=7, already_sorted=True) + }, + expected_values=expected_output) + + def testSerializedContaining3DSparseFeature(self): + original = [ + example(features=features({ + "val": float_feature([3, 4]), + "idx0": int64_feature([5, 10]), + "idx1": int64_feature([0, 2]), + })), + example(features=features({ + "val": float_feature([]), # empty float list + "idx0": int64_feature([]), + "idx1": int64_feature([]), + })), + example(features=features({ + "val": feature(), # feature with nothing in it + # missing idx feature + })), + example(features=features({ + "val": float_feature([1, 2, -1]), + "idx0": int64_feature([0, 9, 3]), # unsorted + "idx1": int64_feature([1, 0, 2]), + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_sp = ( + # indices + np.array( + [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]], + dtype=np.int64), + # values + np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), + # shape batch == 4, max_elems = 13 + np.array([4, 13, 3], dtype=np.int64)) + + expected_output = {"sp": expected_sp,} + + self._test( + ops.convert_to_tensor(serialized), { + "sp": + parsing_ops.SparseFeature(["idx0", "idx1"], "val", + dtypes.float32, [13, 3]) + }, + expected_values=expected_output) + + def testSerializedContainingDense(self): + aname = "a" + bname = "b*has+a:tricky_name" + original = [ + example(features=features({ + aname: float_feature([1, 1]), + bname: bytes_feature([b"b0_str"]), + })), example(features=features({ + aname: float_feature([-1, -1]), + bname: bytes_feature([b""]), + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + aname: + np.array( + [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1), + bname: + np.array( + ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1), + } + + # No defaults, values required + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32), + bname: + parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string), + }, + expected_values=expected_output) + + # This test is identical as the previous one except + # for the creation of 'serialized'. + def testSerializedContainingDenseWithConcat(self): + aname = "a" + bname = "b*has+a:tricky_name" + # TODO(lew): Feature appearing twice should be an error in future. + original = [ + (example(features=features({ + aname: float_feature([10, 10]), + })), example(features=features({ + aname: float_feature([1, 1]), + bname: bytes_feature([b"b0_str"]), + }))), + ( + example(features=features({ + bname: bytes_feature([b"b100"]), + })), + example(features=features({ + aname: float_feature([-1, -1]), + bname: bytes_feature([b"b1"]), + })),), + ] + + serialized = [ + m.SerializeToString() + n.SerializeToString() for (m, n) in original + ] + + expected_output = { + aname: + np.array( + [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1), + bname: + np.array( + ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1), + } + + # No defaults, values required + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32), + bname: + parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string), + }, + expected_values=expected_output) + + def testSerializedContainingDenseScalar(self): + original = [ + example(features=features({ + "a": float_feature([1]), + })), example(features=features({})) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + "a": + np.array( + [[1], [-1]], dtype=np.float32) # 2x1 (column vector) + } + + self._test( + ops.convert_to_tensor(serialized), { + "a": + parsing_ops.FixedLenFeature( + (1,), dtype=dtypes.float32, default_value=-1), + }, + expected_values=expected_output) + + def testSerializedContainingDenseWithDefaults(self): + original = [ + example(features=features({ + "a": float_feature([1, 1]), + })), + example(features=features({ + "b": bytes_feature([b"b1"]), + })), + example(features=features({ + "b": feature() + })), + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + "a": + np.array( + [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2, + 1), + "b": + np.array( + ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1, + 1), + } + + self._test( + ops.convert_to_tensor(serialized), { + "a": + parsing_ops.FixedLenFeature( + (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]), + "b": + parsing_ops.FixedLenFeature( + (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"), + }, + expected_values=expected_output) + + def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self): + expected_st_a = ( # indices, values, shape + np.empty( + (0, 2), dtype=np.int64), # indices + np.empty( + (0,), dtype=np.int64), # sp_a is DT_INT64 + np.array( + [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0 + expected_sp = ( # indices, values, shape + np.array( + [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array( + ["a", "b", "c"], dtype="|S"), np.array( + [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + + original = [ + example(features=features({ + "c": float_feature([3, 4]), + "val": bytes_feature([b"a", b"b"]), + "idx": int64_feature([0, 3]) + })), example(features=features({ + "c": float_feature([1, 2]), + "val": bytes_feature([b"c"]), + "idx": int64_feature([7]) + })) + ] + + serialized = [m.SerializeToString() for m in original] + + a_default = [1, 2, 3] + b_default = np.random.rand(3, 3).astype(bytes) + expected_output = { + "st_a": expected_st_a, + "sp": expected_sp, + "a": np.array(2 * [[a_default]]), + "b": np.array(2 * [b_default]), + "c": np.array( + [[3, 4], [1, 2]], dtype=np.float32), + } + + self._test( + ops.convert_to_tensor(serialized), + { + "st_a": + parsing_ops.VarLenFeature(dtypes.int64), + "sp": + parsing_ops.SparseFeature("idx", "val", dtypes.string, 13), + "a": + parsing_ops.FixedLenFeature( + (1, 3), dtypes.int64, default_value=a_default), + "b": + parsing_ops.FixedLenFeature( + (3, 3), dtypes.string, default_value=b_default), + # Feature "c" must be provided, since it has no default_value. + "c": + parsing_ops.FixedLenFeature((2,), dtypes.float32), + }, + expected_values=expected_output) + + def testSerializedContainingSparseAndSparseFeatureWithReuse(self): + expected_idx = ( # indices, values, shape + np.array( + [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64), + np.array([0, 3, 7, 1]), np.array( + [2, 2], dtype=np.int64)) # batch == 4, max_elems = 2 + + expected_sp = ( # indices, values, shape + np.array( + [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array( + ["a", "b", "d", "c"], dtype="|S"), np.array( + [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13 + + original = [ + example(features=features({ + "val": bytes_feature([b"a", b"b"]), + "idx": int64_feature([0, 3]) + })), example(features=features({ + "val": bytes_feature([b"c", b"d"]), + "idx": int64_feature([7, 1]) + })) + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + "idx": expected_idx, + "sp": expected_sp, + } + + self._test( + ops.convert_to_tensor(serialized), { + "idx": + parsing_ops.VarLenFeature(dtypes.int64), + "sp": + parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]), + }, + expected_values=expected_output) + + def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size): + # During parsing, data read from the serialized proto is stored in buffers. + # For small batch sizes, a buffer will contain one minibatch entry. + # For larger batch sizes, a buffer may contain several minibatch + # entries. This test identified a bug where the code that copied + # data out of the buffers and into the output tensors assumed each + # buffer only contained one minibatch entry. The bug has since been fixed. + truth_int = [i for i in range(batch_size)] + truth_str = [[("foo%d" % i).encode(), ("bar%d" % i).encode()] + for i in range(batch_size)] + + expected_str = copy.deepcopy(truth_str) + + # Delete some intermediate entries + for i in range(batch_size): + col = 1 + if np.random.rand() < 0.25: + # w.p. 25%, drop out the second entry + expected_str[i][col] = b"default" + col -= 1 + truth_str[i].pop() + if np.random.rand() < 0.25: + # w.p. 25%, drop out the second entry (possibly again) + expected_str[i][col] = b"default" + truth_str[i].pop() + + expected_output = { + # Batch size batch_size, 1 time step. + "a": np.array(truth_int, dtype=np.int64).reshape(batch_size, 1), + # Batch size batch_size, 2 time steps. + "b": np.array(expected_str, dtype="|S").reshape(batch_size, 2), + } + + original = [ + example(features=features( + {"a": int64_feature([truth_int[i]]), + "b": bytes_feature(truth_str[i])})) + for i in range(batch_size) + ] + + serialized = [m.SerializeToString() for m in original] + + self._test( + ops.convert_to_tensor(serialized, dtype=dtypes.string), { + "a": + parsing_ops.FixedLenSequenceFeature( + shape=(), + dtype=dtypes.int64, + allow_missing=True, + default_value=-1), + "b": + parsing_ops.FixedLenSequenceFeature( + shape=[], + dtype=dtypes.string, + allow_missing=True, + default_value="default"), + }, + expected_values=expected_output) + + def testSerializedContainingVarLenDenseLargerBatch(self): + np.random.seed(3456) + for batch_size in (1, 10, 20, 100, 256): + self._testSerializedContainingVarLenDenseLargerBatch(batch_size) + + def testSerializedContainingVarLenDense(self): + aname = "a" + bname = "b" + cname = "c" + dname = "d" + original = [ + example(features=features({ + cname: int64_feature([2]), + })), + example(features=features({ + aname: float_feature([1, 1]), + bname: bytes_feature([b"b0_str", b"b1_str"]), + })), + example(features=features({ + aname: float_feature([-1, -1, 2, 2]), + bname: bytes_feature([b"b1"]), + })), + example(features=features({ + aname: float_feature([]), + cname: int64_feature([3]), + })), + ] + + serialized = [m.SerializeToString() for m in original] + + expected_output = { + aname: + np.array( + [ + [0, 0, 0, 0], + [1, 1, 0, 0], + [-1, -1, 2, 2], + [0, 0, 0, 0], + ], + dtype=np.float32).reshape(4, 2, 2, 1), + bname: + np.array( + [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]], + dtype=bytes).reshape(4, 2, 1, 1, 1), + cname: + np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1), + dname: + np.empty(shape=(4, 0), dtype=bytes), + } + + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature( + (2, 1), dtype=dtypes.float32, allow_missing=True), + bname: + parsing_ops.FixedLenSequenceFeature( + (1, 1, 1), dtype=dtypes.string, allow_missing=True), + cname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.int64, allow_missing=True), + dname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.string, allow_missing=True), + }, + expected_values=expected_output) + + # Test with padding values. + expected_output_custom_padding = dict(expected_output) + expected_output_custom_padding[aname] = np.array( + [ + [-2, -2, -2, -2], + [1, 1, -2, -2], + [-1, -1, 2, 2], + [-2, -2, -2, -2], + ], + dtype=np.float32).reshape(4, 2, 2, 1) + + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature( + (2, 1), + dtype=dtypes.float32, + allow_missing=True, + default_value=-2.0), + bname: + parsing_ops.FixedLenSequenceFeature( + (1, 1, 1), dtype=dtypes.string, allow_missing=True), + cname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.int64, allow_missing=True), + dname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.string, allow_missing=True), + }, expected_output_custom_padding) + + # Change number of required values so the inputs are not a + # multiple of this size. + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature( + (2, 1), dtype=dtypes.float32, allow_missing=True), + bname: + parsing_ops.FixedLenSequenceFeature( + (2, 1, 1), dtype=dtypes.string, allow_missing=True), + }, + expected_err=( + errors_impl.OpError, "Key: b, Index: 2. " + "Number of bytes values is not a multiple of stride length.")) + + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature( + (2, 1), + dtype=dtypes.float32, + allow_missing=True, + default_value=[]), + bname: + parsing_ops.FixedLenSequenceFeature( + (2, 1, 1), dtype=dtypes.string, allow_missing=True), + }, + expected_err=(ValueError, + "Cannot reshape a tensor with 0 elements to shape")) + + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32), + bname: + parsing_ops.FixedLenSequenceFeature( + (2, 1, 1), dtype=dtypes.string, allow_missing=True), + }, + expected_err=(ValueError, + "First dimension of shape for feature a unknown. " + "Consider using FixedLenSequenceFeature.")) + + self._test( + ops.convert_to_tensor(serialized), { + cname: + parsing_ops.FixedLenFeature( + (1, None), dtype=dtypes.int64, default_value=[[1]]), + }, + expected_err=(ValueError, + "All dimensions of shape for feature c need to be known " + r"but received \(1, None\).")) + + self._test( + ops.convert_to_tensor(serialized), { + aname: + parsing_ops.FixedLenSequenceFeature( + (2, 1), dtype=dtypes.float32, allow_missing=True), + bname: + parsing_ops.FixedLenSequenceFeature( + (1, 1, 1), dtype=dtypes.string, allow_missing=True), + cname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.int64, allow_missing=False), + dname: + parsing_ops.FixedLenSequenceFeature( + shape=[], dtype=dtypes.string, allow_missing=True), + }, + expected_err=(ValueError, + "Unsupported: FixedLenSequenceFeature requires " + "allow_missing to be True.")) + + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD index 7b9ea191a4524891d1b589e1e228e29241fda7f8..4881f63ab96cb4797e6e071bf3e310c73bc85f3d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD @@ -317,6 +317,19 @@ py_test( ], ) +py_test( + name = "parse_example_dataset_serialization_test", + size = "medium", + srcs = ["parse_example_dataset_serialization_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test_base", + "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "prefetch_dataset_serialization_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py index 9fdbcb66bf7f361c2b3e14ed7162b853fceaf88e..595cecef4de488d795cd9e5ebb433636026e51fc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py @@ -510,7 +510,6 @@ class DatasetSerializationTestBase(test.TestCase): else: init_op, get_next_op, saver = self._build_graph( ds_fn, sparse_tensors=sparse_tensors) - get_next_op = remove_variants(get_next_op) return init_op, get_next_op, saver for i in range(len(break_points) + 1): @@ -616,29 +615,40 @@ class DatasetSerializationTestBase(test.TestCase): # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections # do not support tuples we flatten the tensors and restore the shape in # `_get_iterator_ops_from_collection`. - - # TODO(shivaniagrwal): `output_classes` is a nested structure of classes, - # this base class is specific to current test cases. Update when tests are - # added with `output_classes` as a nested structure with at least one of the - # component being `tf.SparseTensor`. - if (sparse_tensors or - self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): + if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. ops.add_to_collection("iterator_ops", get_next.indices) ops.add_to_collection("iterator_ops", get_next.values) ops.add_to_collection("iterator_ops", get_next.dense_shape) - else: - for el in nest.flatten(get_next): - ops.add_to_collection("iterator_ops", el) + return + + get_next_list = nest.flatten(get_next) + for i, output_class in enumerate( + nest.flatten(self._get_output_classes(ds_fn))): + if output_class is sparse_tensor.SparseTensor: + ops.add_to_collection("iterator_ops", get_next_list[i].indices) + ops.add_to_collection("iterator_ops", get_next_list[i].values) + ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape) + else: + ops.add_to_collection("iterator_ops", get_next_list[i]) def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): all_ops = ops.get_collection("iterator_ops") - if (sparse_tensors or - self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): + if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`. init_op, indices, values, dense_shape = all_ops return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) - else: - return all_ops[0], nest.pack_sequence_as( - self._get_output_types(ds_fn), all_ops[1:]) + get_next_list = [] + i = 1 + for output_class in nest.flatten(self._get_output_classes(ds_fn)): + if output_class is sparse_tensor.SparseTensor: + indices, values, dense_shape = all_ops[i:i + 3] + i += 3 + get_next_list.append( + sparse_tensor.SparseTensor(indices, values, dense_shape)) + else: + get_next_list.append(all_ops[i]) + i += 1 + return all_ops[0], nest.pack_sequence_as( + self._get_output_types(ds_fn), get_next_list) def _get_output_types(self, ds_fn): with ops.Graph().as_default(): diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fa84e74cf25cd82014e459b3a2ee0bff5602e3 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py @@ -0,0 +1,50 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the ParseExampleDataset serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base +from tensorflow.python.platform import test + + +class ParseExampleDatasetSerializationTest( + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def ParseExampleDataset(self, num_repeat, batch_size): + return self.make_batch_feature( + filenames=self.test_filenames, + num_epochs=num_repeat, + batch_size=batch_size, + reader_num_threads=5, + parser_num_threads=10) + + def testSerializationCore(self): + num_repeat = 5 + batch_size = 2 + num_outputs = self._num_records * self._num_files * num_repeat // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self.ParseExampleDataset( + num_repeat=num_repeat, batch_size=batch_size), + lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4), + num_outputs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index a41d21f8c14ed6bec7626599a5aa7f365765ce8b..53c22628c79b22d9bb02e884ef51db00e7d76bf3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -190,7 +190,7 @@ class FeatureStatsDatasetTest( batch_size=batch_size, shuffle=True, shuffle_seed=5, - drop_final_batch=True).apply( + drop_final_batch=False).apply( stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() @@ -198,7 +198,8 @@ class FeatureStatsDatasetTest( with self.test_session() as sess: sess.run(iterator.initializer) - for _ in range(total_records // batch_size): + for _ in range(total_records // batch_size + 1 if total_records % + batch_size else total_records // batch_size): sess.run(next_element) with self.assertRaises(errors.OutOfRangeError): diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py index 1b962b3418a7195f927fe79c949383a475108e0a..1d70b16041e902a5d08383887cbf647eac2e816c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py +++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import re + from tensorflow.python.data.util import nest from tensorflow.python.framework import errors from tensorflow.python.platform import test @@ -45,7 +47,11 @@ class DatasetTestBase(test.TestCase): for i in range(len(op1)): self.assertAllEqual(op1[i], op2[i]) - def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class): + def _assert_datasets_raise_same_error(self, + dataset1, + dataset2, + exception_class, + replacements=None): next1 = dataset1.make_one_shot_iterator().get_next() next2 = dataset2.make_one_shot_iterator().get_next() with self.test_session() as sess: @@ -53,8 +59,12 @@ class DatasetTestBase(test.TestCase): sess.run(next1) raise ValueError( "Expected dataset to raise an error of type %s, but it did not." % - repr(exc_class)) - except exc_class as e: + repr(exception_class)) + except exception_class as e: + expected_message = e.message + for old, new, count in replacements: + expected_message = expected_message.replace(old, new, count) # Check that the first segment of the error messages are the same. - with self.assertRaisesRegexp(exc_class, e.message.split(". ")[0]): + with self.assertRaisesRegexp(exception_class, + re.escape(expected_message)): sess.run(next2) diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index ad9378dfb9d938c826f994da9bbb89101cfbd872..4b45cc7e36d14e99d1132b919dfc175a1217f8b9 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -80,17 +80,14 @@ py_library( ":batching", ":gen_dataset_ops", ":interleave_ops", + ":parsing_ops", ":shuffle_ops", - ":stats_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", - "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", @@ -210,6 +207,22 @@ py_library( ], ) +py_library( + name = "parsing_ops", + srcs = ["parsing_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + py_library( name = "map_defun", srcs = ["map_defun.py"], @@ -331,7 +344,10 @@ py_library( tf_gen_op_wrapper_py( name = "gen_dataset_ops", out = "gen_dataset_ops.py", - deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"], + deps = [ + "//tensorflow/contrib/data:dataset_ops_op_lib", + "//tensorflow/contrib/data:indexed_dataset_ops_op_lib", + ], ) tf_kernel_library( @@ -349,6 +365,7 @@ tf_custom_op_py_library( dso = ["//tensorflow/contrib/data:_dataset_ops.so"], kernels = [ ":dataset_ops_kernels", + "//tensorflow/contrib/data:indexed_dataset_ops_op_lib", "//tensorflow/contrib/data:dataset_ops_op_lib", ], srcs_version = "PY2AND3", @@ -359,6 +376,19 @@ tf_custom_op_py_library( ], ) +py_library( + name = "indexed_dataset_ops", + srcs = ["indexed_dataset_ops.py"], + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "prefetching_ops", srcs = ["prefetching_ops.py"], @@ -380,6 +410,7 @@ py_library( ":error_ops", ":get_single_element", ":grouping", + ":indexed_dataset_ops", ":interleave_ops", ":map_defun", ":optimization", diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a0932b40810972fd017230e2dfacaaddc0e1d1bf --- /dev/null +++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py @@ -0,0 +1,173 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrappers for indexed datasets.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class MaterializedIndexedDataset(object): + """MaterializedIndexedDataset is highly experimental! + """ + + def __init__(self, materialized_resource, materializer, output_classes, + output_types, output_shapes): + self._materialized_resource = materialized_resource + self._materializer = materializer + self._output_classes = output_classes + self._output_types = output_types + self._output_shapes = output_shapes + + @property + def initializer(self): + if self._materializer is not None: + return self._materializer + raise ValueError("MaterializedDataset does not have a materializer") + + def get(self, index): + """Get retrieves a value (or set of values) from the IndexedDataset. + + Args: + index: A uint64 scalar or vector tensor with the indices to retrieve. + + Returns: + A tensor containing the values corresponding to `index`. + """ + # TODO(saeta): nest.pack_sequence_as(...) + return gen_dataset_ops.indexed_dataset_get( + self._materialized_resource, + index, + output_types=nest.flatten( + sparse.as_dense_types(self._output_types, self._output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_types(self._output_shapes, self._output_classes))) + + +class IndexedDataset(dataset_ops.Dataset): + """IndexedDataset is highly experimental! + """ + + def __init__(self): + pass + + def materialize(self, shared_name=None, container=None): + """Materialize creates a MaterializedIndexedDataset. + + IndexedDatasets can be combined through operations such as TBD. Therefore, + they are only materialized when absolutely required. + + Args: + shared_name: a string for the shared name to use for the resource. + container: a string for the container to store the resource. + + Returns: + A MaterializedIndexedDataset. + """ + if container is None: + container = "" + if shared_name is None: + shared_name = "" + materialized_resource = gen_dataset_ops.materialized_index_dataset_handle( + container=container, + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_types(self.output_shapes, self.output_classes))) + + with ops.colocate_with(materialized_resource): + materializer = gen_dataset_ops.indexed_dataset_materialize( + self._as_variant_tensor(), materialized_resource) + return MaterializedIndexedDataset(materialized_resource, materializer, + self.output_classes, self.output_types, + self.output_shapes) + + @abc.abstractproperty + def output_types(self): + """Returns the type of each component of an element of this IndexedDataset. + + Returns: + A nested structure of `tf.DType` objects corresponding to each component + of an element of this IndexedDataset. + """ + raise NotImplementedError("IndexedDataset.output_types") + + @abc.abstractproperty + def output_classes(self): + """Returns the class of each component of an element of this IndexedDataset. + + The expected values are `tf.Tensor` and `tf.SparseTensor`. + + Returns: + A nested structure of Python `type` objects corresponding to each + component of an element of this IndexedDataset. + """ + raise NotImplementedError("IndexedDataset.output_classes") + + @abc.abstractproperty + def output_shapes(self): + """Returns the shape of each component of an element of this IndexedDataset. + + Returns: + A nested structure of `tf.TensorShape` objects corresponding to each + component of an element of this IndexedDataset. + """ + raise NotImplementedError("IndexedDataset.output_shapes") + + @abc.abstractmethod + def _as_variant_tensor(self): + """Creates a `tf.variant` `tf.Tensor` representing this IndexedDataset. + + Returns: + A scalar `tf.Tensor` of `tf.variant` type, which represents this + IndexedDataset. + """ + raise NotImplementedError("IndexedDataset._as_variant_tensor") + + +class IdentityIndexedDataset(IndexedDataset): + """IdentityIndexedDataset is a trivial indexed dataset used for testing. + """ + + def __init__(self, size): + super(IdentityIndexedDataset, self).__init__() + # TODO(saeta): Verify _size is a scalar! + self._size = ops.convert_to_tensor(size, dtype=dtypes.uint64, name="size") + + @property + def output_types(self): + return dtypes.uint64 + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + def _as_variant_tensor(self): + return gen_dataset_ops.identity_indexed_dataset(self._size) diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 5a1a35199abecc3890d5733ddf678af8d4098f33..54a92ab1855f41367d25023c7f7f5dcab330d46c 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -163,7 +163,7 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset): for data_input in data_inputs[1:]: if (data_input.output_types != data_inputs[0].output_types or data_input.output_classes != data_inputs[0].output_classes): - raise TypeError("All datasets must have the same type.") + raise TypeError("All datasets must have the same type and class.") def _as_variant_tensor(self): # pylint: disable=protected-access @@ -216,25 +216,46 @@ def sample_from_datasets(datasets, weights=None, seed=None): length of the `datasets` element. """ num_datasets = len(datasets) - if weights is None: - weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat() - elif not isinstance(weights, dataset_ops.Dataset): - weights = ops.convert_to_tensor(weights, name="weights") - if weights.dtype not in (dtypes.float32, dtypes.float64): - raise TypeError("`weights` must be convertible to a tensor of " - "`tf.float32` or `tf.float64` elements.") - if not weights.shape.is_compatible_with([num_datasets]): - raise ValueError("`weights` must be a vector of length `len(datasets)`.") - weights = dataset_ops.Dataset.from_tensors(weights).repeat() - - # The `stateless_multinomial()` op expects log-probabilities, as opposed to - # weights. - logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) - def select_dataset(logits, seed): - return array_ops.squeeze( - stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - selector_input = dataset_ops.Dataset.zip( - (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) + if not isinstance(weights, dataset_ops.Dataset): + if weights is None: + # Select inputs with uniform probability. + logits = [[1.0] * num_datasets] + else: + # Use the given `weights` as the probability of choosing the respective + # input. + weights = ops.convert_to_tensor(weights, name="weights") + if weights.dtype not in (dtypes.float32, dtypes.float64): + raise TypeError("`weights` must be convertible to a tensor of " + "`tf.float32` or `tf.float64` elements.") + if not weights.shape.is_compatible_with([num_datasets]): + raise ValueError( + "`weights` must be a vector of length `len(datasets)`.") + + # The `stateless_multinomial()` op expects log-probabilities, as opposed + # to weights. + logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0) + + def select_dataset_constant_logits(seed): + return array_ops.squeeze( + stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) + + selector_input = random_ops.RandomDataset(seed).batch(2).map( + select_dataset_constant_logits) + else: + # Use each element of the given `weights` dataset as the probability of + # choosing the respective input. + + # The `stateless_multinomial()` op expects log-probabilities, as opposed to + # weights. + logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) + + def select_dataset_varying_logits(logits, seed): + return array_ops.squeeze( + stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) + + selector_input = dataset_ops.Dataset.zip( + (logits_ds, random_ops.RandomDataset(seed).batch(2) + )).map(select_dataset_varying_logits) return _DirectedInterleaveDataset(selector_input, datasets) diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..2701605e641b190852bb9934ce83f7fc3e90ff15 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/parsing_ops.py @@ -0,0 +1,150 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental `dataset` API for parsing example.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import parsing_ops + + +class _ParseExampleDataset(dataset_ops.Dataset): + """A `Dataset` that parses `example` dataset into a `dict` dataset.""" + + def __init__(self, input_dataset, features, num_parallel_calls): + super(_ParseExampleDataset, self).__init__() + self._input_dataset = input_dataset + if not all(types == dtypes.string + for types in nest.flatten(input_dataset.output_types)): + raise TypeError("Input dataset should be a dataset of vectors of strings") + self._num_parallel_calls = num_parallel_calls + # pylint: disable=protected-access + self._features = parsing_ops._prepend_none_dimension(features) + # sparse_keys and dense_keys come back sorted here. + (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults, + dense_shapes) = parsing_ops._features_to_raw_params( + self._features, [ + parsing_ops.VarLenFeature, parsing_ops.SparseFeature, + parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature + ]) + # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature. + (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes, + dense_shape_as_shape) = parsing_ops._process_raw_parameters( + None, dense_defaults, sparse_keys, sparse_types, dense_keys, + dense_types, dense_shapes) + # pylint: enable=protected-access + self._sparse_keys = sparse_keys + self._sparse_types = sparse_types + self._dense_keys = dense_keys + self._dense_defaults = dense_defaults_vec + self._dense_shapes = dense_shapes + self._dense_types = dense_types + dense_output_shapes = [ + self._input_dataset.output_shapes.concatenate(shape) + for shape in dense_shape_as_shape + ] + sparse_output_shapes = [ + self._input_dataset.output_shapes.concatenate([None]) + for _ in range(len(sparse_keys)) + ] + + self._output_shapes = dict( + zip(self._dense_keys + self._sparse_keys, + dense_output_shapes + sparse_output_shapes)) + self._output_types = dict( + zip(self._dense_keys + self._sparse_keys, + self._dense_types + self._sparse_types)) + self._output_classes = dict( + zip(self._dense_keys + self._sparse_keys, + [ops.Tensor for _ in range(len(self._dense_defaults))] + + [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys)) + ])) + + def _as_variant_tensor(self): + return gen_dataset_ops.parse_example_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._num_parallel_calls, + self._dense_defaults, + self._sparse_keys, + self._dense_keys, + self._sparse_types, + self._dense_shapes, + **dataset_ops.flat_structure(self)) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + @property + def output_classes(self): + return self._output_classes + + +# TODO(b/111553342): add arguments names and example names as well. +def parse_example_dataset(features, num_parallel_calls=1): + """A transformation that parses `Example` protos into a `dict` of tensors. + + Parses a number of serialized `Example` protos given in `serialized`. We refer + to `serialized` as a batch with `batch_size` many entries of individual + `Example` protos. + + This op parses serialized examples into a dictionary mapping keys to `Tensor` + and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`, + `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature` + and `SparseFeature` is mapped to a `SparseTensor`, and each + `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more + details about feature dictionaries. + + Args: + features: A `dict` mapping feature keys to `FixedLenFeature`, + `VarLenFeature`, and `SparseFeature` values. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number of parsing processes to call in parallel. + + Returns: + A dataset transformation function, which can be passed to + `tf.data.Dataset.apply`. + + Raises: + ValueError: if features argument is None. + """ + if features is None: + raise ValueError("Missing: features was %s." % features) + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls) + if any([ + isinstance(feature, parsing_ops.SparseFeature) + for _, feature in features.items() + ]): + # pylint: disable=protected-access + # pylint: disable=g-long-lambda + out_dataset = out_dataset.map( + lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features( + features, x), num_parallel_calls=num_parallel_calls) + return out_dataset + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 3882d4bfdbe899c2ce92f829cb331b32d3d50398..29005859d75514294defb36943756228af3b4402 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -25,8 +25,8 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.data.python.ops import parsing_ops from tensorflow.contrib.data.python.ops import shuffle_ops -from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import convert @@ -37,7 +37,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation @@ -326,7 +325,6 @@ def make_csv_dataset( shuffle_seed=None, prefetch_buffer_size=1, num_parallel_reads=1, - num_parallel_parser_calls=2, sloppy=False, num_rows_for_inference=100, compression_type=None, @@ -393,8 +391,6 @@ def make_csv_dataset( batches consumed per training step. num_parallel_reads: Number of threads used to read CSV records from files. If >1, the results will be interleaved. - num_parallel_parser_calls: Number of parallel invocations of the CSV parsing - function on CSV records. sloppy: If `True`, reading performance will be improved at the cost of non-deterministic ordering. If `False`, the order of elements produced is deterministic prior to shuffling (elements are still @@ -503,7 +499,7 @@ def make_csv_dataset( # indefinitely, and all batches will be full-sized. dataset = dataset.batch(batch_size=batch_size, drop_remainder=num_epochs is None) - dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) + dataset = dataset.map(map_fn) dataset = dataset.prefetch(prefetch_buffer_size) return dataset @@ -778,8 +774,6 @@ def make_batched_features_dataset(file_pattern, dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) - dataset = dataset.apply(stats_ops.feature_stats("record_stats")) - # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to # improve the shape inference, because it makes the batch dimension static. # It is safe to do this because in that case we are repeating the input @@ -788,9 +782,9 @@ def make_batched_features_dataset(file_pattern, batch_size, drop_remainder=drop_final_batch or num_epochs is None) # Parse `Example` tensors to a dictionary of `Feature` tensors. - dataset = dataset.map( - lambda x: parsing_ops.parse_example(x, features), - num_parallel_calls=parser_num_threads) + dataset = dataset.apply( + parsing_ops.parse_example_dataset( + features, num_parallel_calls=parser_num_threads)) # TODO(rachelim): Add an optional label_name argument for extracting the label # from the features dictionary, to comply with the type expected by the @@ -974,3 +968,49 @@ class SqlDataset(dataset_ops.Dataset): @property def output_types(self): return self._output_types + + +class LMDBDataset(dataset_ops.Dataset): + """A LMDB Dataset that reads the lmdb file.""" + + def __init__(self, filenames): + """Create a `LMDBDataset`. + + `LMDBDataset` allows a user to read data from a mdb file as + (key value) pairs sequentially. + For example: + ```python + dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb") + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + # Prints the (key, value) pairs inside a lmdb file. + while True: + try: + print(sess.run(next_element)) + except tf.errors.OutOfRangeError: + break + ``` + Args: + filenames: A `tf.string` tensor containing one or more filenames. + """ + super(LMDBDataset, self).__init__() + self._filenames = ops.convert_to_tensor( + filenames, dtype=dtypes.string, name="filenames") + + def _as_variant_tensor(self): + return contrib_gen_dataset_ops.lmdb_dataset( + self._filenames, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + + @property + def output_classes(self): + return ops.Tensor, ops.Tensor + + @property + def output_shapes(self): + return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([])) + + @property + def output_types(self): + return dtypes.string, dtypes.string diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD index c16f1d6035d9fb4c5ffe29a713edfeaff299affc..02feeafb60a6e182f7061c981c9239881433381b 100644 --- a/tensorflow/contrib/distribute/BUILD +++ b/tensorflow/contrib/distribute/BUILD @@ -35,5 +35,6 @@ py_library( "//tensorflow/contrib/distribute/python:tpu_strategy", "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python/distribute:distribute_config", ], ) diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 2f5dd10550d0771d0cd3c2501d0456dc95077386..ba92ea0b124e2db86eec67fe736f17a36724c5e5 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -1,6 +1,6 @@ # Distribution Strategy -> *NOTE*: This is a experimental feature. The API and performance +> *NOTE*: This is an experimental feature. The API and performance > characteristics are subject to change. ## Overview @@ -9,7 +9,7 @@ API is an easy way to distribute your training across multiple devices/machines. Our goal is to allow users to use existing models and training code with minimal changes to enable distributed training. -Moreover, we've design the API in such a way that it works with both eager and +Moreover, we've designed the API in such a way that it works with both eager and graph execution. Currently we support one type of strategy, called diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py index 588a4f2898b2b7d818898990e4ce7bd343a32bfe..bf763215ba2db00cf4d1e28f938302cfb0184aab 100644 --- a/tensorflow/contrib/distribute/__init__.py +++ b/tensorflow/contrib/distribute/__init__.py @@ -27,6 +27,7 @@ from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceSt from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy from tensorflow.contrib.distribute.python.step_fn import * from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy +from tensorflow.python.distribute.distribute_config import DistributeConfig from tensorflow.python.training.distribute import * from tensorflow.python.training.distribution_strategy_context import * @@ -37,6 +38,7 @@ _allowed_symbols = [ 'AllReduceCrossTowerOps', 'CollectiveAllReduceStrategy', 'CrossTowerOps', + 'DistributeConfig', 'DistributionStrategy', 'MirroredStrategy', 'Monitor', diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 59efd17746d98ba4fd736e4e3b7772f52c2f5bd7..94deb2a432c5e64dfc6d01269a50bd99d506e110 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -23,8 +23,6 @@ py_library( deps = [ ":input_ops", ":prefetching_ops_v2", - "//tensorflow/contrib/data/python/ops:batching", - "//tensorflow/contrib/eager/python:datasets", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:device_util", @@ -85,6 +83,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/distribute:multi_worker_util", "//tensorflow/python/eager:context", "//tensorflow/python/eager:tape", ], @@ -105,6 +104,38 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_test( + name = "parameter_server_strategy_test", + srcs = ["parameter_server_strategy_test.py"], + additional_deps = [ + ":combinations", + ":multi_worker_test_base", + ":parameter_server_strategy", + ":values", + "@absl_py//absl/testing:parameterized", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:layers", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/distribute:multi_worker_util", + "//tensorflow/python/eager:context", + "//tensorflow/python/estimator:estimator_py", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", ], ) @@ -138,6 +169,7 @@ py_library( "//tensorflow/python:collective_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "//tensorflow/python/distribute:multi_worker_util", "//tensorflow/python/eager:context", ], ) @@ -237,40 +269,12 @@ py_test( ], ) -py_test( - name = "parameter_server_strategy_test", - srcs = ["parameter_server_strategy_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ - ":combinations", - ":multi_worker_test_base", - ":parameter_server_strategy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:layers", - "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/eager:context", - "//tensorflow/python/estimator:estimator_py", - "@absl_py//absl/testing:parameterized", - ], -) - cuda_py_test( name = "mirrored_strategy_multigpu_test", srcs = ["mirrored_strategy_multigpu_test.py"], additional_deps = [ ":mirrored_strategy", + ":multi_worker_test_base", ":values", ":strategy_test_lib", "//tensorflow/python:distribute", @@ -339,19 +343,17 @@ py_library( ], ) -py_test( +cuda_py_test( name = "collective_all_reduce_strategy_test", srcs = ["collective_all_reduce_strategy_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ + additional_deps = [ ":collective_all_reduce_strategy", ":combinations", ":cross_tower_utils", ":multi_worker_test_base", ":strategy_test_lib", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -365,8 +367,10 @@ py_test( "//tensorflow/python:variables", "//tensorflow/python/eager:context", "//tensorflow/python/estimator:estimator_py", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", ], ) @@ -446,6 +450,35 @@ cuda_py_test( ], ) +cuda_py_test( + name = "estimator_training_test", + size = "large", + srcs = ["estimator_training_test.py"], + additional_deps = [ + ":combinations", + ":mirrored_strategy", + ":multi_worker_test_base", + ":parameter_server_strategy", + "//third_party/py/numpy", + "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute", + "//tensorflow/python/eager:test", + "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/feature_column", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + ], + tags = [ + "manual", + "multi_and_single_gpu", + "no_pip", + "nogpu", + "notap", + ], +) + py_library( name = "single_loss_example", srcs = ["single_loss_example.py"], @@ -601,6 +634,7 @@ cuda_py_test( ":combinations", ":cross_tower_ops", ":multi_worker_test_base", + ":mirrored_strategy", ":values", "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py index 95b824e51ab23b15afa64cb6cbfa7d12574bd427..865dba803f562e0ab98341dd8343e3c72b03d39b 100644 --- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py +++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py @@ -48,7 +48,7 @@ class CheckpointUtilsWithDistributionStrategyTest( mode=["graph"])) def testInitFromCheckpoint(self, distribution, in_tower_mode): checkpoint_dir = self.get_temp_dir() - with self.test_session() as session: + with self.cached_session() as session: v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints( session, checkpoint_dir) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 9afcaecf78844b011a9dbc30bb95fa3bfeda8470..23314442614590632947fe89f7185ca04706a1fb 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -18,30 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import json -import os - from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import values -from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.distribute import multi_worker_util from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops -from tensorflow.python.training import server_lib - - -# TODO(yuefengz): move this function to a common util file. -def _normalize_cluster_spec(cluster_spec): - if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): - return server_lib.ClusterSpec(cluster_spec) - elif not isinstance(cluster_spec, server_lib.ClusterSpec): - raise ValueError( - "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " - "`tf.train.ClusterDef` object") - return cluster_spec # TODO(yuefengz): shard the dataset. @@ -52,51 +37,45 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): """Distribution strategy that uses collective ops for all-reduce. It is similar to the MirroredStrategy but it uses collective ops for - reduction. It currently only works for between-graph replication and its - reduction will reduce across all workers. + reduction. + + When `cluster_spec` is given by the `configure` method, it turns into the + mulit-worker version that works on multiple workers with between-graph + replication. + + Note: `configure` will be called by higher-level APIs if running in + distributed environment. """ - def __init__(self, - num_gpus_per_worker=0, - cluster_spec=None, - task_type="worker", - task_id=0): + def __init__(self, num_gpus_per_worker=0): """Initializes the object. Args: num_gpus_per_worker: number of local GPUs or GPUs per worker. - cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the - cluster configurations. - task_type: the current task type, such as "worker". - task_id: the current task id. - - Raises: - ValueError: if `task_type` is not in the `cluster_spec`. """ self._num_gpus_per_worker = num_gpus_per_worker - self._initialize(cluster_spec, task_type, task_id) + self._initialize(None, None, None) def _initialize(self, cluster_spec, task_type, task_id): - if task_type not in ["chief", "worker"]: - raise ValueError( - "Unrecognized task_type: %r, valid task types are: \"chief\", " - "\"worker\"." % task_type) if cluster_spec: - self._cluster_spec = _normalize_cluster_spec(cluster_spec) + if task_type is None or task_id is None: + raise ValueError("When `cluster_spec` is given, you must also specify " + "`task_type` and `task_id`") + if task_type not in ["chief", "worker"]: + raise ValueError( + "Unrecognized task_type: %r, valid task types are: \"chief\", " + "\"worker\"." % task_type) + self._cluster_spec = multi_worker_util.normalize_cluster_spec( + cluster_spec) worker_device = "/job:%s/task:%d" % (task_type, task_id) - num_workers = len(self._cluster_spec.as_dict().get(task_type, [])) - if "chief" in self._cluster_spec.as_dict(): - num_workers += 1 + num_workers = len(self._cluster_spec.as_dict().get("worker", [])) + len( + self._cluster_spec.as_dict().get("chief", [])) if not num_workers: - raise ValueError("`task_type` shoud be in `cluster_spec`.") + raise ValueError("No `worker` or `chief` tasks can be found in " + "`cluster_spec`.") - # TODO(yuefengz): create a utility to infer chief. - if "chief" in self._cluster_spec.as_dict() and task_type == "chief": - assert task_id == 0 - self._is_chief = True - else: - assert task_type == "worker" - self._is_chief = task_id == 0 + self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type, + task_id) else: self._cluster_spec = None self._is_chief = True @@ -187,19 +166,41 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): return mirrored_strategy._create_mirrored_variable( devices, _real_mirrored_creator, *args, **kwargs) - def configure(self, session_config=None): - # Use TF_CONFIG to get the cluster spec and the current job. - if not self._cluster_spec: - tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) - cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {})) + def configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): + """Configures the object. - task_env = tf_config.get("task", {}) - if task_env: - task_type = task_env.get("type", "worker") - task_id = int(task_env.get("index", "0")) - else: - task_type = "worker" - task_id = 0 + Args: + session_config: a @{tf.ConfigProto} + cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the + cluster configurations. + task_type: the current task type, such as "worker". + task_id: the current task id. - if cluster_spec: - self._initialize(cluster_spec, task_type, task_id) + Raises: + ValueError: if `task_type` is not in the `cluster_spec`. + """ + # TODO(yuefengz): we'll need to mutate the session_config to add + # configurations for collective ops. + del session_config + if not self._cluster_spec and cluster_spec: + self._initialize(cluster_spec, task_type, task_id) + + @property + def between_graph(self): + return True + + @property + def should_init(self): + return True + + @property + def should_checkpoint(self): + return self._is_chief + + @property + def should_save_summary(self): + return self._is_chief diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index b5e54e3b7d7156e87731e6f79aa66262d127232c..e284969b1a4781a1654beb12b885618fcdd94634 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -25,10 +25,8 @@ from tensorflow.contrib.distribute.python import collective_all_reduce_strategy from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_utils from tensorflow.contrib.distribute.python import multi_worker_test_base -from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context -from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -41,53 +39,43 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -class DistributedCollectiveAllReduceStrategyTest( - multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase): +class CollectiveAllReduceStrategyTestBase( + multi_worker_test_base.MultiWorkerTestBase): collective_key_base = 0 - @classmethod - def setUpClass(cls): - """Create a local cluster with 2 workers.""" - cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( - num_workers=3, num_ps=0) - cls._cluster_spec = { - run_config.TaskType.WORKER: [ - 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' - ] - } - def setUp(self): self._run_options = config_pb2.RunOptions() self._run_options.experimental.collective_graph_key = 6 self._sess_config = config_pb2.ConfigProto() - self._sess_config.experimental.collective_group_leader = ( - '/job:worker/replica:0/task:0') # We use a different key_base for each test so that collective keys won't be # reused. # TODO(yuefengz, tucker): enable it to reuse collective keys in different # tests. - DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000 - super(DistributedCollectiveAllReduceStrategyTest, self).setUp() + CollectiveAllReduceStrategyTestBase.collective_key_base += 100000 + super(CollectiveAllReduceStrategyTestBase, self).setUp() def _get_test_object(self, task_type, task_id, num_gpus=0): distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=num_gpus, - cluster_spec=self._cluster_spec, - task_type=task_type, - task_id=task_id) + num_gpus_per_worker=num_gpus) + if task_type and task_id is not None: + distribution.configure( + cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) collective_keys = cross_tower_utils.CollectiveKeys( group_key_start=10 * num_gpus + - DistributedCollectiveAllReduceStrategyTest.collective_key_base, + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_start=num_gpus * 100 + - DistributedCollectiveAllReduceStrategyTest.collective_key_base, + CollectiveAllReduceStrategyTestBase.collective_key_base, instance_key_with_id_start=num_gpus * 10000 + - DistributedCollectiveAllReduceStrategyTest.collective_key_base) + CollectiveAllReduceStrategyTestBase.collective_key_base) distribution._collective_keys = collective_keys distribution._cross_tower_ops._collective_keys = collective_keys - return distribution, self._workers[task_id].target + if task_type and task_id is not None: + return distribution, 'grpc://' + self._cluster_spec[task_type][task_id] + else: + return distribution, '' def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): d, master_target = self._get_test_object(task_type, task_id, num_gpus) @@ -155,12 +143,6 @@ class DistributedCollectiveAllReduceStrategyTest( self.assertLess(error_after, error_before) return error_after < error_before - @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testMinimizeLossGraph(self, num_gpus): - self._run_between_graph_clients(self._test_minimize_loss_graph, - self._cluster_spec, num_gpus) - def _test_variable_initialization(self, task_type, task_id, num_gpus): distribution, master_target = self._get_test_object(task_type, task_id, num_gpus) @@ -182,16 +164,74 @@ class DistributedCollectiveAllReduceStrategyTest( distribution.reduce( variable_scope.VariableAggregation.MEAN, x, destinations='/cpu:0'))[0] + x = distribution.unwrap(x)[0] sess.run( variables.global_variables_initializer(), options=self._run_options) + x_value, reduced_x_value = sess.run( [x, reduced_x], options=self._run_options) - self.assertTrue(np.array_equal(x_value, reduced_x_value)) - return np.array_equal(x_value, reduced_x_value) + self.assertTrue( + np.allclose(x_value, reduced_x_value, atol=1e-5), + msg=('x_value = %r, reduced_x_value = %r' % (x_value, + reduced_x_value))) + return np.allclose(x_value, reduced_x_value, atol=1e-5) + + +class DistributedCollectiveAllReduceStrategyTest( + CollectiveAllReduceStrategyTestBase, parameterized.TestCase): + + @classmethod + def setUpClass(cls): + """Create a local cluster with 3 workers.""" + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0) + + def setUp(self): + super(DistributedCollectiveAllReduceStrategyTest, self).setUp() + self._sess_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testVariableInitialization(self, num_gpus): + if context.num_gpus() < num_gpus: + return + self._run_between_graph_clients( + self._test_variable_initialization, + self._cluster_spec, + num_gpus=num_gpus) + + +class DistributedCollectiveAllReduceStrategyTestWithChief( + CollectiveAllReduceStrategyTestBase, parameterized.TestCase): + + @classmethod + def setUpClass(cls): + """Create a local cluster with 3 workers and 1 chief.""" + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=0, has_chief=True) + + def setUp(self): + super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp() + self._run_options.experimental.collective_graph_key = 7 + self._sess_config.experimental.collective_group_leader = ( + '/job:chief/replica:0/task:0') @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) def testVariableInitialization(self, num_gpus): if context.num_gpus() < num_gpus: return @@ -201,16 +241,14 @@ class DistributedCollectiveAllReduceStrategyTest( num_gpus=num_gpus) -class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase, - parameterized.TestCase): +class LocalCollectiveAllReduceStrategy( + CollectiveAllReduceStrategyTestBase, parameterized.TestCase): def testMinimizeLossGraph(self, num_gpus=2): # Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: return - distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=num_gpus) - self._test_minimize_loss_graph(distribution) + self._test_minimize_loss_graph(None, None, num_gpus) if __name__ == '__main__': diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 2fbadfe0f5ad9ef0a4255f51abe4aad5a0646efe..2301ba9233d29a1e5d054e71e4d9383af8bd48fd 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -341,33 +341,6 @@ mirrored_strategy_with_two_gpus = NamedDistribution( ["/gpu:0", "/gpu:1"], prefetch_on_device=False), required_gpus=2) -multi_worker_strategy_with_cpu = NamedDistribution( - "MultiWorkerCPU", - lambda: mirrored_lib.MirroredStrategy( - cluster_spec={ - "worker": [ - "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" - ] - }, - num_gpus=0), 0) -multi_worker_strategy_with_one_gpu = NamedDistribution( - "MultiWorker1GPU", - lambda: mirrored_lib.MirroredStrategy( - cluster_spec={ - "worker": [ - "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" - ] - }, - num_gpus=1), 1) -multi_worker_strategy_with_two_gpus = NamedDistribution( - "MultiWorker2GPUs", - lambda: mirrored_lib.MirroredStrategy( - cluster_spec={ - "worker": [ - "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" - ] - }, - num_gpus=2), 2) adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index 163559587da3b8b6f175e295602f767c08468a28..2a653b0f10c89b4938a5d3cf3802afe28cfb9387 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -62,7 +62,44 @@ def validate_destinations(destinations): raise ValueError("destinations can not be empty") +def _make_tensor_into_per_device(input_tensor): + """Converts a single tensor into a PerDevice object.""" + if isinstance(input_tensor, (tuple, list)): + raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object, " + "got %r but expected a object that is not a tuple or list." + % (input_tensor,)) + if isinstance(input_tensor, value_lib.PerDevice): + return input_tensor + + try: + device = input_tensor.device + except AttributeError: + raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object " + "because it doesn't have device set.") + + return value_lib.PerDevice({device: input_tensor}) + + +def _normalize_value_destination_pairs(value_destination_pairs): + """Converts each tensor into a PerDevice object in the input list.""" + result = [] + if not isinstance(value_destination_pairs, (list, tuple)): + raise ValueError("`value_destination_pairs` should be a list or tuple") + for pair in value_destination_pairs: + if not isinstance(pair, tuple): + raise ValueError( + "Each element of `value_destination_pairs` should be a tuple.") + if len(pair) != 2: + raise ValueError("Each element of `value_destination_pairs` should be a " + "tuple of size 2.") + + per_device = _make_tensor_into_per_device(pair[0]) + result.append((per_device, pair[1])) + return result + + def _validate_value_destination_pairs(value_destination_pairs): + # TODO(yuefengz): raise exceptions instead of returning False. # pylint: disable=g-missing-docstring if not value_destination_pairs: return False if not isinstance(value_destination_pairs, (list, tuple)): return False @@ -83,8 +120,10 @@ def get_devices_from(destinations): return [destinations.device] elif isinstance(destinations, six.string_types): return [device_util.resolve(destinations)] - else: + elif isinstance(destinations, (list, tuple)): return [device_util.resolve(destination) for destination in destinations] + else: + return [destinations.device] def _devices_match(left, right): @@ -159,7 +198,7 @@ class CrossTowerOps(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. - per_device_value: a PerDevice object. + per_device_value: a PerDevice object or a tensor with device set. destinations: the reduction destinations. Returns: @@ -169,7 +208,8 @@ class CrossTowerOps(object): ValueError: if per_device_value is not a PerDevice object. """ if not isinstance(per_device_value, value_lib.PerDevice): - raise ValueError("`per_device_value` must be a `PerDevice` object.") + per_device_value = _make_tensor_into_per_device(per_device_value) + if destinations is not None: validate_destinations(destinations) return self._reduce(aggregation, per_device_value, destinations) @@ -184,8 +224,9 @@ class CrossTowerOps(object): aggregation: Indicates how a variable will be aggregated. Accepted values are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. value_destination_pairs: a list or a tuple of tuples of PerDevice objects - and destinations. If a destination is None, then the destinations - are set to match the devices of the input PerDevice object. + (or tensors with device set if there is one tower) and destinations. If + a destination is None, then the destinations are set to match the + devices of the input PerDevice object. Returns: a list of Mirrored objects. @@ -195,8 +236,11 @@ class CrossTowerOps(object): tuples of PerDevice objects and destinations """ if not _validate_value_destination_pairs(value_destination_pairs): - raise ValueError("`value_destination_pairs` must be a list or a tuple of " - "tuples of PerDevice objects and destinations") + # If the first element of each pair is a tensor, we try to turn it into a + # PerDevice object. + value_destination_pairs = _normalize_value_destination_pairs( + value_destination_pairs) + for _, d in value_destination_pairs: if d is not None: validate_destinations(d) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 3508c9d5997070ef1350d4f08f98bf2d9c8b6837..2ad91d56e92fd8b4b847af5ed7a27b8e228b4694 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -26,12 +26,12 @@ import numpy as np from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import test -from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -40,9 +40,17 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import device_util -def _make_per_device(values, devices): +def _make_per_device(values, devices, regroup=False): devices = cross_tower_ops_lib.get_devices_from(devices) assert len(values) == len(devices) + + # We simulate the result of regroup called on PerDevice which strips the + # PerDevice wrapper if it has only one value. + if len(values) == 1 and regroup: + with ops.device(devices[0]): + placed_v = array_ops.identity(values[0]) + return placed_v + index = {} for d, v in zip(devices, values): with ops.device(d): @@ -368,14 +376,27 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, ("xring", 2, -1)], 0, 0, 0)), ], distribution=[ - combinations.multi_worker_strategy_with_cpu, - combinations.multi_worker_strategy_with_one_gpu, - combinations.multi_worker_strategy_with_two_gpus + combinations.NamedDistribution( + "MirroredCPU", + lambda: mirrored_strategy.MirroredStrategy(num_gpus=0), + required_gpus=0), + combinations.NamedDistribution( + "Mirrored1GPU", + lambda: mirrored_strategy.MirroredStrategy(num_gpus=1), + required_gpus=1), + combinations.NamedDistribution( + "Mirrored2GPUs", + lambda: mirrored_strategy.MirroredStrategy(num_gpus=2), + required_gpus=2), ], mode=["graph"]) @combinations.generate(multi_worker_allreduce_combinations) def testReductionAndBroadcast(self, cross_tower_ops, distribution): + distribution.configure(cluster_spec={ + "worker": + ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"] + }) with distribution.scope(): self._testReductionAndBroadcast(cross_tower_ops, distribution) @@ -388,13 +409,8 @@ class MultiWorkerCollectiveAllReduceTest( @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=0) - cls._cluster_spec = { - run_config.TaskType.WORKER: [ - "fake_worker_0", "fake_worker_1", "fake_worker_2" - ] - } def setUp(self): super(MultiWorkerCollectiveAllReduceTest, self).setUp() @@ -428,7 +444,8 @@ class MultiWorkerCollectiveAllReduceTest( ] else: devices = ["/job:%s/task:%d" % (task_type, task_id)] - return collective_all_reduce_ops, devices, self._workers[task_id].target + return (collective_all_reduce_ops, devices, + "grpc://" + self._cluster_spec[task_type][task_id]) def _assert_values_equal(self, left, right, sess): if isinstance(left, list): @@ -455,7 +472,8 @@ class MultiWorkerCollectiveAllReduceTest( num_workers = 1 worker_device = None else: - num_workers = len(self._workers) + num_workers = len(self._cluster_spec.get("chief", [])) + len( + self._cluster_spec.get("worker", [])) worker_device = "/job:%s/task:%d" % (task_type, task_id) with ops.Graph().as_default(), \ ops.device(worker_device), \ @@ -463,7 +481,7 @@ class MultiWorkerCollectiveAllReduceTest( # Collective ops doesn't support scalar tensors, so we have to construct # 1-d tensors. values = [constant_op.constant([float(d)]) for d in range(len(devices))] - per_device = _make_per_device(values, devices) + per_device = _make_per_device(values, devices, regroup=True) mean = np.array([(len(devices) - 1.) / 2.]) values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))] @@ -533,7 +551,7 @@ class MultiWorkerCollectiveAllReduceTest( return True @combinations.generate( - combinations.combine(mode=["graph"], num_gpus=[0, 1, 2])) + combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1)) def testReductionDistributed(self, num_gpus): if context.num_gpus() < num_gpus: return diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5348512016efc504f92e5a956d627698b93b209a --- /dev/null +++ b/tensorflow/contrib/distribute/python/estimator_training_test.py @@ -0,0 +1,659 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests that show Distribute Coordinator works with Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import json +import os +import sys +import tempfile +import threading +from absl.testing import parameterized +import numpy as np +import six + +_portpicker_import_error = None +try: + import portpicker # pylint: disable=g-import-not-at-top +except ImportError as _error: # pylint: disable=invalid-name + _portpicker_import_error = _error + portpicker = None + +# pylint: disable=g-import-not-at-top +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.contrib.optimizer_v2 import adagrad +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribute_coordinator as dc +from tensorflow.python.distribute import estimator_training as dc_training +from tensorflow.python.distribute.distribute_config import DistributeConfig +from tensorflow.python.eager import context +from tensorflow.python.estimator import exporter as exporter_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.estimator import training as estimator_training +from tensorflow.python.estimator.canned import dnn_linear_combined +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export as export_lib +from tensorflow.python.feature_column import feature_column +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary import summary_iterator +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import server_lib + +BATCH_SIZE = 10 +LABEL_DIMENSION = 2 +DATA = np.linspace( + 0., 2., BATCH_SIZE * LABEL_DIMENSION, dtype=np.float32).reshape( + BATCH_SIZE, LABEL_DIMENSION) +EVAL_NAME = "foo" +EXPORTER_NAME = "saved_model_exporter" +MAX_STEPS = 10 + +CHIEF = dc._TaskType.CHIEF +EVALUATOR = dc._TaskType.EVALUATOR +WORKER = dc._TaskType.WORKER +PS = dc._TaskType.PS + +original_run_distribute_coordinator = dc.run_distribute_coordinator + + +# TODO(yuefengz): merge this method back to test_util. +def _create_local_cluster(num_workers, + num_ps, + has_eval=False, + protocol="grpc", + worker_config=None, + ps_config=None): + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] + ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] + + cluster_dict = { + "worker": ["localhost:%s" % port for port in worker_ports], + "ps": ["localhost:%s" % port for port in ps_ports] + } + if has_eval: + cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()] + + cs = server_lib.ClusterSpec(cluster_dict) + + workers = [ + server_lib.Server( + cs, + job_name="worker", + protocol=protocol, + task_index=ix, + config=worker_config, + start=True) for ix in range(num_workers) + ] + ps_servers = [ + server_lib.Server( + cs, + job_name="ps", + protocol=protocol, + task_index=ix, + config=ps_config, + start=True) for ix in range(num_ps) + ] + if has_eval: + evals = [ + server_lib.Server( + cs, + job_name="evaluator", + protocol=protocol, + task_index=0, + config=worker_config, + start=True) + ] + else: + evals = [] + + return workers, ps_servers, evals + + +def _create_in_process_cluster(num_workers, num_ps, has_eval=False): + """Create an in-process cluster that consists of only standard server.""" + # Leave some memory for cuda runtime. + if has_eval: + gpu_mem_frac = 0.7 / (num_workers + 1) + else: + gpu_mem_frac = 0.7 / num_workers + + worker_config = config_pb2.ConfigProto() + worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac + + # Enable collective ops which has no impact on non-collective ops. + # TODO(yuefengz, tucker): removing this after we move the initialization of + # collective mgr to the session level. + worker_config.experimental.collective_group_leader = ( + "/job:worker/replica:0/task:0") + + ps_config = config_pb2.ConfigProto() + ps_config.device_count["GPU"] = 0 + + return _create_local_cluster( + num_workers, + num_ps=num_ps, + has_eval=has_eval, + worker_config=worker_config, + ps_config=ps_config, + protocol="grpc") + + +def _create_cluster_spec(has_chief=False, + num_workers=1, + num_ps=0, + has_eval=False): + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + + cluster_spec = {} + if has_chief: + cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()] + if num_workers: + cluster_spec[WORKER] = [ + "localhost:%s" % portpicker.pick_unused_port() + for _ in range(num_workers) + ] + if num_ps: + cluster_spec[PS] = [ + "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps) + ] + if has_eval: + cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()] + return cluster_spec + + +def _bytes_to_str(maybe_bytes): + if isinstance(maybe_bytes, six.string_types): + return maybe_bytes + else: + return str(maybe_bytes, "utf-8") + + +def _strip_protocol(target): + # cluster_spec expects "host:port" strings. + if "//" in target: + return target.split("//")[1] + else: + return target + + +class DistributeCoordinatorIntegrationTest(test.TestCase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + cls._workers, cls._ps, cls._evals = _create_in_process_cluster( + num_workers=3, num_ps=2, has_eval=True) + cls._cluster_spec = { + "worker": [ + _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers + ], + "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps], + "evaluator": [ + _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals + ] + } + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + self._event = threading.Event() + super(DistributeCoordinatorIntegrationTest, self).setUp() + + def dataset_input_fn(self, x, y, batch_size, shuffle): + + def input_fn(): + dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) + if shuffle: + dataset = dataset.shuffle(batch_size) + dataset = dataset.repeat(100).batch(batch_size) + return dataset + + return input_fn + + def _get_exporter(self, name, fc): + feature_spec = feature_column.make_parse_example_spec(fc) + serving_input_receiver_fn = ( + export_lib.build_parsing_serving_input_receiver_fn(feature_spec)) + return exporter_lib.LatestExporter( + name, serving_input_receiver_fn=serving_input_receiver_fn) + + def _extract_loss_and_global_step(self, event_folder): + """Returns the loss and global step in last event.""" + event_paths = glob.glob(os.path.join(event_folder, "events*")) + + loss = None + global_step_count = None + + for e in summary_iterator.summary_iterator(event_paths[-1]): + current_loss = None + for v in e.summary.value: + if v.tag == "loss": + current_loss = v.simple_value + + # If loss is not found, global step is meaningless. + if current_loss is None: + continue + + current_global_step = e.step + if global_step_count is None or current_global_step > global_step_count: + global_step_count = current_global_step + loss = current_loss + + return (loss, global_step_count) + + def _get_estimator(self, + train_distribute, + eval_distribute, + remote_cluster=None): + input_dimension = LABEL_DIMENSION + linear_feature_columns = [ + feature_column.numeric_column("x", shape=(input_dimension,)) + ] + dnn_feature_columns = [ + feature_column.numeric_column("x", shape=(input_dimension,)) + ] + + return dnn_linear_combined.DNNLinearCombinedRegressor( + linear_feature_columns=linear_feature_columns, + dnn_hidden_units=(2, 2), + dnn_feature_columns=dnn_feature_columns, + label_dimension=LABEL_DIMENSION, + model_dir=self._model_dir, + dnn_optimizer=adagrad.AdagradOptimizer(0.001), + linear_optimizer=adagrad.AdagradOptimizer(0.001), + config=run_config_lib.RunConfig( + experimental_distribute=DistributeConfig( + train_distribute=train_distribute, + eval_distribute=eval_distribute, + remote_cluster=remote_cluster))) + + def _complete_flow(self, + train_distribute, + eval_distribute, + remote_cluster=None): + estimator = self._get_estimator(train_distribute, eval_distribute, + remote_cluster) + + input_dimension = LABEL_DIMENSION + train_input_fn = self.dataset_input_fn( + x={"x": DATA}, + y=DATA, + batch_size=BATCH_SIZE // len(train_distribute.worker_devices), + shuffle=True) + if eval_distribute: + eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices) + else: + eval_batch_size = BATCH_SIZE + eval_input_fn = self.dataset_input_fn( + x={"x": DATA}, y=DATA, batch_size=eval_batch_size, shuffle=False) + + linear_feature_columns = [ + feature_column.numeric_column("x", shape=(input_dimension,)) + ] + dnn_feature_columns = [ + feature_column.numeric_column("x", shape=(input_dimension,)) + ] + feature_columns = linear_feature_columns + dnn_feature_columns + + estimator_training.train_and_evaluate( + estimator, + estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS), + estimator_training.EvalSpec( + name=EVAL_NAME, + input_fn=eval_input_fn, + steps=None, + exporters=self._get_exporter(EXPORTER_NAME, feature_columns), + start_delay_secs=0, + throttle_secs=1)) + return estimator + + def _inspect_train_and_eval_events(self, estimator): + # Make sure nothing is stuck in limbo. + writer_cache.FileWriterCache.clear() + + # Examine the training events. Use a range to check global step to avoid + # flakyness due to global step race condition. + training_loss, _ = self._extract_loss_and_global_step(self._model_dir) + self.assertIsNotNone(training_loss) + + # Examine the eval events. The global step should be accurate. + eval_dir = os.path.join(self._model_dir, "eval_" + EVAL_NAME) + eval_loss, eval_global_step = self._extract_loss_and_global_step( + event_folder=eval_dir) + self.assertIsNotNone(eval_loss) + self.assertGreaterEqual(eval_global_step, MAX_STEPS) + + # Examine the export folder. + export_dir = os.path.join( + os.path.join(self._model_dir, "export"), EXPORTER_NAME) + self.assertTrue(gfile.Exists(export_dir)) + + # Examine the ckpt for predict. + def predict_input_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + "x": DATA + }).batch(BATCH_SIZE) + + predicted_proba = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape) + + @combinations.generate( + combinations.combine( + mode=["graph"], + train_distribute_cls=[ + mirrored_strategy.MirroredStrategy, + parameter_server_strategy.ParameterServerStrategy + ], + eval_distribute_cls=[ + None, mirrored_strategy.MirroredStrategy, + parameter_server_strategy.ParameterServerStrategy + ], + required_gpus=1)) + def test_complete_flow_standalone_client(self, train_distribute_cls, + eval_distribute_cls): + try: + train_distribute = train_distribute_cls(num_gpus=context.num_gpus()) + except TypeError: + train_distribute = train_distribute_cls(num_gpus_per_worker=2) + + if eval_distribute_cls: + eval_distribute = eval_distribute_cls() + else: + eval_distribute = None + + estimator = self._complete_flow( + train_distribute, eval_distribute, remote_cluster=self._cluster_spec) + self._inspect_train_and_eval_events(estimator) + + def _mock_run_distribute_coordinator( + self, + worker_fn, + strategy, + eval_fn, + eval_strategy, + mode=dc.CoordinatorMode.STANDALONE_CLIENT, + cluster_spec=None, + session_config=None): + # Calls the origial `run_distribute_coordinator` method but gets task config + # from environment variables and then signals the caller. + task_type = None + task_id = None + if not cluster_spec: + cluster_spec = None + tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) + if not cluster_spec: + cluster_spec = tf_config.get("cluster", {}) + task_env = tf_config.get("task", {}) + if task_env: + task_type = task_env.get("type", task_type) + task_id = int(task_env.get("index", task_id)) + self._event.set() + original_run_distribute_coordinator( + worker_fn, + strategy, + eval_fn, + eval_strategy, + mode=mode, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id, + session_config=session_config) + + def _task_thread(self, train_distribute, eval_distribute): + with test.mock.patch.object(dc, "run_distribute_coordinator", + self._mock_run_distribute_coordinator): + self._complete_flow(train_distribute, eval_distribute) + + def _run_task_in_thread(self, cluster_spec, task_type, task_id, + train_distribute, eval_distribute): + if task_type: + tf_config = { + "cluster": cluster_spec, + "task": { + "type": task_type, + "index": task_id + } + } + else: + tf_config = { + "cluster": cluster_spec, + "task": { + "type": task_type, + "index": task_id + } + } + self._event.clear() + t = threading.Thread( + target=self._task_thread, args=(train_distribute, eval_distribute)) + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(tf_config)}): + t.start() + self._event.wait() + return t + + def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute, + eval_distribute): + threads = {} + for task_type in cluster_spec.keys(): + threads[task_type] = [] + for task_id in range(len(cluster_spec[task_type])): + t = self._run_task_in_thread(cluster_spec, task_type, task_id, + train_distribute, eval_distribute) + threads[task_type].append(t) + return threads + + @combinations.generate( + combinations.combine( + mode=["graph"], + train_distribute_cls=[ + parameter_server_strategy.ParameterServerStrategy, + ], + eval_distribute_cls=[ + None, mirrored_strategy.MirroredStrategy, + parameter_server_strategy.ParameterServerStrategy + ], + required_gpus=1)) + def test_complete_flow_indepedent_worker_between_graph( + self, train_distribute_cls, eval_distribute_cls): + train_distribute = train_distribute_cls( + num_gpus_per_worker=context.num_gpus()) + + if eval_distribute_cls: + eval_distribute = eval_distribute_cls() + else: + eval_distribute = None + + cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True) + threads = self._run_multiple_tasks_in_threads( + cluster_spec, train_distribute, eval_distribute) + for task_type, ts in threads.items(): + if task_type == PS: + continue + for t in ts: + t.join() + + estimator = self._get_estimator(train_distribute, eval_distribute) + self._inspect_train_and_eval_events(estimator) + + @combinations.generate( + combinations.combine( + mode=["graph"], + train_distribute_cls=[mirrored_strategy.MirroredStrategy], + eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy], + required_gpus=1)) + def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls, + eval_distribute_cls): + train_distribute = train_distribute_cls(num_gpus=context.num_gpus()) + + if eval_distribute_cls: + eval_distribute = eval_distribute_cls() + else: + eval_distribute = None + + cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True) + threads = self._run_multiple_tasks_in_threads( + cluster_spec, train_distribute, eval_distribute) + threads[WORKER][0].join() + threads[EVALUATOR][0].join() + + estimator = self._get_estimator(train_distribute, eval_distribute) + self._inspect_train_and_eval_events(estimator) + + +TF_CONFIG_WITH_CHIEF = { + "cluster": { + "chief": ["fake_chief"], + }, + "task": { + "type": "chief", + "index": 0 + } +} + +TF_CONFIG_WITH_MASTER = { + "cluster": { + "master": ["fake_master"], + }, + "task": { + "type": "master", + "index": 0 + } +} + +TF_CONFIG_WITHOUT_TASK = {"cluster": {"chief": ["fake_worker"]}} + + +class RunConfigTest(test.TestCase): + + def test_previously_unexpected_cluster_spec(self): + with test.mock.patch.dict( + "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}): + run_config_lib.RunConfig( + experimental_distribute=DistributeConfig( + train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + + def test_should_run_distribute_coordinator(self): + """Tests that should_run_distribute_coordinator return a correct value.""" + # We don't use distribute coordinator for local training. + self.assertFalse( + dc_training.should_run_distribute_coordinator( + run_config_lib.RunConfig())) + + # When `train_distribute` is not specified, don't use distribute + # coordinator. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): + self.assertFalse( + dc_training.should_run_distribute_coordinator( + run_config_lib.RunConfig())) + + # When `train_distribute` is specified and TF_CONFIG is detected, use + # distribute coordinator. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): + config_with_train_distribute = run_config_lib.RunConfig( + experimental_distribute=DistributeConfig( + train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + config_with_eval_distribute = run_config_lib.RunConfig( + experimental_distribute=DistributeConfig( + eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + self.assertTrue( + dc_training.should_run_distribute_coordinator( + config_with_train_distribute)) + self.assertFalse( + dc_training.should_run_distribute_coordinator( + config_with_eval_distribute)) + + # With a master in the cluster, don't run distribute coordinator. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): + config = run_config_lib.RunConfig( + experimental_distribute=DistributeConfig( + train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2))) + self.assertFalse(dc_training.should_run_distribute_coordinator(config)) + + def test_init_run_config_duplicate_distribute(self): + with self.assertRaises(ValueError): + run_config_lib.RunConfig( + train_distribute=mirrored_strategy.MirroredStrategy(), + experimental_distribute=DistributeConfig( + train_distribute=mirrored_strategy.MirroredStrategy())) + + with self.assertRaises(ValueError): + run_config_lib.RunConfig( + eval_distribute=mirrored_strategy.MirroredStrategy(), + experimental_distribute=DistributeConfig( + eval_distribute=mirrored_strategy.MirroredStrategy())) + + def test_init_run_config_none_distribute_coordinator_mode(self): + # We don't use distribute coordinator for local training. + config = run_config_lib.RunConfig( + train_distribute=mirrored_strategy.MirroredStrategy()) + dc_training.init_run_config(config, {}) + self.assertIsNone(config._distribute_coordinator_mode) + + # With a master in the cluster, don't run distribute coordinator. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}): + config = run_config_lib.RunConfig( + train_distribute=mirrored_strategy.MirroredStrategy()) + self.assertIsNone(config._distribute_coordinator_mode) + + # When `train_distribute` is not specified, don't use distribute + # coordinator. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): + config = run_config_lib.RunConfig() + self.assertFalse(hasattr(config, "_distribute_coordinator_mode")) + + def test_init_run_config_independent_worker(self): + # When `train_distribute` is specified and TF_CONFIG is detected, use + # distribute coordinator with INDEPENDENT_WORKER mode. + with test.mock.patch.dict("os.environ", + {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}): + config = run_config_lib.RunConfig( + train_distribute=mirrored_strategy.MirroredStrategy()) + self.assertEqual(config._distribute_coordinator_mode, + dc.CoordinatorMode.INDEPENDENT_WORKER) + + def test_init_run_config_standalone_client(self): + # When `train_distribute` is specified, TF_CONFIG is detected and + # `experimental.remote_cluster` is set use distribute coordinator with + # STANDALONE_CLIENT mode. + config = run_config_lib.RunConfig( + train_distribute=mirrored_strategy.MirroredStrategy(), + experimental_distribute=DistributeConfig( + remote_cluster={"chief": ["fake_worker"]})) + self.assertEqual(config._distribute_coordinator_mode, + dc.CoordinatorMode.STANDALONE_CLIENT) + + +if __name__ == "__main__": + with test.mock.patch.object(sys, "exit", os._exit): + test.main() diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD index cbfd17850212a1c007e2edb9dd3986b3109f040d..84b106545e1326fddd3ed299462534af982dc102 100644 --- a/tensorflow/contrib/distribute/python/examples/BUILD +++ b/tensorflow/contrib/distribute/python/examples/BUILD @@ -19,9 +19,20 @@ py_binary( ) py_binary( - name = "simple_tfkeras_example", + name = "keras_model_with_estimator", srcs = [ - "simple_tfkeras_example.py", + "keras_model_with_estimator.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + +py_binary( + name = "keras_mnist", + srcs = [ + "keras_mnist.py", ], deps = [ "//tensorflow:tensorflow_py", diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..a20069c4fe4713897ba9543cd56615db7a2fc3cb --- /dev/null +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -0,0 +1,126 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An example training a Keras Model using MirroredStrategy and native APIs.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +NUM_CLASSES = 10 + + +def get_input_datasets(): + """Downloads the MNIST dataset and creates train and eval dataset objects. + + Returns: + Train dataset, eval dataset and input shape. + + """ + # input image dimensions + img_rows, img_cols = 28, 28 + + # the data, split between train and test sets + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + + if tf.keras.backend.image_data_format() == 'channels_first': + x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) + x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) + input_shape = (1, img_rows, img_cols) + else: + x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) + x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) + input_shape = (img_rows, img_cols, 1) + + x_train = x_train.astype('float32') + x_test = x_test.astype('float32') + x_train /= 255 + x_test /= 255 + + # convert class vectors to binary class matrices + y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES) + y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES) + + # train dataset + train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) + train_ds = train_ds.repeat() + train_ds = train_ds.shuffle(100) + train_ds = train_ds.batch(64) + + # eval dataset + eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + eval_ds = eval_ds.repeat() + eval_ds = eval_ds.shuffle(100) + eval_ds = eval_ds.batch(64) + + return train_ds, eval_ds, input_shape + + +def get_model(input_shape): + """Builds a Sequential CNN model to recognize MNIST digits. + + Args: + input_shape: Shape of the input depending on the `image_data_format`. + + Returns: + a Keras model + + """ + # Define a CNN model to recognize MNIST digits. + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Conv2D(32, kernel_size=(3, 3), + activation='relu', + input_shape=input_shape)) + model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu')) + model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2))) + model.add(tf.keras.layers.Dropout(0.25)) + model.add(tf.keras.layers.Flatten()) + model.add(tf.keras.layers.Dense(128, activation='relu')) + model.add(tf.keras.layers.Dropout(0.5)) + model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')) + return model + + +def main(_): + # Build the train and eval datasets from the MNIST data. Also return the + # input shape which is constructed based on the `image_data_format` + # i.e channels_first or channels_last. + train_ds, eval_ds, input_shape = get_input_datasets() + model = get_model(input_shape) + + # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or + # the `devices` argument then all the GPUs available on the machine are used. + strategy = tf.contrib.distribute.MirroredStrategy() + + # Compile the model by passing the distribution strategy object to the + # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed + # based on the strategy instantiated. + model.compile(loss=tf.keras.losses.categorical_crossentropy, + optimizer=tf.train.RMSPropOptimizer(learning_rate=0.001), + metrics=['accuracy'], + distribute=strategy) + + # Train the model with the train dataset. + model.fit(x=train_ds, epochs=20, steps_per_epoch=310) + + # Evaluate the model with the eval dataset. + score = model.evaluate(eval_ds, steps=10, verbose=0) + print('Test loss:', score[0]) + print('Test accuracy:', score[1]) + + +if __name__ == '__main__': + tf.app.run() diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py similarity index 91% rename from tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py rename to tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py index 518ec9c4232465c3ecd0e4161f707dac499430c7..8d117eb7e8f5463a0a1c7e9814829d65c6111289 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py +++ b/tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py @@ -42,19 +42,19 @@ def main(args): model_dir = args[1] print('Using %s to store checkpoints.' % model_dir) - # Define tf.keras Model. + # Define a Keras Model. model = tf.keras.Sequential() model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) model.add(tf.keras.layers.Dense(1, activation='sigmoid')) - # Compile tf.keras Model. + # Compile the model. optimizer = tf.train.GradientDescentOptimizer(0.2) model.compile(loss='binary_crossentropy', optimizer=optimizer) model.summary() tf.keras.backend.set_learning_phase(True) - # Define a DistributionStrategy and convert the tf.keras Model to a - # tf.Estimator that utilizes the DistributionStrategy. + # Define a DistributionStrategy and convert the Keras Model to an + # Estimator that utilizes the DistributionStrategy. strategy = tf.contrib.distribute.MirroredStrategy( ['/device:GPU:0', '/device:GPU:1']) config = tf.estimator.RunConfig( @@ -62,7 +62,7 @@ def main(args): keras_estimator = tf.keras.estimator.model_to_estimator( keras_model=model, config=config, model_dir=model_dir) - # Train and evaluate the tf.Estimator. + # Train and evaluate the model. keras_estimator.train(input_fn=input_fn, steps=10) eval_result = keras_estimator.evaluate(input_fn=input_fn) print('Eval result: {}'.format(eval_result)) diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py index 16179c3a4903c8149800d411853af734c1633466..c5acb7ced4bcb58cf327398f04fb37675a944e97 100644 --- a/tensorflow/contrib/distribute/python/input_ops_test.py +++ b/tensorflow/contrib/distribute/python/input_ops_test.py @@ -91,7 +91,7 @@ class AutoShardDatasetTest(test.TestCase): def _verifySimpleShardingOutput(self, dataset, record_fn): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for f in range(self._shard_index, self._num_files, self._num_shards): for r in range(self._num_records): self.assertAllEqual(record_fn(r, f), sess.run(next_element)) @@ -150,7 +150,7 @@ class AutoShardDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: actual, expected = [], [] for f in range(self._shard_index, self._num_files, self._num_shards): for r in range(self._num_records): @@ -182,7 +182,7 @@ class AutoShardDatasetTest(test.TestCase): # Verify output. iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: actual = [] num_iterations = (self._num_files * self._num_records * num_epochs) // ( self._num_shards * batch_size) @@ -218,7 +218,7 @@ class AutoShardDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for f in range(self._shard_index, self._num_files, self._num_shards): for r in range(self._num_records): self.assertAllEqual(self._record(r, f), sess.run(next_element)) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index a262d7666e7be2c28857b7b38ad0ccbd1b053463..d39fd57294a67a4a98a528f2aa99f0436f245847 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -116,7 +116,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): model_dir=self._base_dir, train_distribute=dist, eval_distribute=dist) - with self.test_session(): + with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) before_eval_results = est_keras.evaluate( @@ -139,7 +139,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, train_distribute=dist) - with self.test_session(): + with self.cached_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=config) before_eval_results = est_keras.evaluate( @@ -163,7 +163,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir, train_distribute=dist) - with self.test_session(): + with self.cached_session(): est_keras = keras_lib.model_to_estimator(keras_model=keras_model, config=config) with self.assertRaisesRegexp(ValueError, @@ -178,7 +178,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase): class TestWithDistributionStrategy(test.TestCase): def test_validating_dataset_input_tensors_with_shape_mismatch(self): - with self.test_session(): + with self.cached_session(): strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', '/device:CPU:0']) a = constant_op.constant([1, 2], shape=(1, 2)) @@ -197,7 +197,7 @@ class TestWithDistributionStrategy(test.TestCase): strategy, x, y) def test_validating_dataset_input_tensors_with_dtype_mismatch(self): - with self.test_session(): + with self.cached_session(): strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0', '/device:CPU:0']) a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32) @@ -216,7 +216,7 @@ class TestWithDistributionStrategy(test.TestCase): strategy, x, y) def test_calling_model_on_same_dataset(self): - with self.test_session(): + with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) model = keras.Model(x, y) @@ -242,7 +242,7 @@ class TestWithDistributionStrategy(test.TestCase): model.predict(dataset, steps=2) def test_fit_with_tuple_and_dict_dataset_inputs(self): - with self.test_session(): + with self.cached_session(): a = keras.layers.Input(shape=(3,), name='input_a') b = keras.layers.Input(shape=(3,), name='input_b') @@ -283,7 +283,7 @@ class TestWithDistributionStrategy(test.TestCase): model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) def test_fit_eval_and_predict_methods_on_dataset(self): - with self.test_session(): + with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) model = keras.Model(x, y) @@ -320,7 +320,7 @@ class TestWithDistributionStrategy(test.TestCase): def __call__(self, y_true, y_pred): return y_pred - y_true - with self.test_session(): + with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) model = keras.Model(x, y) @@ -336,7 +336,7 @@ class TestWithDistributionStrategy(test.TestCase): model.compile(optimizer, loss, metrics=metrics, distribute=strategy) def test_unsupported_features(self): - with self.test_session(): + with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) model = keras.Model(x, y) @@ -389,7 +389,7 @@ class TestWithDistributionStrategy(test.TestCase): model.predict(dataset, verbose=0) def test_calling_with_unsupported_predefined_callbacks(self): - with self.test_session(): + with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) model = keras.Model(x, y) @@ -428,7 +428,7 @@ class TestWithDistributionStrategy(test.TestCase): callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)]) def test_dataset_input_shape_validation(self): - with self.test_session(): + with self.cached_session(): x = keras.layers.Input(shape=(3,), name='input') y = keras.layers.Dense(4, name='dense')(x) model = keras.Model(x, y) @@ -465,7 +465,7 @@ class TestWithDistributionStrategy(test.TestCase): # TODO(anjalisridhar): Modify this test to use Lambdas since we can compare # meaningful values. Currently we don't pass the learning phase if the # Lambda layer uses the learning phase. - with self.test_session(): + with self.cached_session(): x = keras.layers.Input(shape=(16,), name='input') y = keras.layers.Dense(16)(x) z = keras.layers.Dropout(0.9999)(y) @@ -498,7 +498,7 @@ class TestWithDistributionStrategy(test.TestCase): class LossMaskingWithDistributionStrategyTest(test.TestCase): def test_masking(self): - with self.test_session(): + with self.cached_session(): np.random.seed(1337) x = np.array([[[1], [1]], [[0], [0]]]) model = keras.models.Sequential() @@ -523,7 +523,7 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase): class NormalizationLayerWithDistributionStrategyTest(test.TestCase): def test_batchnorm_correctness(self): - with self.test_session(): + with self.cached_session(): model = keras.models.Sequential() norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) model.add(norm) @@ -550,7 +550,7 @@ class NormalizationLayerWithDistributionStrategyTest(test.TestCase): class CorrectnessWithDistributionStrategyTest(test.TestCase): def test_correctness(self): - with self.test_session(): + with self.cached_session(): keras.backend.set_image_data_format('channels_last') num_samples = 10000 x_train = np.random.rand(num_samples, 1) @@ -565,8 +565,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase): dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) dataset_with = dataset_with.batch(32) strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0', - '/device:GPU:0'], - prefetch_on_device=False) + '/device:GPU:0']) model.compile(loss=keras.losses.mean_squared_error, optimizer=gradient_descent.GradientDescentOptimizer(0.5), diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 516ede7ade7d8c9d09198993f919f15377b1c565..bdac4fb58c2ca8c4f6a322a6f477a9e3657b8f93 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -71,7 +71,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.evaluate(distribution.initialize()) if not context.executing_eagerly(): - with self.test_session() as sess: + with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -108,7 +108,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, iterator.get_next(), run_concurrently=layer.built)) if not context.executing_eagerly(): - with self.test_session() as sess: + with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -168,7 +168,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.evaluate(distribution.initialize()) if not context.executing_eagerly(): - with self.test_session() as sess: + with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -249,7 +249,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.evaluate(distribution.initialize()) if not context.executing_eagerly(): - with self.test_session() as sess: + with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -343,7 +343,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.evaluate(distribution.initialize()) if not context.executing_eagerly(): - with self.test_session() as sess: + with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -466,7 +466,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): self.evaluate(distribution.initialize()) if not context.executing_eagerly(): - with self.test_session() as sess: + with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 6981449a4cc9d15ebc3a0edd145fa5766e9b6503..e87b48ba4182476f182afc123f44c547fc7d3321 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -25,8 +25,8 @@ import threading from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib from tensorflow.contrib.distribute.python import shared_variable_creator from tensorflow.contrib.distribute.python import values -from tensorflow.core.protobuf import cluster_pb2 from tensorflow.python import pywrap_tensorflow +from tensorflow.python.distribute import multi_worker_util from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op @@ -39,7 +39,6 @@ from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import coordinator from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib -from tensorflow.python.training import server_lib from tensorflow.python.util import nest @@ -277,6 +276,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): else: result = values.MirroredVariable(index, index[devices[0]], aggregation) + # Add the wrapped variable to the requested collections. + # The handling of eager mode and the global step matches + # ResourceVariable._init_from_args(). if not context.executing_eagerly(): g = ops.get_default_graph() # If "trainable" is True, next_creator() will add the member variables @@ -290,6 +292,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): for v in index.values(): l.remove(v) g.add_to_collections(collections, result) + elif ops.GraphKeys.GLOBAL_STEP in collections: + ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result) + return result @@ -299,8 +304,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): This strategy uses one tower per device and sync replication for its multi-GPU version. - When `cluster_spec` is given, it turns into the mulit-worker version that - works on multiple workers with in-graph replication. + When `cluster_spec` is given by the `configure` method., it turns into the + mulit-worker version that works on multiple workers with in-graph replication. + Note: `configure` will be called by higher-level APIs if running in + distributed environment. There are several important concepts for distributed TensorFlow, e.g. `client`, `job`, 'task', `cluster`, `in-graph replication` and @@ -330,8 +337,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): num_gpus: number of GPUs. For local training, either specify `devices` or `num_gpus`. In distributed training, this must be specified as number of GPUs on each worker. - cluster_spec: if this is set, it turns into the multi-worker version and - `devices` must not be set but `num_gpus` must be set. cross_tower_ops: optional, a descedant of `CrossTowerOps`. If this is not set, the `configure` method will try to find the best one. prefetch_on_device: optional boolean to specify whether to prefetch input @@ -341,65 +346,76 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def __init__(self, devices=None, num_gpus=None, - cluster_spec=None, cross_tower_ops=None, prefetch_on_device=None): super(MirroredStrategy, self).__init__() - if cluster_spec: - if devices is not None: - raise ValueError("Specifying devices when `cluster_spec` is also given " - "is not supported in MirroredStrategy.") - - # TODO(yuefengz): use the utility method to normalize cluster_spec. - if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): - cluster_spec = server_lib.ClusterSpec(cluster_spec) - elif not isinstance(cluster_spec, server_lib.ClusterSpec): - raise ValueError( - "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " - "`tf.train.ClusterDef` object") - self._cluster_spec = cluster_spec - - self._workers = [] - for job in sorted(cluster_spec.jobs): - for task in range(cluster_spec.num_tasks(job)): - self._workers.append("/job:%s/task:%d" % (job, task)) + self._cross_tower_ops = cross_tower_ops + self._prefetch_on_device = prefetch_on_device + # Rememeber num GPUs which might be needed by `configure` method. + self._num_gpus = num_gpus + self._initialize_local(num_gpus, devices) + + def _initialize_local(self, num_gpus, devices): + """Initializes the object for local training.""" + self._cluster_spec = None + # Convert `num_gpus` into `devices`, shouldn't specify both. + if devices is None: if num_gpus is None: - raise ValueError("`num_gpus` is required if `cluster_spec` is given.") - self._num_gpus = num_gpus - if num_gpus > 0: - self._worker_device_map = { - worker: [ - device_util.canonicalize(worker + "/device:GPU:%d" % gpu) - for gpu in range(num_gpus) - ] for worker in self._workers - } + num_gpus = context.num_gpus() + if num_gpus == 0: + devices = ["/device:CPU:0"] else: - self._worker_device_map = { - worker: [device_util.canonicalize(worker, "/device:CPU:0")] - for worker in self._workers - } - devices = nest.flatten(self._worker_device_map) - - # Setting `_default_device` will add a device scope in the - # distribution.scope. We set the default device to the first worker. When - # users specify device under distribution.scope by - # with tf.device("/cpu:0"): - # ... - # their ops will end up on the cpu device of its first worker, e.g. - # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. - self._default_device = self._workers[0] - else: - self._cluster_spec = None - # Convert `num_gpus` into `devices`, shouldn't specify both. - if devices is None: - if num_gpus is None: - num_gpus = context.num_gpus() devices = ["/device:GPU:%d" % d for d in range(num_gpus)] - elif num_gpus is not None: - raise ValueError("Must only specify one of `devices` and `num_gpus`.") - # TODO(yuefengz): consider setting the default device. + elif num_gpus is not None: + raise ValueError("Must only specify one of `devices` and `num_gpus`.") + self._num_gpus = num_gpus + # TODO(yuefengz): consider setting the default device. + + assert devices, "Must specify at least one device." + assert len(set(devices)) == len(devices), ( + "No duplicates allowed in `devices` argument.") + # TODO(josh11b): Require at least 2 devices? + self._devices = [device_util.resolve(d) for d in devices] + self._canonical_device_set = set(self._devices) + self._device_index = values.PerDevice({d: i for i, d in enumerate(devices)}) + + def _initialize_multi_worker(self, num_gpus, cluster_spec): + """Initializes the object for multi-worker training.""" + cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) + self._cluster_spec = cluster_spec + + self._workers = [] + for job in ["chief", "worker"]: + for task in range(len(cluster_spec.as_dict().get(job, []))): + self._workers.append("/job:%s/task:%d" % (job, task)) + + if num_gpus is None: + raise ValueError("`num_gpus` is required if `cluster_spec` is given.") + if num_gpus > 0: + self._worker_device_map = { + worker: [ + device_util.canonicalize(worker + "/device:GPU:%d" % gpu) + for gpu in range(num_gpus) + ] for worker in self._workers + } + else: + self._worker_device_map = { + worker: [device_util.canonicalize(worker, "/device:CPU:0")] + for worker in self._workers + } + + devices = nest.flatten(self._worker_device_map) + + # Setting `_default_device` will add a device scope in the + # distribution.scope. We set the default device to the first worker. When + # users specify device under distribution.scope by + # with tf.device("/cpu:0"): + # ... + # their ops will end up on the cpu device of its first worker, e.g. + # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. + self._default_device = self._workers[0] assert devices, "Must specify at least one device." assert len(set(devices)) == len(devices), ( @@ -409,8 +425,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._canonical_device_set = set(self._devices) self._device_index = values.PerDevice( {d: i for i, d in enumerate(devices)}) - self._cross_tower_ops = cross_tower_ops - self._prefetch_on_device = prefetch_on_device def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" @@ -544,7 +558,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): cluster_spec=None, task_type=None, task_id=None): - del cluster_spec, task_type, task_id + del task_type, task_id + if cluster_spec: + self._initialize_multi_worker(self._num_gpus, cluster_spec) + if self._cross_tower_ops is None: if self._cluster_spec: self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce( @@ -636,6 +653,22 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def parameter_devices(self): return list(self._devices) + @property + def between_graph(self): + return False + + @property + def should_init(self): + return True + + @property + def should_checkpoint(self): + return True + + @property + def should_save_summary(self): + return True + def non_slot_devices(self, var_list): del var_list return list(self._devices) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 9a4cc0a8975c39cf82e474d660968afc17991db0..a12ff662db2c9314b7fa86ba017661a556388926 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import sys from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 @@ -41,6 +42,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import server_lib GPU_TEST = "test_gpu" in sys.argv[0] @@ -886,8 +888,18 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertIsInstance(mirrored_var, values.MirroredVariable) self.evaluate(variables.global_variables_initializer()) self.assertEquals(1.0, self.evaluate(mirrored_var)) - mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0)) + + # read_value == True + mirrored_var_result = self.evaluate( + mirrored_var.assign_add(6.0, read_value=True)) self.assertEquals(7.0, mirrored_var_result) + self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + + # read_value == False + self.evaluate(mirrored_var.assign_add(2.0, read_value=False)) + self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) + self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) @test_util.run_in_graph_and_eager_modes(config=config) def testAssignAddMirroredVarTowerContext(self): @@ -954,6 +966,8 @@ class MirroredVariableUpdateTest(test.TestCase): self.assertEquals(5.0, self.evaluate(mirrored_var)) mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0)) self.assertEquals(3.0, mirrored_var_result) + self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0"))) + self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0"))) @test_util.run_in_graph_and_eager_modes(config=config) def testAssignSubMirroredVarTowerContext(self): @@ -1244,5 +1258,39 @@ class MirroredStrategyDefunTest(test.TestCase): self._call_and_check(fn1, [factors], expected_result, [fn1]) +class MultiWorkerMirroredStrategyTest( + multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + cluster_spec = server_lib.ClusterSpec({ + "worker": ["/job:worker/task:0", "/job:worker/task:1"] + }) + strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) + strategy.configure(cluster_spec=cluster_spec) + return strategy + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy(), + learning_rate=0.05) + + +class MultiWorkerMirroredStrategyTestWithChief( + multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers and 1 chief.""" + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=2, num_ps=0, has_chief=True) + cls._default_target = "grpc://" + cls._cluster_spec["chief"][0] + + def testMinimizeLossGraph(self): + strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) + strategy.configure(cluster_spec=self._cluster_spec) + self._test_minimize_loss_graph(strategy, learning_rate=0.05) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index 55d59adc078ad546e4fe0a3acb88741e8666b562..969e1269560e52736d05e6b14ce320d9bd4fcac0 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import strategy_test_lib from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -28,7 +27,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import variable_scope from tensorflow.python.training import distribution_strategy_context -from tensorflow.python.training import server_lib class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): @@ -64,6 +62,7 @@ class VariableCreatorStackTest(test.TestCase): def model_fn(device_id): assert isinstance(device_id, int) + def thread_creator_fn(next_creator, *args, **kwargs): return next_creator(*args, **kwargs) + ":thread_" + str(device_id) @@ -90,32 +89,20 @@ class VariableCreatorStackTest(test.TestCase): self.assertEquals(expected, result) -class MultiWorkerMirroredStrategyTest( - multi_worker_test_base.MultiWorkerTestBase, - strategy_test_lib.DistributionTestBase): - - def _get_distribution_strategy(self): - return mirrored_strategy.MirroredStrategy( - cluster_spec=server_lib.ClusterSpec({ - 'worker': ['/job:worker/task:0', '/job:worker/task:1'] - }), - num_gpus=context.num_gpus()) - - def testMinimizeLossGraph(self): - self._test_minimize_loss_graph(self._get_distribution_strategy()) +class MultiWorkerMirroredStrategyTest(test.TestCase): def testDeviceScope(self): """Test the device scope of multi-worker MirroredStrategy.""" with context.graph_mode(): - strategy = mirrored_strategy.MirroredStrategy( - cluster_spec={'worker': ['/job:worker/task:0', '/job:worker/task:1']}, - num_gpus=context.num_gpus()) + strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus()) + strategy.configure( + cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]}) with strategy.scope(): a = constant_op.constant(1.) - with ops.device('/cpu:0'): + with ops.device("/cpu:0"): b = constant_op.constant(1.) - self.assertEqual(a.device, '/job:worker/task:0') - self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') + self.assertEqual(a.device, "/job:worker/task:0") + self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0") if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py index 2892ce439494320a115b8eae0025a132841c4a8f..16be839e1d155003b9490fbe3da6ab85b7d2d78a 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -45,7 +45,7 @@ class MonitorTest(test.TestCase, parameterized.TestCase): if context.executing_eagerly(): monitor = monitor_lib.Monitor(single_loss_step, None) else: - with self.test_session() as sess: + with self.cached_session() as sess: monitor = monitor_lib.Monitor(single_loss_step, sess) monitor.run_steps(1) diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py index 249de01f0880b02d603687db99692088480f7136..18b4503eff4c7e83e8b98a6d71893dee15c19898 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -23,26 +23,105 @@ import copy import threading import numpy as np +_portpicker_import_error = None +try: + import portpicker # pylint: disable=g-import-not-at-top +except ImportError as _error: # pylint: disable=invalid-name + _portpicker_import_error = _error + portpicker = None + +# pylint: disable=g-import-not-at-top from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.estimator import run_config from tensorflow.python.platform import test -from tensorflow.python.framework import test_util - - -def create_in_process_cluster(num_workers, num_ps): +from tensorflow.python.training import server_lib + + +def _create_cluster(num_workers, + num_ps, + has_chief=False, + has_eval=False, + protocol='grpc', + worker_config=None, + ps_config=None): + """Creates and starts local servers and returns the cluster_spec dict.""" + if _portpicker_import_error: + raise _portpicker_import_error # pylint: disable=raising-bad-type + worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)] + ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)] + + cluster_dict = {} + if num_workers > 0: + cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports] + if num_ps > 0: + cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports] + if has_eval: + cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()] + if has_chief: + cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()] + + cs = server_lib.ClusterSpec(cluster_dict) + + for i in range(num_workers): + server_lib.Server( + cs, + job_name='worker', + protocol=protocol, + task_index=i, + config=worker_config, + start=True) + + for i in range(num_ps): + server_lib.Server( + cs, + job_name='ps', + protocol=protocol, + task_index=i, + config=ps_config, + start=True) + + if has_chief: + server_lib.Server( + cs, + job_name='chief', + protocol=protocol, + task_index=0, + config=worker_config, + start=True) + + if has_eval: + server_lib.Server( + cs, + job_name='evaluator', + protocol=protocol, + task_index=0, + config=worker_config, + start=True) + + return cluster_dict + + +def create_in_process_cluster(num_workers, + num_ps, + has_chief=False, + has_eval=False): """Create an in-process cluster that consists of only standard server.""" # Leave some memory for cuda runtime. - gpu_mem_frac = 0.7 / num_workers + gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval)) worker_config = config_pb2.ConfigProto() worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac # Enable collective ops which has no impact on non-collective ops. # TODO(yuefengz, tucker): removing this after we move the initialization of # collective mgr to the session level. - worker_config.experimental.collective_group_leader = ( - '/job:worker/replica:0/task:0') + if has_chief: + worker_config.experimental.collective_group_leader = ( + '/job:chief/replica:0/task:0') + else: + worker_config.experimental.collective_group_leader = ( + '/job:worker/replica:0/task:0') ps_config = config_pb2.ConfigProto() ps_config.device_count['GPU'] = 0 @@ -56,9 +135,10 @@ def create_in_process_cluster(num_workers, num_ps): # 2) there is something global in CUDA such that if we initialize CUDA in the # parent process, the child process cannot initialize it again and thus cannot # use GPUs (https://stackoverflow.com/questions/22950047). - return test_util.create_local_cluster( + return _create_cluster( num_workers, num_ps=num_ps, + has_chief=has_chief, worker_config=worker_config, ps_config=ps_config, protocol='grpc') @@ -70,7 +150,8 @@ class MultiWorkerTestBase(test.TestCase): @classmethod def setUpClass(cls): """Create a local cluster with 2 workers.""" - cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0) + cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0) + cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0] def setUp(self): # We only cache the session in one test because another test may have a @@ -111,17 +192,17 @@ class MultiWorkerTestBase(test.TestCase): config.graph_options.rewrite_options.constant_folding = ( rewriter_config_pb2.RewriterConfig.OFF) + if target is None: + target = self._default_target if graph is None: if getattr(self._thread_local, 'cached_session', None) is None: self._thread_local.cached_session = session.Session( - graph=None, config=config, target=target or self._workers[0].target) + graph=None, config=config, target=target) sess = self._thread_local.cached_session with sess.graph.as_default(), sess.as_default(): yield sess else: - with session.Session( - graph=graph, config=config, target=target or - self._workers[0].target) as sess: + with session.Session(graph=graph, config=config, target=target) as sess: yield sess def _run_client(self, client_fn, task_type, task_id, num_gpus, *args, diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index a2d736e42271ab1627240949b99088ed3f0746f6..6e9ba37a198fc8038c086d2672251adfac30fdcf 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -51,7 +51,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): model_fn, iterator.get_next(), run_concurrently=layer.built))) if not context.executing_eagerly(): - with self.test_session() as sess: + with self.cached_session() as sess: run_step = sess.make_callable(run_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 96b6519bc4d0a280746632fef57c54a9b1e82fe8..361c8be5903d63fe7e126e441d0e56b552f41bce 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -22,6 +22,7 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import values from tensorflow.python.distribute import multi_worker_util +from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -94,11 +95,18 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): cluster configurations. task_type: the current task type. task_id: the current task id. + + Raises: + ValueError: if `cluster_spec` is given but `task_type` or `task_id` is + not. """ super(ParameterServerStrategy, self).__init__() self._num_gpus_per_worker = num_gpus_per_worker if cluster_spec: cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) + if task_type is None or task_id is None: + raise ValueError("When `cluster_spec` is given, must also specify " + "`task_type` and `task_id`.") self._cluster_spec = cluster_spec # We typically don't need to do all-reduce in this strategy. @@ -233,8 +241,35 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): " for variable: " + kwargs["name"]) def var_creator(*args, **kwargs): + # Record what collections this variable should be added to. + collections = kwargs.pop("collections", None) + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + kwargs["collections"] = [] + + # Create and wrap the variable. v = next_creator(*args, **kwargs) - return values.AggregatingVariable(v, aggregation) + wrapped = values.AggregatingVariable(v, aggregation) + + # Add the wrapped variable to the requested collections. + # The handling of eager mode and the global step matches + # ResourceVariable._init_from_args(). + if not context.executing_eagerly(): + g = ops.get_default_graph() + # If "trainable" is True, next_creator() will add the contained + # variable to the TRAINABLE_VARIABLES collection, so we manually + # remove it and replace with the wrapper. We can't set "trainable" + # to False for next_creator() since that causes functions like + # implicit_gradients to skip those variables. + if kwargs.get("trainable", True): + collections.append(ops.GraphKeys.TRAINABLE_VARIABLES) + l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES) + l.remove(v) + g.add_to_collections(collections, wrapped) + elif ops.GraphKeys.GLOBAL_STEP in collections: + ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped) + + return wrapped else: var_creator = next_creator @@ -345,6 +380,10 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): cluster configurations. task_type: the current task type. task_id: the current task id. + + Raises: + ValueError: if `cluster_spec` is given but `task_type` or `task_id` is + not. """ del session_config @@ -353,6 +392,9 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): if not self._cluster_spec and cluster_spec: self._cluster_spec = multi_worker_util.normalize_cluster_spec( cluster_spec) + if task_type is None or task_id is None: + raise ValueError("When `cluster_spec` is given, must also specify " + "`task_type` and `task_id`.") self._initialize_devices(self._num_gpus_per_worker, self._cluster_spec, task_type, task_id) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index adfe3e8b020521d9c2c409da7c6d79e0ba060330..0e2bfcec5f6bcf0eeaa163ebd276666763bc68a6 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -24,6 +24,8 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import parameter_server_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.python.distribute import multi_worker_util from tensorflow.python.eager import context from tensorflow.python.estimator import run_config from tensorflow.python.framework import constant_op @@ -37,21 +39,15 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.training import training_util +CHIEF = run_config.TaskType.CHIEF +WORKER = run_config.TaskType.WORKER +PS = run_config.TaskType.PS -class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, - parameterized.TestCase): - @classmethod - def setUpClass(cls): - cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster( - num_workers=3, num_ps=2) - cls._cluster_spec = { - run_config.TaskType.WORKER: [ - 'fake_worker_0', 'fake_worker_1', 'fake_worker_2' - ], - run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] - } +class ParameterServerStrategyTestBase( + multi_worker_test_base.MultiWorkerTestBase): def setUp(self): self._result = 0 @@ -60,7 +56,7 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, self._init_reached = 0 self._finish_condition = threading.Condition() self._finish_reached = 0 - super(ParameterServerStrategyTest, self).setUp() + super(ParameterServerStrategyTestBase, self).setUp() def _get_test_objects(self, task_type, task_id, num_gpus): distribution = parameter_server_strategy.ParameterServerStrategy( @@ -70,13 +66,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, distribution.configure( cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id) - return distribution, self._workers[task_id].target + return distribution, 'grpc://' + self._cluster_spec[WORKER][task_id] def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) d, _ = self._get_test_objects(task_type, task_id, num_gpus) with ops.Graph().as_default(), \ - self.test_session(target=self._workers[0].target) as sess, \ + self.test_session(target=self._default_target) as sess, \ d.scope(): # Define a variable outside the call_for_each_tower scope. This is not @@ -172,18 +168,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) - @combinations.generate( - combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) - def testDeviceAssignmentDistributed(self, num_gpus): - self._test_device_assignment_distributed('worker', 1, num_gpus) - def _test_device_assignment_local(self, d, compute_device='CPU', variable_device='CPU', num_gpus=0): with ops.Graph().as_default(), \ - self.test_session(target=self._workers[0].target) as sess, \ + self.test_session(target=self._default_target) as sess, \ d.scope(): def model_fn(): @@ -276,29 +267,12 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) - def testDeviceAssignmentLocalCPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=0) - self._test_device_assignment_local( - distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) - - def testDeviceAssignmentLocalOneGPU(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=1) - self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) - - def testDeviceAssignmentLocalTwoGPUs(self): - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=2) - self._test_device_assignment_local( - distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) - def _test_simple_increment(self, task_type, task_id, num_gpus): d, master_target = self._get_test_objects(task_type, task_id, num_gpus) if hasattr(d, '_cluster_spec') and d._cluster_spec: - num_workers = len(d._cluster_spec.as_dict().get('worker', - ['dummy_worker'])) + num_workers = len(d._cluster_spec.as_dict().get(WORKER)) + if 'chief' in d._cluster_spec.as_dict(): + num_workers += 1 else: num_workers = 1 with ops.Graph().as_default(), \ @@ -357,6 +331,11 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): d, master_target = self._get_test_objects(task_type, task_id, num_gpus) + assert hasattr(d, '_cluster_spec') and d._cluster_spec + num_workers = len(d._cluster_spec.as_dict().get(WORKER)) + if CHIEF in d._cluster_spec.as_dict(): + num_workers += 1 + with ops.Graph().as_default(), \ self.test_session(target=master_target) as sess, \ d.scope(): @@ -405,13 +384,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, if context.num_gpus() < d._num_gpus_per_worker: return True - if task_id == 0: + if multi_worker_util.is_chief(d._cluster_spec, task_type, task_id): variables.global_variables_initializer().run() # Workers waiting for chief worker's initializing variables. self._init_condition.acquire() self._init_reached += 1 - while self._init_reached != 3: + while self._init_reached != num_workers: self._init_condition.wait() self._init_condition.notify_all() self._init_condition.release() @@ -428,9 +407,42 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, self.assertLess(error_after, error_before) return error_after < error_before + +class ParameterServerStrategyTest(ParameterServerStrategyTestBase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=2) + cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0] + + def testDeviceAssignmentLocalCPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=0) + self._test_device_assignment_local( + distribution, compute_device='CPU', variable_device='CPU', num_gpus=0) + + def testDeviceAssignmentLocalOneGPU(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=1) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='GPU', num_gpus=1) + + def testDeviceAssignmentLocalTwoGPUs(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + self._test_device_assignment_local( + distribution, compute_device='GPU', variable_device='CPU', num_gpus=2) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributed(self, num_gpus): + self._test_device_assignment_distributed('worker', 1, num_gpus) + def testSimpleBetweenGraph(self): self._run_between_graph_clients(self._test_simple_increment, - self._cluster_spec, 0) + self._cluster_spec, context.num_gpus()) @combinations.generate( combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) @@ -444,5 +456,38 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, self._cluster_spec, num_gpus) +class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, + parameterized.TestCase): + + @classmethod + def setUpClass(cls): + cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( + num_workers=3, num_ps=2, has_chief=True) + cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0] + + def testSimpleBetweenGraph(self): + self._run_between_graph_clients(self._test_simple_increment, + self._cluster_spec, context.num_gpus()) + + @combinations.generate( + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) + + def testGlobalStepIsWrapped(self): + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=2) + with ops.Graph().as_default(), distribution.scope(): + created_step = training_util.create_global_step() + get_step = training_util.get_global_step() + self.assertEqual(created_step, get_step, + msg=('created_step %s type %s vs. get_step %s type %s' % + (id(created_step), created_step.__class__.__name__, + id(get_step), get_step.__class__.__name__))) + self.assertIs(values.AggregatingVariable, type(created_step)) + self.assertIs(values.AggregatingVariable, type(get_step)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py index a68dbce6c7d03f6a1695ebfcd00178e21ac1cda0..bb10b546a1907bba26cd0d7e7c5308420adbaf3f 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -37,7 +37,7 @@ class PrefetchingOpsV2Test(test.TestCase): iterator = device_dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: for i in range(10): self.assertEqual(i, sess.run(next_element)) with self.assertRaises(errors.OutOfRangeError): @@ -55,7 +55,7 @@ class PrefetchingOpsV2Test(test.TestCase): next_element = iterator.get_next() output = [] - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(5): result = sess.run(next_element) self.assertEqual(2, len(result)) @@ -75,7 +75,7 @@ class PrefetchingOpsV2Test(test.TestCase): iterator = device_dataset.make_initializable_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(iterator.initializer) for _ in range(5): sess.run(next_element) diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 8605ab1f7daeb81e778577ad3c4a18b39c57d743..f1ada49fa378358f112fb75a4bcdbe9a8a09cd13 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -49,7 +49,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): if context.executing_eagerly(): run_step = single_loss_step else: - with self.test_session() as sess: + with self.cached_session() as sess: run_step = sess.make_callable(single_loss_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py index 371b97ba96a826194a6469ba63e485fc67639585..6ee26e19acc71a64952da89080354c83986e44e5 100644 --- a/tensorflow/contrib/distribute/python/strategy_test_lib.py +++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py @@ -130,7 +130,8 @@ class DistributionTestBase(test.TestCase): # Error should go down self.assertLess(error_after, error_before) - def _test_minimize_loss_graph(self, d, soft_placement=False): + def _test_minimize_loss_graph(self, d, soft_placement=False, + learning_rate=0.2): config = config_pb2.ConfigProto() config.allow_soft_placement = soft_placement config.gpu_options.per_process_gpu_memory_fraction = 0.3 @@ -150,7 +151,7 @@ class DistributionTestBase(test.TestCase): grad_fn = backprop.implicit_grad(loss) def update(v, g): - return v.assign_sub(0.2 * g) + return v.assign_sub(learning_rate * g) one = d.broadcast(constant_op.constant([[1.]])) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index a4860030769fab92ec946c5a436240e7c88af1bf..6202a0750a9140e9ac449b081b28dc42049d79a3 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -59,7 +59,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver): class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Experimental TPU distribution strategy implementation.""" - def __init__(self, tpu_cluster_resolver, steps_per_run): + def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None): """Initializes the TPUStrategy object. Args: @@ -70,6 +70,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): metrics, summaries etc. This parameter is only used when Distribution Strategy is used with estimator or keras. + num_cores: Number of cores to use on the TPU. If None specified, then + auto-detect the cores and topology of the TPU system. """ # TODO(isaprykin): Generalize the defaults. They are currently tailored for # the unit test. @@ -77,13 +79,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) + self._num_cores_override = num_cores - # TODO(priyag): This should not be hardcoded here. - self._host = '/device:CPU:0' # TODO(sourabhbajaj): Remove this once performance of running one step # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run + # TODO(frankchn): This should not be hardcoded here for pod purposes. + self._host = self.tpu_host_cpu_device(0) + def distribute_dataset(self, dataset_fn): # TODO(priyag): Perhaps distribute across cores here. return self._call_dataset_fn(dataset_fn) @@ -106,6 +110,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): """Enqueue ops for one iteration.""" control_deps = [] sharded_inputs = [] + # TODO(sourabhbajaj): Add support for TPU pods with ops.device(self._host): for _ in range(self.num_towers): # Use control dependencies to ensure a deterministic ordering. @@ -258,4 +263,10 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): @property def num_towers(self): - return self._tpu_metadata.num_of_cores_per_host + return self._num_cores_override or self._tpu_metadata.num_cores + + def tpu_host_cpu_device(self, host_id): + if self._tpu_cluster_resolver.get_master() in ('', 'local'): + return '/replica:0/task:0/device:CPU:0' + return '/job:%s/task:%d/device:CPU:0' % ('tpu_worker', host_id) + diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index a58bb3a8492a372d29089db0943e2e993ba47ad3..3ccaa2690e84807cb66f10726e636b614a9d4a41 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -183,6 +183,14 @@ class Mirrored(DistributedDelegate): return self._index[device] return list(self._index.values())[0] + def _as_graph_element(self): + obj = self.get() + # pylint: disable=protected-access + conv_fn = getattr(obj, "_as_graph_element", None) + if conv_fn and callable(conv_fn): + return conv_fn() + return obj + def _assign_on_device(device, variable, tensor): with ops.device(device): @@ -296,6 +304,10 @@ class DistributedVariable(DistributedDelegate): self._primary_var.op.type) return self.get().op + @property + def _in_graph_mode(self): + return self._primary_var._in_graph_mode # pylint: disable=protected-access + def read_value(self): return distribution_strategy_context.get_distribution_strategy().read_var( self) @@ -354,8 +366,19 @@ class MirroredVariable(DistributedVariable, Mirrored, # We are calling assign on the mirrored variable in cross tower context, # use update to update the variable. - return distribution_strategy_context.get_distribution_strategy().update( - self, f, *args, **kwargs) + strategy = distribution_strategy_context.get_distribution_strategy() + updates = strategy.update(self, f, *args, **kwargs) + grouped = strategy.group(updates) + if isinstance(updates, DistributedValues) and updates.is_tensor_like: + # Make sure we run all updates. Without this, something like + # session.run(mirrored_var.assign*(...)) may only update one tower. + index = {} + for d in updates.devices: + with ops.device(d), ops.control_dependencies([grouped]): + index[d] = array_ops.identity(updates.get(d)) + return Mirrored(index) + else: + return grouped else: _assert_tower_context() # We are calling an assign function on the mirrored variable in tower @@ -1180,6 +1203,10 @@ class AggregatingVariable(checkpointable.CheckpointableBase): def __repr__(self): return repr(self._v) + def _should_act_as_resource_variable(self): + """Pass resource_variable_ops.is_resource_variable check.""" + pass + # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py index 042c8ebd51c47facfc5c942cae56bd56be9df7c5..372b7e37b74066e86b2c6ec9875249afe9a54e00 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py @@ -31,7 +31,7 @@ class AbsoluteValueTest(test.TestCase): """Tests correctness of the absolute value bijector.""" def testBijectorVersusNumpyRewriteOfBasicFunctionsEventNdims0(self): - with self.test_session() as sess: + with self.cached_session() as sess: bijector = AbsoluteValue(validate_args=True) self.assertEqual("absolute_value", bijector.name) x = array_ops.constant([[0., 1., -1], [0., -5., 3.]]) # Shape [2, 3] @@ -54,13 +54,13 @@ class AbsoluteValueTest(test.TestCase): y, event_ndims=0))) def testNegativeYRaisesForInverseIfValidateArgs(self): - with self.test_session() as sess: + with self.cached_session() as sess: bijector = AbsoluteValue(validate_args=True) with self.assertRaisesOpError("y was negative"): sess.run(bijector.inverse(-1.)) def testNegativeYRaisesForILDJIfValidateArgs(self): - with self.test_session() as sess: + with self.cached_session() as sess: bijector = AbsoluteValue(validate_args=True) with self.assertRaisesOpError("y was negative"): sess.run(bijector.inverse_log_det_jacobian(-1., event_ndims=0)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py index 1e4ad724d00f751a55370ef9aa6dde0003a2098c..a7bd51430e384c199ca8abd06ef9887e998cc380 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py @@ -28,7 +28,7 @@ from tensorflow.python.platform import test class AffineLinearOperatorTest(test.TestCase): def testIdentity(self): - with self.test_session(): + with self.cached_session(): affine = AffineLinearOperator( validate_args=True) x = np.array([[1, 0, -1], [2, 3, 4]], dtype=np.float32) @@ -45,7 +45,7 @@ class AffineLinearOperatorTest(test.TestCase): affine.forward_log_det_jacobian(x, event_ndims=2).eval()) def testDiag(self): - with self.test_session(): + with self.cached_session(): shift = np.array([-1, 0, 1], dtype=np.float32) diag = np.array([[1, 2, 3], [2, 5, 6]], dtype=np.float32) @@ -67,7 +67,7 @@ class AffineLinearOperatorTest(test.TestCase): affine.forward_log_det_jacobian(x, event_ndims=1).eval()) def testTriL(self): - with self.test_session(): + with self.cached_session(): shift = np.array([-1, 0, 1], dtype=np.float32) tril = np.array([[[3, 0, 0], [2, -1, 0], diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py index d2533620bebeb0400b6d4a6346e8315c7e37c5c6..bc6752a69dfaabb6008f1de86ca3c5242251d242 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py @@ -31,14 +31,14 @@ class AffineScalarBijectorTest(test.TestCase): """Tests correctness of the Y = scale @ x + shift transformation.""" def testProperties(self): - with self.test_session(): + with self.cached_session(): mu = -1. # scale corresponds to 1. bijector = AffineScalar(shift=mu) self.assertEqual("affine_scalar", bijector.name) def testNoBatchScalar(self): - with self.test_session() as sess: + with self.cached_session() as sess: def static_run(fun, x, **kwargs): return fun(x, **kwargs).eval() @@ -60,7 +60,7 @@ class AffineScalarBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self): - with self.test_session() as sess: + with self.cached_session() as sess: def static_run(fun, x, **kwargs): return fun(x, **kwargs).eval() @@ -83,7 +83,7 @@ class AffineScalarBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) def testOneBatchScalarViaIdentityIn64BitUserProvidesScaleOnly(self): - with self.test_session() as sess: + with self.cached_session() as sess: def static_run(fun, x, **kwargs): return fun(x, **kwargs).eval() @@ -106,7 +106,7 @@ class AffineScalarBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) def testTwoBatchScalarIdentityViaIdentity(self): - with self.test_session() as sess: + with self.cached_session() as sess: def static_run(fun, x, **kwargs): return fun(x, **kwargs).eval() @@ -129,7 +129,7 @@ class AffineScalarBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) def testTwoBatchScalarIdentityViaScale(self): - with self.test_session() as sess: + with self.cached_session() as sess: def static_run(fun, x, **kwargs): return fun(x, **kwargs).eval() @@ -152,7 +152,7 @@ class AffineScalarBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=0)) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): bijector = AffineScalar(shift=3.6, scale=0.42) assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index 9e14b9a53e6c63876478d876030c476c5d77dbbb..dc18eb3df69bf5ad9c493d1bdbe882a9e48daaad 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -32,14 +32,14 @@ class AffineBijectorTest(test.TestCase): """Tests correctness of the Y = scale @ x + shift transformation.""" def testProperties(self): - with self.test_session(): + with self.cached_session(): mu = -1. # scale corresponds to 1. bijector = Affine(shift=mu) self.assertEqual("affine", bijector.name) def testNoBatchMultivariateIdentity(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -71,7 +71,7 @@ class AffineBijectorTest(test.TestCase): 0., run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateDiag(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -114,7 +114,7 @@ class AffineBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateFullDynamic(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.float32, name="x") mu = array_ops.placeholder(dtypes.float32, name="mu") scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag") @@ -137,7 +137,7 @@ class AffineBijectorTest(test.TestCase): feed_dict)) def testBatchMultivariateIdentity(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -161,7 +161,7 @@ class AffineBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testBatchMultivariateDiag(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -185,7 +185,7 @@ class AffineBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testBatchMultivariateFullDynamic(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.float32, name="x") mu = array_ops.placeholder(dtypes.float32, name="mu") scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag") @@ -209,7 +209,7 @@ class AffineBijectorTest(test.TestCase): x, event_ndims=1), feed_dict)) def testIdentityWithDiagUpdate(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -235,7 +235,7 @@ class AffineBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityWithTriL(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -261,7 +261,7 @@ class AffineBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testDiagWithTriL(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -285,7 +285,7 @@ class AffineBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityAndDiagWithTriL(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -312,7 +312,7 @@ class AffineBijectorTest(test.TestCase): run(bijector.inverse_log_det_jacobian, x, event_ndims=1)) def testIdentityWithVDVTUpdate(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -349,7 +349,7 @@ class AffineBijectorTest(test.TestCase): run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testDiagWithVDVTUpdate(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -385,7 +385,7 @@ class AffineBijectorTest(test.TestCase): run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testTriLWithVDVTUpdate(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -422,7 +422,7 @@ class AffineBijectorTest(test.TestCase): run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testTriLWithVDVTUpdateNoDiagonal(self): - with self.test_session() as sess: + with self.cached_session() as sess: placeholder = array_ops.placeholder(dtypes.float32, name="x") def static_run(fun, x, **kwargs): @@ -459,7 +459,7 @@ class AffineBijectorTest(test.TestCase): run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1)) def testNoBatchMultivariateRaisesWhenSingular(self): - with self.test_session(): + with self.cached_session(): mu = [1., -1] bijector = Affine( shift=mu, @@ -531,7 +531,7 @@ class AffineBijectorTest(test.TestCase): itertools.combinations(s, r) for r in range(len(s) + 1)) for args in _powerset(scale_params.items()): - with self.test_session(): + with self.cached_session(): args = dict(args) scale_args = dict({"x": x}, **args) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py index c832fcaa686c92f83810e4f99ca3b23ae694b723..bf61e9f2fe36f0455aadee762a8eca4894bc1806 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py @@ -69,7 +69,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers, ] for input_shape, event_dims, training in params: x_ = np.arange(5 * 4 * 2).astype(np.float32).reshape(input_shape) - with self.test_session() as sess: + with self.cached_session() as sess: x = constant_op.constant(x_) # When training, memorize the exact mean of the last # minibatch that it normalized (instead of moving average assignment). @@ -145,7 +145,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers, def testMaximumLikelihoodTraining(self): # Test Maximum Likelihood training with default bijector. - with self.test_session() as sess: + with self.cached_session() as sess: base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.]) batch_norm = BatchNormalization(training=True) dist = transformed_distribution_lib.TransformedDistribution( @@ -176,7 +176,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers, self.assertAllClose([1., 1.], moving_var_, atol=5e-2) def testLogProb(self): - with self.test_session() as sess: + with self.cached_session() as sess: layer = normalization.BatchNormalization(epsilon=0.) batch_norm = BatchNormalization(batchnorm_layer=layer, training=False) base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.]) @@ -196,7 +196,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers, def testMutuallyConsistent(self): # BatchNorm bijector is only mutually consistent when training=False. dims = 4 - with self.test_session() as sess: + with self.cached_session() as sess: layer = normalization.BatchNormalization(epsilon=0.) batch_norm = BatchNormalization(batchnorm_layer=layer, training=False) dist = transformed_distribution_lib.TransformedDistribution( @@ -215,7 +215,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers, def testInvertMutuallyConsistent(self): # BatchNorm bijector is only mutually consistent when training=False. dims = 4 - with self.test_session() as sess: + with self.cached_session() as sess: layer = normalization.BatchNormalization(epsilon=0.) batch_norm = Invert( BatchNormalization(batchnorm_layer=layer, training=False)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index dc45114b1c23b5edb78d68ad4f38f5201d265170..ada99ec9c6eccac410903ac4f1c26a89a75c842c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -46,7 +46,7 @@ class ChainBijectorTest(test.TestCase): """Tests the correctness of the Y = Chain(bij1, bij2, bij3) transformation.""" def testBijector(self): - with self.test_session(): + with self.cached_session(): chain = Chain((Exp(), Softplus())) self.assertEqual("chain_of_exp_of_softplus", chain.name) x = np.asarray([[[1., 2.], @@ -61,7 +61,7 @@ class ChainBijectorTest(test.TestCase): chain.forward_log_det_jacobian(x, event_ndims=1).eval()) def testBijectorIdentity(self): - with self.test_session(): + with self.cached_session(): chain = Chain() self.assertEqual("identity", chain.name) x = np.asarray([[[1., 2.], @@ -74,13 +74,13 @@ class ChainBijectorTest(test.TestCase): 0., chain.forward_log_det_jacobian(x, event_ndims=1).eval()) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): chain = Chain((Exp(), Softplus())) assert_scalar_congruency( chain, lower_x=1e-3, upper_x=1.5, rtol=0.05) def testShapeGetters(self): - with self.test_session(): + with self.cached_session(): chain = Chain([ SoftmaxCentered(validate_args=True), SoftmaxCentered(validate_args=True), @@ -195,7 +195,7 @@ class ChainBijectorTest(test.TestCase): dtype=np.float32, shape=[None, 10], name="samples") ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0) self.assertTrue(ildj is not None) - with self.test_session(): + with self.cached_session(): ildj.eval({samples: np.zeros([2, 10], np.float32)}) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py index d1ce273499c8a646c0757844c91a785fa8d56ce4..9681b64cedfaedfb79ce0aedfa42e36993d557ba 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py @@ -30,7 +30,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase): """Tests the correctness of the Y = X @ X.T transformation.""" def testBijectorMatrix(self): - with self.test_session(): + with self.cached_session(): bijector = bijectors.CholeskyOuterProduct(validate_args=True) self.assertEqual("cholesky_outer_product", bijector.name) x = [[[1., 0], [2, 1]], [[np.sqrt(2.), 0], [np.sqrt(8.), 1]]] @@ -75,7 +75,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase): bijector = bijectors.CholeskyOuterProduct() x_pl = array_ops.placeholder(dtypes.float32) - with self.test_session(): + with self.cached_session(): log_det_jacobian = bijector.forward_log_det_jacobian(x_pl, event_ndims=2) # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4. @@ -86,7 +86,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase): def testNoBatchStatic(self): x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) - with self.test_session() as sess: + with self.cached_session() as sess: y_actual = bijectors.CholeskyOuterProduct().forward(x=x) x_actual = bijectors.CholeskyOuterProduct().inverse(y=y) [y_actual_, x_actual_] = sess.run([y_actual, x_actual]) @@ -98,7 +98,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase): def testNoBatchDeferred(self): x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y) y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T) - with self.test_session() as sess: + with self.cached_session() as sess: x_pl = array_ops.placeholder(dtypes.float32) y_pl = array_ops.placeholder(dtypes.float32) y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl) @@ -119,7 +119,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase): [2, 5]], [[9., 3], [3, 5]]]) # np.matmul(x, x.T) - with self.test_session() as sess: + with self.cached_session() as sess: y_actual = bijectors.CholeskyOuterProduct().forward(x=x) x_actual = bijectors.CholeskyOuterProduct().inverse(y=y) [y_actual_, x_actual_] = sess.run([y_actual, x_actual]) @@ -137,7 +137,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase): [2, 5]], [[9., 3], [3, 5]]]) # np.matmul(x, x.T) - with self.test_session() as sess: + with self.cached_session() as sess: x_pl = array_ops.placeholder(dtypes.float32) y_pl = array_ops.placeholder(dtypes.float32) y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py index 7be939cd274e6f0e33c9b01c82494755db2caa73..d2c00865e7ad609ab7b6b37e981fff4dbc151c74 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py @@ -30,7 +30,7 @@ class ExpBijectorTest(test.TestCase): """Tests correctness of the Y = g(X) = exp(X) transformation.""" def testBijector(self): - with self.test_session(): + with self.cached_session(): bijector = Exp() self.assertEqual("exp", bijector.name) x = [[[1.], [2.]]] @@ -48,13 +48,13 @@ class ExpBijectorTest(test.TestCase): x, event_ndims=1).eval()) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): bijector = Exp() assert_scalar_congruency( bijector, lower_x=-2., upper_x=1.5, rtol=0.05) def testBijectiveAndFinite(self): - with self.test_session(): + with self.cached_session(): bijector = Exp() x = np.linspace(-10, 10, num=10).astype(np.float32) y = np.logspace(-10, 10, num=10).astype(np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py index 54e54c3296a89a4fe29a3cce971760502b65e784..b9cdbfb823d4d4a0dd6b4bb7cc2bd6a5dd6a908e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py @@ -31,7 +31,7 @@ class GumbelBijectorTest(test.TestCase): """Tests correctness of the Gumbel bijector.""" def testBijector(self): - with self.test_session(): + with self.cached_session(): loc = 0.3 scale = 5. bijector = Gumbel(loc=loc, scale=scale, validate_args=True) @@ -52,12 +52,12 @@ class GumbelBijectorTest(test.TestCase): atol=0.) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): assert_scalar_congruency( Gumbel(loc=0.3, scale=20.), lower_x=1., upper_x=100., rtol=0.02) def testBijectiveAndFinite(self): - with self.test_session(): + with self.cached_session(): bijector = Gumbel(loc=0., scale=3.0, validate_args=True) x = np.linspace(-10., 10., num=10).astype(np.float32) y = np.linspace(0.01, 0.99, num=10).astype(np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py index 7d3bd758cd2db307f95d2d934923ea2133dc1217..c9bccb36fcc8029ace564c6408adf6ee790e5c18 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py @@ -32,7 +32,7 @@ class InlineBijectorTest(test.TestCase): """Tests correctness of the inline constructed bijector.""" def testBijector(self): - with self.test_session(): + with self.cached_session(): exp = Exp() inline = Inline( forward_fn=math_ops.exp, @@ -55,7 +55,7 @@ class InlineBijectorTest(test.TestCase): inline.forward_log_det_jacobian(x, event_ndims=1).eval()) def testShapeGetters(self): - with self.test_session(): + with self.cached_session(): bijector = Inline( forward_event_shape_tensor_fn=lambda x: array_ops.concat((x, [1]), 0), forward_event_shape_fn=lambda x: x.as_list() + [1], diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py index 8b14c8327f08902044f50483f9f8dfe67b58cd70..7e3340aeb0e5bd1e07e2ed487446e06ae373c204 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py @@ -31,7 +31,7 @@ class InvertBijectorTest(test.TestCase): """Tests the correctness of the Y = Invert(bij) transformation.""" def testBijector(self): - with self.test_session(): + with self.cached_session(): for fwd in [ bijectors.Identity(), bijectors.Exp(), @@ -53,13 +53,13 @@ class InvertBijectorTest(test.TestCase): rev.forward_log_det_jacobian(x, event_ndims=1).eval()) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): bijector = bijectors.Invert(bijectors.Exp()) assert_scalar_congruency( bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) def testShapeGetters(self): - with self.test_session(): + with self.cached_session(): bijector = bijectors.Invert(bijectors.SoftmaxCentered(validate_args=True)) x = tensor_shape.TensorShape([2]) y = tensor_shape.TensorShape([1]) @@ -73,7 +73,7 @@ class InvertBijectorTest(test.TestCase): bijector.inverse_event_shape_tensor(y.as_list()).eval()) def testDocstringExample(self): - with self.test_session(): + with self.cached_session(): exp_gamma_distribution = ( transformed_distribution_lib.TransformedDistribution( distribution=gamma_lib.Gamma(concentration=1., rate=2.), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py index a8089881f684db9f8876d6dd738e52bf2f1f7606..b3fb50005e581a33210041b5206cf1831de88ad3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py @@ -30,7 +30,7 @@ class KumaraswamyBijectorTest(test.TestCase): """Tests correctness of the Kumaraswamy bijector.""" def testBijector(self): - with self.test_session(): + with self.cached_session(): a = 2. b = 0.3 bijector = Kumaraswamy( @@ -54,13 +54,13 @@ class KumaraswamyBijectorTest(test.TestCase): atol=0.) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): assert_scalar_congruency( Kumaraswamy(concentration1=0.5, concentration0=1.1), lower_x=0., upper_x=1., n=int(10e3), rtol=0.02) def testBijectiveAndFinite(self): - with self.test_session(): + with self.cached_session(): concentration1 = 1.2 concentration0 = 2. bijector = Kumaraswamy( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py index 5ba5a2083bf11791d7d58146dc2e6283b524d241..ad4329d42595b03747f2918317216692c1354a07 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -71,7 +71,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers, def testBijector(self): x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4, 2) - with self.test_session() as sess: + with self.cached_session() as sess: ma = MaskedAutoregressiveFlow( validate_args=True, **self._autoregressive_flow_kwargs) @@ -102,7 +102,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers, def testMutuallyConsistent(self): dims = 4 - with self.test_session() as sess: + with self.cached_session() as sess: ma = MaskedAutoregressiveFlow( validate_args=True, **self._autoregressive_flow_kwargs) @@ -121,7 +121,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers, def testInvertMutuallyConsistent(self): dims = 4 - with self.test_session() as sess: + with self.cached_session() as sess: ma = Invert(MaskedAutoregressiveFlow( validate_args=True, **self._autoregressive_flow_kwargs)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py index 49a9afe3f6debe048369c52328fb5534946ab9e5..31ee36f024e607f0a6c37fc3a66570c0e209f328 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.platform import test +@test_util.run_all_in_graph_and_eager_modes class MatrixInverseTriLBijectorTest(test.TestCase): """Tests the correctness of the Y = inv(tril) transformation.""" @@ -40,7 +41,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase): y[idx][np.triu_indices(y[idx].shape[-1], 1)] = 0 return y - @test_util.run_in_graph_and_eager_modes def testComputesCorrectValues(self): inv = bijectors.MatrixInverseTriL(validate_args=True) self.assertEqual("matrix_inverse_tril", inv.name) @@ -62,7 +62,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes def testOneByOneMatrix(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[5.]], dtype=np.float32) @@ -81,7 +80,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes def testZeroByZeroMatrix(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.eye(0, dtype=np.float32) @@ -100,7 +98,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertNear(expected_fldj_, fldj_, err=1e-3) self.assertNear(-expected_fldj_, ildj_, err=1e-3) - @test_util.run_in_graph_and_eager_modes def testBatch(self): # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape # (2, 1). @@ -125,20 +122,18 @@ class MatrixInverseTriLBijectorTest(test.TestCase): self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3) self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3) - @test_util.run_in_graph_and_eager_modes def testErrorOnInputRankTooLow(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([0.1], dtype=np.float32) rank_error_msg = "must have rank at least 2" - with self.test_session(): - with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): - inv.forward(x_).eval() - with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): - inv.inverse(x_).eval() - with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): - inv.forward_log_det_jacobian(x_, event_ndims=2).eval() - with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): - inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + self.evaluate(inv.forward(x_)) + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + self.evaluate(inv.inverse(x_)) + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2)) + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2)) # TODO(b/80481923): Figure out why these assertions fail, and fix them. ## def testErrorOnInputNonSquare(self): @@ -146,55 +141,50 @@ class MatrixInverseTriLBijectorTest(test.TestCase): ## x_ = np.array([[1., 2., 3.], ## [4., 5., 6.]], dtype=np.float32) ## square_error_msg = "must be a square matrix" - ## with self.test_session(): - ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - ## square_error_msg): - ## inv.forward(x_).eval() - ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - ## square_error_msg): - ## inv.inverse(x_).eval() - ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - ## square_error_msg): - ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval() - ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - ## square_error_msg): - ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() - - @test_util.run_in_graph_and_eager_modes + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## self.evaluate(inv.forward(x_)) + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## self.evaluate(inv.inverse(x_)) + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2)) + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2)) + def testErrorOnInputNotLowerTriangular(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[1., 2.], [3., 4.]], dtype=np.float32) triangular_error_msg = "must be lower triangular" - with self.test_session(): - with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - triangular_error_msg): - inv.forward(x_).eval() - with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - triangular_error_msg): - inv.inverse(x_).eval() - with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - triangular_error_msg): - inv.forward_log_det_jacobian(x_, event_ndims=2).eval() - with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - triangular_error_msg): - inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() - - @test_util.run_in_graph_and_eager_modes + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + self.evaluate(inv.forward(x_)) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + self.evaluate(inv.inverse(x_)) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2)) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2)) + def testErrorOnInputSingular(self): inv = bijectors.MatrixInverseTriL(validate_args=True) x_ = np.array([[1., 0.], [0., 0.]], dtype=np.float32) nonsingular_error_msg = "must have all diagonal entries nonzero" - with self.test_session(): - with self.assertRaisesOpError(nonsingular_error_msg): - inv.forward(x_).eval() - with self.assertRaisesOpError(nonsingular_error_msg): - inv.inverse(x_).eval() - with self.assertRaisesOpError(nonsingular_error_msg): - inv.forward_log_det_jacobian(x_, event_ndims=2).eval() - with self.assertRaisesOpError(nonsingular_error_msg): - inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + self.evaluate(inv.forward(x_)) + with self.assertRaisesOpError(nonsingular_error_msg): + self.evaluate(inv.inverse(x_)) + with self.assertRaisesOpError(nonsingular_error_msg): + self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2)) + with self.assertRaisesOpError(nonsingular_error_msg): + self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2)) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py index cb42331a21a6acdd5244c311a7def5359bb6c574..9a88f8f1bc99f80a17f64b40749ef0e5b781a242 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py @@ -38,26 +38,25 @@ class OrderedBijectorTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testBijectorVector(self): - with self.test_session(): - ordered = Ordered() - self.assertEqual("ordered", ordered.name) - x = np.asarray([[2., 3, 4], [4., 8, 13]]) - y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]] - self.assertAllClose(y, self.evaluate(ordered.forward(x))) - self.assertAllClose(x, self.evaluate(ordered.inverse(y))) - self.assertAllClose( - np.sum(np.asarray(y)[..., 1:], axis=-1), - self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)), - atol=0., - rtol=1e-7) - self.assertAllClose( - self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)), - self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)), - atol=0., - rtol=1e-7) + ordered = Ordered() + self.assertEqual("ordered", ordered.name) + x = np.asarray([[2., 3, 4], [4., 8, 13]]) + y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]] + self.assertAllClose(y, self.evaluate(ordered.forward(x))) + self.assertAllClose(x, self.evaluate(ordered.inverse(y))) + self.assertAllClose( + np.sum(np.asarray(y)[..., 1:], axis=-1), + self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)), + atol=0., + rtol=1e-7) + self.assertAllClose( + self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)), + self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)), + atol=0., + rtol=1e-7) def testBijectorUnknownShape(self): - with self.test_session(): + with self.cached_session(): ordered = Ordered() self.assertEqual("ordered", ordered.name) x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) @@ -84,21 +83,20 @@ class OrderedBijectorTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testShapeGetters(self): - with self.test_session(): - x = tensor_shape.TensorShape([4]) - y = tensor_shape.TensorShape([4]) - bijector = Ordered(validate_args=True) - self.assertAllEqual(y, bijector.forward_event_shape(x)) - self.assertAllEqual(y.as_list(), - self.evaluate(bijector.forward_event_shape_tensor( - x.as_list()))) - self.assertAllEqual(x, bijector.inverse_event_shape(y)) - self.assertAllEqual(x.as_list(), - self.evaluate(bijector.inverse_event_shape_tensor( - y.as_list()))) + x = tensor_shape.TensorShape([4]) + y = tensor_shape.TensorShape([4]) + bijector = Ordered(validate_args=True) + self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y.as_list(), + self.evaluate(bijector.forward_event_shape_tensor( + x.as_list()))) + self.assertAllEqual(x, bijector.inverse_event_shape(y)) + self.assertAllEqual(x.as_list(), + self.evaluate(bijector.inverse_event_shape_tensor( + y.as_list()))) def testBijectiveAndFinite(self): - with self.test_session(): + with self.cached_session(): ordered = Ordered() x = np.sort(self._rng.randn(3, 10), axis=-1).astype(np.float32) y = (self._rng.randn(3, 10)).astype(np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py index 7eef4ab599951bbb624652f13a0091363b36b93d..e2062ed55d5e6367a7e1b1cfdbdd5541b6b1fd53 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py @@ -38,7 +38,7 @@ class PermuteBijectorTest(test.TestCase): expected_x = np.random.randn(4, 2, 3) expected_y = expected_x[..., expected_permutation] - with self.test_session() as sess: + with self.cached_session() as sess: permutation_ph = array_ops.placeholder(dtype=dtypes.int32) bijector = Permute( permutation=permutation_ph, @@ -64,7 +64,7 @@ class PermuteBijectorTest(test.TestCase): self.assertAllClose(0., ildj, rtol=1e-6, atol=0) def testRaisesOpError(self): - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesOpError("Permutation over `d` must contain"): permutation_ph = array_ops.placeholder(dtype=dtypes.int32) bijector = Permute( @@ -77,7 +77,7 @@ class PermuteBijectorTest(test.TestCase): permutation = np.int32([2, 0, 1]) x = np.random.randn(4, 2, 3) y = x[..., permutation] - with self.test_session(): + with self.cached_session(): bijector = Permute(permutation=permutation, validate_args=True) assert_bijective_and_finite( bijector, x, y, event_ndims=1, rtol=1e-6, atol=0) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py index 85d22830132816cd6c77cd0b07870f3a22ae9798..ef303ab664c1438b60c07ae2f3af83f42332b2bb 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py @@ -30,7 +30,7 @@ class PowerTransformBijectorTest(test.TestCase): """Tests correctness of the power transformation.""" def testBijector(self): - with self.test_session(): + with self.cached_session(): c = 0.2 bijector = PowerTransform(power=c, validate_args=True) self.assertEqual("power_transform", bijector.name) @@ -48,13 +48,13 @@ class PowerTransformBijectorTest(test.TestCase): atol=0.) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): bijector = PowerTransform(power=0.2, validate_args=True) assert_scalar_congruency( bijector, lower_x=-2., upper_x=1.5, rtol=0.05) def testBijectiveAndFinite(self): - with self.test_session(): + with self.cached_session(): bijector = PowerTransform(power=0.2, validate_args=True) x = np.linspace(-4.999, 10, num=10).astype(np.float32) y = np.logspace(0.001, 10, num=10).astype(np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py index 2d52895fbe0967cdd2260d6d298a291286858d09..b3b7b8535e1387490c1f330444b8decbc4e28292 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py @@ -43,7 +43,7 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase): def testBijector(self): x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4 * 2) - with self.test_session() as sess: + with self.cached_session() as sess: nvp = RealNVP( num_masked=4, validate_args=True, @@ -78,7 +78,7 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase): def testMutuallyConsistent(self): dims = 4 - with self.test_session() as sess: + with self.cached_session() as sess: nvp = RealNVP( num_masked=3, validate_args=True, @@ -98,7 +98,7 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase): def testInvertMutuallyConsistent(self): dims = 4 - with self.test_session() as sess: + with self.cached_session() as sess: nvp = Invert(RealNVP( num_masked=3, validate_args=True, diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index d44e49b4874a5b91f7633cd9c97dbb1a7da70f27..79eadf524b5111331ecf44b56c42dc157239a461 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -50,7 +50,7 @@ class _ReshapeBijectorTest(object): expected_x = np.random.randn(4, 3, 2) expected_y = np.reshape(expected_x, [4, 6]) - with self.test_session() as sess: + with self.cached_session() as sess: shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,]) bijector = Reshape( event_shape_out=shape_out, @@ -84,7 +84,7 @@ class _ReshapeBijectorTest(object): # using the _tensor methods, we should always get a fully-specified # result since these are evaluated at graph runtime. - with self.test_session() as sess: + with self.cached_session() as sess: (shape_out_, shape_in_) = sess.run(( bijector.forward_event_shape_tensor(shape_in), @@ -103,7 +103,7 @@ class _ReshapeBijectorTest(object): expected_y_scalar = expected_x_scalar[0] shape_in, shape_out, feed_dict = self.build_shapes([], [1,]) - with self.test_session() as sess: + with self.cached_session() as sess: bijector = Reshape( event_shape_out=shape_in, event_shape_in=shape_out, validate_args=True) @@ -124,7 +124,7 @@ class _ReshapeBijectorTest(object): def testMultipleUnspecifiedDimensionsOpError(self): - with self.test_session() as sess: + with self.cached_session() as sess: shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,]) bijector = Reshape( event_shape_out=shape_out, @@ -139,7 +139,7 @@ class _ReshapeBijectorTest(object): # pylint: disable=invalid-name def _testInvalidDimensionsOpError(self, expected_error_message): - with self.test_session() as sess: + with self.cached_session() as sess: shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,]) bijector = Reshape( @@ -155,7 +155,7 @@ class _ReshapeBijectorTest(object): def testValidButNonMatchingInputOpError(self): x = np.random.randn(4, 3, 2) - with self.test_session() as sess: + with self.cached_session() as sess: shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,]) bijector = Reshape( event_shape_out=shape_out, @@ -173,7 +173,7 @@ class _ReshapeBijectorTest(object): def testValidButNonMatchingInputPartiallySpecifiedOpError(self): x = np.random.randn(4, 3, 2) - with self.test_session() as sess: + with self.cached_session() as sess: shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,]) bijector = Reshape( event_shape_out=shape_out, @@ -190,7 +190,7 @@ class _ReshapeBijectorTest(object): x1 = np.random.randn(4, 2, 3) x2 = np.random.randn(4, 1, 1, 5) - with self.test_session() as sess: + with self.cached_session() as sess: shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3], [1, 1, 5]) bijector = Reshape( @@ -208,7 +208,7 @@ class _ReshapeBijectorTest(object): expected_x = np.random.randn(4, 6) expected_y = np.reshape(expected_x, [4, 2, 3]) - with self.test_session() as sess: + with self.cached_session() as sess: # one of input/output shapes is partially specified shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3]) bijector = Reshape( @@ -227,7 +227,7 @@ class _ReshapeBijectorTest(object): def testBothShapesPartiallySpecified(self): expected_x = np.random.randn(4, 2, 3) expected_y = np.reshape(expected_x, [4, 3, 2]) - with self.test_session() as sess: + with self.cached_session() as sess: shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2]) bijector = Reshape( event_shape_out=shape_out, @@ -245,7 +245,7 @@ class _ReshapeBijectorTest(object): def testDefaultVectorShape(self): expected_x = np.random.randn(4, 4) expected_y = np.reshape(expected_x, [4, 2, 2]) - with self.test_session() as sess: + with self.cached_session() as sess: _, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2]) bijector = Reshape(shape_out, validate_args=True) @@ -292,7 +292,7 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): def testBijectiveAndFinite(self): x = np.random.randn(4, 2, 3) y = np.reshape(x, [4, 1, 2, 3]) - with self.test_session(): + with self.cached_session(): bijector = Reshape( event_shape_in=[2, 3], event_shape_out=[1, 2, 3], diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py index cea4a62c22af5d98d38ee881b29c773e6a27a4b4..a6d432753db1574c1781a236567f346b00d3c1b5 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py @@ -31,7 +31,7 @@ class SigmoidBijectorTest(test.TestCase): """Tests correctness of the Y = g(X) = (1 + exp(-X))^-1 transformation.""" def testBijector(self): - with self.test_session(): + with self.cached_session(): self.assertEqual("sigmoid", Sigmoid().name) x = np.linspace(-10., 10., 100).reshape([2, 5, 10]).astype(np.float32) y = special.expit(x) @@ -45,11 +45,11 @@ class SigmoidBijectorTest(test.TestCase): x, event_ndims=0).eval(), atol=0., rtol=1e-4) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): assert_scalar_congruency(Sigmoid(), lower_x=-7., upper_x=7.) def testBijectiveAndFinite(self): - with self.test_session(): + with self.cached_session(): x = np.linspace(-7., 7., 100).astype(np.float32) eps = 1e-3 y = np.linspace(eps, 1. - eps, 100).astype(np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 795f1993ba5c31bf5a26333f31f1bc73125bff07..282619a73b24629b878b1a8b41a35af2ef572cee 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -33,7 +33,7 @@ class SinhArcsinhBijectorTest(test.TestCase): """Tests correctness of the power transformation.""" def testBijectorVersusNumpyRewriteOfBasicFunctions(self): - with self.test_session(): + with self.cached_session(): skewness = 0.2 tailweight = 2.0 bijector = SinhArcsinh( @@ -58,7 +58,7 @@ class SinhArcsinhBijectorTest(test.TestCase): atol=0.) def testLargerTailWeightPutsMoreWeightInTails(self): - with self.test_session(): + with self.cached_session(): # Will broadcast together to shape [3, 2]. x = [-1., 1.] tailweight = [[0.5], [1.0], [2.0]] @@ -75,7 +75,7 @@ class SinhArcsinhBijectorTest(test.TestCase): self.assertLess(forward_1[1], forward_1[2]) def testSkew(self): - with self.test_session(): + with self.cached_session(): # Will broadcast together to shape [3, 2]. x = [-1., 1.] skewness = [[-1.], [0.], [1.]] @@ -92,24 +92,24 @@ class SinhArcsinhBijectorTest(test.TestCase): self.assertLess(np.abs(y[2, 0]), np.abs(y[2, 1])) def testScalarCongruencySkewness1Tailweight0p5(self): - with self.test_session(): + with self.cached_session(): bijector = SinhArcsinh(skewness=1.0, tailweight=0.5, validate_args=True) assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.0, rtol=0.05) def testScalarCongruencySkewnessNeg1Tailweight1p5(self): - with self.test_session(): + with self.cached_session(): bijector = SinhArcsinh(skewness=-1.0, tailweight=1.5, validate_args=True) assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.0, rtol=0.05) def testBijectiveAndFiniteSkewnessNeg1Tailweight0p5(self): - with self.test_session(): + with self.cached_session(): bijector = SinhArcsinh(skewness=-1., tailweight=0.5, validate_args=True) x = np.concatenate((-np.logspace(-2, 10, 1000), [0], np.logspace( -2, 10, 1000))).astype(np.float32) assert_bijective_and_finite(bijector, x, x, event_ndims=0, rtol=1e-3) def testBijectiveAndFiniteSkewness1Tailweight3(self): - with self.test_session(): + with self.cached_session(): bijector = SinhArcsinh(skewness=1., tailweight=3., validate_args=True) x = np.concatenate((-np.logspace(-2, 5, 1000), [0], np.logspace( -2, 5, 1000))).astype(np.float32) @@ -117,7 +117,7 @@ class SinhArcsinhBijectorTest(test.TestCase): bijector, x, x, event_ndims=0, rtol=1e-3) def testBijectorEndpoints(self): - with self.test_session(): + with self.cached_session(): for dtype in (np.float32, np.float64): bijector = SinhArcsinh( skewness=dtype(0.), tailweight=dtype(1.), validate_args=True) @@ -129,7 +129,7 @@ class SinhArcsinhBijectorTest(test.TestCase): bijector, bounds, bounds, event_ndims=0, atol=2e-6) def testBijectorOverRange(self): - with self.test_session(): + with self.cached_session(): for dtype in (np.float32, np.float64): skewness = np.array([1.2, 5.], dtype=dtype) tailweight = np.array([2., 10.], dtype=dtype) @@ -176,12 +176,12 @@ class SinhArcsinhBijectorTest(test.TestCase): atol=0.) def testZeroTailweightRaises(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("not positive"): SinhArcsinh(tailweight=0., validate_args=True).forward(1.0).eval() def testDefaultDtypeIsFloat32(self): - with self.test_session(): + with self.cached_session(): bijector = SinhArcsinh() self.assertEqual(bijector.tailweight.dtype, np.float32) self.assertEqual(bijector.skewness.dtype, np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py index 0f0a2fa531a0585a709df4c2c3e2631e5c275986..8d18400487d5f65a595d6d325816231c831fad78 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py @@ -35,7 +35,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase): """Tests correctness of the Y = g(X) = exp(X) / sum(exp(X)) transformation.""" def testBijectorVector(self): - with self.test_session(): + with self.cached_session(): softmax = SoftmaxCentered() self.assertEqual("softmax_centered", softmax.name) x = np.log([[2., 3, 4], [4., 8, 12]]) @@ -54,7 +54,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase): rtol=1e-7) def testBijectorUnknownShape(self): - with self.test_session(): + with self.cached_session(): softmax = SoftmaxCentered() self.assertEqual("softmax_centered", softmax.name) x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) @@ -80,7 +80,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase): rtol=1e-7) def testShapeGetters(self): - with self.test_session(): + with self.cached_session(): x = tensor_shape.TensorShape([4]) y = tensor_shape.TensorShape([5]) bijector = SoftmaxCentered(validate_args=True) @@ -94,7 +94,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase): y.as_list()).eval()) def testBijectiveAndFinite(self): - with self.test_session(): + with self.cached_session(): softmax = SoftmaxCentered() x = np.linspace(-50, 50, num=10).reshape(5, 2).astype(np.float32) # Make y values on the simplex with a wide range. diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py index 3d8a0a32bba3539f732140e8eb7ebeb532d73ff5..e805619041d5c96ce9c4340d79834b5cc69de0c3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py @@ -42,13 +42,13 @@ class SoftplusBijectorTest(test.TestCase): return -np.log(1 - np.exp(-y)) def testHingeSoftnessZeroRaises(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus(hinge_softness=0., validate_args=True) with self.assertRaisesOpError("must be non-zero"): bijector.forward([1., 1.]).eval() def testBijectorForwardInverseEventDimsZero(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus() self.assertEqual("softplus", bijector.name) x = 2 * rng.randn(2, 10) @@ -58,7 +58,7 @@ class SoftplusBijectorTest(test.TestCase): self.assertAllClose(x, bijector.inverse(y).eval()) def testBijectorForwardInverseWithHingeSoftnessEventDimsZero(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus(hinge_softness=1.5) x = 2 * rng.randn(2, 10) y = 1.5 * self._softplus(x / 1.5) @@ -67,7 +67,7 @@ class SoftplusBijectorTest(test.TestCase): self.assertAllClose(x, bijector.inverse(y).eval()) def testBijectorLogDetJacobianEventDimsZero(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus() y = 2 * rng.rand(2, 10) # No reduction needed if event_dims = 0. @@ -77,7 +77,7 @@ class SoftplusBijectorTest(test.TestCase): y, event_ndims=0).eval()) def testBijectorForwardInverseEventDimsOne(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus() self.assertEqual("softplus", bijector.name) x = 2 * rng.randn(2, 10) @@ -87,7 +87,7 @@ class SoftplusBijectorTest(test.TestCase): self.assertAllClose(x, bijector.inverse(y).eval()) def testBijectorLogDetJacobianEventDimsOne(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus() y = 2 * rng.rand(2, 10) ildj_before = self._softplus_ildj_before_reduction(y) @@ -97,25 +97,25 @@ class SoftplusBijectorTest(test.TestCase): y, event_ndims=1).eval()) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus() assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testScalarCongruencyWithPositiveHingeSoftness(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus(hinge_softness=1.3) assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testScalarCongruencyWithNegativeHingeSoftness(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus(hinge_softness=-1.3) assert_scalar_congruency( bijector, lower_x=-2., upper_x=2.) def testBijectiveAndFinite32bit(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus() x = np.linspace(-20., 20., 100).astype(np.float32) y = np.logspace(-10, 10, 100).astype(np.float32) @@ -123,7 +123,7 @@ class SoftplusBijectorTest(test.TestCase): bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFiniteWithPositiveHingeSoftness32Bit(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus(hinge_softness=1.23) x = np.linspace(-20., 20., 100).astype(np.float32) y = np.logspace(-10, 10, 100).astype(np.float32) @@ -131,7 +131,7 @@ class SoftplusBijectorTest(test.TestCase): bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFiniteWithNegativeHingeSoftness32Bit(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus(hinge_softness=-0.7) x = np.linspace(-20., 20., 100).astype(np.float32) y = -np.logspace(-10, 10, 100).astype(np.float32) @@ -139,7 +139,7 @@ class SoftplusBijectorTest(test.TestCase): bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2) def testBijectiveAndFinite16bit(self): - with self.test_session(): + with self.cached_session(): bijector = Softplus() # softplus(-20) is zero, so we can't use such a large range as in 32bit. x = np.linspace(-10., 20., 100).astype(np.float16) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py index d0098c3c105626da1da5855710169069ebeffbd9..8dad80aa647f0c7d53685aed4025dd49ffa0f6d0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py @@ -43,16 +43,15 @@ class SoftsignBijectorTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testBijectorBounds(self): bijector = Softsign(validate_args=True) - with self.test_session(): - with self.assertRaisesOpError("greater than -1"): - bijector.inverse(-3.).eval() - with self.assertRaisesOpError("greater than -1"): - bijector.inverse_log_det_jacobian(-3., event_ndims=0).eval() - - with self.assertRaisesOpError("less than 1"): - bijector.inverse(3.).eval() - with self.assertRaisesOpError("less than 1"): - bijector.inverse_log_det_jacobian(3., event_ndims=0).eval() + with self.assertRaisesOpError("greater than -1"): + self.evaluate(bijector.inverse(-3.)) + with self.assertRaisesOpError("greater than -1"): + self.evaluate(bijector.inverse_log_det_jacobian(-3., event_ndims=0)) + + with self.assertRaisesOpError("less than 1"): + self.evaluate(bijector.inverse(3.)) + with self.assertRaisesOpError("less than 1"): + self.evaluate(bijector.inverse_log_det_jacobian(3., event_ndims=0)) @test_util.run_in_graph_and_eager_modes def testBijectorForwardInverse(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py index 30c7a738c320b609ce90685512e6b8344dffc9dc..e5550cc83033b3bfbd336bcd3bd42306131ac909 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py @@ -29,7 +29,7 @@ class SquareBijectorTest(test.TestCase): """Tests the correctness of the Y = X ** 2 transformation.""" def testBijectorScalar(self): - with self.test_session(): + with self.cached_session(): bijector = bijectors.Square(validate_args=True) self.assertEqual("square", bijector.name) x = [[[1., 5], @@ -50,7 +50,7 @@ class SquareBijectorTest(test.TestCase): rtol=1e-7) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): bijector = bijectors.Square(validate_args=True) assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py index f57adcda898a1fdb18aacbb0804411db1bb4e4c8..424eb58fa06ef43644ac224106cc43062287ba48 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py @@ -31,7 +31,7 @@ class WeibullBijectorTest(test.TestCase): """Tests correctness of the weibull bijector.""" def testBijector(self): - with self.test_session(): + with self.cached_session(): scale = 5. concentration = 0.3 bijector = Weibull( @@ -54,13 +54,13 @@ class WeibullBijectorTest(test.TestCase): atol=0.) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): assert_scalar_congruency( Weibull(scale=20., concentration=0.3), lower_x=1., upper_x=100., rtol=0.02) def testBijectiveAndFinite(self): - with self.test_session(): + with self.cached_session(): bijector = Weibull( scale=20., concentration=2., validate_args=True) x = np.linspace(1., 8., num=10).astype(np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index f7b2efa7bc0f76be9b9df3e74c769ed5532554dd..05f5d306664ededdfbf867a93e15aadaa3d1a80c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -542,9 +542,9 @@ class PadDynamicTest(_PadTest, test.TestCase): return False +@test_util.run_all_in_graph_and_eager_modes class TestMoveDimension(test.TestCase): - @test_util.run_in_graph_and_eager_modes def test_move_dimension_static_shape(self): x = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) @@ -561,7 +561,6 @@ class TestMoveDimension(test.TestCase): x_perm = distribution_util.move_dimension(x, 4, 2) self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1]) - @test_util.run_in_graph_and_eager_modes def test_move_dimension_dynamic_shape(self): x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index fa3f1bb7ad187993379afeedf3790c789b4538aa..84517b57c7d0af56ba7724d18e78f38041ebe773 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -14,6 +14,7 @@ py_library( ":datasets", ":metrics", ":network", + ":remote", ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", @@ -223,11 +224,24 @@ py_test( ], ) +py_library( + name = "remote", + srcs = ["remote.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform", + "//tensorflow/python/eager:context", + ], +) + py_test( name = "remote_test", srcs = ["remote_test.py"], srcs_version = "PY2AND3", deps = [ + ":remote", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/contrib/eager/python/examples/notebooks/README.md b/tensorflow/contrib/eager/python/examples/notebooks/README.md index 0d5ed848946d1eee643a57bf8c341520268c56b1..2778b228e93b582b6235a6498cd7ca1e52d05279 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/README.md +++ b/tensorflow/contrib/eager/python/examples/notebooks/README.md @@ -1,11 +1,3 @@ -## Research and experimentation - -Eager execution provides an imperative, define-by-run interface for advanced -operations. Write custom layers, forward passes, and training loops with auto -differentiation. Start with these notebooks, then read the -[eager execution guide](https://www.tensorflow.org/guide/eager). - -1. [Eager execution basics](./eager_basics.ipynb) -2. [Automatic differentiation and gradient tapes](./automatic_differentiation.ipynb) -3. [Custom training: basics](./custom_training.ipynb) -4. [Custom layers](./custom_layers.ipynb) +The notebooks have been moved to the +[tensorflow/docs](https://github.com/tensorflow/docs/tree/master/site/en/tutorials/eager) +repository. diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb index 51b7ffc4de0cee31f7a907ae7bf90f17056f9bcf..8fae622e12864ddeee0cedd3cf99be8ea5e4bc48 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb @@ -15,12 +15,7 @@ "execution_count": 0, "metadata": { "cellView": "form", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "GCCk8_dHpuNf" }, @@ -53,308 +48,35 @@ "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "idv0bPeCp325" - }, - "source": [ - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\n", - " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", - "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "vDJ4XzMqodTy" - }, - "source": [ - "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "GQJysDM__Qb0" - }, - "source": [ - "## Setup\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "OiMPZStlibBv" - }, - "outputs": [], - "source": [ - "import tensorflow as tf\n", - "tf.enable_eager_execution()\n", - "\n", - "tfe = tf.contrib.eager # Shorthand for some symbols" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1CLWJl0QliB0" - }, - "source": [ - "## Derivatives of a function\n", - "\n", - "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: " - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "9FViq92UX7P8" - }, - "outputs": [], - "source": [ - "from math import pi\n", - "\n", - "def f(x):\n", - " return tf.square(tf.sin(x))\n", - "\n", - "assert f(pi/2).numpy() == 1.0\n", - "\n", - "\n", - "# grad_f will return a list of derivatives of f\n", - "# with respect to its arguments. Since f() has a single argument,\n", - "# grad_f will return a list with a single element.\n", - "grad_f = tfe.gradients_function(f)\n", - "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "v9fPs8RyopCf" - }, - "source": [ - "### Higher-order gradients\n", - "\n", - "The same API can be used to differentiate as many times as you like:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "3D0ZvnGYo0rW" - }, - "outputs": [], - "source": [ - "def f(x):\n", - " return tf.square(tf.sin(x))\n", - "\n", - "def grad(f):\n", - " return lambda x: tfe.gradients_function(f)(x)[0]\n", - "\n", - "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "plt.plot(x, f(x), label=\"f\")\n", - "plt.plot(x, grad(f)(x), label=\"first derivative\")\n", - "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n", - "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", - "plt.legend()\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-39gouo7mtgu" - }, - "source": [ - "## Gradient tapes\n", - "\n", - "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", - "\n", - "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "MH0UfjympWf7" - }, - "outputs": [], - "source": [ - "def f(x, y):\n", - " output = 1\n", - " # Must use range(int(y)) instead of range(y) in Python 3 when\n", - " # using TensorFlow 1.10 and earlier. Can use range(y) in 1.11+\n", - " for i in range(int(y)):\n", - " output = tf.multiply(output, x)\n", - " return output\n", - "\n", - "def g(x, y):\n", - " # Return the gradient of `f` with respect to it's first parameter\n", - " return tfe.gradients_function(f)(x, y)[0]\n", - "\n", - "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n", - "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", - "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", - "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "aNmR5-jhpX2t" - }, - "source": [ - "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", - "\n", - "For example:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "bAFeIE8EuVIq" + "id": "clNGnJ3u8Rl6" }, - "outputs": [], "source": [ - "x = tf.ones((2, 2))\n", - " \n", - "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n", - "# a single t.gradient() call when the bug is resolved.\n", - "with tf.GradientTape(persistent=True) as t:\n", - " # TODO(ashankar): Explain with \"watch\" argument better?\n", - " t.watch(x)\n", - " y = tf.reduce_sum(x)\n", - " z = tf.multiply(y, y)\n", - "\n", - "# Use the same tape to compute the derivative of z with respect to the\n", - "# intermediate value y.\n", - "dz_dy = t.gradient(z, y)\n", - "assert dz_dy.numpy() == 8.0\n", - "\n", - "# Derivative of z with respect to the original input tensor x\n", - "dz_dx = t.gradient(z, x)\n", - "for i in [0, 1]:\n", - " for j in [0, 1]:\n", - " assert dz_dx[i][j].numpy() == 8.0" + "This file has moved." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "DK05KXrAAld3" - }, - "source": [ - "### Higher-order gradients\n", - "\n", - "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "colab_type": "code", - "id": "cPQgthZ7ugRJ" - }, - "outputs": [], - "source": [ - "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", - "\n", - "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n", - "\n", - "with tf.GradientTape() as t:\n", - " with tf.GradientTape() as t2:\n", - " t2.watch(x)\n", - " y = x * x * x\n", - " # Compute the gradient inside the 't' context manager\n", - " # which means the gradient computation is differentiable as well.\n", - " dy_dx = t2.gradient(y, x)\n", - "d2y_dx2 = t.gradient(dy_dx, x)\n", - "\n", - "assert dy_dx.numpy() == 3.0\n", - "assert d2y_dx2.numpy() == 6.0" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "4U1KKzUpNl58" + "id": "idv0bPeCp325" }, "source": [ - "## Next Steps\n", - "\n", - "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] } ], "metadata": { "colab": { "collapsed_sections": [], - "default_view": {}, "name": "automatic_differentiation.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true, - "version": "0.3.2", - "views": {} + "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb index a0bbbb612381c5eb386b04fd7bb9914eb01f4c8e..d89774c45efe115b7774517570f02fef145dc7a4 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb @@ -1,46 +1,25 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "custom_layers.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "tDnwEv8FtJm7", - "colab_type": "text" + "colab_type": "text", + "id": "tDnwEv8FtJm7" }, - "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "JlknJBWQtKkI", + "cellView": "form", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "form" + "id": "JlknJBWQtKkI" }, - "cell_type": "code", + "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -53,347 +32,57 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "60RdWsg1tETW", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Custom layers" - ] - }, - { - "metadata": { - "id": "BcJg7Enms86w", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "
\n", - "\n", - " Run in Google Colab\n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "id": "UEu3q4jmpKVT", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n" ] }, { - "metadata": { - "id": "pwX7Fii1rwsJ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "import tensorflow as tf\n", - "tfe = tf.contrib.eager\n", - "\n", - "tf.enable_eager_execution()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "zSFfVVjkrrsI", - "colab_type": "text" - }, "cell_type": "markdown", - "source": [ - "## Layers: common sets of useful operations\n", - "\n", - "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", - "\n", - "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", - "\n", - "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" - ] - }, - { "metadata": { - "id": "8PyXlPl-4TzQ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "colab_type": "text", + "id": "60RdWsg1tETW" }, - "cell_type": "code", - "source": [ - "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", - "# simply construct the object. Most layers take as a first argument the number\n", - "# of output dimensions / channels.\n", - "layer = tf.keras.layers.Dense(100)\n", - "# The number of input dimensions is often unnecessary, as it can be inferred\n", - "# the first time the layer is used, but it can be provided if you want to \n", - "# specify it manually, which is useful in some complex models.\n", - "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "Fn69xxPO5Psr", - "colab_type": "text" - }, - "cell_type": "markdown", "source": [ - "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", - "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + "# Custom layers" ] }, { - "metadata": { - "id": "E3XKNknP5Mhb", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# To use a layer, simply call it.\n", - "layer(tf.zeros([10, 5]))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "Wt_Nsv-L5t2s", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Layers have many useful methods. For example, you can inspect all variables\n", - "# in a layer by calling layer.variables. In this case a fully-connected layer\n", - "# will have variables for weights and biases.\n", - "layer.variables" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "6ilvKjz8_4MQ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# The variables are also accessible through nice accessors\n", - "layer.kernel, layer.bias" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "O0kDbE54-5VS", - "colab_type": "text" - }, "cell_type": "markdown", - "source": [ - "## Implementing custom layers\n", - "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n", - " * `__init__` , where you can do all input-independent initialization\n", - " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", - " * `call`, where you do the forward computation\n", - "\n", - "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." - ] - }, - { - "metadata": { - "id": "5Byl3n1k5kIy", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class MyDenseLayer(tf.keras.layers.Layer):\n", - " def __init__(self, num_outputs):\n", - " super(MyDenseLayer, self).__init__()\n", - " self.num_outputs = num_outputs\n", - " \n", - " def build(self, input_shape):\n", - " self.kernel = self.add_variable(\"kernel\", \n", - " shape=[input_shape[-1].value, \n", - " self.num_outputs])\n", - " \n", - " def call(self, input):\n", - " return tf.matmul(input, self.kernel)\n", - " \n", - "layer = MyDenseLayer(10)\n", - "print(layer(tf.zeros([10, 5])))\n", - "print(layer.variables)" - ], - "execution_count": 0, - "outputs": [] - }, - { "metadata": { - "id": "tk8E2vY0-z4Z", - "colab_type": "text" + "colab_type": "text", + "id": "9sFn_RV_8zM-" }, - "cell_type": "markdown", "source": [ - "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", - "\n", - "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" + "This file has moved." ] }, { - "metadata": { - "id": "Qhg4KlbKrs3G", - "colab_type": "text" - }, "cell_type": "markdown", - "source": [ - "## Models: composing layers\n", - "\n", - "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", - "\n", - "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model." - ] - }, - { - "metadata": { - "id": "N30DTXiRASlb", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class ResnetIdentityBlock(tf.keras.Model):\n", - " def __init__(self, kernel_size, filters):\n", - " super(ResnetIdentityBlock, self).__init__(name='')\n", - " filters1, filters2, filters3 = filters\n", - "\n", - " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", - " self.bn2a = tf.keras.layers.BatchNormalization()\n", - "\n", - " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", - " self.bn2b = tf.keras.layers.BatchNormalization()\n", - "\n", - " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", - " self.bn2c = tf.keras.layers.BatchNormalization()\n", - "\n", - " def call(self, input_tensor, training=False):\n", - " x = self.conv2a(input_tensor)\n", - " x = self.bn2a(x, training=training)\n", - " x = tf.nn.relu(x)\n", - "\n", - " x = self.conv2b(x)\n", - " x = self.bn2b(x, training=training)\n", - " x = tf.nn.relu(x)\n", - "\n", - " x = self.conv2c(x)\n", - " x = self.bn2c(x, training=training)\n", - "\n", - " x += input_tensor\n", - " return tf.nn.relu(x)\n", - "\n", - " \n", - "block = ResnetIdentityBlock(1, [1, 2, 3])\n", - "print(block(tf.zeros([1, 2, 3, 3])))\n", - "print([x.name for x in block.variables])" - ], - "execution_count": 0, - "outputs": [] - }, - { "metadata": { - "id": "wYfucVw65PMj", - "colab_type": "text" + "colab_type": "text", + "id": "BcJg7Enms86w" }, - "cell_type": "markdown", "source": [ - "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_layers.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_layers.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "custom_layers.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2" }, - { - "metadata": { - "id": "L9frk7Ur4uvJ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", - " tf.keras.layers.BatchNormalization(),\n", - " tf.keras.layers.Conv2D(2, 1, \n", - " padding='same'),\n", - " tf.keras.layers.BatchNormalization(),\n", - " tf.keras.layers.Conv2D(3, (1, 1)),\n", - " tf.keras.layers.BatchNormalization()])\n", - "my_seq(tf.zeros([1, 2, 3, 3]))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "c5YwYcnuK-wc", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Next steps\n", - "\n", - "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." - ] + "kernelspec": { + "display_name": "Python 3", + "name": "python3" } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb index 5f1b48fa0d4aea06adab19a0e561923e1f557e50..86dca0b423d0615de48a30de7eebc17eae0aff69 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb @@ -1,46 +1,25 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Custom training: basics", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - } - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "5rmpybwysXGV", - "colab_type": "text" + "colab_type": "text", + "id": "5rmpybwysXGV" }, - "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "m8y3rGtQsYP2", + "cellView": "form", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "form" + "id": "m8y3rGtQsYP2" }, - "cell_type": "code", + "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -53,425 +32,57 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "hrXv0rU9sIma", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Custom training: basics" - ] - }, - { - "metadata": { - "id": "7S0BwJ_8sLu7", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "
\n", - "\n", - " Run in Google Colab\n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "id": "k2o3TTG4TFpt", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n", - "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n", - "\n", - "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation." - ] - }, - { - "metadata": { - "id": "3LXMVuV0VhDr", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Setup" - ] - }, - { - "metadata": { - "id": "PJ64L90aVir3", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "import tensorflow as tf\n", - "\n", - "tf.enable_eager_execution()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "eMAWbDJFVmMk", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Variables\n", - "\n", - "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" - ] - }, - { - "metadata": { - "id": "VkJwtLS_Jbn8", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "# Using python state\n", - "x = tf.zeros([10, 10])\n", - "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", - " # value of x\n", - "print(x)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "wfneTXy7JcUz", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", - "\n", - "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." ] }, { - "metadata": { - "id": "itxmrMil6DQi", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "v = tf.Variable(1.0)\n", - "assert v.numpy() == 1.0\n", - "\n", - "# Re-assign the value\n", - "v.assign(3.0)\n", - "assert v.numpy() == 3.0\n", - "\n", - "# Use `v` in a TensorFlow operation like tf.square() and reassign\n", - "v.assign(tf.square(v))\n", - "assert v.numpy() == 9.0" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "-paSaeq1JzwC", - "colab_type": "text" - }, "cell_type": "markdown", - "source": [ - "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", - "\n", - "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." - ] - }, - { "metadata": { - "id": "BMiFcDzE7Qu3", - "colab_type": "text" + "colab_type": "text", + "id": "hrXv0rU9sIma" }, - "cell_type": "markdown", "source": [ - "## Example: Fitting a linear model\n", - "\n", - "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n", - "\n", - "1. Define the model.\n", - "2. Define a loss function.\n", - "3. Obtain training data.\n", - "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", - "\n", - "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`." - ] - }, - { - "metadata": { - "id": "gFzH64Jn9PIm", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "### Define the model\n", - "\n", - "Let's define a simple class to encapsulate the variables and the computation." + "# Custom training: basics" ] }, { - "metadata": { - "id": "_WRu7Pze7wk8", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "class Model(object):\n", - " def __init__(self):\n", - " # Initialize variable to (5.0, 0.0)\n", - " # In practice, these should be initialized to random values.\n", - " self.W = tf.Variable(5.0)\n", - " self.b = tf.Variable(0.0)\n", - " \n", - " def __call__(self, x):\n", - " return self.W * x + self.b\n", - " \n", - "model = Model()\n", - "\n", - "assert model(3.0).numpy() == 15.0" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "xa6j_yXa-j79", - "colab_type": "text" - }, "cell_type": "markdown", - "source": [ - "### Define a loss function\n", - "\n", - "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss." - ] - }, - { - "metadata": { - "id": "Y0ysUFGY924U", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def loss(predicted_y, desired_y):\n", - " return tf.reduce_mean(tf.square(predicted_y - desired_y))" - ], - "execution_count": 0, - "outputs": [] - }, - { "metadata": { - "id": "qutT_fkl_CBc", - "colab_type": "text" + "colab_type": "text", + "id": "IGPZTmwn9IT4" }, - "cell_type": "markdown", "source": [ - "### Obtain training data\n", - "\n", - "Let's synthesize the training data with some noise." + "This file has moved." ] }, { - "metadata": { - "id": "gxPTb-kt_N5m", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "TRUE_W = 3.0\n", - "TRUE_b = 2.0\n", - "NUM_EXAMPLES = 1000\n", - "\n", - "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n", - "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n", - "outputs = inputs * TRUE_W + TRUE_b + noise" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "-50nq-wPBsAW", - "colab_type": "text" - }, "cell_type": "markdown", - "source": [ - "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue." - ] - }, - { "metadata": { - "id": "_eb83LtrB4nt", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "colab_type": "text", + "id": "7S0BwJ_8sLu7" }, - "cell_type": "code", "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "plt.scatter(inputs, outputs, c='b')\n", - "plt.scatter(inputs, model(inputs), c='r')\n", - "plt.show()\n", - "\n", - "print('Current loss: '),\n", - "print(loss(model(inputs), outputs).numpy())" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "sSDP-yeq_4jE", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "### Define a training loop\n", - "\n", - "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves." + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_training.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_training.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "Custom training: basics", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2" }, - { - "metadata": { - "id": "MBIACgdnA55X", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def train(model, inputs, outputs, learning_rate):\n", - " with tf.GradientTape() as t:\n", - " current_loss = loss(model(inputs), outputs)\n", - " dW, db = t.gradient(current_loss, [model.W, model.b])\n", - " model.W.assign_sub(learning_rate * dW)\n", - " model.b.assign_sub(learning_rate * db)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "RwWPaJryD2aN", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve." - ] - }, - { - "metadata": { - "id": "XdfkR223D9dW", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "model = Model()\n", - "\n", - "# Collect the history of W-values and b-values to plot later\n", - "Ws, bs = [], []\n", - "epochs = range(10)\n", - "for epoch in epochs:\n", - " Ws.append(model.W.numpy())\n", - " bs.append(model.b.numpy())\n", - " current_loss = loss(model(inputs), outputs)\n", - "\n", - " train(model, inputs, outputs, learning_rate=0.1)\n", - " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", - " (epoch, Ws[-1], bs[-1], current_loss))\n", - "\n", - "# Let's plot it all\n", - "plt.plot(epochs, Ws, 'r',\n", - " epochs, bs, 'b')\n", - "plt.plot([TRUE_W] * len(epochs), 'r--',\n", - " [TRUE_b] * len(epochs), 'b--')\n", - "plt.legend(['W', 'b', 'true W', 'true_b'])\n", - "plt.show()\n", - " " - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "vPnIVuaSJwWz", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Next Steps\n", - "\n", - "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n", - "\n", - "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n", - "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n", - "\n", - "The [next tutorial](TODO) will cover these higher level APIs." - ] + "kernelspec": { + "display_name": "Python 3", + "name": "python3" } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb index f1e13de5dec2fbda126caeb355494875317e3373..c6d1a566043d80741c4075a50f142b2780c78d06 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb @@ -1,46 +1,25 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "eager_basics.ipynb", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - } - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "iPpI7RaYoZuE", - "colab_type": "text" + "colab_type": "text", + "id": "iPpI7RaYoZuE" }, - "cell_type": "markdown", "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "hro2InpHobKk", + "cellView": "form", + "colab": {}, "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "form" + "id": "hro2InpHobKk" }, - "cell_type": "code", + "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -53,439 +32,47 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "U9i2Dsh-ziXr", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Eager execution basics" - ] - }, - { - "metadata": { - "id": "Hndw-YcxoOJK", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "
\n", - "\n", - " Run in Google Colab\n", - "\n", - "View source on GitHub
" - ] - }, - { - "metadata": { - "id": "6sILUVbHoSgH", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "This is an introductory tutorial for using TensorFlow. It will cover:\n", - "\n", - "* Importing required packages\n", - "* Creating and using Tensors\n", - "* Using GPU acceleration\n", - "* Datasets" - ] - }, - { - "metadata": { - "id": "z1JcS5iBXMRO", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Import TensorFlow\n", - "\n", - "To get started, import the `tensorflow` module and enable eager execution.\n", - "Eager execution enables a more interactive frontend to TensorFlow, the details of which we will discuss much later." - ] - }, - { - "metadata": { - "id": "RlIWhyeLoYnG", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "import tensorflow as tf\n", - "\n", - "tf.enable_eager_execution()" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "H9UySOPLXdaw", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Tensors\n", - "\n", - "A Tensor is a multi-dimensional array. Similar to NumPy `ndarray` objects, `Tensor` objects have a data type and a shape. Additionally, Tensors can reside in accelerator (like GPU) memory. TensorFlow offers a rich library of operations ([tf.add](https://www.tensorflow.org/api_docs/python/tf/add), [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul), [tf.linalg.inv](https://www.tensorflow.org/api_docs/python/tf/linalg/inv) etc.) that consume and produce Tensors. These operations automatically convert native Python types. For example:\n" - ] - }, - { - "metadata": { - "id": "ngUe237Wt48W", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "print(tf.add(1, 2))\n", - "print(tf.add([1, 2], [3, 4]))\n", - "print(tf.square(5))\n", - "print(tf.reduce_sum([1, 2, 3]))\n", - "print(tf.encode_base64(\"hello world\"))\n", - "\n", - "# Operator overloading is also supported\n", - "print(tf.square(2) + tf.square(3))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "IDY4WsYRhP81", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "Each Tensor has a shape and a datatype" - ] - }, - { - "metadata": { - "id": "srYWH1MdJNG7", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "x = tf.matmul([[1]], [[2, 3]])\n", - "print(x.shape)\n", - "print(x.dtype)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "eBPw8e8vrsom", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n", - "\n", - "1. Tensors can be backed by accelerator memory (like GPU, TPU).\n", - "2. Tensors are immutable." - ] - }, - { - "metadata": { - "id": "Dwi1tdW3JBw6", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "### NumPy Compatibility\n", - "\n", - "Conversion between TensorFlow Tensors and NumPy ndarrays is quite simple as:\n", - "* TensorFlow operations automatically convert NumPy ndarrays to Tensors.\n", - "* NumPy operations automatically convert Tensors to NumPy ndarrays.\n", - "\n", - "Tensors can be explicitly converted to NumPy ndarrays by invoking the `.numpy()` method on them.\n", - "These conversions are typically cheap as the array and Tensor share the underlying memory representation if possible. However, sharing the underlying representation isn't always possible since the Tensor may be hosted in GPU memory while NumPy arrays are always backed by host memory, and the conversion will thus involve a copy from GPU to host memory." - ] - }, - { - "metadata": { - "id": "lCUWzso6mbqR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "import numpy as np\n", - "\n", - "ndarray = np.ones([3, 3])\n", - "\n", - "print(\"TensorFlow operations convert numpy arrays to Tensors automatically\")\n", - "tensor = tf.multiply(ndarray, 42)\n", - "print(tensor)\n", - "\n", - "\n", - "print(\"And NumPy operations convert Tensors to numpy arrays automatically\")\n", - "print(np.add(tensor, 1))\n", - "\n", - "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n", - "print(tensor.numpy())" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "PBNP8yTRfu_X", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## GPU acceleration\n", - "\n", - "Many TensorFlow operations can be accelerated by using the GPU for computation. Without any annotations, TensorFlow automatically decides whether to use the GPU or CPU for an operation (and copies the tensor between CPU and GPU memory if necessary). Tensors produced by an operation are typically backed by the memory of the device on which the operation executed. For example:" - ] - }, - { - "metadata": { - "id": "3Twf_Rw-gQFM", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "x = tf.random_uniform([3, 3])\n", - "\n", - "print(\"Is there a GPU available: \"),\n", - "print(tf.test.is_gpu_available())\n", - "\n", - "print(\"Is the Tensor on GPU #0: \"),\n", - "print(x.device.endswith('GPU:0'))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "vpgYzgVXW2Ud", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "### Device Names\n", - "\n", - "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:` if the tensor is placed on the `N`-th tensor on the host." ] }, { - "metadata": { - "id": "ZWZQCimzuqyP", - "colab_type": "text" - }, "cell_type": "markdown", - "source": [ - "\n", - "\n", - "### Explicit Device Placement\n", - "\n", - "The term \"placement\" in TensorFlow refers to how individual operations are assigned (placed on) a device for execution. As mentioned above, when there is no explicit guidance provided, TensorFlow automatically decides which device to execute an operation, and copies Tensors to that device if needed. However, TensorFlow operations can be explicitly placed on specific devices using the `tf.device` context manager. For example:" - ] - }, - { - "metadata": { - "id": "RjkNZTuauy-Q", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "def time_matmul(x):\n", - " %timeit tf.matmul(x, x)\n", - "\n", - "# Force execution on CPU\n", - "print(\"On CPU:\")\n", - "with tf.device(\"CPU:0\"):\n", - " x = tf.random_uniform([1000, 1000])\n", - " assert x.device.endswith(\"CPU:0\")\n", - " time_matmul(x)\n", - "\n", - "# Force execution on GPU #0 if available\n", - "if tf.test.is_gpu_available():\n", - " with tf.device(\"GPU:0\"): # Or GPU:1 for the 2nd GPU, GPU:2 for the 3rd etc.\n", - " x = tf.random_uniform([1000, 1000])\n", - " assert x.device.endswith(\"GPU:0\")\n", - " time_matmul(x)" - ], - "execution_count": 0, - "outputs": [] - }, - { "metadata": { - "id": "o1K4dlhhHtQj", - "colab_type": "text" + "colab_type": "text", + "id": "U9i2Dsh-ziXr" }, - "cell_type": "markdown", "source": [ - "## Datasets\n", - "\n", - "This section demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/guide/datasets) to build pipelines to feed data to your model. It covers:\n", - "\n", - "* Creating a `Dataset`.\n", - "* Iteration over a `Dataset` with eager execution enabled.\n", - "\n", - "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n", - "\n", - "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n", - "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n", - "As a result, the discussion on iterators in the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets) is not relevant when eager execution is enabled." - ] - }, - { - "metadata": { - "id": "zI0fmOynH-Ne", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "### Create a source `Dataset`\n", - "\n", - "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets#reading_input_data) for more information." + "# Eager execution basics" ] }, { - "metadata": { - "id": "F04fVOHQIBiG", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n", - "\n", - "# Create a CSV file\n", - "import tempfile\n", - "_, filename = tempfile.mkstemp()\n", - "\n", - "with open(filename, 'w') as f:\n", - " f.write(\"\"\"Line 1\n", - "Line 2\n", - "Line 3\n", - " \"\"\")\n", - "\n", - "ds_file = tf.data.TextLineDataset(filename)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "vbxIhC-5IPdf", - "colab_type": "text" - }, "cell_type": "markdown", - "source": [ - "### Apply transformations\n", - "\n", - "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for details." - ] - }, - { "metadata": { - "id": "uXSDZWE-ISsd", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "colab_type": "text", + "id": "Hndw-YcxoOJK" }, - "cell_type": "code", "source": [ - "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n", - "\n", - "ds_file = ds_file.batch(2)" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "A8X1GNfoIZKJ", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "### Iterate\n", - "\n", - "When eager execution is enabled `Dataset` objects support iteration.\n", - "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls." + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb\"\u003e\n", + " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + "\u003c/td\u003e\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "eager_basics.ipynb", + "private_outputs": true, + "provenance": [], + "toc_visible": true, + "version": "0.3.2" }, - { - "metadata": { - "id": "ws-WKRk5Ic6-", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "cell_type": "code", - "source": [ - "print('Elements of ds_tensors:')\n", - "for x in ds_tensors:\n", - " print(x)\n", - "\n", - "print('\\nElements in ds_file:')\n", - "for x in ds_file:\n", - " print(x)" - ], - "execution_count": 0, - "outputs": [] + "kernelspec": { + "display_name": "Python 3", + "name": "python3" } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py index a28bc8a43d7c90737c9baf9a634d736e9de52948..3f70f573b1faeeb09e814e761f7e0f285cf328bd 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -272,8 +272,8 @@ class ResNet50(tf.keras.Model): else: self.global_pooling = None - def call(self, input_tensor, training): - x = self.conv1(input_tensor) + def call(self, inputs, training=True): + x = self.conv1(inputs) x = self.bn_conv1(x, training=training) x = tf.nn.relu(x) x = self.max_pool(x) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 5ee2176154ec7011dcb3d7b384a86213e778014f..74ebb1ec77131a560b1ebfd062c690920c35e261 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -243,8 +243,8 @@ def train_one_epoch(model, optimizer, train_data, log_interval=10): print("train/batch #%d\tloss: %.6f" % (batch, batch_model_loss())) -SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv" -SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv" +SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv" +SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv" def main(_): diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 6efafccd6b93ad58da395e0b2e1e647809af62ad..930e62b68096b468846a01b9674c669a8b8e9a53 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -336,9 +336,27 @@ class Mean(Metric): return values return values, weights - def result(self): + def result(self, write_summary=True): + """Returns the result of the Metric. + + Args: + write_summary: bool indicating whether to feed the result to the summary + before returning. + Returns: + aggregated metric as float. + Raises: + ValueError: if the optional argument is not bool + """ + # Convert the boolean to tensor for tf.cond, if it is not. + if not isinstance(write_summary, ops.Tensor): + write_summary = ops.convert_to_tensor(write_summary) t = self.numer / self.denom - summary_ops.scalar(name=self.name, tensor=t) + def write_summary_f(): + summary_ops.scalar(name=self.name, tensor=t) + return t + control_flow_ops.cond(write_summary, + write_summary_f, + lambda: t) return t diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 20d938d492bf78fab852c638ba675d7ee6ed9073..aa9961681024b84a7e465845a3502e205f209119 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -46,6 +46,18 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testSummaryArg(self): + m = metrics.Mean() + m([1, 10, 100]) + m(1000) + m([10000.0, 100000.0]) + self.assertEqual(111111.0/6, m.result(write_summary=True).numpy()) + self.assertEqual(111111.0/6, m.result(write_summary=False).numpy()) + with self.assertRaises(ValueError): + m.result(write_summary=5) + with self.assertRaises(ValueError): + m.result(write_summary=[True]) + def testVariableCollections(self): with context.graph_mode(), ops.Graph().as_default(): m = metrics.Mean() @@ -93,6 +105,16 @@ class MetricsTest(test.TestCase): self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 37.0) + # Get result without saving the summary. + logdir = tempfile.mkdtemp() + with summary_ops.create_file_writer( + logdir, max_queue=0, + name="t0").as_default(), summary_ops.always_record_summaries(): + m.result(write_summary=False) # As a side-effect will write summaries. + # events_from_logdir(_) asserts the directory exists. + events = summary_test_util.events_from_logdir(logdir) + self.assertEqual(len(events), 1) + def testWeightedMean(self): m = metrics.Mean() m([1, 100, 100000], weights=[1, 0.2, 0.3]) diff --git a/tensorflow/contrib/eager/python/remote.py b/tensorflow/contrib/eager/python/remote.py new file mode 100644 index 0000000000000000000000000000000000000000..b74cf394f682b64327bc570ef8dbe79f5657902c --- /dev/null +++ b/tensorflow/contrib/eager/python/remote.py @@ -0,0 +1,73 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helpers to connect to remote servers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.core.protobuf.cluster_pb2 import ClusterDef +from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef +from tensorflow.python.eager import context + + +def connect_to_remote_host(remote_host=None, job_name="worker"): + """Connects to a single machine to enable remote execution on it. + + Will make devices on the remote host available to use. Note that calling this + more than once will work, but will invalidate any tensor handles on the old + remote devices. + + Using the default job_name of worker, you can schedule ops to run remotely as + follows: + ```python + # Enable eager execution, and connect to the remote host. + tf.enable_eager_execution() + tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876") + + with ops.device("job:worker/replica:0/task:1/device:CPU:0"): + # The following tensors should be resident on the remote device, and the op + # will also execute remotely. + x1 = array_ops.ones([2, 2]) + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + ``` + + Args: + remote_host: The addr of the remote server in host-port format. + job_name: The job name under which the new server will be accessible. + + Raises: + ValueError: if remote_host is None. + """ + if remote_host is None: + raise ValueError("Must provide an remote_host") + cluster_def = ClusterDef() + job_def = cluster_def.job.add() + job_def.name = job_name + job_def.tasks[0] = "127.0.0.1:0" + job_def.tasks[1] = remote_host + + server_def = ServerDef( + cluster=cluster_def, + job_name=job_name, + task_index=0, + protocol="grpc") + + # TODO(nareshmodi): Make this default since it works in more situations. + os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1" + context.set_server_def(server_def) diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index 76f48eeb1cab9d1f014adeafe4827cb5d3a8c77d..13029db975bcbf8a6b31ba3c11d4c2b08edfdb6f 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -23,6 +23,7 @@ import os import numpy as np +from tensorflow.contrib.eager.python import remote from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.eager import backprop @@ -85,6 +86,7 @@ class RemoteExecutionTest(test.TestCase): self._cached_server1_target = self._cached_server1.target[len("grpc://"):] self._cached_server2_target = self._cached_server2.target[len("grpc://"):] + def setUp(self): # Start the local server. context.set_server_def( server_def=get_server_def( @@ -172,6 +174,17 @@ class RemoteExecutionTest(test.TestCase): y = math_ops.matmul(x1, x1) np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + @run_sync_and_async + def testConnectToRemoteServer(self): + """Basic server connection.""" + remote.connect_to_remote_host(self._cached_server1_target) + + with ops.device("job:worker/replica:0/task:1/device:CPU:0"): + x1 = array_ops.ones([2, 2]) + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 4dfd0834430b2295d1454314e88c824efe4c8b13..f5b8d95e4fc7fe5cd90d658eda49590e0b330bb0 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -74,6 +74,8 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@TensorSpec +@@connect_to_remote_host + @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT @@ -94,6 +96,7 @@ from tensorflow.contrib.eager.python.network import Network from tensorflow.contrib.eager.python.network import Sequential from tensorflow.contrib.eager.python.network import save_network_checkpoint from tensorflow.contrib.eager.python.network import restore_network_checkpoint +from tensorflow.contrib.eager.python.remote import connect_to_remote_host from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index 26449b46516fe1d8c93a8e3567f93801c689a65a..e3c44bea663969b5f251275ca10676d1cd567de2 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -26,6 +26,7 @@ from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.util import function_utils @@ -140,7 +141,7 @@ def clip_gradients_by_norm(optimizer, clip_norm): name='ClipByNorm' + optimizer.get_name()) -def forward_features(estimator, keys=None): +def forward_features(estimator, keys=None, sparse_default_values=None): """Forward features to predictions dictionary. In some cases, user wants to see some of the features in estimators prediction @@ -148,39 +149,36 @@ def forward_features(estimator, keys=None): runs inference on the users graph and returns the results. Keys are essential because there is no order guarantee on the outputs so they need to be rejoined to the inputs via keys or transclusion of the inputs in the outputs. - Example: - ```python def input_fn(): features, labels = ... features['unique_example_id'] = ... features, labels - estimator = tf.estimator.LinearClassifier(...) estimator = tf.contrib.estimator.forward_features( estimator, 'unique_example_id') estimator.train(...) assert 'unique_example_id' in estimator.predict(...) ``` - Args: estimator: A `tf.estimator.Estimator` object. - keys: a `string` or a `list` of `string`. If it is `None`, all of the + keys: A `string` or a `list` of `string`. If it is `None`, all of the `features` in `dict` is forwarded to the `predictions`. If it is a `string`, only given key is forwarded. If it is a `list` of strings, all the given `keys` are forwarded. + sparse_default_values: A dict of `str` keys mapping the name of the sparse + features to be converted to dense, to the default value to use. Only + sparse features indicated in the dictionary are converted to dense and the + provided default value is used. Returns: A new `tf.estimator.Estimator` which forwards features to predictions. - Raises: ValueError: * if `keys` is already part of `predictions`. We don't allow override. * if 'keys' does not exist in `features`. - * if feature key refers to a `SparseTensor`, since we don't support - `SparseTensor` in `predictions`. `SparseTensor` is common in `features`. TypeError: if `keys` type is not one of `string` or list/tuple of `string`. """ @@ -231,11 +229,18 @@ def forward_features(estimator, keys=None): for key in get_keys(features): feature = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( features[key]) + if sparse_default_values and (key in sparse_default_values): + if not isinstance(feature, sparse_tensor_lib.SparseTensor): + raise ValueError( + 'Feature ({}) is expected to be a `SparseTensor`.'.format(key)) + feature = sparse_ops.sparse_tensor_to_dense( + feature, default_value=sparse_default_values[key]) if not isinstance(feature, ops.Tensor): raise ValueError( - 'Forwarded feature ({}) should be a Tensor. Please use keys ' - 'argument of forward_features to filter unwanted features. Type of ' - 'features[{}] is {}.'.format(key, key, type(feature))) + 'Feature ({}) should be a Tensor. Please use `keys` ' + 'argument of forward_features to filter unwanted features, or' + 'add key to argument `sparse_default_values`.' + 'Type of features[{}] is {}.'.format(key, key, type(feature))) predictions[key] = feature spec = spec._replace(predictions=predictions) if spec.export_outputs: diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py index 407af2deaf0928361a4f0b0e44e842b7750118cb..c8fdaa8791b83e54d69993cfed3205d6d343ed19 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py @@ -14,6 +14,7 @@ # ============================================================================== """extenders tests.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -23,6 +24,7 @@ import tempfile import numpy as np from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.predictor import from_saved_model from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator_lib @@ -170,19 +172,53 @@ class ClipGradientsByNormTest(test.TestCase): class ForwardFeaturesTest(test.TestCase): """Tests forward_features.""" - def test_forward_single_key(self): - - def input_fn(): - return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] + def _export_estimator(self, estimator, serving_input_fn): + tmpdir = tempfile.mkdtemp() + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn) + self.assertTrue(gfile.Exists(export_dir)) + return export_dir, tmpdir + def make_dummy_input_fn(self): + def _input_fn(): + dataset = dataset_ops.Dataset.from_tensors({ + 'x': [[3.], [5.]], + 'id': [[101], [102]], + 'sparse_id': sparse_tensor.SparseTensor( + values=[1, 2, 3], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]), + 'labels': [[1.], [2.]] + }) + def _split(x): + labels = x.pop('labels') + return x, labels + dataset = dataset.map(_split) + return dataset + return _input_fn + + def test_forward_keys(self): + + input_fn = self.make_dummy_input_fn() estimator = linear.LinearRegressor([fc.numeric_column('x')]) estimator.train(input_fn=input_fn, steps=1) - self.assertNotIn('id', next(estimator.predict(input_fn=input_fn))) - estimator = extenders.forward_features(estimator, 'id') - predictions = next(estimator.predict(input_fn=input_fn)) - self.assertIn('id', predictions) - self.assertEqual(101, predictions['id']) + forwarded_keys = ['id', 'sparse_id'] + + for key in forwarded_keys: + self.assertNotIn(key, next(estimator.predict(input_fn=input_fn))) + + estimator = extenders.forward_features( + estimator, forwarded_keys, sparse_default_values={'sparse_id': 1}) + + expected_results = [101, 2, 102, 5] + predictions = estimator.predict(input_fn=input_fn) + for _ in range(2): + prediction = next(predictions) + for key in forwarded_keys: + self.assertIn(key, prediction) + self.assertEqual(expected_results.pop(0), sum(prediction[key])) def test_forward_in_exported(self): @@ -205,11 +241,7 @@ class ForwardFeaturesTest(test.TestCase): estimator = extenders.forward_features(estimator, 'id') # export saved model - tmpdir = tempfile.mkdtemp() - export_dir_base = os.path.join( - compat.as_bytes(tmpdir), compat.as_bytes('export')) - export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn) - self.assertTrue(gfile.Exists(export_dir)) + export_dir, tmpdir = self._export_estimator(estimator, serving_input_fn) # restore model predict_fn = from_saved_model(export_dir, signature_def_key='predict') @@ -222,6 +254,47 @@ class ForwardFeaturesTest(test.TestCase): # Clean up. gfile.DeleteRecursively(tmpdir) + def test_forward_in_exported_sparse(self): + features_columns = [fc.indicator_column( + fc.categorical_column_with_vocabulary_list('x', range(10)))] + + classifier = linear.LinearClassifier(feature_columns=features_columns) + + def train_input_fn(): + dataset = dataset_ops.Dataset.from_tensors({ + 'x': sparse_tensor.SparseTensor( + values=[1, 2, 3], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]), + 'labels': [[0], [1]] + }) + def _split(x): + labels = x.pop('labels') + return x, labels + dataset = dataset.map(_split) + return dataset + + classifier.train(train_input_fn, max_steps=1) + + classifier = extenders.forward_features( + classifier, keys=['x'], sparse_default_values={'x': 0}) + + def serving_input_fn(): + features_ph = array_ops.placeholder(dtype=dtypes.int32, name='x', + shape=[None]) + features = {'x': layers.dense_to_sparse(features_ph)} + return estimator_lib.export.ServingInputReceiver(features, + {'x': features_ph}) + export_dir, tmpdir = self._export_estimator(classifier, serving_input_fn) + prediction_fn = from_saved_model(export_dir, signature_def_key='predict') + + features = (0, 2) + prediction = prediction_fn({'x': features}) + + self.assertIn('x', prediction) + self.assertEqual(features, tuple(prediction['x'])) + gfile.DeleteRecursively(tmpdir) + def test_forward_list(self): def input_fn(): @@ -266,7 +339,6 @@ class ForwardFeaturesTest(test.TestCase): extenders.forward_features(estimator, ['x', estimator]) def test_key_should_be_in_features(self): - def input_fn(): return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] @@ -279,27 +351,36 @@ class ForwardFeaturesTest(test.TestCase): next(estimator.predict(input_fn=input_fn)) def test_forwarded_feature_should_not_be_a_sparse_tensor(self): - def input_fn(): return { 'x': [[3.], [5.]], - 'id': - sparse_tensor.SparseTensor( - values=['1', '2'], - indices=[[0, 0], [1, 0]], - dense_shape=[2, 1]) - }, [[1.], [2.]] + 'id': sparse_tensor.SparseTensor( + values=['1', '2'], + indices=[[0, 0], [1, 0]], + dense_shape=[2, 1]) + }, [[1.], [2.]] estimator = linear.LinearRegressor([fc.numeric_column('x')]) estimator.train(input_fn=input_fn, steps=1) estimator = extenders.forward_features(estimator) with self.assertRaisesRegexp(ValueError, - 'Forwarded feature.* should be a Tensor.'): + 'Feature .* should be a Tensor.*'): next(estimator.predict(input_fn=input_fn)) - def test_predictions_should_be_dict(self): + def test_forwarded_feature_should_be_a_sparse_tensor(self): + input_fn = self.make_dummy_input_fn() + + estimator = linear.LinearRegressor([fc.numeric_column('x')]) + estimator.train(input_fn=input_fn, steps=1) + estimator = extenders.forward_features( + estimator, sparse_default_values={'id': 0, 'sparse_id': 0}) + with self.assertRaisesRegexp( + ValueError, 'Feature .* is expected to be a `SparseTensor`.'): + next(estimator.predict(input_fn=input_fn)) + + def test_predictions_should_be_dict(self): def input_fn(): return {'x': [[3.], [5.]], 'id': [[101], [102]]} diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 9866fccfba3562221ea7fe845e860ab470e238a0..9d0e6e1335d0be3477b78abce94999122672ff05 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -105,6 +105,7 @@ py_library( deps = [ ":gan_estimator", ":head", + ":stargan_estimator", "//tensorflow/python:util", ], ) @@ -533,6 +534,57 @@ py_test( ], ) +py_library( + name = "stargan_estimator", + srcs = [ + "python/estimator/python/stargan_estimator.py", + "python/estimator/python/stargan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":namedtuples", + ":summaries", + ":train", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "stargan_estimator_test", + srcs = ["python/estimator/python/stargan_estimator_test.py"], + shard_count = 1, + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":namedtuples", + ":stargan_estimator", + ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/learn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@six_archive//:six", + ], +) + py_library( name = "sliced_wasserstein", srcs = [ diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index c9f7bc61b25230e4159cf8cbc7c9cceead0aa706..99d38011ba677f03e198a431634fbb2ce349f912 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -26,15 +26,18 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.estimator.python import gan_estimator from tensorflow.contrib.gan.python.estimator.python import head +from tensorflow.contrib.gan.python.estimator.python import stargan_estimator from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * from tensorflow.contrib.gan.python.estimator.python.head import * +from tensorflow.contrib.gan.python.estimator.python.stargan_estimator import * # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'gan_estimator', + 'stargan_estimator', 'head', -] + gan_estimator.__all__ + head.__all__ +] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py similarity index 70% rename from tensorflow/contrib/kfac/python/ops/optimizer_lib.py rename to tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py index 87d1866e06bb0a572033828dd5c2f04b05296039..341bdf9fbbc54893afb5d754e29c2d49754d1aec 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The KFAC optimizer.""" +"""`tf.Learn` components for `GANEstimator`.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.optimizer import * +from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.stargan_estimator_impl import * +# pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import -_allowed_symbols = [ - "KfacOptimizer", -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) +__all__ = stargan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..f60e16bc04662b33bc0bb22b5acc8c7fcc7a03ba --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py @@ -0,0 +1,363 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TFGAN-backed StarGAN Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import enum + +from tensorflow.contrib.framework.python.ops import variables as variable_lib +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.util import tf_inspect as inspect + +__all__ = ['StarGANEstimator', 'SummaryType'] + + +class SummaryType(enum.IntEnum): + NONE = 0 + VARIABLES = 1 + IMAGES = 2 + IMAGE_COMPARISON = 3 + + +_summary_type_map = { + SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries, + SummaryType.IMAGES: tfgan_summaries.add_stargan_image_summaries, +} + + +class StarGANEstimator(estimator.Estimator): + """An estimator for Generative Adversarial Networks (GANs). + + This Estimator is backed by TFGAN. The network functions follow the TFGAN API + except for one exception: if either `generator_fn` or `discriminator_fn` have + an argument called `mode`, then the tf.Estimator mode is passed in for that + argument. This helps with operations like batch normalization, which have + different train and evaluation behavior. + + Example: + + ```python + import tensorflow as tf + tfgan = tf.contrib.gan + + # See TFGAN's `train.py` for a description of the generator and + # discriminator API. + def generator_fn(generator_inputs): + ... + return generated_data + + def discriminator_fn(data, conditioning): + ... + return logits + + # Create GAN estimator. + stargan_estimator = tfgan.estimator.StarGANEstimator( + model_dir, + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + loss_fn=loss_fn, + generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5)) + + # Train estimator. + stargan_estimator.train(train_input_fn, steps) + + # Evaluate resulting estimator. + stargan_estimator.evaluate(eval_input_fn) + + # Generate samples from generator. + stargan_estimator = np.array([ + x for x in stargan_estimator.predict(predict_input_fn)]) + ``` + """ + + def __init__(self, + model_dir=None, + generator_fn=None, + discriminator_fn=None, + loss_fn=None, + generator_optimizer=None, + discriminator_optimizer=None, + get_hooks_fn=None, + get_eval_metric_ops_fn=None, + add_summaries=None, + use_loss_summaries=True, + config=None): + """Initializes a StarGANEstimator instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + generator_fn: A python function that takes a Tensor, Tensor list, or + Tensor dictionary as inputs and returns the outputs of the GAN + generator. See `TFGAN` for more details and examples. Additionally, if + it has an argument called `mode`, the Estimator's `mode` will be passed + in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch + normalization. + discriminator_fn: A python function that takes the output of + `generator_fn` or real data in the GAN setup, and `input_data`. Outputs + a Tensor in the range [-inf, inf]. See `TFGAN` for more details and + examples. + loss_fn: The loss function on the generator. Takes a `StarGANModel` + namedtuple and return a `GANLoss` namedtuple. + generator_optimizer: The optimizer for generator updates, or a function + that takes no arguments and returns an optimizer. This function will be + called when the default graph is the `StarGANEstimator`'s graph, so + utilities like `tf.contrib.framework.get_or_create_global_step` will + work. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. These hooks are run on the generator and discriminator + train ops, and can be used to implement the GAN training scheme. + Defaults to `train.get_sequential_train_hooks()`. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If loss functions aren't callable. + ValueError: If `use_loss_summaries` isn't boolean or `None`. + ValueError: If `get_hooks_fn` isn't callable or `None`. + """ + if not callable(loss_fn): + raise ValueError('loss_fn must be callable.') + if use_loss_summaries not in [True, False, None]: + raise ValueError('use_loss_summaries must be True, False or None.') + if get_hooks_fn is not None and not callable(get_hooks_fn): + raise TypeError('get_hooks_fn must be callable.') + + def _model_fn(features, labels, mode): + """StarGANEstimator model function.""" + if mode not in [ + model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT + ]: + raise ValueError('Mode not recognized: %s' % mode) + + if mode == model_fn_lib.ModeKeys.PREDICT: + input_data = features[0] + input_data_domain_label = features[1] + else: + input_data = features # rename inputs for clarity + input_data_domain_label = labels # rename inputs for clarity + + # Make StarGANModel, which encapsulates the GAN model architectures. + gan_model = _get_gan_model(mode, generator_fn, discriminator_fn, + input_data, input_data_domain_label, + add_summaries) + + # Make the EstimatorSpec, which incorporates the StarGANModel, losses, + # eval, metrics, and optimizers (if required). + return _get_estimator_spec(mode, gan_model, loss_fn, + get_eval_metric_ops_fn, generator_optimizer, + discriminator_optimizer, get_hooks_fn) + + super(StarGANEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + + +def _get_gan_model(mode, + generator_fn, + discriminator_fn, + input_data, + input_data_domain_label, + add_summaries, + generator_scope='Generator'): + """Makes the StarGANModel tuple.""" + if mode == model_fn_lib.ModeKeys.PREDICT: + gan_model = _make_prediction_gan_model(input_data, input_data_domain_label, + generator_fn, generator_scope) + else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL + gan_model = _make_gan_model(generator_fn, discriminator_fn, input_data, + input_data_domain_label, generator_scope, + add_summaries, mode) + + return gan_model + + +def _get_estimator_spec(mode, + gan_model, + loss_fn, + get_eval_metric_ops_fn, + generator_optimizer, + discriminator_optimizer, + get_hooks_fn=None): + """Get the EstimatorSpec for the current mode.""" + if mode == model_fn_lib.ModeKeys.PREDICT: + estimator_spec = model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data) + else: + gan_loss = loss_fn(gan_model) + if mode == model_fn_lib.ModeKeys.EVAL: + estimator_spec = _get_eval_estimator_spec(gan_model, gan_loss, + get_eval_metric_ops_fn) + else: # model_fn_lib.ModeKeys.TRAIN: + gopt = ( + generator_optimizer() + if callable(generator_optimizer) else generator_optimizer) + dopt = ( + discriminator_optimizer() + if callable(discriminator_optimizer) else discriminator_optimizer) + get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() + estimator_spec = _get_train_estimator_spec(gan_model, gan_loss, gopt, + dopt, get_hooks_fn) + + return estimator_spec + + +def _make_gan_model(generator_fn, discriminator_fn, input_data, + input_data_domain_label, generator_scope, add_summaries, + mode): + """Construct a `StarGANModel`, and optionally pass in `mode`.""" + # If network functions have an argument `mode`, pass mode to it. + if 'mode' in inspect.getargspec(generator_fn).args: + generator_fn = functools.partial(generator_fn, mode=mode) + if 'mode' in inspect.getargspec(discriminator_fn).args: + discriminator_fn = functools.partial(discriminator_fn, mode=mode) + gan_model = tfgan_train.stargan_model( + generator_fn, + discriminator_fn, + input_data, + input_data_domain_label, + generator_scope=generator_scope) + if add_summaries: + if not isinstance(add_summaries, (tuple, list)): + add_summaries = [add_summaries] + with ops.name_scope(None): + for summary_type in add_summaries: + _summary_type_map[summary_type](gan_model) + + return gan_model + + +def _make_prediction_gan_model(input_data, input_data_domain_label, + generator_fn, generator_scope): + """Make a `StarGANModel` from just the generator.""" + # If `generator_fn` has an argument `mode`, pass mode to it. + if 'mode' in inspect.getargspec(generator_fn).args: + generator_fn = functools.partial( + generator_fn, mode=model_fn_lib.ModeKeys.PREDICT) + with variable_scope.variable_scope(generator_scope) as gen_scope: + # pylint:disable=protected-access + input_data = tfgan_train._convert_tensor_or_l_or_d(input_data) + input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d( + input_data_domain_label) + # pylint:enable=protected-access + generated_data = generator_fn(input_data, input_data_domain_label) + generator_variables = variable_lib.get_trainable_variables(gen_scope) + + return tfgan_tuples.StarGANModel( + input_data=input_data, + input_data_domain_label=None, + generated_data=generated_data, + generated_data_domain_target=input_data_domain_label, + reconstructed_data=None, + discriminator_input_data_source_predication=None, + discriminator_generated_data_source_predication=None, + discriminator_input_data_domain_predication=None, + discriminator_generated_data_domain_predication=None, + generator_variables=generator_variables, + generator_scope=generator_scope, + generator_fn=generator_fn, + discriminator_variables=None, + discriminator_scope=None, + discriminator_fn=None) + + +def _get_eval_estimator_spec(gan_model, + gan_loss, + get_eval_metric_ops_fn=None, + name=None): + """Return an EstimatorSpec for the eval case.""" + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + with ops.name_scope(None, 'metrics', + [gan_loss.generator_loss, gan_loss.discriminator_loss]): + + def _summary_key(head_name, val): + return '%s/%s' % (val, head_name) if head_name else val + + eval_metric_ops = { + _summary_key(name, 'generator_loss'): + metrics_lib.mean(gan_loss.generator_loss), + _summary_key(name, 'discriminator_loss'): + metrics_lib.mean(gan_loss.discriminator_loss) + } + if get_eval_metric_ops_fn is not None: + custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('get_eval_metric_ops_fn must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, + predictions=gan_model.generated_data, + loss=scalar_loss, + eval_metric_ops=eval_metric_ops) + + +def _get_train_estimator_spec(gan_model, + gan_loss, + generator_optimizer, + discriminator_optimizer, + get_hooks_fn, + train_op_fn=tfgan_train.gan_train_ops): + """Return an EstimatorSpec for the train case.""" + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer, + discriminator_optimizer) + training_hooks = get_hooks_fn(train_ops) + return model_fn_lib.EstimatorSpec( + loss=scalar_loss, + mode=model_fn_lib.ModeKeys.TRAIN, + train_op=train_ops.global_step_inc_op, + training_hooks=training_hooks) + + +def stargan_prediction_input_fn_wrapper(fn): + """StarGAN Estimator prediction input_fn wrapper. + + Since estimator will disregard the "label" variable pass to the model, we will + use a wrapper to pack the (feature, label) tuple as feature passed to the + model. + + Args: + fn: input_fn for the prediction. + + Returns: + A tuple ((feature, label), None) where the second element is the dummy label + to be disregarded and the first element is the true input to the estimator. + """ + + def new_fn(): + return fn(), None + + return new_fn diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2ec7938c7c4051842c7e982b54c1213b6e841b79 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py @@ -0,0 +1,306 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TFGAN's stargan_estimator.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +from absl.testing import parameterized +import numpy as np +import six + +from tensorflow.contrib import layers +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl as estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import training +from tensorflow.python.training import training_util + + +def dummy_generator_fn(input_data, input_data_domain_label, mode): + del input_data_domain_label, mode + + return variable_scope.get_variable('dummy_g', initializer=0.5) * input_data + + +def dummy_discriminator_fn(input_data, num_domains, mode): + del mode + + hidden = layers.flatten(input_data) + output_src = math_ops.reduce_mean(hidden, axis=1) + output_cls = layers.fully_connected( + inputs=hidden, num_outputs=num_domains, scope='debug') + + return output_src, output_cls + + +class StarGetGANModelTest(test.TestCase, parameterized.TestCase): + """Tests that `StarGetGANModel` produces the correct model.""" + + @parameterized.named_parameters(('train', model_fn_lib.ModeKeys.TRAIN), + ('eval', model_fn_lib.ModeKeys.EVAL), + ('predict', model_fn_lib.ModeKeys.PREDICT)) + def test_get_gan_model(self, mode): + with ops.Graph().as_default(): + input_data = array_ops.ones([6, 4, 4, 3]) + input_data_domain_label = array_ops.one_hot([0] * 6, 5) + gan_model = estimator._get_gan_model( + mode, + dummy_generator_fn, + dummy_discriminator_fn, + input_data, + input_data_domain_label, + add_summaries=False) + + self.assertEqual(input_data, gan_model.input_data) + self.assertIsNotNone(gan_model.generated_data) + self.assertIsNotNone(gan_model.generated_data_domain_target) + self.assertEqual(1, len(gan_model.generator_variables)) + self.assertIsNotNone(gan_model.generator_scope) + self.assertIsNotNone(gan_model.generator_fn) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertIsNone(gan_model.input_data_domain_label) + self.assertEqual(input_data_domain_label, + gan_model.generated_data_domain_target) + self.assertIsNone(gan_model.reconstructed_data) + self.assertIsNone(gan_model.discriminator_input_data_source_predication) + self.assertIsNone( + gan_model.discriminator_generated_data_source_predication) + self.assertIsNone(gan_model.discriminator_input_data_domain_predication) + self.assertIsNone( + gan_model.discriminator_generated_data_domain_predication) + self.assertIsNone(gan_model.discriminator_variables) + self.assertIsNone(gan_model.discriminator_scope) + self.assertIsNone(gan_model.discriminator_fn) + else: + self.assertEqual(input_data_domain_label, + gan_model.input_data_domain_label) + self.assertIsNotNone(gan_model.reconstructed_data.shape) + self.assertIsNotNone( + gan_model.discriminator_input_data_source_predication) + self.assertIsNotNone( + gan_model.discriminator_generated_data_source_predication) + self.assertIsNotNone( + gan_model.discriminator_input_data_domain_predication) + self.assertIsNotNone( + gan_model.discriminator_generated_data_domain_predication) + self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertIsNotNone(gan_model.discriminator_scope) + self.assertIsNotNone(gan_model.discriminator_fn) + + +def get_dummy_gan_model(): + """Similar to get_gan_model().""" + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) + with variable_scope.variable_scope('discriminator') as dis_scope: + dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) + return tfgan_tuples.StarGANModel( + input_data=array_ops.ones([1, 2, 2, 3]), + input_data_domain_label=array_ops.ones([1, 2]), + generated_data=array_ops.ones([1, 2, 2, 3]), + generated_data_domain_target=array_ops.ones([1, 2]), + reconstructed_data=array_ops.ones([1, 2, 2, 3]), + discriminator_input_data_source_predication=array_ops.ones([1]) * dis_var, + discriminator_generated_data_source_predication=array_ops.ones( + [1]) * gen_var * dis_var, + discriminator_input_data_domain_predication=array_ops.ones([1, 2 + ]) * dis_var, + discriminator_generated_data_domain_predication=array_ops.ones([1, 2]) * + gen_var * dis_var, + generator_variables=[gen_var], + generator_scope=gen_scope, + generator_fn=None, + discriminator_variables=[dis_var], + discriminator_scope=dis_scope, + discriminator_fn=None) + + +def dummy_loss_fn(gan_model): + loss = math_ops.reduce_sum( + gan_model.discriminator_input_data_domain_predication - + gan_model.discriminator_generated_data_domain_predication) + loss += math_ops.reduce_sum(gan_model.input_data - gan_model.generated_data) + return tfgan_tuples.GANLoss(loss, loss) + + +def get_metrics(gan_model): + return { + 'mse_custom_metric': + metrics_lib.mean_squared_error(gan_model.input_data, + gan_model.generated_data) + } + + +class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): + """Tests that the EstimatorSpec is constructed appropriately.""" + + @classmethod + def setUpClass(cls): + cls._generator_optimizer = training.GradientDescentOptimizer(1.0) + cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) + + @parameterized.named_parameters(('train', model_fn_lib.ModeKeys.TRAIN), + ('eval', model_fn_lib.ModeKeys.EVAL), + ('predict', model_fn_lib.ModeKeys.PREDICT)) + def test_get_estimator_spec(self, mode): + with ops.Graph().as_default(): + self._gan_model = get_dummy_gan_model() + spec = estimator._get_estimator_spec( + mode, + self._gan_model, + loss_fn=dummy_loss_fn, + get_eval_metric_ops_fn=get_metrics, + generator_optimizer=self._generator_optimizer, + discriminator_optimizer=self._discriminator_optimizer) + + self.assertEqual(mode, spec.mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + elif mode == model_fn_lib.ModeKeys.TRAIN: + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.train_op) + self.assertIsNotNone(spec.training_hooks) + elif mode == model_fn_lib.ModeKeys.EVAL: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.eval_metric_ops) + + +# TODO(joelshor): Add pandas test. +class StarGANEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow(self, + train_input_fn, + eval_input_fn, + predict_input_fn, + prediction_size, + lr_decay=False): + + def make_opt(): + gstep = training_util.get_or_create_global_step() + lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) + return training.GradientDescentOptimizer(lr) + + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + est = estimator.StarGANEstimator( + generator_fn=dummy_generator_fn, + discriminator_fn=dummy_discriminator_fn, + loss_fn=dummy_loss_fn, + generator_optimizer=gopt, + discriminator_optimizer=dopt, + get_eval_metric_ops_fn=get_metrics, + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], + scores['loss']) + self.assertIn('mse_custom_metric', six.iterkeys(scores)) + + # PREDICT + predictions = np.array([x for x in est.predict(predict_input_fn)]) + + self.assertAllEqual(prediction_size, predictions.shape) + + @staticmethod + def _numpy_input_fn_wrapper(numpy_input_fn, batch_size, label_size): + """Wrapper to remove the dictionary in numpy_input_fn. + + NOTE: + We create the domain_label here because the model expect a fully define + batch_size from the input. + + Args: + numpy_input_fn: input_fn created from numpy_io + batch_size: (int) number of items for each batch + label_size: (int) number of domains + + Returns: + a new input_fn + """ + + def new_input_fn(): + features = numpy_input_fn() + return features['x'], array_ops.one_hot([0] * batch_size, label_size) + + return new_input_fn + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + batch_size = 5 + img_size = 8 + channel_size = 3 + label_size = 3 + image_data = np.zeros( + [batch_size, img_size, img_size, channel_size], dtype=np.float32) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': image_data}, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': image_data}, batch_size=batch_size, shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': image_data}, shuffle=False) + + train_input_fn = self._numpy_input_fn_wrapper(train_input_fn, batch_size, + label_size) + eval_input_fn = self._numpy_input_fn_wrapper(eval_input_fn, batch_size, + label_size) + predict_input_fn = self._numpy_input_fn_wrapper(predict_input_fn, + batch_size, label_size) + + predict_input_fn = estimator.stargan_prediction_input_fn_wrapper( + predict_input_fn) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + prediction_size=[batch_size, img_size, img_size, channel_size]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 7e6a0f14f6f5e467801fef39ebb597565b3d7e98..726f74c7b7addbd6c048d0b05f5695a77deb53b2 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -186,22 +186,22 @@ class GdrMemoryManager : public RemoteMemoryManager { // TODO(byronyi): remove this class and its registration when the default // cpu_allocator() returns visitable allocator, or cpu_allocator() is no // longer in use. -class BFCRdmaAllocator : public BFCAllocator { +class BFCGdrAllocator : public BFCAllocator { public: - BFCRdmaAllocator() + BFCGdrAllocator() : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36, - true, "cpu_rdma_bfc") {} + true, "cpu_gdr_bfc") {} }; -class BFCRdmaAllocatorFactory : public AllocatorFactory { +class BFCGdrAllocatorFactory : public AllocatorFactory { public: - Allocator* CreateAllocator() override { return new BFCRdmaAllocator; } + Allocator* CreateAllocator() override { return new BFCGdrAllocator; } virtual SubAllocator* CreateSubAllocator(int numa_node) { return new BasicCPUAllocator(numa_node); } }; -REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocatorFactory); +REGISTER_MEM_ALLOCATOR("BFCGdrAllocator", 102, BFCGdrAllocatorFactory); GdrMemoryManager::GdrMemoryManager(const string& host, const string& port) : host_(host), diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD deleted file mode 100644 index b719046b37ac761d56e8d5aa34772103be691cd6..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -# Description: -# Contains KfacOptimizer, an implementation of the K-FAC optimization -# algorithm in TensorFlow. -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -py_library( - name = "kfac", - srcs = ["__init__.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:curvature_matrix_vector_products_lib", - "//tensorflow/contrib/kfac/python/ops:fisher_blocks_lib", - "//tensorflow/contrib/kfac/python/ops:fisher_estimator_lib", - "//tensorflow/contrib/kfac/python/ops:fisher_factors_lib", - "//tensorflow/contrib/kfac/python/ops:kfac_optimizer_lib", - "//tensorflow/contrib/kfac/python/ops:layer_collection_lib", - "//tensorflow/contrib/kfac/python/ops:loss_functions_lib", - "//tensorflow/contrib/kfac/python/ops:op_queue_lib", - "//tensorflow/contrib/kfac/python/ops:utils_lib", - "//tensorflow/python:util", - ], -) diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 102626925db560e47cdc73eb1e25e08836cb4fba..42b91d031375b8edb7e4f364ac91ffb74ef1f54b 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,94 +1,3 @@ # K-FAC: Kronecker-Factored Approximate Curvature -# WARNING: -# ==third_party/tensorflow/contrib/kfac is deprecated. This will be== -# ==removed on 15-07-2018. Please import third_party/tensorflow_kfac.== -# ==== - -**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an -approximate second-order optimization method, in TensorFlow. When applied to -feedforward and convolutional neural networks, K-FAC can converge `>3.5x` -faster in `>14x` fewer iterations than SGD with Momentum. - -[kfac-paper]: https://arxiv.org/abs/1503.05671 - -## What is K-FAC? - -K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation -to the [Natural Gradient][natural_gradient] algorithm designed specifically for -neural networks. It maintains a block-diagonal approximation to the [Fisher -Information matrix][fisher_information], whose inverse preconditions the -gradient. - -K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations. -Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD. - -Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What -are the weights for layer i?"). As such, you must add some additional code while -constructing your model to use K-FAC. - -[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746 -[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form - -## Why should I use K-FAC? - -K-FAC can take advantage of the curvature of the optimization problem, resulting -in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same -loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See how -training loss changes as a function of number of epochs, steps, and seconds: - -![autoencoder](g3doc/autoencoder.png) - -## Is K-FAC for me? - -If you have a feedforward or convolutional model for classification that is -converging too slowly, K-FAC is for you. K-FAC can be used in your model if: - -* Your model defines a posterior distribution. -* Your model uses only fully-connected or convolutional layers (residual - connections OK). -* You are training on CPU or GPU. -* You can modify model code to register layers with K-FAC. - -## How do I use K-FAC? - -Using K-FAC requires three steps: - -1. Registering layer inputs, weights, and pre-activations with a - `LayerCollection`. -1. Minimizing the loss with a `KfacOptimizer`. -1. Keeping K-FAC's preconditioner updated. - -```python -# Build model. -w = tf.get_variable("w", ...) -b = tf.get_variable("b", ...) -logits = tf.matmul(x, w) + b -loss = tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)) - -# Register layers. -layer_collection = LayerCollection() -layer_collection.register_fully_connected((w, b), x, logits) -layer_collection.register_categorical_predictive_distribution(logits) - -# Construct training ops. -optimizer = KfacOptimizer(..., layer_collection=layer_collection) -train_op = optimizer.minimize(loss) - -# Minimize loss. -with tf.Session() as sess: - ... - sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op]) -``` - -See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations. - -## Authors - -- Alok Aggarwal -- Daniel Duckworth -- James Martens -- Matthew Johnson -- Olga Wichrowska -- Roger Grosse +## KFAC moved to third_party/tensorflow_kfac. diff --git a/tensorflow/contrib/kfac/__init__.py b/tensorflow/contrib/kfac/__init__.py deleted file mode 100644 index 1ea354e6cdf3e78eaca1f3e5dff174ed489c752e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Kronecker-factored Approximate Curvature Optimizer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long -from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products_lib as curvature_matrix_vector_products -from tensorflow.contrib.kfac.python.ops import estimator_lib as estimator -from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fisher_blocks -from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as fisher_factors -from tensorflow.contrib.kfac.python.ops import layer_collection_lib as layer_collection -from tensorflow.contrib.kfac.python.ops import loss_functions_lib as loss_functions -from tensorflow.contrib.kfac.python.ops import op_queue_lib as op_queue -from tensorflow.contrib.kfac.python.ops import optimizer_lib as optimizer -from tensorflow.contrib.kfac.python.ops import utils_lib as utils -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long - -_allowed_symbols = [ - "curvature_matrix_vector_products", - "estimator", - "fisher_blocks", - "fisher_factors", - "layer_collection", - "loss_functions", - "op_queue", - "optimizer", - "utils", -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD deleted file mode 100644 index 8186fa1c62cb952f86614a96c3965bcddae1686e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/BUILD +++ /dev/null @@ -1,80 +0,0 @@ -package(default_visibility = [ - "//learning/brain/contrib/kfac/examples:__subpackages__", - "//tensorflow/contrib/kfac/examples:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -py_binary( - name = "mlp_mnist_main", - srcs = ["mlp_mnist_main.py"], - srcs_version = "PY2AND3", - deps = [ - ":mlp", - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "mlp", - srcs = ["mlp.py"], - srcs_version = "PY2AND3", - deps = [ - ":mnist", - "//tensorflow:tensorflow_py", - ], -) - -py_binary( - name = "convnet_mnist_single_main", - srcs = ["convnet_mnist_single_main.py"], - srcs_version = "PY2AND3", - deps = [ - ":convnet", - "//tensorflow:tensorflow_py", - ], -) - -py_binary( - name = "convnet_mnist_multi_tower_main", - srcs = ["convnet_mnist_multi_tower_main.py"], - srcs_version = "PY2AND3", - deps = [ - ":convnet", - "//tensorflow:tensorflow_py", - ], -) - -py_binary( - name = "convnet_mnist_distributed_main", - srcs = ["convnet_mnist_distributed_main.py"], - srcs_version = "PY2AND3", - deps = [ - ":convnet", - "//tensorflow:tensorflow_py", - ], -) - -py_library( - name = "convnet", - srcs = ["convnet.py"], - srcs_version = "PY2AND3", - deps = [ - ":mlp", - ":mnist", - "//tensorflow:tensorflow_py", - "//third_party/py/numpy", - ], -) - -py_library( - name = "mnist", - srcs = ["mnist.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - "//third_party/py/numpy", - ], -) diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py deleted file mode 100644 index 44e01e1aebf80e83fa0f84d9cd8ed9e9ea2526f5..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ /dev/null @@ -1,667 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -r"""Train a ConvNet on MNIST using K-FAC. - -This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the -following structure, - -- Conv Layer: 5x5 kernel, 16 output channels. -- Max Pool: 3x3 kernel, stride 2. -- Conv Layer: 5x5 kernel, 16 output channels. -- Max Pool: 3x3 kernel, stride 2. -- Linear: 10 output dims. - -After 3k~6k steps, this should reach perfect accuracy on the training set. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -import numpy as np -import tensorflow as tf - -from tensorflow.contrib.kfac.examples import mlp -from tensorflow.contrib.kfac.examples import mnist -from tensorflow.contrib.kfac.python.ops import optimizer as opt - - -lc = tf.contrib.kfac.layer_collection -oq = tf.contrib.kfac.op_queue -opt = tf.contrib.kfac.optimizer - -__all__ = [ - "conv_layer", - "max_pool_layer", - "linear_layer", - "build_model", - "minimize_loss_single_machine", - "distributed_grads_only_and_ops_chief_worker", - "distributed_grads_and_ops_dedicated_workers", - "train_mnist_single_machine", - "train_mnist_distributed_sync_replicas", - "train_mnist_multitower" -] - - -# Inverse update ops will be run every _INVERT_EVRY iterations. -_INVERT_EVERY = 10 - - -def conv_layer(layer_id, inputs, kernel_size, out_channels): - """Builds a convolutional layer with ReLU non-linearity. - - Args: - layer_id: int. Integer ID for this layer's variables. - inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row - corresponds to a single example. - kernel_size: int. Width and height of the convolution kernel. The kernel is - assumed to be square. - out_channels: int. Number of output features per pixel. - - Returns: - preactivations: Tensor of shape [num_examples, width, height, out_channels]. - Values of the layer immediately before the activation function. - activations: Tensor of shape [num_examples, width, height, out_channels]. - Values of the layer immediately after the activation function. - params: Tuple of (kernel, bias), parameters for this layer. - """ - # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. - layer = tf.layers.Conv2D( - out_channels, - kernel_size=[kernel_size, kernel_size], - kernel_initializer=tf.random_normal_initializer(stddev=0.01), - padding="SAME", - name="conv_%d" % layer_id) - preactivations = layer(inputs) - activations = tf.nn.relu(preactivations) - - # layer.weights is a list. This converts it a (hashable) tuple. - return preactivations, activations, (layer.kernel, layer.bias) - - -def max_pool_layer(layer_id, inputs, kernel_size, stride): - """Build a max-pooling layer. - - Args: - layer_id: int. Integer ID for this layer's variables. - inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row - corresponds to a single example. - kernel_size: int. Width and height to pool over per input channel. The - kernel is assumed to be square. - stride: int. Step size between pooling operations. - - Returns: - Tensor of shape [num_examples, width/stride, height/stride, out_channels]. - Result of applying max pooling to 'inputs'. - """ - # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. - with tf.variable_scope("pool_%d" % layer_id): - return tf.nn.max_pool( - inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1], - padding="SAME", - name="pool") - - -def linear_layer(layer_id, inputs, output_size): - """Builds the final linear layer for an MNIST classification problem. - - Args: - layer_id: int. Integer ID for this layer's variables. - inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row - corresponds to a single example. - output_size: int. Number of output dims per example. - - Returns: - activations: Tensor of shape [num_examples, output_size]. Values of the - layer immediately after the activation function. - params: Tuple of (weights, bias), parameters for this layer. - """ - # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. - pre, _, params = mlp.fc_layer(layer_id, inputs, output_size) - return pre, params - - -def build_model(examples, labels, num_labels, layer_collection): - """Builds a ConvNet classification model. - - Args: - examples: Tensor of shape [num_examples, num_features]. Represents inputs of - model. - labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted - by softmax for each example. - num_labels: int. Number of distinct values 'labels' can take on. - layer_collection: LayerCollection instance. Layers will be registered here. - - Returns: - loss: 0-D Tensor representing loss to be minimized. - accuracy: 0-D Tensor representing model's accuracy. - """ - # Build a ConvNet. For each layer with parameters, we'll keep track of the - # preactivations, activations, weights, and bias. - tf.logging.info("Building model.") - pre0, act0, params0 = conv_layer( - layer_id=0, inputs=examples, kernel_size=5, out_channels=16) - act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2) - pre2, act2, params2 = conv_layer( - layer_id=2, inputs=act1, kernel_size=5, out_channels=16) - act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2) - flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))]) - logits, params4 = linear_layer( - layer_id=4, inputs=flat_act3, output_size=num_labels) - loss = tf.reduce_mean( - tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=labels, logits=logits)) - accuracy = tf.reduce_mean( - tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) - - with tf.device("/cpu:0"): - tf.summary.scalar("loss", loss) - tf.summary.scalar("accuracy", accuracy) - - # Register parameters. K-FAC needs to know about the inputs, outputs, and - # parameters of each conv/fully connected layer and the logits powering the - # posterior probability over classes. - tf.logging.info("Building LayerCollection.") - layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples, - pre0) - layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2) - layer_collection.register_fully_connected(params4, flat_act3, logits) - layer_collection.register_categorical_predictive_distribution( - logits, name="logits") - - return loss, accuracy - - -def minimize_loss_single_machine(loss, - accuracy, - layer_collection, - device="/gpu:0", - session_config=None): - """Minimize loss with K-FAC on a single machine. - - A single Session is responsible for running all of K-FAC's ops. The covariance - and inverse update ops are placed on `device`. All model variables are on CPU. - - Args: - loss: 0-D Tensor. Loss to be minimized. - accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. - layer_collection: LayerCollection instance describing model architecture. - Used by K-FAC to construct preconditioner. - device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse - update ops are run on this device. - session_config: None or tf.ConfigProto. Configuration for tf.Session(). - - Returns: - final value for 'accuracy'. - """ - # Train with K-FAC. - g_step = tf.train.get_or_create_global_step() - optimizer = opt.KfacOptimizer( - learning_rate=0.0001, - cov_ema_decay=0.95, - damping=0.001, - layer_collection=layer_collection, - placement_strategy="round_robin", - cov_devices=[device], - inv_devices=[device], - momentum=0.9) - (cov_update_thunks, - inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - - def make_update_op(update_thunks): - update_ops = [thunk() for thunk in update_thunks] - return tf.group(*update_ops) - - cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([cov_update_op]): - inverse_op = tf.cond( - tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), - lambda: make_update_op(inv_update_thunks), tf.no_op) - with tf.control_dependencies([inverse_op]): - with tf.device(device): - train_op = optimizer.minimize(loss, global_step=g_step) - - tf.logging.info("Starting training.") - with tf.train.MonitoredTrainingSession(config=session_config) as sess: - while not sess.should_stop(): - global_step_, loss_, accuracy_, _ = sess.run( - [g_step, loss, accuracy, train_op]) - - if global_step_ % _INVERT_EVERY == 0: - tf.logging.info("global_step: %d | loss: %f | accuracy: %s", - global_step_, loss_, accuracy_) - - return accuracy_ - - -def _is_gradient_task(task_id, num_tasks): - """Returns True if this task should update the weights.""" - if num_tasks < 3: - return True - return 0 <= task_id < 0.6 * num_tasks - - -def _is_cov_update_task(task_id, num_tasks): - """Returns True if this task should update K-FAC's covariance matrices.""" - if num_tasks < 3: - return False - return 0.6 * num_tasks <= task_id < num_tasks - 1 - - -def _is_inv_update_task(task_id, num_tasks): - """Returns True if this task should update K-FAC's preconditioner.""" - if num_tasks < 3: - return False - return task_id == num_tasks - 1 - - -def _num_gradient_tasks(num_tasks): - """Number of tasks that will update weights.""" - if num_tasks < 3: - return num_tasks - return int(np.ceil(0.6 * num_tasks)) - - -def _make_distributed_train_op( - task_id, - num_worker_tasks, - num_ps_tasks, - layer_collection -): - """Creates optimizer and distributed training op. - - Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes - the train op. - - Args: - task_id: int. Integer in [0, num_worker_tasks). ID for this worker. - num_worker_tasks: int. Number of workers in this distributed training setup. - num_ps_tasks: int. Number of parameter servers holding variables. If 0, - parameter servers are not used. - layer_collection: LayerCollection instance describing model architecture. - Used by K-FAC to construct preconditioner. - - Returns: - sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC - optimizer. - optimizer: Instance of `opt.KfacOptimizer`. - global_step: `tensor`, Global step. - """ - tf.logging.info("Task id : %d", task_id) - with tf.device(tf.train.replica_device_setter(num_ps_tasks)): - global_step = tf.train.get_or_create_global_step() - optimizer = opt.KfacOptimizer( - learning_rate=0.0001, - cov_ema_decay=0.95, - damping=0.001, - layer_collection=layer_collection, - momentum=0.9) - sync_optimizer = tf.train.SyncReplicasOptimizer( - opt=optimizer, - replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks), - total_num_replicas=num_worker_tasks) - return sync_optimizer, optimizer, global_step - - -def distributed_grads_only_and_ops_chief_worker( - task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, - loss, accuracy, layer_collection, invert_every=10): - """Minimize loss with a synchronous implementation of K-FAC. - - All workers perform gradient computation. Chief worker applies gradient after - averaging the gradients obtained from all the workers. All workers block - execution until the update is applied. Chief worker runs covariance and - inverse update ops. Covariance and inverse matrices are placed on parameter - servers in a round robin manner. For further details on synchronous - distributed optimization check `tf.train.SyncReplicasOptimizer`. - - Args: - task_id: int. Integer in [0, num_worker_tasks). ID for this worker. - is_chief: `boolean`, `True` if the worker is chief worker. - num_worker_tasks: int. Number of workers in this distributed training setup. - num_ps_tasks: int. Number of parameter servers holding variables. If 0, - parameter servers are not used. - master: string. IP and port of TensorFlow runtime process. Set to empty - string to run locally. - checkpoint_dir: string or None. Path to store checkpoints under. - loss: 0-D Tensor. Loss to be minimized. - accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to - run with each step. - layer_collection: LayerCollection instance describing model architecture. - Used by K-FAC to construct preconditioner. - invert_every: `int`, Number of steps between update the inverse. - - Returns: - final value for 'accuracy'. - - Raises: - ValueError: if task_id >= num_worker_tasks. - """ - - sync_optimizer, optimizer, global_step = _make_distributed_train_op( - task_id, num_worker_tasks, num_ps_tasks, layer_collection) - (cov_update_thunks, - inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - - tf.logging.info("Starting training.") - hooks = [sync_optimizer.make_session_run_hook(is_chief)] - - def make_update_op(update_thunks): - update_ops = [thunk() for thunk in update_thunks] - return tf.group(*update_ops) - - if is_chief: - cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([cov_update_op]): - inverse_op = tf.cond( - tf.equal(tf.mod(global_step, invert_every), 0), - lambda: make_update_op(inv_update_thunks), - tf.no_op) - with tf.control_dependencies([inverse_op]): - train_op = sync_optimizer.minimize(loss, global_step=global_step) - else: - train_op = sync_optimizer.minimize(loss, global_step=global_step) - - with tf.train.MonitoredTrainingSession( - master=master, - is_chief=is_chief, - checkpoint_dir=checkpoint_dir, - hooks=hooks, - stop_grace_period_secs=0) as sess: - while not sess.should_stop(): - global_step_, loss_, accuracy_, _ = sess.run( - [global_step, loss, accuracy, train_op]) - tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, - loss_, accuracy_) - return accuracy_ - - -def distributed_grads_and_ops_dedicated_workers( - task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir, - loss, accuracy, layer_collection): - """Minimize loss with a synchronous implementation of K-FAC. - - Different workers are responsible for different parts of K-FAC's Ops. The - first 60% of tasks compute gradients; the next 20% accumulate covariance - statistics; the last 20% invert the matrices used to precondition gradients. - The chief worker applies the gradient . - - Args: - task_id: int. Integer in [0, num_worker_tasks). ID for this worker. - is_chief: `boolean`, `True` if the worker is chief worker. - num_worker_tasks: int. Number of workers in this distributed training setup. - num_ps_tasks: int. Number of parameter servers holding variables. If 0, - parameter servers are not used. - master: string. IP and port of TensorFlow runtime process. Set to empty - string to run locally. - checkpoint_dir: string or None. Path to store checkpoints under. - loss: 0-D Tensor. Loss to be minimized. - accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to - run with each step. - layer_collection: LayerCollection instance describing model architecture. - Used by K-FAC to construct preconditioner. - - Returns: - final value for 'accuracy'. - - Raises: - ValueError: if task_id >= num_worker_tasks. - """ - sync_optimizer, optimizer, global_step = _make_distributed_train_op( - task_id, num_worker_tasks, num_ps_tasks, layer_collection) - _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars() - train_op = sync_optimizer.minimize(loss, global_step=global_step) - inv_update_queue = oq.OpQueue(inv_update_ops) - - tf.logging.info("Starting training.") - is_chief = (task_id == 0) - hooks = [sync_optimizer.make_session_run_hook(is_chief)] - with tf.train.MonitoredTrainingSession( - master=master, - is_chief=is_chief, - checkpoint_dir=checkpoint_dir, - hooks=hooks, - stop_grace_period_secs=0) as sess: - while not sess.should_stop(): - # Choose which op this task is responsible for running. - if _is_gradient_task(task_id, num_worker_tasks): - learning_op = train_op - elif _is_cov_update_task(task_id, num_worker_tasks): - learning_op = cov_update_op - elif _is_inv_update_task(task_id, num_worker_tasks): - # TODO(duckworthd): Running this op before cov_update_op has been run a - # few times can result in "InvalidArgumentError: Cholesky decomposition - # was not successful." Delay running this op until cov_update_op has - # been run a few times. - learning_op = inv_update_queue.next_op(sess) - else: - raise ValueError("Which op should task %d do?" % task_id) - - global_step_, loss_, accuracy_, _ = sess.run( - [global_step, loss, accuracy, learning_op]) - tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, - loss_, accuracy_) - - return accuracy_ - - -def train_mnist_single_machine(data_dir, - num_epochs, - use_fake_data=False, - device="/gpu:0"): - """Train a ConvNet on MNIST. - - Args: - data_dir: string. Directory to read MNIST examples from. - num_epochs: int. Number of passes to make over the training set. - use_fake_data: bool. If True, generate a synthetic dataset. - device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse - update ops are run on this device. - - Returns: - accuracy of model on the final minibatch of training data. - """ - # Load a dataset. - tf.logging.info("Loading MNIST into memory.") - examples, labels = mnist.load_mnist( - data_dir, - num_epochs=num_epochs, - batch_size=128, - use_fake_data=use_fake_data, - flatten_images=False) - - # Build a ConvNet. - layer_collection = lc.LayerCollection() - loss, accuracy = build_model( - examples, labels, num_labels=10, layer_collection=layer_collection) - - # Fit model. - return minimize_loss_single_machine( - loss, accuracy, layer_collection, device=device) - - -def train_mnist_multitower(data_dir, num_epochs, num_towers, - use_fake_data=True, devices=None): - """Train a ConvNet on MNIST. - - Training data is split equally among the towers. Each tower computes loss on - its own batch of data and the loss is aggregated on the CPU. The model - variables are placed on first tower. The covariance and inverse update ops - and variables are placed on GPUs in a round robin manner. - - Args: - data_dir: string. Directory to read MNIST examples from. - num_epochs: int. Number of passes to make over the training set. - num_towers: int. Number of CPUs to split inference across. - use_fake_data: bool. If True, generate a synthetic dataset. - devices: string, Either list of CPU or GPU. The covariance and inverse - update ops are run on this device. - - Returns: - accuracy of model on the final minibatch of training data. - """ - if devices: - device_count = {"GPU": num_towers} - else: - device_count = {"CPU": num_towers} - - devices = devices or [ - "/cpu:{}".format(tower_id) for tower_id in range(num_towers) - ] - # Load a dataset. - tf.logging.info("Loading MNIST into memory.") - tower_batch_size = 128 - batch_size = tower_batch_size * num_towers - tf.logging.info( - ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d " - "tower batch size.") % (batch_size, num_towers, tower_batch_size)) - examples, labels = mnist.load_mnist( - data_dir, - num_epochs=num_epochs, - batch_size=batch_size, - use_fake_data=use_fake_data, - flatten_images=False) - - # Split minibatch across towers. - examples = tf.split(examples, num_towers) - labels = tf.split(labels, num_towers) - - # Build an MLP. Each tower's layers will be added to the LayerCollection. - layer_collection = lc.LayerCollection() - tower_results = [] - for tower_id in range(num_towers): - with tf.device(devices[tower_id]): - with tf.name_scope("tower%d" % tower_id): - with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): - tf.logging.info("Building tower %d." % tower_id) - tower_results.append( - build_model(examples[tower_id], labels[tower_id], 10, - layer_collection)) - losses, accuracies = zip(*tower_results) - - # Average across towers. - loss = tf.reduce_mean(losses) - accuracy = tf.reduce_mean(accuracies) - - # Fit model. - - session_config = tf.ConfigProto( - allow_soft_placement=False, - device_count=device_count, - ) - - g_step = tf.train.get_or_create_global_step() - optimizer = opt.KfacOptimizer( - learning_rate=0.0001, - cov_ema_decay=0.95, - damping=0.001, - layer_collection=layer_collection, - placement_strategy="round_robin", - cov_devices=devices, - inv_devices=devices, - momentum=0.9) - (cov_update_thunks, - inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - - def make_update_op(update_thunks): - update_ops = [thunk() for thunk in update_thunks] - return tf.group(*update_ops) - - cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([cov_update_op]): - inverse_op = tf.cond( - tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), - lambda: make_update_op(inv_update_thunks), tf.no_op) - with tf.control_dependencies([inverse_op]): - train_op = optimizer.minimize(loss, global_step=g_step) - - tf.logging.info("Starting training.") - with tf.train.MonitoredTrainingSession(config=session_config) as sess: - while not sess.should_stop(): - global_step_, loss_, accuracy_, _ = sess.run( - [g_step, loss, accuracy, train_op]) - - if global_step_ % _INVERT_EVERY == 0: - tf.logging.info("global_step: %d | loss: %f | accuracy: %s", - global_step_, loss_, accuracy_) - - -def train_mnist_distributed_sync_replicas(task_id, - is_chief, - num_worker_tasks, - num_ps_tasks, - master, - data_dir, - num_epochs, - op_strategy, - use_fake_data=False): - """Train a ConvNet on MNIST using Sync replicas optimizer. - - Args: - task_id: int. Integer in [0, num_worker_tasks). ID for this worker. - is_chief: `boolean`, `True` if the worker is chief worker. - num_worker_tasks: int. Number of workers in this distributed training setup. - num_ps_tasks: int. Number of parameter servers holding variables. - master: string. IP and port of TensorFlow runtime process. - data_dir: string. Directory to read MNIST examples from. - num_epochs: int. Number of passes to make over the training set. - op_strategy: `string`, Strategy to run the covariance and inverse - ops. If op_strategy == `chief_worker` then covariance and inverse - update ops are run on chief worker otherwise they are run on dedicated - workers. - - use_fake_data: bool. If True, generate a synthetic dataset. - - Returns: - accuracy of model on the final minibatch of training data. - - Raises: - ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"]. - """ - # Load a dataset. - tf.logging.info("Loading MNIST into memory.") - examples, labels = mnist.load_mnist( - data_dir, - num_epochs=num_epochs, - batch_size=128, - use_fake_data=use_fake_data, - flatten_images=False) - - # Build a ConvNet. - layer_collection = lc.LayerCollection() - with tf.device(tf.train.replica_device_setter(num_ps_tasks)): - loss, accuracy = build_model( - examples, labels, num_labels=10, layer_collection=layer_collection) - - # Fit model. - checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac") - if op_strategy == "chief_worker": - return distributed_grads_only_and_ops_chief_worker( - task_id, is_chief, num_worker_tasks, num_ps_tasks, master, - checkpoint_dir, loss, accuracy, layer_collection) - elif op_strategy == "dedicated_workers": - return distributed_grads_and_ops_dedicated_workers( - task_id, is_chief, num_worker_tasks, num_ps_tasks, master, - checkpoint_dir, loss, accuracy, layer_collection) - else: - raise ValueError("Only supported op strategies are : {}, {}".format( - "chief_worker", "dedicated_workers")) - - -if __name__ == "__main__": - tf.app.run() diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py deleted file mode 100644 index b4c2d4a9e9bfcc4bfb55a25d2f23e66afe5b1375..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -r"""Train a ConvNet on MNIST using K-FAC. - -Distributed training with sync replicas optimizer. See -`convnet.train_mnist_distributed_sync_replicas` for details. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -from absl import flags -import tensorflow as tf - -from tensorflow.contrib.kfac.examples import convnet - -FLAGS = flags.FLAGS -flags.DEFINE_integer("task", -1, "Task identifier") -flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir") -flags.DEFINE_string( - "cov_inv_op_strategy", "chief_worker", - "In dist training mode run the cov, inv ops on chief or dedicated workers." -) -flags.DEFINE_string("master", "local", "Session master.") -flags.DEFINE_integer("ps_tasks", 2, - "Number of tasks in the parameter server job.") -flags.DEFINE_integer("replicas_to_aggregate", 5, - "Number of replicas to aggregate.") -flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.") -flags.DEFINE_integer("num_epochs", None, "Number of epochs.") - - -def _is_chief(): - """Determines whether a job is the chief worker.""" - if "chief_worker" in FLAGS.brain_jobs: - return FLAGS.brain_job_name == "chief_worker" - else: - return FLAGS.task == 0 - - -def main(unused_argv): - _ = unused_argv - convnet.train_mnist_distributed_sync_replicas( - FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks, - FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy) - -if __name__ == "__main__": - tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py deleted file mode 100644 index 4249bf8a8d9d3a5beb87d4140a55b0ee6eadbc64..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -r"""Train a ConvNet on MNIST using K-FAC. - -Multi tower training mode. See `convnet.train_mnist_multitower` for details. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -from absl import flags -import tensorflow as tf - -from tensorflow.contrib.kfac.examples import convnet - -FLAGS = flags.FLAGS -flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir") -flags.DEFINE_integer("num_towers", 2, - "Number of towers for multi tower training.") - - -def main(unused_argv): - _ = unused_argv - assert FLAGS.num_towers > 1 - devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)] - convnet.train_mnist_multitower( - FLAGS.data_dir, - num_epochs=200, - num_towers=FLAGS.num_towers, - devices=devices) - - -if __name__ == "__main__": - tf.app.run(main=main) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py deleted file mode 100644 index ea2b252a05702d5adcdc5f70d713277ba604f691..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/mlp.py +++ /dev/null @@ -1,354 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -r"""Train an MLP on MNIST using K-FAC. - -This library fits a 3-layer, tanh-activated MLP on MNIST using K-FAC. After -~25k steps, this should reach perfect accuracy on the training set. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -from tensorflow.contrib.kfac.examples import mnist - -lc = tf.contrib.kfac.layer_collection -opt = tf.contrib.kfac.optimizer - -__all__ = [ - "fc_layer", - "train_mnist", - "train_mnist_multitower", -] - - -def fc_layer(layer_id, inputs, output_size): - """Builds a fully connected layer. - - Args: - layer_id: int. Integer ID for this layer's variables. - inputs: Tensor of shape [num_examples, input_size]. Each row corresponds - to a single example. - output_size: int. Number of output dimensions after fully connected layer. - - Returns: - preactivations: Tensor of shape [num_examples, output_size]. Values of the - layer immediately before the activation function. - activations: Tensor of shape [num_examples, output_size]. Values of the - layer immediately after the activation function. - params: Tuple of (weights, bias), parameters for this layer. - """ - # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. - layer = tf.layers.Dense( - output_size, - kernel_initializer=tf.random_normal_initializer(), - name="fc_%d" % layer_id) - preactivations = layer(inputs) - activations = tf.nn.tanh(preactivations) - - # layer.weights is a list. This converts it a (hashable) tuple. - return preactivations, activations, (layer.kernel, layer.bias) - - -def build_model(examples, labels, num_labels, layer_collection): - """Builds an MLP classification model. - - Args: - examples: Tensor of shape [num_examples, num_features]. Represents inputs of - model. - labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted - by softmax for each example. - num_labels: int. Number of distinct values 'labels' can take on. - layer_collection: LayerCollection instance describing model architecture. - - Returns: - loss: 0-D Tensor representing loss to be minimized. - accuracy: 0-D Tensor representing model's accuracy. - """ - # Build an MLP. For each layer, we'll keep track of the preactivations, - # activations, weights, and bias. - pre0, act0, params0 = fc_layer(layer_id=0, inputs=examples, output_size=128) - pre1, act1, params1 = fc_layer(layer_id=1, inputs=act0, output_size=64) - pre2, act2, params2 = fc_layer(layer_id=2, inputs=act1, output_size=32) - logits, _, params3 = fc_layer(layer_id=3, inputs=act2, output_size=num_labels) - loss = tf.reduce_mean( - tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=labels, logits=logits)) - accuracy = tf.reduce_mean( - tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) - - # Register parameters. K-FAC needs to know about the inputs, outputs, and - # parameters of each layer and the logits powering the posterior probability - # over classes. - tf.logging.info("Building LayerCollection.") - layer_collection.register_fully_connected(params0, examples, pre0) - layer_collection.register_fully_connected(params1, act0, pre1) - layer_collection.register_fully_connected(params2, act1, pre2) - layer_collection.register_fully_connected(params3, act2, logits) - layer_collection.register_categorical_predictive_distribution( - logits, name="logits") - - return loss, accuracy - - -def minimize(loss, accuracy, layer_collection, num_towers, session_config=None): - """Minimize 'loss' with KfacOptimizer. - - Args: - loss: 0-D Tensor. Loss to be minimized. - accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. - layer_collection: LayerCollection instance. Describes layers in model. - num_towers: int. Number of CPUs to split minibatch across. - session_config: tf.ConfigProto. Configuration for tf.Session(). - - Returns: - accuracy of classifier on final minibatch. - """ - devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers)) - - # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2 - # every 10k iterations. - tf.logging.info("Building KFAC Optimizer.") - global_step = tf.train.get_or_create_global_step() - optimizer = opt.KfacOptimizer( - learning_rate=tf.train.exponential_decay( - 0.00002, global_step, 10000, 0.5, staircase=True), - cov_ema_decay=0.95, - damping=0.0005, - layer_collection=layer_collection, - momentum=0.99, - placement_strategy="round_robin", - cov_devices=devices, - inv_devices=devices) - - (cov_update_thunks, - inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - - def make_update_op(update_thunks): - update_ops = [thunk() for thunk in update_thunks] - return tf.group(*update_ops) - - # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt - # once that gets moved over? Could still leave more advanced examples as they - # are (e.g. train_mnist_estimator in this file) - - cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([cov_update_op]): - # We update the inverses only every 20 iterations. - inverse_op = tf.cond( - tf.equal(tf.mod(global_step, 100), 0), - lambda: make_update_op(inv_update_thunks), tf.no_op) - with tf.control_dependencies([inverse_op]): - train_op = optimizer.minimize(loss, global_step=global_step) - - tf.logging.info("Starting training.") - with tf.train.MonitoredTrainingSession(config=session_config) as sess: - while not sess.should_stop(): - global_step_, loss_, accuracy_, _ = sess.run( - [global_step, loss, accuracy, train_op]) - - if global_step_ % 100 == 0: - tf.logging.info("global_step: %d | loss: %f | accuracy: %f", - global_step_, loss_, accuracy_) - - return accuracy_ - - -def train_mnist(data_dir, num_epochs, use_fake_data=False): - """Train an MLP on MNIST. - - Args: - data_dir: string. Directory to read MNIST examples from. - num_epochs: int. Number of passes to make over the training set. - use_fake_data: bool. If True, generate a synthetic dataset. - - Returns: - accuracy of model on the final minibatch of training data. - """ - # Load a dataset. - tf.logging.info("Loading MNIST into memory.") - examples, labels = mnist.load_mnist( - data_dir, - num_epochs=num_epochs, - batch_size=64, - flatten_images=True, - use_fake_data=use_fake_data) - - # Build an MLP. The model's layers will be added to the LayerCollection. - tf.logging.info("Building model.") - layer_collection = lc.LayerCollection() - loss, accuracy = build_model(examples, labels, 10, layer_collection) - - # Fit model. - minimize(loss, accuracy, layer_collection, 1) - - -def train_mnist_multitower(data_dir, - num_epochs, - num_towers, - use_fake_data=False): - """Train an MLP on MNIST, splitting the minibatch across multiple towers. - - Args: - data_dir: string. Directory to read MNIST examples from. - num_epochs: int. Number of passes to make over the training set. - num_towers: int. Number of CPUs to split minibatch across. - use_fake_data: bool. If True, generate a synthetic dataset. - - Returns: - accuracy of model on the final minibatch of training data. - """ - # Load a dataset. - tower_batch_size = 64 - batch_size = tower_batch_size * num_towers - tf.logging.info( - ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d " - "tower batch size.") % (batch_size, num_towers, tower_batch_size)) - examples, labels = mnist.load_mnist( - data_dir, - num_epochs=num_epochs, - batch_size=batch_size, - flatten_images=True, - use_fake_data=use_fake_data) - - # Split minibatch across towers. - examples = tf.split(examples, num_towers) - labels = tf.split(labels, num_towers) - - # Build an MLP. Each tower's layers will be added to the LayerCollection. - layer_collection = lc.LayerCollection() - tower_results = [] - for tower_id in range(num_towers): - with tf.device("/cpu:%d" % tower_id): - with tf.name_scope("tower%d" % tower_id): - with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): - tf.logging.info("Building tower %d." % tower_id) - tower_results.append( - build_model(examples[tower_id], labels[tower_id], 10, - layer_collection)) - losses, accuracies = zip(*tower_results) - - # Average across towers. - loss = tf.reduce_mean(losses) - accuracy = tf.reduce_mean(accuracies) - - # Fit model. - session_config = tf.ConfigProto( - allow_soft_placement=False, device_count={ - "CPU": num_towers - }) - return minimize( - loss, accuracy, layer_collection, num_towers, - session_config=session_config) - - -def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): - """Train an MLP on MNIST using tf.estimator. - - Args: - data_dir: string. Directory to read MNIST examples from. - num_epochs: int. Number of passes to make over the training set. - use_fake_data: bool. If True, generate a synthetic dataset. - - Returns: - accuracy of model on the final minibatch of training data. - """ - - # Load a dataset. - def input_fn(): - tf.logging.info("Loading MNIST into memory.") - return mnist.load_mnist( - data_dir, - num_epochs=num_epochs, - batch_size=64, - flatten_images=True, - use_fake_data=use_fake_data) - - def model_fn(features, labels, mode, params): - """Model function for MLP trained with K-FAC. - - Args: - features: Tensor of shape [batch_size, input_size]. Input features. - labels: Tensor of shape [batch_size]. Target labels for training. - mode: tf.estimator.ModeKey. Must be TRAIN. - params: ignored. - - Returns: - EstimatorSpec for training. - - Raises: - ValueError: If 'mode' is anything other than TRAIN. - """ - del params - - if mode != tf.estimator.ModeKeys.TRAIN: - raise ValueError("Only training is supposed with this API.") - - # Build a ConvNet. - layer_collection = lc.LayerCollection() - loss, accuracy = build_model( - features, labels, num_labels=10, layer_collection=layer_collection) - - # Train with K-FAC. - global_step = tf.train.get_or_create_global_step() - optimizer = opt.KfacOptimizer( - learning_rate=tf.train.exponential_decay( - 0.00002, global_step, 10000, 0.5, staircase=True), - cov_ema_decay=0.95, - damping=0.0001, - layer_collection=layer_collection, - momentum=0.99) - - (cov_update_thunks, - inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - - def make_update_op(update_thunks): - update_ops = [thunk() for thunk in update_thunks] - return tf.group(*update_ops) - - def make_batch_executed_op(update_thunks, batch_size=1): - return tf.group(*tf.contrib.kfac.utils.batch_execute( - global_step, update_thunks, batch_size=batch_size)) - - # Run cov_update_op every step. Run 1 inv_update_ops per step. - cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([cov_update_op]): - # But make sure to execute all the inverse ops on the first step - inverse_op = tf.cond(tf.equal(global_step, 0), - lambda: make_update_op(inv_update_thunks), - lambda: make_batch_executed_op(inv_update_thunks)) - with tf.control_dependencies([inverse_op]): - train_op = optimizer.minimize(loss, global_step=global_step) - - # Print metrics every 5 sec. - hooks = [ - tf.train.LoggingTensorHook( - { - "loss": loss, - "accuracy": accuracy - }, every_n_secs=5), - ] - return tf.estimator.EstimatorSpec( - mode=mode, loss=loss, train_op=train_op, training_hooks=hooks) - - run_config = tf.estimator.RunConfig( - model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100) - - # Train until input_fn() is empty with Estimator. This is a prerequisite for - # TPU compatibility. - estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) - estimator.train(input_fn=input_fn) diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py deleted file mode 100644 index 9c34ade1d2018135b3636fddb9dcc65839cd59de..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -r"""Train an MLP on MNIST using K-FAC. - -See mlp.py for details. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import sys - -import tensorflow as tf - -from tensorflow.contrib.kfac.examples import mlp - -FLAGS = None - - -def main(argv): - _ = argv - if FLAGS.use_estimator: - if FLAGS.num_towers != 1: - raise ValueError("Only 1 device supported in tf.estimator example.") - mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200) - elif FLAGS.num_towers > 1: - mlp.train_mnist_multitower( - FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) - else: - mlp.train_mnist(FLAGS.data_dir, num_epochs=200) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--data_dir", - type=str, - default="/tmp/mnist", - help="Directory to store dataset in.") - parser.add_argument( - "--num_towers", - type=int, - default=1, - help="Number of CPUs to split minibatch across.") - parser.add_argument( - "--use_estimator", - action="store_true", - help="Use tf.estimator API to train.") - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/kfac/examples/mnist.py b/tensorflow/contrib/kfac/examples/mnist.py deleted file mode 100644 index 547c4ab25d589192f2a5b65987be3b05128fe298..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/mnist.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities for loading MNIST into TensorFlow.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf - -__all__ = [ - 'load_mnist', -] - - -def load_mnist(data_dir, - num_epochs, - batch_size, - flatten_images=True, - use_fake_data=False): - """Loads MNIST dataset into memory. - - Args: - data_dir: string. Directory to read MNIST examples from. - num_epochs: int. Number of passes to make over the dataset. - batch_size: int. Number of examples per minibatch. - flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into - [784]-shaped vectors. - use_fake_data: bool. If True, generate a synthetic dataset rather than - reading MNIST in. - - Returns: - examples: Tensor of shape [batch_size, 784] if 'flatten_images' is - True, else [batch_size, 28, 28, 1]. Each row is one example. - Values in [0, 1]. - labels: Tensor of shape [batch_size]. Indices of integer corresponding to - each example. Values in {0...9}. - """ - if use_fake_data: - rng = np.random.RandomState(42) - num_examples = batch_size * 4 - images = rng.rand(num_examples, 28 * 28) - if not flatten_images: - images = np.reshape(images, [num_examples, 28, 28, 1]) - labels = rng.randint(10, size=num_examples) - else: - mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets( - data_dir, reshape=flatten_images) - num_examples = len(mnist_data.train.labels) - images = mnist_data.train.images - labels = mnist_data.train.labels - - dataset = tf.data.Dataset.from_tensor_slices((np.asarray( - images, dtype=np.float32), np.asarray(labels, dtype=np.int64))) - return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size) - .make_one_shot_iterator().get_next()) diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD deleted file mode 100644 index ede7f183fe24f26bd86e232e831dea5f8ea1fdc4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/tests/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -package(default_visibility = ["//visibility:private"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow:tensorflow.bzl", "py_test") - -py_test( - name = "mlp_test", - size = "large", - srcs = ["mlp_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "notsan", - ], - deps = [ - "//tensorflow:tensorflow_py", - "//tensorflow/contrib/kfac/examples:mlp", - "//third_party/py/numpy", - ], -) - -py_test( - name = "convnet_test", - size = "large", - srcs = ["convnet_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "notsan", - ], - deps = [ - "//tensorflow:tensorflow_py", - "//tensorflow/contrib/kfac", - "//tensorflow/contrib/kfac/examples:convnet", - "//third_party/py/numpy", - ], -) - -py_test( - name = "mnist_test", - srcs = ["mnist_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ - "//tensorflow:tensorflow_py", - "//tensorflow/contrib/kfac/examples:mnist", - "//third_party/py/numpy", - ], -) diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py deleted file mode 100644 index adecda71666ee74bc577859589060fa65baf5166..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for convnet.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf - -from tensorflow.contrib.kfac import layer_collection as lc -from tensorflow.contrib.kfac.examples import convnet - - -class ConvNetTest(tf.test.TestCase): - - def testConvLayer(self): - with tf.Graph().as_default(): - pre, act, (w, b) = convnet.conv_layer( - layer_id=1, - inputs=tf.zeros([5, 3, 3, 2]), - kernel_size=3, - out_channels=5) - self.assertShapeEqual(np.zeros([5, 3, 3, 5]), pre) - self.assertShapeEqual(np.zeros([5, 3, 3, 5]), act) - self.assertShapeEqual(np.zeros([3, 3, 2, 5]), tf.convert_to_tensor(w)) - self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b)) - self.assertIsInstance(w, tf.Variable) - self.assertIsInstance(b, tf.Variable) - self.assertIn("conv_1", w.op.name) - self.assertIn("conv_1", b.op.name) - - def testMaxPoolLayer(self): - with tf.Graph().as_default(): - act = convnet.max_pool_layer( - layer_id=1, inputs=tf.zeros([5, 6, 6, 2]), kernel_size=5, stride=3) - self.assertShapeEqual(np.zeros([5, 2, 2, 2]), act) - self.assertEqual(act.op.name, "pool_1/pool") - - def testLinearLayer(self): - with tf.Graph().as_default(): - act, (w, b) = convnet.linear_layer( - layer_id=1, inputs=tf.zeros([5, 20]), output_size=5) - self.assertShapeEqual(np.zeros([5, 5]), act) - self.assertShapeEqual(np.zeros([20, 5]), tf.convert_to_tensor(w)) - self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b)) - self.assertIsInstance(w, tf.Variable) - self.assertIsInstance(b, tf.Variable) - self.assertIn("fc_1", w.op.name) - self.assertIn("fc_1", b.op.name) - - def testBuildModel(self): - with tf.Graph().as_default(): - x = tf.placeholder(tf.float32, [None, 6, 6, 3]) - y = tf.placeholder(tf.int64, [None]) - layer_collection = lc.LayerCollection() - loss, accuracy = convnet.build_model( - x, y, num_labels=5, layer_collection=layer_collection) - - # Ensure layers and logits were registered. - self.assertEqual(len(layer_collection.fisher_blocks), 3) - self.assertEqual(len(layer_collection.losses), 1) - - # Ensure inference doesn't crash. - with self.test_session() as sess: - sess.run(tf.global_variables_initializer()) - feed_dict = { - x: np.random.randn(10, 6, 6, 3).astype(np.float32), - y: np.random.randint(5, size=10).astype(np.int64), - } - sess.run([loss, accuracy], feed_dict=feed_dict) - - def _build_toy_problem(self): - """Construct a toy linear regression problem. - - Initial loss should be, - 2.5 = 0.5 * (1^2 + 2^2) - - Returns: - loss: 0-D Tensor representing loss to be minimized. - accuracy: 0-D Tensors representing model accuracy. - layer_collection: LayerCollection instance describing model architecture. - """ - x = np.asarray([[1.], [2.]]).astype(np.float32) - y = np.asarray([1., 2.]).astype(np.float32) - x, y = (tf.data.Dataset.from_tensor_slices((x, y)) - .repeat(100).batch(2).make_one_shot_iterator().get_next()) - w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer()) - y_hat = tf.matmul(x, w) - loss = tf.reduce_mean(0.5 * tf.square(y_hat - y)) - accuracy = loss - - layer_collection = lc.LayerCollection() - layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat) - layer_collection.register_normal_predictive_distribution(y_hat) - - return loss, accuracy, layer_collection - - def testMinimizeLossSingleMachine(self): - with tf.Graph().as_default(): - loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.minimize_loss_single_machine( - loss, accuracy, layer_collection, device="/cpu:0") - self.assertLess(accuracy_, 2.0) - - def testMinimizeLossDistributed(self): - with tf.Graph().as_default(): - loss, accuracy, layer_collection = self._build_toy_problem() - accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker( - task_id=0, - is_chief=True, - num_worker_tasks=1, - num_ps_tasks=0, - master="", - checkpoint_dir=None, - loss=loss, - accuracy=accuracy, - layer_collection=layer_collection) - self.assertLess(accuracy_, 2.0) - - def testTrainMnistSingleMachine(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - # - # Ideally, we should check that accuracy increases as the model converges, - # but there are too few parameters for the model to effectively memorize - # the training set the way an MLP can. - convnet.train_mnist_single_machine( - data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0") - - def testTrainMnistMultitower(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - convnet.train_mnist_multitower( - data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) - - def testTrainMnistDistributed(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - convnet.train_mnist_distributed_sync_replicas( - task_id=0, - is_chief=True, - num_worker_tasks=1, - num_ps_tasks=0, - master="", - data_dir=None, - num_epochs=2, - op_strategy="chief_worker", - use_fake_data=True) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py deleted file mode 100644 index 22da6c29f1b364d94432315988d844db9b95ec28..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/tests/mlp_test.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for mlp.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf - -from tensorflow.contrib.kfac.examples import mlp - - -class MlpTest(tf.test.TestCase): - - def testFcLayer(self): - with tf.Graph().as_default(): - pre, act, (w, b) = mlp.fc_layer( - layer_id=1, inputs=tf.zeros([5, 3]), output_size=10) - self.assertShapeEqual(np.zeros([5, 10]), pre) - self.assertShapeEqual(np.zeros([5, 10]), act) - self.assertShapeEqual(np.zeros([3, 10]), tf.convert_to_tensor(w)) - self.assertShapeEqual(np.zeros([10]), tf.convert_to_tensor(b)) - self.assertIsInstance(w, tf.Variable) - self.assertIsInstance(b, tf.Variable) - self.assertIn("fc_1/", w.op.name) - self.assertIn("fc_1/", b.op.name) - - def testTrainMnist(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - # - # Ideally, we should check that accuracy increases as the model converges, - # but that takes a non-trivial amount of compute. - mlp.train_mnist(data_dir=None, num_epochs=1, use_fake_data=True) - - def testTrainMnistMultitower(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - mlp.train_mnist_multitower( - data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) - - def testTrainMnistEstimator(self): - with tf.Graph().as_default(): - # Ensure model training doesn't crash. - mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/kfac/examples/tests/mnist_test.py b/tensorflow/contrib/kfac/examples/tests/mnist_test.py deleted file mode 100644 index 92f84623573d3ad3af26b500fccfe533280d0199..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/examples/tests/mnist_test.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for mnist.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf - -from tensorflow.contrib.kfac.examples import mnist - - -class MnistTest(tf.test.TestCase): - - def testValues(self): - """Ensure values are in their expected range.""" - with tf.Graph().as_default(): - examples, labels = mnist.load_mnist( - data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True) - - with self.test_session() as sess: - examples_, labels_ = sess.run([examples, labels]) - self.assertTrue(np.all((0 <= examples_) & (examples_ < 1))) - self.assertTrue(np.all((0 <= labels_) & (labels_ < 10))) - - def testFlattenedShapes(self): - """Ensure images are flattened into their appropriate shape.""" - with tf.Graph().as_default(): - examples, labels = mnist.load_mnist( - data_dir=None, - num_epochs=1, - batch_size=64, - flatten_images=True, - use_fake_data=True) - - with self.test_session() as sess: - examples_, labels_ = sess.run([examples, labels]) - self.assertEqual(examples_.shape, (64, 784)) - self.assertEqual(labels_.shape, (64,)) - - def testNotFlattenedShapes(self): - """Ensure non-flattened images are their appropriate shape.""" - with tf.Graph().as_default(): - examples, labels = mnist.load_mnist( - data_dir=None, - num_epochs=1, - batch_size=64, - flatten_images=False, - use_fake_data=True) - - with self.test_session() as sess: - examples_, labels_ = sess.run([examples, labels]) - self.assertEqual(examples_.shape, (64, 28, 28, 1)) - self.assertEqual(labels_.shape, (64,)) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/contrib/kfac/g3doc/autoencoder.png b/tensorflow/contrib/kfac/g3doc/autoencoder.png deleted file mode 100644 index 20f93c77034f3355653a6a260cccdad29c080eaf..0000000000000000000000000000000000000000 Binary files a/tensorflow/contrib/kfac/g3doc/autoencoder.png and /dev/null differ diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD deleted file mode 100644 index 6e4a8d71baa85d05d514e4683016c2f4d299ec8e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ /dev/null @@ -1,160 +0,0 @@ -package(default_visibility = ["//visibility:private"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow:tensorflow.bzl", "py_test") - -py_test( - name = "estimator_test", - srcs = ["estimator_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:fisher_estimator", - "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/contrib/kfac/python/ops:utils", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_test( - name = "fisher_factors_test", - srcs = ["fisher_factors_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:fisher_blocks", - "//tensorflow/contrib/kfac/python/ops:fisher_factors", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_test( - name = "fisher_blocks_test", - srcs = ["fisher_blocks_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:fisher_blocks", - "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/contrib/kfac/python/ops:linear_operator", - "//tensorflow/contrib/kfac/python/ops:utils", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:state_ops", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_test( - name = "layer_collection_test", - srcs = ["layer_collection_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:fisher_blocks", - "//tensorflow/contrib/kfac/python/ops:fisher_factors", - "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:variable_scope", - ], -) - -py_test( - name = "optimizer_test", - srcs = ["optimizer_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:fisher_factors", - "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", - "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_test( - name = "utils_test", - srcs = ["utils_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], # TODO: needs investigation on Windows - deps = [ - "//tensorflow/contrib/kfac/python/ops:utils", - "//tensorflow/contrib/tpu", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_test( - name = "op_queue_test", - srcs = ["op_queue_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:op_queue", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - ], -) - -py_test( - name = "loss_functions_test", - srcs = ["loss_functions_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/kfac/python/ops:loss_functions", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:framework_ops", - "//tensorflow/python:random_ops", - "//third_party/py/numpy", - ], -) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py deleted file mode 100644 index 76b31a5730ad7c298711b1533a991f4adfd68cc1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.estimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.kfac.python.ops import estimator -from tensorflow.contrib.kfac.python.ops import layer_collection as lc -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.training import training_util - -_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] - - -class EstimatorTest(test.TestCase): - - def setUp(self): - self._graph = ops.Graph() - with self._graph.as_default(): - self.layer_collection = lc.LayerCollection() - - self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32) - self.weights = variable_scope.get_variable( - "w", shape=(2, 2), dtype=dtypes.float32) - self.bias = variable_scope.get_variable( - "b", initializer=init_ops.zeros_initializer(), shape=(2, 1)) - self.output = math_ops.matmul(self.inputs, self.weights) + self.bias - - # Only register the weights. - self.layer_collection.register_fully_connected( - params=(self.weights,), inputs=self.inputs, outputs=self.output) - - self.outputs = math_ops.tanh(self.output) - self.targets = array_ops.zeros_like(self.outputs) - self.layer_collection.register_categorical_predictive_distribution( - logits=self.outputs, targets=self.targets) - - def testEstimatorInitManualRegistration(self): - with self._graph.as_default(): - # We should be able to build an estimator for only the registered vars. - estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection - ) - - # Check that we throw an error if we try to build an estimator for vars - # that were not manually registered. - with self.assertRaises(ValueError): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights, self.bias], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection - ) - est.make_vars_and_create_op_thunks() - - # Check that we throw an error if we don't include registered variables, - # i.e. self.weights - with self.assertRaises(ValueError): - est = estimator.FisherEstimatorRoundRobin( - variables=[], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection) - est.make_vars_and_create_op_thunks() - - @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) - def testVariableWrongNumberOfUses(self, mock_uses): - with self.assertRaises(ValueError): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection) - est.make_vars_and_create_op_thunks() - - def testInvalidEstimationMode(self): - with self.assertRaises(ValueError): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection, - estimation_mode="not_a_real_mode") - est.make_vars_and_create_op_thunks() - - def testGradientsModeBuild(self): - with self._graph.as_default(): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection, - estimation_mode="gradients") - est.make_vars_and_create_op_thunks() - - def testEmpiricalModeBuild(self): - with self._graph.as_default(): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection, - estimation_mode="empirical") - est.make_vars_and_create_op_thunks() - - def testCurvaturePropModeBuild(self): - with self._graph.as_default(): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection, - estimation_mode="curvature_prop") - est.make_vars_and_create_op_thunks() - - def testExactModeBuild(self): - with self._graph.as_default(): - est = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - cov_ema_decay=0.1, - damping=0.2, - layer_collection=self.layer_collection, - estimation_mode="exact") - est.make_vars_and_create_op_thunks() - - def test_cov_update_thunks(self): - """Ensures covariance update ops run once per global_step.""" - with self._graph.as_default(), self.cached_session() as sess: - fisher_estimator = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - layer_collection=self.layer_collection, - damping=0.2, - cov_ema_decay=0.0) - - # Construct an op that executes one covariance update per step. - global_step = training_util.get_or_create_global_step() - (cov_variable_thunks, cov_update_op_thunks, _, - _) = fisher_estimator.create_ops_and_vars_thunks() - for thunk in cov_variable_thunks: - thunk() - cov_matrices = [ - fisher_factor.get_cov() - for fisher_factor in self.layer_collection.get_factors() - ] - cov_update_op = control_flow_ops.case( - [(math_ops.equal(global_step, i), thunk) - for i, thunk in enumerate(cov_update_op_thunks)]) - increment_global_step = global_step.assign_add(1) - - sess.run(variables.global_variables_initializer()) - initial_cov_values = sess.run(cov_matrices) - - # Ensure there's one update per covariance matrix. - self.assertEqual(len(cov_matrices), len(cov_update_op_thunks)) - - # Test is no-op if only 1 covariance matrix. - assert len(cov_matrices) > 1 - - for i in range(len(cov_matrices)): - # Compare new and old covariance values - new_cov_values = sess.run(cov_matrices) - is_cov_equal = [ - np.allclose(initial_cov_value, new_cov_value) - for (initial_cov_value, - new_cov_value) in zip(initial_cov_values, new_cov_values) - ] - num_cov_equal = sum(is_cov_equal) - - # Ensure exactly one covariance matrix changes per step. - self.assertEqual(num_cov_equal, len(cov_matrices) - i) - - # Run all covariance update ops. - sess.run(cov_update_op) - sess.run(increment_global_step) - - def test_round_robin_placement(self): - """Check if the ops and variables are placed on devices correctly.""" - with self._graph.as_default(): - fisher_estimator = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - layer_collection=self.layer_collection, - damping=0.2, - cov_ema_decay=0.0, - cov_devices=["/cpu:{}".format(i) for i in range(2)], - inv_devices=["/cpu:{}".format(i) for i in range(2)]) - - # Construct an op that executes one covariance update per step. - (cov_update_thunks, - inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks( - scope="test") - cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) - inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) - self.assertEqual(cov_update_ops[0].device, "/device:CPU:0") - self.assertEqual(cov_update_ops[1].device, "/device:CPU:1") - self.assertEqual(inv_update_ops[0].device, "/device:CPU:0") - self.assertEqual(inv_update_ops[1].device, "/device:CPU:1") - cov_matrices = [ - fisher_factor.get_cov() - for fisher_factor in self.layer_collection.get_factors() - ] - inv_matrices = [ - matrix - for fisher_factor in self.layer_collection.get_factors() - for matrix in fisher_factor._matpower_by_exp_and_damping.values() - ] - self.assertEqual(cov_matrices[0].device, "/device:CPU:0") - self.assertEqual(cov_matrices[1].device, "/device:CPU:1") - # Inverse matrices need to be explicitly placed. - self.assertEqual(inv_matrices[0].device, "") - self.assertEqual(inv_matrices[1].device, "") - - def test_inv_update_thunks(self): - """Ensures inverse update ops run once per global_step.""" - with self._graph.as_default(), self.cached_session() as sess: - fisher_estimator = estimator.FisherEstimatorRoundRobin( - variables=[self.weights], - layer_collection=self.layer_collection, - damping=0.2, - cov_ema_decay=0.0) - - # Construct op that updates one inverse per global step. - global_step = training_util.get_or_create_global_step() - (cov_variable_thunks, _, inv_variable_thunks, - inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks() - for thunk in cov_variable_thunks: - thunk() - for thunk in inv_variable_thunks: - thunk() - inv_matrices = [ - matrix - for fisher_factor in self.layer_collection.get_factors() - for matrix in fisher_factor._matpower_by_exp_and_damping.values() - ] - inv_update_op = control_flow_ops.case( - [(math_ops.equal(global_step, i), thunk) - for i, thunk in enumerate(inv_update_op_thunks)]) - increment_global_step = global_step.assign_add(1) - - sess.run(variables.global_variables_initializer()) - initial_inv_values = sess.run(inv_matrices) - - # Ensure there's one update per inverse matrix. This is true as long as - # there's no fan-in/fan-out or parameter re-use. - self.assertEqual(len(inv_matrices), len(inv_update_op_thunks)) - - # Test is no-op if only 1 invariance matrix. - assert len(inv_matrices) > 1 - - # Assign each covariance matrix a value other than the identity. This - # ensures that the inverse matrices are updated to something different as - # well. - cov_matrices = [ - fisher_factor.get_cov() - for fisher_factor in self.layer_collection.get_factors() - ] - sess.run([ - cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0]))) - for cov_matrix in cov_matrices - ]) - - for i in range(len(inv_matrices)): - # Compare new and old inverse values - new_inv_values = sess.run(inv_matrices) - is_inv_equal = [ - np.allclose(initial_inv_value, new_inv_value) - for (initial_inv_value, - new_inv_value) in zip(initial_inv_values, new_inv_values) - ] - num_inv_equal = sum(is_inv_equal) - - # Ensure exactly one inverse matrix changes per step. - self.assertEqual(num_inv_equal, len(inv_matrices) - i) - - # Run all inverse update ops. - sess.run(inv_update_op) - sess.run(increment_global_step) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py deleted file mode 100644 index f845def5074ea7510e13d00140ecd12ab3da53a0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ /dev/null @@ -1,1018 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.fisher_blocks.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb -from tensorflow.contrib.kfac.python.ops import fisher_factors as ff -from tensorflow.contrib.kfac.python.ops import layer_collection as lc -from tensorflow.contrib.kfac.python.ops import linear_operator as lo -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variables as tf_variables -from tensorflow.python.platform import test - - -# We need to set these constants since the numerical values used in the tests -# were chosen when these used to be the defaults. -ff.set_global_constants(init_covariances_at_zero=False, - zero_debias=False, - init_inverses_at_zero=False) - -# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our -# inverse is something other than the identity" are actually broken. They never -# run the covariance update ops and so the inverse actually is the identity -# (possible plus the damping term, which would still make it a multiple of the -# identity). - - -def _make_psd(dim): - """Constructs a PSD matrix of the given dimension.""" - mat = np.ones((dim, dim), dtype=np.float32) - mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim) - return array_ops.constant(mat) - - -class UtilsTest(test.TestCase): - - def testComputePiTracenorm(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - diag = ops.convert_to_tensor([1., 2., 0., 1.]) - left_factor = lo.LinearOperatorDiag(diag) - right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2])) - - # pi is the sqrt of the left trace norm divided by the right trace norm - pi = fb.compute_pi_tracenorm(left_factor, right_factor) - - pi_val = sess.run(pi) - self.assertEqual(1., pi_val) - - -class FullFBTest(test.TestCase): - - def testFullFBInitSingleTensor(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - self.assertAllEqual(params, block.tensors_to_compute_grads()) - - def testFullFBInitTensorTuple(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - self.assertAllEqual(params, block.tensors_to_compute_grads()) - - def testInstantiateFactors(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - grads = (params[0]**2, math_ops.sqrt(params[1])) - block.instantiate_factors(grads, 0.5) - - def testMultiplyInverseTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = (params[0]**2, math_ops.sqrt(params[1])) - block.instantiate_factors((grads,), 0.5) - block._factor.instantiate_cov_variables() - block.register_inverse() - block._factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_inverse_update_ops()) - - vector = array_ops.ones(3,) * 2 - output = block.multiply_inverse(vector) - - self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) - - def testMultiplyInverseNotTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = array_ops.constant([[1.], [2.]]) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = params**2 - block.instantiate_factors((grads,), 0.5) - block._factor.instantiate_cov_variables() - block.register_inverse() - block._factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_inverse_update_ops()) - - vector = array_ops.ones(2,) * 2 - output = block.multiply_inverse(vector) - - self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) - - def testMultiplyInverseAgainstExplicit(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = (array_ops.constant([2., 3.]), array_ops.constant(4.)) - damping = 0.5 - block.instantiate_factors((grads,), damping) - block._factor.instantiate_cov_variables() - block.register_inverse() - block._factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(state_ops.assign(block._factor._cov, _make_psd(3))) - sess.run(block._factor.make_inverse_update_ops()) - - v_flat = np.array([4., 5., 6.], dtype=np.float32) - vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) - output = block.multiply_inverse(vector) - output_flat = sess.run(utils.tensors_to_column(output)).ravel() - - full = sess.run(block.full_fisher_block()) - explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) - - self.assertAllClose(output_flat, explicit) - - -class NaiveDiagonalFBTest(test.TestCase): - - def testNaiveDiagonalFBInitSingleTensor(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - self.assertAllEqual(params, block.tensors_to_compute_grads()) - - def testNaiveDiagonalFBInitTensorTuple(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - self.assertAllEqual(params, block.tensors_to_compute_grads()) - - def testInstantiateFactors(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - - grads = (params[0]**2, math_ops.sqrt(params[1])) - block.instantiate_factors(grads, 0.5) - - def testMultiplyInverseTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = (params[0]**2, math_ops.sqrt(params[1])) - block.instantiate_factors((grads,), 0.5) - block._factor.instantiate_cov_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_inverse_update_ops()) - - vector = array_ops.ones(3,) * 2 - output = block.multiply_inverse(vector) - - self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) - - def testMultiplyInverseNotTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = array_ops.constant([[1.], [2.]]) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = params**2 - block.instantiate_factors((grads,), 0.5) - block._factor.instantiate_cov_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_inverse_update_ops()) - vector = array_ops.ones(2,) * 2 - output = block.multiply_inverse(vector) - - self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) - - def testMultiplyInverseAgainstExplicit(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) - block.register_additional_tower(32) - grads = (params[0]**2, math_ops.sqrt(params[1])) - damping = 0.5 - block.instantiate_factors((grads,), damping) - block._factor.instantiate_cov_variables() - - cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1]) - sess.run(state_ops.assign(block._factor._cov, cov)) - sess.run(block._factor.make_inverse_update_ops()) - - v_flat = np.array([4., 5., 6.], dtype=np.float32) - vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) - output = block.multiply_inverse(vector) - output_flat = sess.run(utils.tensors_to_column(output)).ravel() - - full = sess.run(block.full_fisher_block()) - explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) - self.assertAllClose(output_flat, explicit) - - -class FullyConnectedDiagonalFBTest(test.TestCase): - - def setUp(self): - super(FullyConnectedDiagonalFBTest, self).setUp() - - self.batch_size = 4 - self.input_size = 6 - self.output_size = 3 - - self.inputs = np.random.randn(self.batch_size, self.input_size).astype( - np.float32) - self.outputs = np.zeros([self.batch_size, self.output_size]).astype( - np.float32) - self.output_grads = np.random.randn(self.batch_size, - self.output_size).astype(np.float32) - self.w = np.random.randn(self.input_size, self.output_size).astype( - np.float32) - self.b = np.random.randn(self.output_size).astype(np.float32) - - def fisherApprox(self, has_bias=False): - """Fisher approximation using default inputs.""" - if has_bias: - inputs = np.concatenate( - [self.inputs, np.ones([self.batch_size, 1])], axis=1) - else: - inputs = self.inputs - return self.buildDiagonalFisherApproximation(inputs, self.output_grads) - - def buildDiagonalFisherApproximation(self, inputs, output_grads): - """Builds explicit diagonal Fisher approximation. - - Fisher's diagonal is (d loss / d w)'s elements squared for - d/dw = E[outer(input, output_grad)] - - where the expectation is taken over examples. - - Args: - inputs: np.array of shape [batch_size, input_size]. - output_grads: np.array of shape [batch_size, output_size]. - - Returns: - Diagonal np.array of shape [num_params, num_params] for num_params = - input_size * output_size. - """ - batch_size = inputs.shape[0] - assert output_grads.shape[0] == batch_size - input_size = inputs.shape[1] - output_size = output_grads.shape[1] - fisher_diag = np.zeros((input_size, output_size)) - for i in range(batch_size): - fisher_diag += np.square(np.outer(inputs[i], output_grads[i])) - return np.diag(fisher_diag.flatten()) / batch_size - - def testMultiply(self): - result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], - [self.output_grads]) - - # Construct Fisher-vector product. - expected_result = self.fisherApprox().dot(self.w.flatten()) - expected_result = expected_result.reshape( - [self.input_size, self.output_size]) - - self.assertAllClose(expected_result, result) - - def testMultiplyInverse(self): - _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], - [self.output_grads]) - - # Construct inverse Fisher-vector product. - expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) - expected_result = expected_result.reshape( - [self.input_size, self.output_size]) - - self.assertAllClose(expected_result, result) - - def testRegisterAdditionalTower(self): - """Ensure 1 big tower and 2 small towers are equivalent.""" - multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( - self.w, [self.inputs], [self.outputs], [self.output_grads]) - multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, np.split(self.inputs, 2), - np.split(self.outputs, 2), - np.split(self.output_grads, 2))) - - self.assertAllClose(multiply_result_big, multiply_result_small) - self.assertAllClose(multiply_inverse_result_big, - multiply_inverse_result_small) - - def testMultiplyHasBias(self): - result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], - [self.outputs], [self.output_grads]) - expected_result = self.fisherApprox(True).dot( - np.concatenate([self.w.flatten(), self.b.flatten()])) - expected_result = expected_result.reshape( - [self.input_size + 1, self.output_size]) - expected_result = (expected_result[:-1], expected_result[-1]) - - self.assertEqual(len(result), 2) - self.assertAllClose(expected_result[0], result[0]) - self.assertAllClose(expected_result[1], result[1]) - - def runFisherBlockOps(self, params, inputs, outputs, output_grads): - """Run Ops guaranteed by FisherBlock interface. - - Args: - params: Tensor or 2-tuple of Tensors. Represents weights or weights and - bias of this layer. - inputs: list of Tensors of shape [batch_size, input_size]. Inputs to - layer. - outputs: list of Tensors of shape [batch_size, output_size]. - Preactivations produced by layer. - output_grads: list of Tensors of shape [batch_size, output_size]. - Gradient of loss with respect to 'outputs'. - - Returns: - multiply_result: Result of FisherBlock.multiply(params) - multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) - """ - with ops.Graph().as_default(), self.cached_session() as sess: - inputs = as_tensors(inputs) - outputs = as_tensors(outputs) - output_grads = as_tensors(output_grads) - params = as_tensors(params) - - block = fb.FullyConnectedDiagonalFB( - lc.LayerCollection(), has_bias=isinstance(params, (tuple, list))) - for (i, o) in zip(inputs, outputs): - block.register_additional_tower(i, o) - - block.instantiate_factors((output_grads,), damping=0.0) - block._factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_covariance_update_op(0.0)) - multiply_result = sess.run(block.multiply(params)) - multiply_inverse_result = sess.run(block.multiply_inverse(params)) - - return multiply_result, multiply_inverse_result - - -class EmbeddingKFACFBTest(test.TestCase): - - def testInstantiateFactors(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - - # Create a Fisher Block. - vocab_size = 5 - block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) - - # Add some examples. - inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) - outputs = array_ops.constant([[0.], [1.], [2.]]) - block.register_additional_tower(inputs, outputs) - - # Instantiate factor's variables. Ensure it doesn't fail. - grads = outputs**2. - damping = array_ops.constant(0.) - block.instantiate_factors(((grads,),), damping) - - def testMultiplyInverse(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - - # Create a Fisher Block. - vocab_size = 5 - block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) - - # Add some examples. - inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) - outputs = array_ops.constant([[0.], [1.], [2.]]) - block.register_additional_tower(inputs, outputs) - - # Instantiate factor's variables. Ensure it doesn't fail. - grads = outputs**2. - damping = array_ops.constant(0.) - block.instantiate_factors(((grads,),), damping) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Create a sparse update. - indices = array_ops.constant([1, 3, 4]) - values = array_ops.constant([[1.], [1.], [1.]]) - sparse_vector = ops.IndexedSlices( - values, indices, dense_shape=[vocab_size, 1]) - dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1]) - - # Compare Fisher-vector product against explicit result. - result = block.multiply_inverse(sparse_vector) - expected_result = linalg_ops.matrix_solve(block.full_fisher_block(), - dense_vector) - - sess.run(tf_variables.global_variables_initializer()) - self.assertAlmostEqual( - sess.run(expected_result[1]), sess.run(result.values[0])) - self.assertAlmostEqual( - sess.run(expected_result[3]), sess.run(result.values[1])) - self.assertAlmostEqual( - sess.run(expected_result[4]), sess.run(result.values[2])) - - -class FullyConnectedKFACBasicFBTest(test.TestCase): - - def testFullyConnectedKFACBasicFBInit(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([1., 2.]) - outputs = array_ops.constant([3., 4.]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection()) - block.register_additional_tower(inputs, outputs) - - self.assertAllEqual([outputs], block.tensors_to_compute_grads()) - - def testInstantiateFactorsHasBias(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2.], [3., 4.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True) - block.register_additional_tower(inputs, outputs) - - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - - def testInstantiateFactorsNoBias(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2.], [3., 4.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_tower(inputs, outputs) - - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - - def testMultiplyInverseTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - vector = ( - np.arange(2, 6).reshape(2, 2).astype(np.float32), # - np.arange(1, 3).reshape(2, 1).astype(np.float32)) - output = block.multiply_inverse((array_ops.constant(vector[0]), - array_ops.constant(vector[1]))) - - output = sess.run(output) - self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]], - output[0]) - self.assertAllClose([0.343146, 0.686291], output[1]) - - def testMultiplyInverseNotTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2.], [3., 4.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - vector = np.arange(2, 6).reshape(2, 2).astype(np.float32) - output = block.multiply_inverse(array_ops.constant(vector)) - - self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]], - sess.run(output)) - - def testMultiplyInverseAgainstExplicit(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - input_dim, output_dim = 3, 2 - inputs = array_ops.zeros([32, input_dim]) - outputs = array_ops.zeros([32, output_dim]) - params = array_ops.zeros([input_dim, output_dim]) - block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - damping = 0. # This test is only valid without damping. - block.instantiate_factors(((grads,),), damping) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - - sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3))) - sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) - - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - v_flat = np.arange(6, dtype=np.float32) - vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) - output = block.multiply_inverse(vector) - output_flat = sess.run(utils.tensors_to_column(output)).ravel() - - full = sess.run(block.full_fisher_block()) - explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat) - - self.assertAllClose(output_flat, explicit) - - -class ConvDiagonalFBTest(test.TestCase): - - def setUp(self): - super(ConvDiagonalFBTest, self).setUp() - - self.batch_size = 2 - self.height = 8 - self.width = 4 - self.input_channels = 6 - self.output_channels = 3 - self.kernel_size = 1 - - self.inputs = np.random.randn(self.batch_size, self.height, self.width, - self.input_channels).astype(np.float32) - self.outputs = np.zeros( - [self.batch_size, self.height, self.width, - self.output_channels]).astype(np.float32) - self.output_grads = np.random.randn( - self.batch_size, self.height, self.width, self.output_channels).astype( - np.float32) - self.w = np.random.randn(self.kernel_size, self.kernel_size, - self.input_channels, self.output_channels).astype( - np.float32) - self.b = np.random.randn(self.output_channels).astype(np.float32) - - def fisherApprox(self, has_bias=False): - """Fisher approximation using default inputs.""" - if has_bias: - inputs = np.concatenate( - [self.inputs, - np.ones([self.batch_size, self.height, self.width, 1])], - axis=-1) - else: - inputs = self.inputs - return self.buildDiagonalFisherApproximation(inputs, self.output_grads, - self.kernel_size) - - def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size): - r"""Builds explicit diagonal Fisher approximation. - - Fisher's diagonal is (d loss / d w)'s elements squared for - d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})] - - where the expectation is taken over examples and the sum over (x, y) - locations upon which the convolution is applied. - - Args: - inputs: np.array of shape [batch_size, height, width, input_channels]. - output_grads: np.array of shape [batch_size, height, width, - output_channels]. - kernel_size: int. height and width of kernel. - - Returns: - Diagonal np.array of shape [num_params, num_params] for num_params = - kernel_size^2 * input_channels * output_channels. - """ - batch_size, height, width, input_channels = inputs.shape - assert output_grads.shape[0] == batch_size - assert output_grads.shape[1] == height - assert output_grads.shape[2] == width - output_channels = output_grads.shape[3] - - # If kernel_size == 1, then we don't need to worry about capturing context - # around the pixel upon which a convolution is applied. This makes testing - # easier. - assert kernel_size == 1, "kernel_size != 1 isn't supported." - num_locations = height * width - inputs = np.reshape(inputs, [batch_size, num_locations, input_channels]) - output_grads = np.reshape(output_grads, - [batch_size, num_locations, output_channels]) - - fisher_diag = np.zeros((input_channels, output_channels)) - for i in range(batch_size): - # Each example's approximation is a square(sum-of-outer-products). - example_fisher_diag = np.zeros((input_channels, output_channels)) - for j in range(num_locations): - example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j]) - fisher_diag += np.square(example_fisher_diag) - - # Normalize by batch_size (not num_locations). - return np.diag(fisher_diag.flatten()) / batch_size - - def testMultiply(self): - result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], - [self.output_grads]) - - # Construct Fisher-vector product. - expected_result = self.fisherApprox().dot(self.w.flatten()) - expected_result = expected_result.reshape([ - self.kernel_size, self.kernel_size, self.input_channels, - self.output_channels - ]) - - self.assertAllClose(expected_result, result) - - def testMultiplyInverse(self): - _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], - [self.output_grads]) - - # Construct inverse Fisher-vector product. - expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) - expected_result = expected_result.reshape([ - self.kernel_size, self.kernel_size, self.input_channels, - self.output_channels - ]) - - self.assertAllClose(expected_result, result, atol=1e-3) - - def testRegisterAdditionalTower(self): - """Ensure 1 big tower and 2 small towers are equivalent.""" - multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( - self.w, [self.inputs], [self.outputs], [self.output_grads]) - multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, np.split(self.inputs, 2), - np.split(self.outputs, 2), - np.split(self.output_grads, 2))) - - self.assertAllClose(multiply_result_big, multiply_result_small) - self.assertAllClose(multiply_inverse_result_big, - multiply_inverse_result_small) - - def testMultiplyHasBias(self): - result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], - [self.outputs], [self.output_grads]) - # Clone 'b' along 'input_channels' dimension. - b_filter = np.tile( - np.reshape(self.b, [1, 1, 1, self.output_channels]), - [self.kernel_size, self.kernel_size, 1, 1]) - params = np.concatenate([self.w, b_filter], axis=2) - expected_result = self.fisherApprox(True).dot(params.flatten()) - - # Extract 'b' from concatenated parameters. - expected_result = expected_result.reshape([ - self.kernel_size, self.kernel_size, self.input_channels + 1, - self.output_channels - ]) - expected_result = (expected_result[:, :, 0:-1, :], - np.reshape(expected_result[:, :, -1, :], - [self.output_channels])) - - self.assertEqual(len(result), 2) - self.assertAllClose(expected_result[0], result[0]) - self.assertAllClose(expected_result[1], result[1]) - - def runFisherBlockOps(self, params, inputs, outputs, output_grads): - """Run Ops guaranteed by FisherBlock interface. - - Args: - params: Tensor or 2-tuple of Tensors. Represents weights or weights and - bias of this layer. - inputs: list of Tensors of shape [batch_size, input_size]. Inputs to - layer. - outputs: list of Tensors of shape [batch_size, output_size]. - Preactivations produced by layer. - output_grads: list of Tensors of shape [batch_size, output_size]. - Gradient of loss with respect to 'outputs'. - - Returns: - multiply_result: Result of FisherBlock.multiply(params) - multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) - """ - with ops.Graph().as_default(), self.cached_session() as sess: - inputs = as_tensors(inputs) - outputs = as_tensors(outputs) - output_grads = as_tensors(output_grads) - params = as_tensors(params) - - block = fb.ConvDiagonalFB( - lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME') - for (i, o) in zip(inputs, outputs): - block.register_additional_tower(i, o) - - block.instantiate_factors((output_grads,), damping=0.0) - block._factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._factor.make_covariance_update_op(0.0)) - multiply_result = sess.run(block.multiply(params)) - multiply_inverse_result = sess.run(block.multiply_inverse(params)) - - return multiply_result, multiply_inverse_result - - -class DepthwiseConvKFCBasicFBTest(test.TestCase): - - def testInstantiateFactors(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - params = random_ops.random_normal((3, 3, 8, 2)) - inputs = random_ops.random_normal((32, 5, 5, 8)) - outputs = random_ops.random_normal((32, 5, 5, 16)) - layer_collection = lc.LayerCollection() - block = fb.DepthwiseConvKFCBasicFB( - layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) - - def testMultiplyInverse(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = random_ops.random_normal((3, 3, 8, 2)) - inputs = random_ops.random_normal((32, 5, 5, 8)) - outputs = random_ops.random_normal((32, 5, 5, 16)) - layer_collection = lc.LayerCollection() - block = fb.DepthwiseConvKFCBasicFB( - layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME') - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - block.instantiate_factors(([grads],), 0.5) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Ensure inverse update op doesn't crash. - sess.run(tf_variables.global_variables_initializer()) - sess.run([ - factor.make_inverse_update_ops() - for factor in layer_collection.get_factors() - ]) - - # Ensure inverse-vector multiply doesn't crash. - output = block.multiply_inverse(params) - sess.run(output) - - # Ensure same shape. - self.assertAllEqual(output.shape, params.shape) - - -class ConvKFCBasicFBTest(test.TestCase): - - def _testConvKFCBasicFBInitParams(self, params): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - if isinstance(params, (list, tuple)): - params = [array_ops.constant(param) for param in params] - else: - params = array_ops.constant(params) - inputs = random_ops.random_normal((2, 2, 2)) - outputs = random_ops.random_normal((2, 2, 2)) - block = fb.ConvKFCBasicFB( - lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_tower(inputs, outputs) - - self.assertAllEqual([outputs], block.tensors_to_compute_grads()) - - def testConvKFCBasicFBInitParamsParamsTuple(self): - self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])]) - - def testConvKFCBasicFBInitParamsParamsSingle(self): - self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])]) - - def testMultiplyInverseTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = random_ops.random_normal((2, 2, 2, 2)) - inputs = random_ops.random_normal((2, 2, 2, 2)) - outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB( - lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), - np.arange(2, 4).reshape(2, 1).astype(np.float32)) - output = block.multiply_inverse((array_ops.constant(vector[0]), - array_ops.constant(vector[1]))) - - output = sess.run(output) - self.assertAllClose([0.136455, 0.27291], output[0][0]) - self.assertAllClose([0.27291, 0.409365], output[1]) - - def testMultiplyInverseNotTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = random_ops.random_normal((2, 2, 2, 2)) - inputs = random_ops.random_normal((2, 2, 2, 2)) - outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB( - lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_tower(inputs, outputs) - self.assertFalse(block._has_bias) - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - vector = np.arange(1, 17).reshape(8, 2).astype(np.float32) - output = block.multiply_inverse(array_ops.constant(vector)) - - self.assertAllClose([0.136455, 0.27291], sess.run(output)[0]) - - def testMultiplyInverseNotTupleWithBias(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = [random_ops.random_normal((2, 2, 2, 2))] - inputs = random_ops.random_normal((2, 2, 2, 2)) - outputs = random_ops.random_normal((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB( - lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_tower(inputs, outputs) - self.assertTrue(block._has_bias) - grads = outputs**2 - block.instantiate_factors(((grads,),), 0.5) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - # Make sure our inverse is something other than the identity. - sess.run(tf_variables.global_variables_initializer()) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - vector = np.arange(1, 19).reshape(9, 2).astype(np.float32) - output = block.multiply_inverse(array_ops.constant(vector)) - - self.assertAllClose([0.136455, 0.27291], sess.run(output)[0]) - - def testMultiplyInverseAgainstExplicit(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - params = array_ops.zeros((2, 2, 2, 2)) - inputs = array_ops.zeros((2, 2, 2, 2)) - outputs = array_ops.zeros((2, 2, 2, 2)) - block = fb.ConvKFCBasicFB( - lc.LayerCollection(), params=params, padding='SAME') - block.register_additional_tower(inputs, outputs) - grads = outputs**2 - damping = 0. # This test is only valid without damping. - block.instantiate_factors(((grads,),), damping) - block._input_factor.instantiate_cov_variables() - block._output_factor.instantiate_cov_variables() - block.register_inverse() - block._input_factor.instantiate_inv_variables() - block._output_factor.instantiate_inv_variables() - - sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8))) - sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) - sess.run(block._input_factor.make_inverse_update_ops()) - sess.run(block._output_factor.make_inverse_update_ops()) - - v_flat = np.arange(16, dtype=np.float32) - vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) - output = block.multiply_inverse(vector) - output_flat = sess.run(utils.tensors_to_column(output)).ravel() - - full = sess.run(block.full_fisher_block()) - explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat) - - self.assertAllClose(output_flat, explicit) - - -class FullyConnectedSeriesFBTest(test.TestCase): - - def testFullyConnectedSeriesFBInit(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([1., 2.]) - outputs = array_ops.constant([3., 4.]) - block = fb.FullyConnectedSeriesFB(lc.LayerCollection()) - block.register_additional_tower([inputs], [outputs]) - self.assertAllEqual([[outputs]], block.tensors_to_compute_grads()) - - def testInstantiateFactorsHasBias(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2.], [3., 4.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedSeriesFB( - lc.LayerCollection(), - has_bias=True) - block.register_additional_tower([inputs], [outputs]) - grads = outputs**2 - block.instantiate_factors((((grads,),),), 0.5) - - def testInstantiateFactorsNoBias(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - inputs = array_ops.constant([[1., 2.], [3., 4.]]) - outputs = array_ops.constant([[3., 4.], [5., 6.]]) - block = fb.FullyConnectedSeriesFB( - lc.LayerCollection(), - has_bias=False) - block.register_additional_tower([inputs], [outputs]) - grads = outputs**2 - block.instantiate_factors((((grads,),),), 0.5) - - -def as_tensors(tensor_or_tuple): - """Converts a potentially nested tuple of np.array to Tensors.""" - if isinstance(tensor_or_tuple, (tuple, list)): - return tuple(as_tensors(t) for t in tensor_or_tuple) - return ops.convert_to_tensor(tensor_or_tuple) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py deleted file mode 100644 index a396ca3f8581f9738f09bda30ee9a6cb6ae3fbab..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ /dev/null @@ -1,955 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.fisher_factors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import numpy.random as npr - -from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb -from tensorflow.contrib.kfac.python.ops import fisher_factors as ff -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variables as tf_variables -from tensorflow.python.platform import test - - -# We need to set these constants since the numerical values used in the tests -# were chosen when these used to be the defaults. -ff.set_global_constants(init_covariances_at_zero=False, - zero_debias=False, - init_inverses_at_zero=False) - - -def make_damping_func(damping): - return fb._package_func(lambda: damping, damping) - - -class FisherFactorTestingDummy(ff.FisherFactor): - """Dummy class to test the non-abstract methods on ff.FisherFactor.""" - - @property - def _var_scope(self): - return 'dummy/a_b_c' - - @property - def _cov_shape(self): - raise NotImplementedError - - @property - def _num_sources(self): - return 1 - - @property - def _dtype(self): - return dtypes.float32 - - def _compute_new_cov(self): - raise NotImplementedError - - def instantiate_covariance(self): - pass - - def make_inverse_update_ops(self): - return [] - - def get_cov(self): - return NotImplementedError - - def instantiate_inv_variables(self): - return NotImplementedError - - def _num_towers(self): - raise NotImplementedError - - def _get_data_device(self): - raise NotImplementedError - - def register_matpower(self, exp, damping_func): - raise NotImplementedError - - def register_cholesky(self, damping_func): - raise NotImplementedError - - def register_cholesky_inverse(self, damping_func): - raise NotImplementedError - - def get_matpower(self, exp, damping_func): - raise NotImplementedError - - def get_cholesky(self, damping_func): - raise NotImplementedError - - def get_cholesky_inverse(self, damping_func): - raise NotImplementedError - - def get_cov_as_linear_operator(self): - raise NotImplementedError - - -class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor): - """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor. - """ - - def __init__(self, shape): - self._shape = shape - super(DenseSquareMatrixFactorTestingDummy, self).__init__() - - @property - def _var_scope(self): - return 'dummy/a_b_c' - - @property - def _cov_shape(self): - return self._shape - - @property - def _num_sources(self): - return 1 - - @property - def _dtype(self): - return dtypes.float32 - - def _compute_new_cov(self): - raise NotImplementedError - - def instantiate_covariance(self): - pass - - def _num_towers(self): - raise NotImplementedError - - def _get_data_device(self): - raise NotImplementedError - - -class NumericalUtilsTest(test.TestCase): - - def testComputeCovAgainstNumpy(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - npr.seed(0) - random_seed.set_random_seed(200) - - x = npr.randn(100, 3) - cov = ff.compute_cov(array_ops.constant(x)) - np_cov = np.dot(x.T, x) / x.shape[0] - - self.assertAllClose(sess.run(cov), np_cov) - - def testComputeCovAgainstNumpyWithAlternativeNormalizer(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - npr.seed(0) - random_seed.set_random_seed(200) - - normalizer = 10. - x = npr.randn(100, 3) - cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer) - np_cov = np.dot(x.T, x) / normalizer - - self.assertAllClose(sess.run(cov), np_cov) - - def testAppendHomog(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - npr.seed(0) - - m, n = 3, 4 - a = npr.randn(m, n) - a_homog = ff.append_homog(array_ops.constant(a)) - np_result = np.hstack([a, np.ones((m, 1))]) - - self.assertAllClose(sess.run(a_homog), np_result) - - -class NameStringUtilFunctionTest(test.TestCase): - - def _make_tensor(self): - x = array_ops.placeholder(dtypes.float64, (3, 1)) - w = array_ops.constant(npr.RandomState(0).randn(3, 3)) - y = math_ops.matmul(w, x) - g = gradients_impl.gradients(y, x)[0] - return g - - def testScopeStringFromParamsSingleTensor(self): - with tf_ops.Graph().as_default(): - g = self._make_tensor() - scope_string = ff.scope_string_from_params(g) - self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string) - - def testScopeStringFromParamsMultipleTensors(self): - with tf_ops.Graph().as_default(): - x = array_ops.constant(1,) - y = array_ops.constant(2,) - scope_string = ff.scope_string_from_params((x, y)) - self.assertEqual('Const_Const_1', scope_string) - - def testScopeStringFromParamsMultipleTypes(self): - with tf_ops.Graph().as_default(): - x = array_ops.constant(1,) - y = array_ops.constant(2,) - scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, - (x, y)]) - self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string) - - def testScopeStringFromParamsUnsupportedType(self): - with tf_ops.Graph().as_default(): - x = array_ops.constant(1,) - y = array_ops.constant(2,) - unsupported = 1.2 # Floats are not supported. - with self.assertRaises(ValueError): - ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y), - unsupported]) - - def testScopeStringFromName(self): - with tf_ops.Graph().as_default(): - g = self._make_tensor() - scope_string = ff.scope_string_from_name(g) - self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string) - - def testScalarOrTensorToString(self): - with tf_ops.Graph().as_default(): - self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.)) - - g = self._make_tensor() - scope_string = ff.scope_string_from_name(g) - self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string) - - -class FisherFactorTest(test.TestCase): - - def testMakeInverseUpdateOps(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - factor = FisherFactorTestingDummy() - - self.assertEqual(0, len(factor.make_inverse_update_ops())) - - -class DenseSquareMatrixFactorTest(test.TestCase): - - def testRegisterDampedInverse(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - shape = [2, 2] - factor = DenseSquareMatrixFactorTestingDummy(shape) - factor_var_scope = 'dummy/a_b_c' - - damping_funcs = [make_damping_func(0.1), - make_damping_func(0.1), - make_damping_func(1e-5), - make_damping_func(1e-5)] - for damping_func in damping_funcs: - factor.register_inverse(damping_func) - - factor.instantiate_inv_variables() - - inv = factor.get_inverse(damping_funcs[0]).to_dense() - self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense()) - self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense()) - self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(), - factor.get_inverse(damping_funcs[3]).to_dense()) - factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, - factor_var_scope) - factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) - - self.assertEqual(set([inv, - factor.get_inverse(damping_funcs[2]).to_dense()]), - set(factor_tensors)) - self.assertEqual(shape, inv.get_shape()) - - def testRegisterMatpower(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - shape = [3, 3] - factor = DenseSquareMatrixFactorTestingDummy(shape) - factor_var_scope = 'dummy/a_b_c' - - # TODO(b/74201126): Change to using the same func for both once - # Topohash is in place. - damping_func_1 = make_damping_func(0.5) - damping_func_2 = make_damping_func(0.5) - - factor.register_matpower(-0.5, damping_func_1) - factor.register_matpower(2, damping_func_2) - - factor.instantiate_inv_variables() - - factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, - factor_var_scope) - - factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) - - matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense() - matpower2 = factor.get_matpower(2, damping_func_2).to_dense() - - self.assertEqual(set([matpower1, matpower2]), set(factor_tensors)) - - self.assertEqual(shape, matpower1.get_shape()) - self.assertEqual(shape, matpower2.get_shape()) - - def testMakeInverseUpdateOps(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - factor = FisherFactorTestingDummy() - - self.assertEqual(0, len(factor.make_inverse_update_ops())) - - def testMakeInverseUpdateOpsManyInversesEigenDecomp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - cov = np.array([[1., 2.], [3., 4.]]) - factor = DenseSquareMatrixFactorTestingDummy(cov.shape) - factor._cov = array_ops.constant(cov, dtype=dtypes.float32) - - damping_funcs = [] - for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): - damping_funcs.append(make_damping_func(1./i)) - - for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): - factor.register_inverse(damping_funcs[i]) - - factor.instantiate_inv_variables() - ops = factor.make_inverse_update_ops() - self.assertEqual(1, len(ops)) - - sess.run(tf_variables.global_variables_initializer()) - new_invs = [] - sess.run(ops) - for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): - # The inverse op will assign the damped inverse of cov to the inv var. - new_invs.append( - sess.run(factor.get_inverse(damping_funcs[i]).to_dense())) - - # We want to see that the new invs are all different from each other. - for i in range(len(new_invs)): - for j in range(i + 1, len(new_invs)): - # Just check the first element. - self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0]) - - def testMakeInverseUpdateOpsMatPowerEigenDecomp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - cov = np.array([[6., 2.], [2., 4.]]) - factor = DenseSquareMatrixFactorTestingDummy(cov.shape) - factor._cov = array_ops.constant(cov, dtype=dtypes.float32) - exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power - damping = 0.5 - damping_func = make_damping_func(damping) - - factor.register_matpower(exp, damping_func) - factor.instantiate_inv_variables() - ops = factor.make_inverse_update_ops() - self.assertEqual(1, len(ops)) - - sess.run(tf_variables.global_variables_initializer()) - sess.run(ops[0]) - matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense()) - matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) - self.assertAllClose(matpower, matpower_np) - - def testMakeInverseUpdateOpsNoEigenDecomp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric - factor = DenseSquareMatrixFactorTestingDummy(cov.shape) - factor._cov = array_ops.constant(cov, dtype=dtypes.float32) - - damping_func = make_damping_func(0) - - factor.register_inverse(damping_func) - factor.instantiate_inv_variables() - ops = factor.make_inverse_update_ops() - self.assertEqual(1, len(ops)) - - sess.run(tf_variables.global_variables_initializer()) - # The inverse op will assign the damped inverse of cov to the inv var. - old_inv = sess.run(factor.get_inverse(damping_func).to_dense()) - self.assertAllClose( - sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) - - sess.run(ops) - new_inv = sess.run(factor.get_inverse(damping_func).to_dense()) - self.assertAllClose(new_inv, np.linalg.inv(cov)) - - -class FullFactorTest(test.TestCase): - - def testFullFactorInit(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') - factor = ff.FullFactor((tensor,), 32) - factor.instantiate_cov_variables() - self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) - - def testFullFactorInitFloat64(self): - with tf_ops.Graph().as_default(): - dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.FullFactor((tensor,), 32) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual([6, 6], cov.get_shape().as_list()) - - def testMakeCovarianceUpdateOp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([1., 2.], name='a/b/c') - factor = ff.FullFactor((tensor,), 2) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov) - - -class NaiveDiagonalFactorTest(test.TestCase): - - def testNaiveDiagonalFactorInit(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') - factor = ff.NaiveDiagonalFactor((tensor,), 32) - factor.instantiate_cov_variables() - self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) - - def testNaiveDiagonalFactorInitFloat64(self): - with tf_ops.Graph().as_default(): - dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.NaiveDiagonalFactor((tensor,), 32) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual([6, 1], cov.get_shape().as_list()) - - def testMakeCovarianceUpdateOp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([1., 2.], name='a/b/c') - factor = ff.NaiveDiagonalFactor((tensor,), 2) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[0.75], [1.5]], new_cov) - - -class EmbeddingInputKroneckerFactorTest(test.TestCase): - - def testInitialization(self): - with tf_ops.Graph().as_default(): - input_ids = array_ops.constant([[0], [1], [4]]) - vocab_size = 5 - factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.shape.as_list(), [vocab_size]) - - def testCovarianceUpdateOp(self): - with tf_ops.Graph().as_default(): - input_ids = array_ops.constant([[0], [1], [4]]) - vocab_size = 5 - factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) - factor.instantiate_cov_variables() - cov_update_op = factor.make_covariance_update_op(0.0) - - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(cov_update_op) - self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov) - - -class ConvDiagonalFactorTest(test.TestCase): - - def setUp(self): - self.batch_size = 10 - self.height = self.width = 32 - self.in_channels = 3 - self.out_channels = 1 - self.kernel_height = self.kernel_width = 3 - self.strides = [1, 2, 2, 1] - self.data_format = 'NHWC' - self.padding = 'SAME' - self.kernel_shape = [ - self.kernel_height, self.kernel_width, self.in_channels, - self.out_channels - ] - - def testInit(self): - with tf_ops.Graph().as_default(): - inputs = random_ops.random_uniform( - [self.batch_size, self.height, self.width, self.in_channels]) - outputs_grads = [ - random_ops.random_uniform([ - self.batch_size, self.height // self.strides[1], - self.width // self.strides[2], self.out_channels - ]) for _ in range(3) - ] - - factor = ff.ConvDiagonalFactor( - (inputs,), - (outputs_grads,), - self.kernel_shape, - self.strides, - self.padding, - data_format=self.data_format) - factor.instantiate_cov_variables() - - # Ensure covariance matrix's shape makes sense. - self.assertEqual([ - self.kernel_height * self.kernel_width * self.in_channels, - self.out_channels - ], - factor.get_cov().shape.as_list()) - - def testMakeCovarianceUpdateOp(self): - with tf_ops.Graph().as_default(): - # Construct all arguments such that convolution kernel is applied in - # exactly one spatial location. - inputs = np.random.randn( - 1, # batch_size - self.kernel_height, - self.kernel_width, - self.in_channels) # in_channels - outputs_grad = np.random.randn( - 1, # batch_size - 1, # output_height - 1, # output_width - self.out_channels) - - factor = ff.ConvDiagonalFactor( - (constant_op.constant(inputs),), - ((constant_op.constant(outputs_grad),),), - self.kernel_shape, - strides=[1, 1, 1, 1], - padding='VALID') - factor.instantiate_cov_variables() - - # Completely forget initial value on first update. - cov_update_op = factor.make_covariance_update_op(0.0) - - # Ensure new covariance value is same as outer-product of inputs/outputs - # vectorized, squared. - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - cov = sess.run(cov_update_op) - expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2 - self.assertAllClose(expected_cov, cov) - - def testHasBias(self): - with tf_ops.Graph().as_default(): - inputs = random_ops.random_uniform( - [self.batch_size, self.height, self.width, self.in_channels]) - outputs_grads = [ - random_ops.random_uniform([ - self.batch_size, self.height // self.strides[1], - self.width // self.strides[2], self.out_channels - ]) for _ in range(3) - ] - - factor = ff.ConvDiagonalFactor( - (inputs,), - (outputs_grads,), - self.kernel_shape, - self.strides, - self.padding, - data_format=self.data_format, - has_bias=True) - factor.instantiate_cov_variables() - - # Ensure shape accounts for bias. - self.assertEqual([ - self.kernel_height * self.kernel_width * self.in_channels + 1, - self.out_channels - ], - factor.get_cov().shape.as_list()) - - # Ensure update op doesn't crash. - cov_update_op = factor.make_covariance_update_op(0.0) - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(cov_update_op) - - -class FullyConnectedKroneckerFactorTest(test.TestCase): - - def _testFullyConnectedKroneckerFactorInit(self, - has_bias, - final_shape, - dtype=dtypes.float32_ref): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual(final_shape, cov.get_shape().as_list()) - - def testFullyConnectedKroneckerFactorInitNoBias(self): - for dtype in (dtypes.float32_ref, dtypes.float64_ref): - self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype) - - def testFullyConnectedKroneckerFactorInitWithBias(self): - for dtype in (dtypes.float32_ref, dtypes.float64_ref): - self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype) - - def testMakeCovarianceUpdateOpWithBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) - - def testMakeCovarianceUpdateOpNoBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedKroneckerFactor(((tensor,),)) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) - - -class ConvFactorTestCase(test.TestCase): - - def assertMatrixRank(self, rank, matrix, atol=1e-5): - assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.' - eigvals = np.linalg.eigvals(matrix) - nnz_eigvals = np.sum(eigvals > atol) - self.assertEqual( - rank, - nnz_eigvals, - msg=('Found %d of %d expected non-zero eigenvalues: %s.' % - (nnz_eigvals, rank, eigvals))) - - -class ConvInputKroneckerFactorTest(ConvFactorTestCase): - - def test3DConvolution(self): - with tf_ops.Graph().as_default(): - batch_size = 1 - width = 3 - in_channels = 3**3 - out_channels = 4 - - factor = ff.ConvInputKroneckerFactor( - inputs=(random_ops.random_uniform( - (batch_size, width, width, width, in_channels), seed=0),), - filter_shape=(width, width, width, in_channels, out_channels), - padding='SAME', - strides=(2, 2, 2), - extract_patches_fn='extract_convolution_patches', - has_bias=False) - factor.instantiate_cov_variables() - - # Ensure shape of covariance matches input size of filter. - input_size = in_channels * (width**3) - self.assertEqual([input_size, input_size], - factor.get_cov().shape.as_list()) - - # Ensure cov_update_op doesn't crash. - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov()) - - # Cov should be rank-8, as the filter will be applied at each corner of - # the 4-D cube. - self.assertMatrixRank(8, cov) - - def testPointwiseConv2d(self): - with tf_ops.Graph().as_default(): - batch_size = 1 - width = 3 - in_channels = 3**2 - out_channels = 4 - - factor = ff.ConvInputKroneckerFactor( - inputs=(random_ops.random_uniform( - (batch_size, width, width, in_channels), seed=0),), - filter_shape=(1, 1, in_channels, out_channels), - padding='SAME', - strides=(1, 1, 1, 1), - extract_patches_fn='extract_pointwise_conv2d_patches', - has_bias=False) - factor.instantiate_cov_variables() - - # Ensure shape of covariance matches input size of filter. - self.assertEqual([in_channels, in_channels], - factor.get_cov().shape.as_list()) - - # Ensure cov_update_op doesn't crash. - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov()) - - # Cov should be rank-9, as the filter will be applied at each location. - self.assertMatrixRank(9, cov) - - def testStrides(self): - with tf_ops.Graph().as_default(): - batch_size = 1 - width = 3 - in_channels = 3**2 - out_channels = 4 - - factor = ff.ConvInputKroneckerFactor( - inputs=(random_ops.random_uniform( - (batch_size, width, width, in_channels), seed=0),), - filter_shape=(1, 1, in_channels, out_channels), - padding='SAME', - strides=(1, 2, 1, 1), - extract_patches_fn='extract_image_patches', - has_bias=False) - factor.instantiate_cov_variables() - - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov()) - - # Cov should be the sum of 3 * 2 = 6 outer products. - self.assertMatrixRank(6, cov) - - def testDilationRate(self): - with tf_ops.Graph().as_default(): - batch_size = 1 - width = 3 - in_channels = 2 - out_channels = 4 - - factor = ff.ConvInputKroneckerFactor( - inputs=(random_ops.random_uniform( - (batch_size, width, width, in_channels), seed=0),), - filter_shape=(3, 3, in_channels, out_channels), - padding='SAME', - extract_patches_fn='extract_image_patches', - strides=(1, 1, 1, 1), - dilation_rate=(1, width, width, 1), - has_bias=False) - factor.instantiate_cov_variables() - - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov()) - - # Cov should be rank = in_channels, as only the center of the filter - # receives non-zero input for each input channel. - self.assertMatrixRank(in_channels, cov) - - def testConvInputKroneckerFactorInitNoBias(self): - with tf_ops.Graph().as_default(): - tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') - factor = ff.ConvInputKroneckerFactor( - inputs=(tensor,), - filter_shape=(1, 2, 3, 4), - padding='SAME', - has_bias=False) - factor.instantiate_cov_variables() - self.assertEqual([1 * 2 * 3, 1 * 2 * 3], - factor.get_cov().get_shape().as_list()) - - def testConvInputKroneckerFactorInit(self): - with tf_ops.Graph().as_default(): - tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c') - factor = ff.ConvInputKroneckerFactor( - (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) - factor.instantiate_cov_variables() - self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], - factor.get_cov().get_shape().as_list()) - - def testConvInputKroneckerFactorInitFloat64(self): - with tf_ops.Graph().as_default(): - dtype = dtypes.float64_ref - tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64) - factor = ff.ConvInputKroneckerFactor( - (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], - cov.get_shape().as_list()) - - def testMakeCovarianceUpdateOpWithBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - input_shape = (2, 1, 1, 1) - tensor = array_ops.constant( - np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( - np.float32)) - factor = ff.ConvInputKroneckerFactor( - (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(0.)) - self.assertAllClose( - [ - [(1. + 4.) / 2., (1. + 2.) / 2.], # - [(1. + 2.) / 2., (1. + 1.) / 2.] - ], # - new_cov) - - def testMakeCovarianceUpdateOpNoBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - input_shape = (2, 1, 1, 1) - tensor = array_ops.constant( - np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype( - np.float32)) - factor = ff.ConvInputKroneckerFactor( - (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME') - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(0.)) - self.assertAllClose([[(1. + 4.) / 2.]], new_cov) - - def testSubSample(self): - with tf_ops.Graph().as_default(): - patches_1 = array_ops.constant(1, shape=(10, 2)) - patches_2 = array_ops.constant(1, shape=(10, 8)) - patches_3 = array_ops.constant(1, shape=(3, 3)) - patches_1_sub = ff._subsample_for_cov_computation(patches_1) - patches_2_sub = ff._subsample_for_cov_computation(patches_2) - patches_3_sub = ff._subsample_for_cov_computation(patches_3) - patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0] - patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0] - patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0] - self.assertEqual(2, patches_1_sub_batch_size) - self.assertEqual(8, patches_2_sub_batch_size) - self.assertEqual(3, patches_3_sub_batch_size) - - -class ConvOutputKroneckerFactorTest(ConvFactorTestCase): - - def test3DConvolution(self): - with tf_ops.Graph().as_default(): - batch_size = 1 - width = 3 - out_channels = width**3 - - factor = ff.ConvOutputKroneckerFactor(outputs_grads=([ - random_ops.random_uniform( - (batch_size, width, width, width, out_channels), seed=0) - ],)) - factor.instantiate_cov_variables() - - with self.cached_session() as sess: - sess.run(tf_variables.global_variables_initializer()) - sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov()) - - # Cov should be rank 3^3, as each spatial position donates a rank-1 - # update. - self.assertMatrixRank(width**3, cov) - - def testConvOutputKroneckerFactorInit(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c') - factor = ff.ConvOutputKroneckerFactor(((tensor,),)) - factor.instantiate_cov_variables() - self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) - - def testConvOutputKroneckerFactorInitFloat64(self): - with tf_ops.Graph().as_default(): - dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') - factor = ff.ConvOutputKroneckerFactor(((tensor,),)) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual([5, 5], cov.get_shape().as_list()) - - def testMakeCovarianceUpdateOp(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32) - factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),)) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov) - - -class FullyConnectedMultiKFTest(test.TestCase): - - def testFullyConnectedMultiKFInit(self): - with tf_ops.Graph().as_default(): - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') - factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) - factor.instantiate_cov_variables() - self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) - - def testFullyConnectedMultiKFInitFloat64(self): - with tf_ops.Graph().as_default(): - dtype = dtypes.float64_ref - random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') - factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False) - factor.instantiate_cov_variables() - cov = factor.get_cov() - self.assertEqual(cov.dtype, dtype) - self.assertEqual([3, 3], cov.get_shape().as_list()) - - def testMakeCovarianceUpdateOpWithBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) - - def testMakeCovarianceUpdateOpNoBias(self): - with tf_ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') - factor = ff.FullyConnectedMultiKF(((tensor,),)) - factor.instantiate_cov_variables() - - sess.run(tf_variables.global_variables_initializer()) - new_cov = sess.run(factor.make_covariance_update_op(.5)) - self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py deleted file mode 100644 index 586fcd4c3cf364227bae7bd1546250fd7921eb7a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ /dev/null @@ -1,597 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.layer_collection.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.kfac.python.ops import fisher_blocks -from tensorflow.contrib.kfac.python.ops import fisher_factors -from tensorflow.contrib.kfac.python.ops import layer_collection -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import test - - -class MockFisherBlock(object): - """A fake FisherBlock.""" - - num_registered_towers = 2 - - def __init__(self, name='MockFisherBlock'): - self.name = name - - def __eq__(self, other): - return isinstance(other, MockFisherBlock) and other.name == self.name - - def __hash__(self): - return hash(self.name) - - -class LayerParametersDictTest(test.TestCase): - - def testSetItem(self): - """Ensure insertion, contains, retrieval works for supported key types.""" - with ops.Graph().as_default(): - lp_dict = layer_collection.LayerParametersDict() - - x = array_ops.constant(0) - y0 = array_ops.constant(0) - y1 = array_ops.constant(0) - z0 = array_ops.constant(0) - z1 = array_ops.constant(0) - keys = [x, (y0, y1), [z0, z1]] - for key in keys: - lp_dict[key] = key - - for key in keys: - self.assertTrue(key in lp_dict) - self.assertEqual(lp_dict[key], key) - - def testSetItemOverlap(self): - """Ensure insertion fails if key overlaps with existing key.""" - with ops.Graph().as_default(): - lp_dict = layer_collection.LayerParametersDict() - - x = array_ops.constant(0) - y = array_ops.constant(0) - lp_dict[x] = 'value' - - with self.assertRaises(ValueError): - lp_dict[(x, y)] = 'value' - - # Ensure 'y' wasn't inserted. - self.assertTrue(x in lp_dict) - self.assertFalse(y in lp_dict) - - -class LayerCollectionTest(test.TestCase): - - def testLayerCollectionInit(self): - lc = layer_collection.LayerCollection() - self.assertEqual(0, len(lc.get_blocks())) - self.assertEqual(0, len(lc.get_factors())) - self.assertFalse(lc.losses) - - def testRegisterBlocks(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - lc = layer_collection.LayerCollection() - lc.register_fully_connected( - array_ops.constant(1), array_ops.constant(2), array_ops.constant(3)) - lc.register_fully_connected( - array_ops.constant(1), - array_ops.constant(2), - array_ops.constant(3), - approx=layer_collection.APPROX_DIAGONAL_NAME) - lc.register_conv2d( - params=array_ops.ones((2, 3, 4, 5)), - strides=[1, 1, 1, 1], - padding='SAME', - inputs=array_ops.ones((1, 2, 3, 4)), - outputs=array_ops.ones((1, 1, 1, 5))) - lc.register_conv2d( - params=array_ops.ones((2, 3, 4, 5)), - strides=[1, 1, 1, 1], - padding='SAME', - inputs=array_ops.ones((1, 2, 3, 4)), - outputs=array_ops.ones((1, 1, 1, 5)), - approx=layer_collection.APPROX_DIAGONAL_NAME) - lc.register_separable_conv2d( - depthwise_params=array_ops.ones((3, 3, 1, 2)), - pointwise_params=array_ops.ones((1, 1, 2, 4)), - inputs=array_ops.ones((32, 5, 5, 1)), - depthwise_outputs=array_ops.ones((32, 5, 5, 2)), - pointwise_outputs=array_ops.ones((32, 5, 5, 4)), - strides=[1, 1, 1, 1], - padding='SAME') - lc.register_convolution( - params=array_ops.ones((3, 3, 1, 8)), - inputs=array_ops.ones((32, 5, 5, 1)), - outputs=array_ops.ones((32, 5, 5, 8)), - padding='SAME') - lc.register_generic( - array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) - lc.register_generic( - array_ops.constant(6), - 16, - approx=layer_collection.APPROX_DIAGONAL_NAME) - lc.register_fully_connected_multi( - array_ops.constant(1), - (array_ops.constant(2), array_ops.constant(3)), - (array_ops.constant(4), array_ops.constant(5))) - lc.register_conv2d_multi( - params=array_ops.ones((2, 3, 4, 5)), - strides=[1, 1, 1, 1], - padding='SAME', - inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))), - outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10)))) - lc.register_embedding_multi( - array_ops.constant((1,)), - (array_ops.constant(2), array_ops.constant(3)), - (array_ops.constant(4), array_ops.constant(5))) - - self.assertEqual(12, len(lc.get_blocks())) - - def testRegisterBlocksMultipleRegistrations(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - lc = layer_collection.LayerCollection() - key = array_ops.constant(1) - lc.register_fully_connected(key, array_ops.constant(2), - array_ops.constant(3)) - with self.assertRaises(ValueError) as cm: - lc.register_generic(key, 16) - self.assertIn('already in LayerCollection', str(cm.exception)) - - def testRegisterSingleParamNotRegistered(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = { - variable_scope.get_variable('y', initializer=array_ops.constant(1,)): - '1' - } - lc.register_block(x, 'foo') - - def testShouldRegisterSingleParamRegistered(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {x: '1'} - with self.assertRaises(ValueError) as cm: - lc.register_block(x, 'foo') - self.assertIn('already in LayerCollection', str(cm.exception)) - - def testRegisterSingleParamRegisteredInTuple(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {(x, y): '1'} - with self.assertRaises(ValueError) as cm: - lc.register_block(x, 'foo') - self.assertIn('was already registered', str(cm.exception)) - - def testRegisterTupleParamNotRegistered(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = { - variable_scope.get_variable('z', initializer=array_ops.constant(1,)): - '1' - } - - lc.register_block((x, y), 'foo') - self.assertEqual(set(['1', 'foo']), set(lc.get_blocks())) - - def testRegisterTupleParamRegistered(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {(x, y): '1'} - - with self.assertRaises(ValueError) as cm: - lc.register_block((x, y), 'foo') - self.assertIn('already in LayerCollection', str(cm.exception)) - - def testRegisterTupleParamRegisteredInSuperset(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {(x, y, z): '1'} - - with self.assertRaises(ValueError) as cm: - lc.register_block((x, y), 'foo') - self.assertIn('was already registered', str(cm.exception)) - - def testRegisterTupleParamSomeRegistered(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')} - - with self.assertRaises(ValueError) as cm: - lc.register_block((x, y), MockFisherBlock('foo')) - self.assertIn('was already registered', str(cm.exception)) - - def testRegisterTupleVarSomeRegisteredInOtherTuples(self): - x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) - y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) - z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) - w = variable_scope.get_variable('w', initializer=array_ops.constant(1,)) - lc = layer_collection.LayerCollection() - lc.fisher_blocks = {(x, z): '1', (z, w): '2'} - - with self.assertRaises(ValueError) as cm: - lc.register_block((x, y), 'foo') - self.assertIn('was already registered', str(cm.exception)) - - def testRegisterCategoricalPredictiveDistribution(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - logits = linalg_ops.eye(2) - - lc = layer_collection.LayerCollection() - lc.register_categorical_predictive_distribution(logits, seed=200) - single_loss = sess.run(lc.total_sampled_loss()) - - lc2 = layer_collection.LayerCollection() - lc2.register_categorical_predictive_distribution(logits, seed=200) - lc2.register_categorical_predictive_distribution(logits, seed=200) - double_loss = sess.run(lc2.total_sampled_loss()) - self.assertAlmostEqual(2 * single_loss, double_loss) - - def testLossFunctionByName(self): - """Ensure loss functions can be identified by name.""" - with ops.Graph().as_default(): - logits = linalg_ops.eye(2) - lc = layer_collection.LayerCollection() - - # Create a new loss function by name. - lc.register_categorical_predictive_distribution(logits, name='loss1') - self.assertEqual(1, len(lc.towers_by_loss)) - - # Add logits to same loss function. - lc.register_categorical_predictive_distribution( - logits, name='loss1', reuse=True) - self.assertEqual(1, len(lc.towers_by_loss)) - - # Add another new loss function. - lc.register_categorical_predictive_distribution(logits, name='loss2') - self.assertEqual(2, len(lc.towers_by_loss)) - - def testLossFunctionWithoutName(self): - """Ensure loss functions get unique names if 'name' not specified.""" - with ops.Graph().as_default(): - logits = linalg_ops.eye(2) - lc = layer_collection.LayerCollection() - - # Create a new loss function with default names. - lc.register_categorical_predictive_distribution(logits) - lc.register_categorical_predictive_distribution(logits) - self.assertEqual(2, len(lc.losses)) - - def testCategoricalPredictiveDistributionMultipleMinibatches(self): - """Ensure multiple minibatches are registered.""" - with ops.Graph().as_default(): - batch_size = 3 - output_size = 2 - logits = array_ops.zeros([batch_size, output_size]) - targets = array_ops.ones([batch_size], dtype=dtypes.int32) - lc = layer_collection.LayerCollection() - - # Create a new loss function. - lc.register_categorical_predictive_distribution( - logits, targets=targets, name='loss1') - - # Can add when reuse=True - lc.register_categorical_predictive_distribution( - logits, targets=targets, name='loss1', reuse=True) - - # Can add when reuse=VARIABLE_SCOPE and reuse=True there. - with variable_scope.variable_scope( - variable_scope.get_variable_scope(), reuse=True): - lc.register_categorical_predictive_distribution( - logits, - targets=targets, - name='loss1', - reuse=layer_collection.VARIABLE_SCOPE) - - # Can't add when reuse=False - with self.assertRaises(KeyError): - lc.register_categorical_predictive_distribution( - logits, targets=targets, name='loss1', reuse=False) - - # Can't add when reuse=VARIABLE_SCOPE and reuse=False there. - with self.assertRaises(KeyError): - lc.register_categorical_predictive_distribution( - logits, - targets=targets, - name='loss1', - reuse=layer_collection.VARIABLE_SCOPE) - - self.assertEqual(len(lc.towers_by_loss), 1) - # Three successful registrations. - self.assertEqual(len(lc.towers_by_loss[0]), 3) - - def testRegisterCategoricalPredictiveDistributionBatchSize1(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - logits = random_ops.random_normal((1, 2)) - lc = layer_collection.LayerCollection() - - lc.register_categorical_predictive_distribution(logits, seed=200) - - def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32) - lc = layer_collection.LayerCollection() - targets = array_ops.constant([0, 1], dtype=dtypes.int32) - - lc.register_categorical_predictive_distribution(logits, targets=targets) - single_loss = sess.run(lc.total_loss()) - self.assertAlmostEqual(1.6265233, single_loss) - - def testRegisterNormalPredictiveDistribution(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - predictions = array_ops.constant( - [[1., 2.], [3., 4]], dtype=dtypes.float32) - - lc = layer_collection.LayerCollection() - lc.register_normal_predictive_distribution(predictions, 1., seed=200) - single_loss = sess.run(lc.total_sampled_loss()) - - lc2 = layer_collection.LayerCollection() - lc2.register_normal_predictive_distribution(predictions, 1., seed=200) - lc2.register_normal_predictive_distribution(predictions, 1., seed=200) - double_loss = sess.run(lc2.total_sampled_loss()) - - self.assertAlmostEqual(2 * single_loss, double_loss) - - def testRegisterNormalPredictiveDistributionSpecifiedTargets(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - predictions = array_ops.constant( - [[1., 2.], [3., 4.]], dtype=dtypes.float32) - lc = layer_collection.LayerCollection() - targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32) - - lc.register_normal_predictive_distribution( - predictions, 2.**2, targets=targets) - single_loss = sess.run(lc.total_loss()) - self.assertAlmostEqual(7.6983433, single_loss) - - def ensureLayerReuseWorks(self, register_fn): - """Ensure the 'reuse' keyword argument function as intended. - - Args: - register_fn: function for registering a layer. Arguments are - layer_collection, reuse, and approx. - """ - # Fails on second if reuse=False. - lc = layer_collection.LayerCollection() - register_fn(lc) - with self.assertRaises(ValueError): - register_fn(lc, reuse=False) - - # Succeeds on second if reuse=True. - lc = layer_collection.LayerCollection() - register_fn(lc) - register_fn(lc, reuse=True) - - # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse. - lc = layer_collection.LayerCollection() - register_fn(lc) - with self.assertRaises(ValueError): - register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE) - - # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse. - lc = layer_collection.LayerCollection() - register_fn(lc) - with variable_scope.variable_scope( - variable_scope.get_variable_scope(), reuse=True): - register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE) - - # Fails if block type changes. - lc = layer_collection.LayerCollection() - register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME) - with self.assertRaises(ValueError): - register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True) - - # Fails if reuse requested but no FisherBlock exists. - lc = layer_collection.LayerCollection() - with self.assertRaises(KeyError): - register_fn(lc, reuse=True) - - def testRegisterFullyConnectedReuse(self): - """Ensure the 'reuse' works with register_fully_connected.""" - with ops.Graph().as_default(): - inputs = array_ops.ones([2, 10]) - outputs = array_ops.zeros([2, 5]) - params = ( - variable_scope.get_variable('w', [10, 5]), # - variable_scope.get_variable('b', [5])) - - def register_fn(lc, **kwargs): - lc.register_fully_connected( - params=params, inputs=inputs, outputs=outputs, **kwargs) - - self.ensureLayerReuseWorks(register_fn) - - def testRegisterConv2dReuse(self): - """Ensure the 'reuse' works with register_conv2d.""" - with ops.Graph().as_default(): - inputs = array_ops.ones([2, 5, 5, 10]) - outputs = array_ops.zeros([2, 5, 5, 3]) - params = ( - variable_scope.get_variable('w', [1, 1, 10, 3]), # - variable_scope.get_variable('b', [3])) - - def register_fn(lc, **kwargs): - lc.register_conv2d( - params=params, - strides=[1, 1, 1, 1], - padding='SAME', - inputs=inputs, - outputs=outputs, - **kwargs) - - self.ensureLayerReuseWorks(register_fn) - - def testReuseWithInvalidRegistration(self): - """Invalid registrations shouldn't overwrite existing blocks.""" - with ops.Graph().as_default(): - inputs = array_ops.ones([2, 5, 5, 10]) - outputs = array_ops.zeros([2, 5, 5, 3]) - w = variable_scope.get_variable('w', [1, 1, 10, 3]) - b = variable_scope.get_variable('b', [3]) - lc = layer_collection.LayerCollection() - lc.register_fully_connected(w, inputs, outputs) - self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) - with self.assertRaises(KeyError): - lc.register_fully_connected((w, b), inputs, outputs, reuse=True) - self.assertNotIn((w, b), lc.fisher_blocks) - self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1) - lc.register_fully_connected(w, inputs, outputs, reuse=True) - self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2) - - def testMakeOrGetFactor(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - lc = layer_collection.LayerCollection() - key = array_ops.constant(1) - lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) - lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) - lc.make_or_get_factor(fisher_factors.FullFactor, - ((array_ops.constant(2),), 16)) - - self.assertEqual(2, len(lc.get_factors())) - variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertTrue( - all([var.name.startswith('LayerCollection') for var in variables])) - - def testMakeOrGetFactorCustomScope(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - scope = 'Foo' - lc = layer_collection.LayerCollection(name=scope) - key = array_ops.constant(1) - lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) - lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) - lc.make_or_get_factor(fisher_factors.FullFactor, - ((array_ops.constant(2),), 16)) - - self.assertEqual(2, len(lc.get_factors())) - variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - self.assertTrue(all([var.name.startswith(scope) for var in variables])) - - def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): - x = variable_scope.get_variable('x', shape=()) - y = variable_scope.get_variable('y', shape=()) - z = variable_scope.get_variable('z', shape=()) - lc = layer_collection.LayerCollection() - lc.define_linked_parameters((x, y)) - - with self.assertRaises(ValueError): - lc.define_linked_parameters((x, z)) - - def testIdentifySubsetPreviouslyRegisteredTensor(self): - x = variable_scope.get_variable('x', shape=()) - y = variable_scope.get_variable('y', shape=()) - lc = layer_collection.LayerCollection() - lc.define_linked_parameters((x, y)) - - with self.assertRaises(ValueError): - lc.define_linked_parameters(x) - - def testSpecifyApproximation(self): - w_0 = variable_scope.get_variable('w_0', [10, 10]) - w_1 = variable_scope.get_variable('w_1', [10, 10]) - - b_0 = variable_scope.get_variable('b_0', [10]) - b_1 = variable_scope.get_variable('b_1', [10]) - - x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) - x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) - - pre_bias_0 = math_ops.matmul(x_0, w_0) - pre_bias_1 = math_ops.matmul(x_1, w_1) - - # Build the fully connected layers in the graph. - pre_bias_0 + b_0 # pylint: disable=pointless-statement - pre_bias_1 + b_1 # pylint: disable=pointless-statement - - lc = layer_collection.LayerCollection() - lc.define_linked_parameters( - w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME) - lc.define_linked_parameters( - w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME) - lc.define_linked_parameters( - b_0, approximation=layer_collection.APPROX_FULL_NAME) - lc.define_linked_parameters( - b_1, approximation=layer_collection.APPROX_FULL_NAME) - - lc.register_fully_connected(w_0, x_0, pre_bias_0) - lc.register_fully_connected( - w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME) - self.assertIsInstance(lc.fisher_blocks[w_0], - fisher_blocks.FullyConnectedDiagonalFB) - self.assertIsInstance(lc.fisher_blocks[w_1], - fisher_blocks.FullyConnectedKFACBasicFB) - - lc.register_generic(b_0, batch_size=1) - lc.register_generic( - b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME) - self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) - self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) - - def testDefaultLayerCollection(self): - with ops.Graph().as_default(): - # Can't get default if there isn't one set. - with self.assertRaises(ValueError): - layer_collection.get_default_layer_collection() - - # Can't set default twice. - lc = layer_collection.LayerCollection() - layer_collection.set_default_layer_collection(lc) - with self.assertRaises(ValueError): - layer_collection.set_default_layer_collection(lc) - - # Same as one set. - self.assertTrue(lc is layer_collection.get_default_layer_collection()) - - # Can set to None. - layer_collection.set_default_layer_collection(None) - with self.assertRaises(ValueError): - layer_collection.get_default_layer_collection() - - # as_default() is the same as setting/clearing. - with lc.as_default(): - self.assertTrue(lc is layer_collection.get_default_layer_collection()) - with self.assertRaises(ValueError): - layer_collection.get_default_layer_collection() - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py deleted file mode 100644 index f424e02360da7e3399b630be5391591d4bdd07d5..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.loss_functions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.kfac.python.ops import loss_functions -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test - - -class InsertSliceInZerosTest(test.TestCase): - - def testBadShape(self): - bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1 - with self.assertRaises(ValueError): - loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17) - - def test3d(self): - input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]]) - expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]] - op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0) - with self.cached_session() as sess: - actual_output_array = sess.run(op) - self.assertAllEqual(expected_output_array, actual_output_array) - - -class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): - - def testSample(self): - """Ensure samples can be drawn.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits)) - sample = loss.sample(42) - sample = sess.run(sample) - self.assertEqual(sample.shape, (2,)) - - def testEvaluateOnTargets(self): - """Ensure log probability can be evaluated correctly.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - targets = np.asarray([2, 1]).astype(np.int32) - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits), targets=array_ops.constant(targets)) - neg_log_prob = loss.evaluate() - neg_log_prob = sess.run(neg_log_prob) - - # Calculate explicit log probability of targets. - probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) - log_probs = np.log([ - probs[0, targets[0]], # - probs[1, targets[1]] - ]) - expected_log_prob = np.sum(log_probs) - - self.assertAllClose(neg_log_prob, -expected_log_prob) - - def testEvaluateOnSample(self): - """Ensure log probability of a sample can be drawn.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits)) - neg_log_prob = loss.evaluate_on_sample(42) - - # Simply ensure this doesn't crash. As the output is random, it's - # difficult to say if the output is correct or not... - neg_log_prob = sess.run(neg_log_prob) - - def testMultiplyFisherSingleVector(self): - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.array([1., 2., 3.]) - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) - - # the LossFunction.multiply_fisher docstring only says it supports the - # case where the vector is the same shape as the input natural parameters - # (i.e. the logits here), but here we also test leading dimensions - vector = np.array([1., 2., 3.]) - vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)] - - probs = np.exp(logits - np.logaddexp.reduce(logits)) - fisher = np.diag(probs) - np.outer(probs, probs) - - for vector in vectors: - result = loss.multiply_fisher(vector) - expected_result = np.dot(vector, fisher) - self.assertAllClose(expected_result, sess.run(result)) - - def testMultiplyFisherBatch(self): - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.array([[1., 2., 3.], [4., 6., 8.]]) - loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) - - vector = np.array([[1., 2., 3.], [5., 3., 1.]]) - - na = np.newaxis - probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1, - keepdims=True)) - fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :] - - result = loss.multiply_fisher(vector) - expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :] - self.assertEqual(sess.run(result).shape, logits.shape) - self.assertAllClose(expected_result, sess.run(result)) - - -class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): - - def testSample(self): - """Ensure samples can be drawn.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits)) - sample = loss.sample(42) - sample = sess.run(sample) - self.assertEqual(sample.shape, (2, 3)) - - def testEvaluateOnTargets(self): - """Ensure log probability can be evaluated correctly.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - targets = np.asarray([2, 1]).astype(np.int32) - loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits), targets=array_ops.one_hot(targets, 3)) - neg_log_prob = loss.evaluate() - neg_log_prob = sess.run(neg_log_prob) - - # Calculate explicit log probability of targets. - probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) - log_probs = np.log([ - probs[0, targets[0]], # - probs[1, targets[1]] - ]) - expected_log_prob = np.sum(log_probs) - - self.assertAllClose(neg_log_prob, -expected_log_prob) - - def testEvaluateOnSample(self): - """Ensure log probability of a sample can be drawn.""" - with ops.Graph().as_default(), self.cached_session() as sess: - logits = np.asarray([ - [0., 0., 0.], # - [1., -1., 0.] - ]).astype(np.float32) - loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( - array_ops.constant(logits)) - neg_log_prob = loss.evaluate_on_sample(42) - - # Simply ensure this doesn't crash. As the output is random, it's - # difficult to say if the output is correct or not... - neg_log_prob = sess.run(neg_log_prob) - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py deleted file mode 100644 index 4fae4374e17e0b14d5e9c8ebbb626f348865f475..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.op_queue.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.kfac.python.ops import op_queue -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import test - - -class OpQueueTest(test.TestCase): - - def testNextOp(self): - """Ensures all ops get selected eventually.""" - with tf_ops.Graph().as_default(): - ops = [ - math_ops.add(1, 2), - math_ops.subtract(1, 2), - math_ops.reduce_mean([1, 2]), - ] - queue = op_queue.OpQueue(ops, seed=0) - - with self.cached_session() as sess: - # Ensure every inv update op gets selected. - selected_ops = set([queue.next_op(sess) for _ in ops]) - self.assertEqual(set(ops), set(selected_ops)) - - # Ensure additional calls don't create any new ops. - selected_ops.add(queue.next_op(sess)) - self.assertEqual(set(ops), set(selected_ops)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py deleted file mode 100644 index 0b0de12ce6a651455614a20bf48e1e9a951d046c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.optimizer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.kfac.python.ops import fisher_factors as ff -from tensorflow.contrib.kfac.python.ops import layer_collection as lc -from tensorflow.contrib.kfac.python.ops import optimizer -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as tf_variables -from tensorflow.python.platform import test - - -# We need to set these constants since the numerical values used in the tests -# were chosen when these used to be the defaults. -ff.set_global_constants(init_covariances_at_zero=False, - zero_debias=False, - init_inverses_at_zero=False) - - -def dummy_layer_collection(): - lcoll = lc.LayerCollection() - dummy = array_ops.constant([1., 2.]) - lcoll.register_categorical_predictive_distribution(logits=dummy) - return lcoll - - -class OptimizerTest(test.TestCase): - - def testOptimizerInitInvalidMomentumRegistration(self): - with self.assertRaises(ValueError): - optimizer.KfacOptimizer( - 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo') - - def testOptimizerInit(self): - with ops.Graph().as_default(): - layer_collection = lc.LayerCollection() - - inputs = array_ops.ones((2, 1)) * 2 - weights_val = np.ones((1, 1), dtype=np.float32) * 3. - weights = variable_scope.get_variable( - 'w', initializer=array_ops.constant(weights_val)) - bias = variable_scope.get_variable( - 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) - output = math_ops.matmul(inputs, weights) + bias - - layer_collection.register_fully_connected((weights, bias), inputs, output) - - logits = math_ops.tanh(output) - targets = array_ops.constant([[0.], [1.]]) - output = math_ops.reduce_mean( - nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) - - layer_collection.register_categorical_predictive_distribution(logits) - - optimizer.KfacOptimizer( - 0.1, - 0.2, - 0.3, - layer_collection, - momentum=0.5, - momentum_type='regular') - - def testSquaredFisherNorm(self): - with ops.Graph().as_default(), self.cached_session() as sess: - grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), - (array_ops.constant([[2., 3.], [4., 5.]]), None)] - pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), - (array_ops.constant([[7., 8.], [9., 10.]]), None)] - opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection()) - sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) - self.assertAlmostEqual(174., sess.run(sq_norm), places=5) - - def testUpdateClipCoeff(self): - with ops.Graph().as_default(), self.cached_session() as sess: - grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), - (array_ops.constant([[2., 3.], [4., 5.]]), None)] - pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), - (array_ops.constant([[7., 8.], [9., 10.]]), None)] - lrate = 0.1 - - # Note: without rescaling, the squared Fisher norm of the update - # is 1.74 - - # If the update already satisfies the norm constraint, there should - # be no rescaling. - opt = optimizer.KfacOptimizer( - lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.) - coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) - self.assertAlmostEqual(1., sess.run(coeff), places=5) - - # If the update violates the constraint, it should be rescaled to - # be on the constraint boundary. - opt = optimizer.KfacOptimizer( - lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5) - coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) - sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) - sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad - self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5) - - def testComputeUpdateStepsRegular(self): - # TODO(olganw): implement this. - pass - - def testComputeUpdateStepsAdam(self): - # TODO(olganw): implement this. - pass - - def testUpdateVelocities(self): - with ops.Graph().as_default(), self.cached_session() as sess: - layers = lc.LayerCollection() - layers.register_categorical_predictive_distribution( - array_ops.constant([1.0])) - opt = optimizer.KfacOptimizer( - 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular') - x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2))) - y = variable_scope.get_variable( - 'y', initializer=array_ops.ones((2, 2)) * 2) - vec1 = array_ops.ones((2, 2)) * 3 - vec2 = array_ops.ones((2, 2)) * 4 - - model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5) - opt_vars = [ - v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) - if v not in model_vars - ] - - sess.run(tf_variables.global_variables_initializer()) - old_opt_vars = sess.run(opt_vars) - - # Optimizer vars start out at 0. - for opt_var in old_opt_vars: - self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var) - - sess.run(update_op) - new_opt_vars = sess.run(opt_vars) - # After one update, the velocities are equal to the vectors. - for vec, opt_var in zip([vec1, vec2], new_opt_vars): - self.assertAllEqual(sess.run(vec), opt_var) - - sess.run(update_op) - final_opt_vars = sess.run(opt_vars) - for first, second in zip(new_opt_vars, final_opt_vars): - self.assertFalse(np.equal(first, second).all()) - - def testApplyGradients(self): - with ops.Graph().as_default(), self.cached_session() as sess: - layer_collection = lc.LayerCollection() - - inputs = array_ops.ones((2, 1)) * 2 - weights_val = np.ones((1, 1), dtype=np.float32) * 3. - weights = variable_scope.get_variable( - 'w', initializer=array_ops.constant(weights_val)) - bias = variable_scope.get_variable( - 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) - output = math_ops.matmul(inputs, weights) + bias - - layer_collection.register_fully_connected((weights, bias), inputs, output) - - logits = math_ops.tanh(output) - targets = array_ops.constant([[0.], [1.]]) - output = math_ops.reduce_mean( - nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) - - layer_collection.register_categorical_predictive_distribution(logits) - - opt = optimizer.KfacOptimizer( - 0.1, - 0.2, - 0.3, - layer_collection, - momentum=0.5, - momentum_type='regular') - (cov_update_thunks, - inv_update_thunks) = opt.make_vars_and_create_op_thunks() - cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) - inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) - - grads_and_vars = opt.compute_gradients(output, [weights, bias]) - all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars] - - op = opt.apply_gradients(grads_and_vars) - - sess.run(tf_variables.global_variables_initializer()) - old_vars = sess.run(all_vars) - sess.run(cov_update_ops) - sess.run(inv_update_ops) - sess.run(op) - new_vars = sess.run(all_vars) - - for old_var, new_var in zip(old_vars, new_vars): - self.assertNotEqual(old_var, new_var) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py deleted file mode 100644 index 7df79a3c7fe4fbebe2106e255a1945c0ef36b37b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.kfac.utils.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import numpy.random as npr - -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class SequenceDictTest(test.TestCase): - - def testSequenceDictInit(self): - seq_dict = utils.SequenceDict() - self.assertFalse(seq_dict._dict) - - def testSequenceDictInitWithIterable(self): - reg_dict = {'a': 'foo', 'b': 'bar'} - itr = zip(reg_dict.keys(), reg_dict.values()) - seq_dict = utils.SequenceDict(itr) - self.assertEqual(reg_dict, seq_dict._dict) - - def testGetItemSingleKey(self): - seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) - self.assertEqual('foo', seq_dict['a']) - - def testGetItemMultipleKeys(self): - seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) - self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')]) - - def testSetItemSingleKey(self): - seq_dict = utils.SequenceDict() - seq_dict['a'] = 'foo' - self.assertEqual([('a', 'foo')], seq_dict.items()) - - def testSetItemMultipleKeys(self): - seq_dict = utils.SequenceDict() - keys = ('a', 'b', 'c') - values = ('foo', 'bar', 'baz') - seq_dict[keys] = values - self.assertItemsEqual(list(zip(keys, values)), seq_dict.items()) - - -class SubGraphTest(test.TestCase): - - def testBasicGraph(self): - a = array_ops.constant([[1., 2.], [3., 4.]]) - b = array_ops.constant([[5., 6.], [7., 8.]]) - c = a + b - d = a * b - sub_graph = utils.SubGraph((c,)) - self.assertTrue(sub_graph.is_member(a)) - self.assertTrue(sub_graph.is_member(b)) - self.assertTrue(sub_graph.is_member(c)) - self.assertFalse(sub_graph.is_member(d)) - - def testRepeatedAdds(self): - a = array_ops.constant([[1., 2.], [3., 4.]]) - b = array_ops.constant([[5., 6.], [7., 8.]]) - c = a + b + a # note that a appears twice in this graph - sub_graph = utils.SubGraph((c,)) - self.assertTrue(sub_graph.is_member(a)) - self.assertTrue(sub_graph.is_member(b)) - self.assertTrue(sub_graph.is_member(c)) - - def testFilterList(self): - a = array_ops.constant([[1., 2.], [3., 4.]]) - b = array_ops.constant([[5., 6.], [7., 8.]]) - c = a + b - d = a * b - sub_graph = utils.SubGraph((c,)) - input_list = [b, d] - filtered_list = sub_graph.filter_list(input_list) - self.assertEqual(filtered_list, [b]) - - def testVariableUses(self): - with ops.Graph().as_default(): - var = variable_scope.get_variable('var', shape=[10, 10]) - resource_var = variable_scope.get_variable( - 'resource_var', shape=[10, 10], use_resource=True) - x = array_ops.zeros([3, 10]) - z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var) - z1 = math_ops.matmul(x, resource_var) - sub_graph = utils.SubGraph((z0, z1)) - self.assertEqual(2, sub_graph.variable_uses(var)) - self.assertEqual(1, sub_graph.variable_uses(resource_var)) - - -class UtilsTest(test.TestCase): - - def _fully_connected_layer_params(self): - weights_part = array_ops.constant([[1., 2.], [4., 3.]]) - bias_part = array_ops.constant([1., 2.]) - return (weights_part, bias_part) - - def _conv_layer_params(self): - weights_shape = 2, 2, 3, 4 - biases_shape = weights_shape[-1:] - weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape)) - biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape)) - return (weights, biases) - - def testFullyConnectedLayerParamsTupleToMat2d(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - layer_params = self._fully_connected_layer_params() - output = utils.layer_params_to_mat2d(layer_params) - self.assertListEqual([3, 2], output.get_shape().as_list()) - self.assertAllClose( - sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]])) - - def testFullyConnectedLayerParamsTensorToMat2d(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - layer_params = self._fully_connected_layer_params() - output = utils.layer_params_to_mat2d(layer_params[0]) - self.assertListEqual([2, 2], output.get_shape().as_list()) - self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]])) - - def testConvLayerParamsTupleToMat2d(self): - with ops.Graph().as_default(): - random_seed.set_random_seed(200) - layer_params = self._conv_layer_params() - output = utils.layer_params_to_mat2d(layer_params) - self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list()) - - def testKron(self): - with ops.Graph().as_default(), self.cached_session() as sess: - mat1 = np.array([[1., 2.], [3., 4.]]) - mat2 = np.array([[5., 6.], [7., 8.]]) - mat1_tf = array_ops.constant(mat1) - mat2_tf = array_ops.constant(mat2) - ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf)) - ans_np = np.kron(mat1, mat2) - self.assertAllClose(ans_tf, ans_np) - - def testMat2dToFullyConnectedLayerParamsTuple(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - vector_template = self._fully_connected_layer_params() - mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]]) - - output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) - - self.assertIsInstance(output, tuple) - self.assertEqual(len(output), 2) - a, b = output - self.assertAllClose(a, np.array([[5., 4.], [3., 2.]])) - self.assertAllClose(b, np.array([1., 0.])) - - def testMat2dToFullyConnectedLayerParamsTensor(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - vector_template = self._fully_connected_layer_params()[0] - mat2d = array_ops.constant([[5., 4.], [3., 2.]]) - - output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) - - self.assertAllClose(output, np.array([[5., 4.], [3., 2.]])) - - def testTensorsToColumn(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - - vector = array_ops.constant(np.array([[0., 1.], [2., 3.]])) - output = utils.tensors_to_column(vector) - self.assertListEqual([4, 1], output.get_shape().as_list()) - self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None]) - - vector = self._fully_connected_layer_params() - output = utils.tensors_to_column(vector) - self.assertListEqual([6, 1], output.get_shape().as_list()) - self.assertAllClose( - sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None]) - - vector = list(vector) - vector.append(array_ops.constant([[6.], [7.], [8.], [9.]])) - - output = utils.tensors_to_column(vector) - self.assertListEqual([10, 1], output.get_shape().as_list()) - self.assertAllClose( - sess.run(output), - np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None]) - - def testColumnToTensors(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - - vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]])) - colvec = array_ops.constant(np.arange(4.)[:, None]) - output = sess.run(utils.column_to_tensors(vector_template, colvec)) - self.assertAllClose(output, np.array([[0., 1.], [2., 3.]])) - - vector_template = self._fully_connected_layer_params() - colvec = array_ops.constant(np.arange(6.)[:, None]) - output = sess.run(utils.column_to_tensors(vector_template, colvec)) - - self.assertIsInstance(output, tuple) - self.assertEqual(len(output), 2) - a, b = output - self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) - self.assertAllClose(b, np.array([4., 5.])) - - vector_template = list(vector_template) - vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]])) - colvec = array_ops.constant(np.arange(10.)[:, None]) - output = sess.run(utils.column_to_tensors(vector_template, colvec)) - self.assertIsInstance(output, tuple) - self.assertEqual(len(output), 3) - a, b, c = output - self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) - self.assertAllClose(b, np.array([4., 5.])) - self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) - - def testPosDefInvCholesky(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - npr.seed(0) - square = lambda x: np.dot(x, x.T) - - size = 3 - x = square(npr.randn(size, size)) - damp = 0.1 - identity = linalg_ops.eye(size, dtype=dtypes.float64) - - tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp) - np_inv = np.linalg.inv(x + damp * np.eye(size)) - self.assertAllClose(sess.run(tf_inv), np_inv) - - def testPosDefInvMatrixInverse(self): - with ops.Graph().as_default(), self.cached_session() as sess: - random_seed.set_random_seed(200) - npr.seed(0) - square = lambda x: np.dot(x, x.T) - - size = 3 - x = square(npr.randn(size, size)) - damp = 0.1 - identity = linalg_ops.eye(size, dtype=dtypes.float64) - - tf_inv = utils.posdef_inv_matrix_inverse( - array_ops.constant(x), identity, damp) - np_inv = np.linalg.inv(x + damp * np.eye(size)) - self.assertAllClose(sess.run(tf_inv), np_inv) - - def testCrossReplicaMean(self): - """Ensures that cross_replica_mean() executes only when num_shards > 1.""" - with ops.Graph().as_default(): - with tpu_function.tpu_shard_context(4): - tensor = array_ops.zeros([], dtype=dtypes.float32) - mean = utils.cross_replica_mean(tensor) - self.assertNotEqual(mean, tensor) - - with ops.Graph().as_default(): - with tpu_function.tpu_shard_context(1): - tensor = array_ops.zeros([], dtype=dtypes.float32) - mean = utils.cross_replica_mean(tensor) - self.assertEqual(mean, tensor) - - with ops.Graph().as_default(): - with self.assertRaises(ValueError): # Outside of TPU context. - tensor = array_ops.zeros([], dtype=dtypes.float32) - mean = utils.cross_replica_mean(tensor) - - def testBatchExecute(self): - """Ensure batch_execute runs in a round-robin fashion.""" - - def increment_var(var): - return lambda: var.assign_add(1) - - with ops.Graph().as_default(), self.cached_session() as sess: - i = variable_scope.get_variable('i', initializer=0) - accumulators = [ - variable_scope.get_variable('var%d' % j, initializer=0) - for j in range(3) - ] - thunks = [increment_var(var) for var in accumulators] - increment_accumulators = utils.batch_execute(i, thunks, 2) - increment_i = i.assign_add(1) - - sess.run(variables.global_variables_initializer()) - - # Ensure one op per thunk. - self.assertEqual(3, len(increment_accumulators)) - - # Ensure round-robin execution. - values = [] - for _ in range(5): - sess.run(increment_accumulators) - sess.run(increment_i) - values.append(sess.run(accumulators)) - self.assertAllClose( - [ - [1, 1, 0], # - [2, 1, 1], # - [2, 2, 2], # - [3, 3, 2], # - [4, 3, 3] - ], - values) - - def testExtractConvolutionPatches(self): - with ops.Graph().as_default(), self.cached_session() as sess: - batch_size = 10 - image_spatial_shape = [9, 10, 11] - in_channels = out_channels = 32 - kernel_spatial_shape = [5, 3, 3] - spatial_strides = [1, 2, 1] - spatial_dilation = [1, 1, 1] - padding = 'SAME' - - images = random_ops.random_uniform( - [batch_size] + image_spatial_shape + [in_channels], seed=0) - kernel_shape = kernel_spatial_shape + [in_channels, out_channels] - kernel = random_ops.random_uniform(kernel_shape, seed=1) - - # Ensure shape matches expectation. - patches = utils.extract_convolution_patches( - images, - kernel_shape, - padding, - strides=spatial_strides, - dilation_rate=spatial_dilation) - result_spatial_shape = ( - patches.shape.as_list()[1:1 + len(image_spatial_shape)]) - self.assertEqual(patches.shape.as_list(), - [batch_size] + result_spatial_shape + - kernel_spatial_shape + [in_channels]) - - # Ensure extract...patches() + matmul() and convolution() implementation - # give the same answer. - outputs = nn_ops.convolution( - images, - kernel, - padding, - strides=spatial_strides, - dilation_rate=spatial_dilation) - - patches_flat = array_ops.reshape( - patches, [-1, np.prod(kernel_spatial_shape) * in_channels]) - kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) - outputs_flat = math_ops.matmul(patches_flat, kernel_flat) - - outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) - self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) - - def testExtractPointwiseConv2dPatches(self): - with ops.Graph().as_default(), self.cached_session() as sess: - batch_size = 10 - image_height = image_width = 8 - in_channels = out_channels = 3 - kernel_height = kernel_width = 1 - strides = [1, 1, 1, 1] - padding = 'VALID' - - images = random_ops.random_uniform( - [batch_size, image_height, image_width, in_channels], seed=0) - kernel_shape = [kernel_height, kernel_width, in_channels, out_channels] - kernel = random_ops.random_uniform(kernel_shape, seed=1) - - # Ensure shape matches expectation. - patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape) - self.assertEqual(patches.shape.as_list(), [ - batch_size, image_height, image_width, kernel_height, kernel_width, - in_channels - ]) - - # Ensure extract...patches() + matmul() and conv2d() implementation - # give the same answer. - outputs = nn_ops.conv2d(images, kernel, strides, padding) - - patches_flat = array_ops.reshape( - patches, [-1, kernel_height * kernel_width * in_channels]) - kernel_flat = array_ops.reshape(kernel, [-1, out_channels]) - outputs_flat = math_ops.matmul(patches_flat, kernel_flat) - - outputs_, outputs_flat_ = sess.run([outputs, outputs_flat]) - self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD deleted file mode 100644 index 3c01eb65e7a687d6c477b858b8d91ea7f309dc64..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ /dev/null @@ -1,263 +0,0 @@ -package(default_visibility = [ - "//tensorflow/contrib/kfac:__pkg__", - "//tensorflow/contrib/kfac/python/kernel_tests:__pkg__", -]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -py_library( - name = "fisher_blocks", - srcs = ["fisher_blocks.py"], - srcs_version = "PY2AND3", - deps = [ - ":fisher_factors", - ":utils", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "@six_archive//:six", - ], -) - -py_library( - name = "fisher_blocks_lib", - srcs = ["fisher_blocks_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":fisher_blocks", - "//tensorflow/python:util", - ], -) - -py_library( - name = "fisher_factors", - srcs = ["fisher_factors.py"], - srcs_version = "PY2AND3", - deps = [ - ":linear_operator", - ":utils", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:special_math_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_library( - name = "fisher_factors_lib", - srcs = ["fisher_factors_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":fisher_factors", - "//tensorflow/python:util", - ], -) - -py_library( - name = "linear_operator", - srcs = ["linear_operator.py"], - srcs_version = "PY2AND3", - deps = [ - ":utils", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python/ops/linalg", - "@six_archive//:six", - ], -) - -py_library( - name = "loss_functions", - srcs = ["loss_functions.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/ops/distributions", - "@six_archive//:six", - ], -) - -py_library( - name = "loss_functions_lib", - srcs = ["loss_functions_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":loss_functions", - "//tensorflow/python:util", - ], -) - -py_library( - name = "curvature_matrix_vector_products", - srcs = ["curvature_matrix_vector_products.py"], - srcs_version = "PY2AND3", - deps = [ - ":utils", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:util", - ], -) - -py_library( - name = "curvature_matrix_vector_products_lib", - srcs = ["curvature_matrix_vector_products_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":curvature_matrix_vector_products", - "//tensorflow/python:util", - ], -) - -py_library( - name = "layer_collection", - srcs = ["layer_collection.py"], - srcs_version = "PY2AND3", - deps = [ - ":fisher_blocks", - ":loss_functions", - ":utils", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "@six_archive//:six", - ], -) - -py_library( - name = "layer_collection_lib", - srcs = ["layer_collection_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":layer_collection", - "//tensorflow/python:util", - ], -) - -py_library( - name = "kfac_optimizer", - srcs = [ - "optimizer.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":curvature_matrix_vector_products", - ":fisher_estimator", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -py_library( - name = "kfac_optimizer_lib", - srcs = [ - "optimizer_lib.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":kfac_optimizer", - "//tensorflow/python:util", - ], -) - -py_library( - name = "fisher_estimator", - srcs = [ - "estimator.py", - "placement.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":utils", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:util", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_library( - name = "fisher_estimator_lib", - srcs = [ - "estimator_lib.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":fisher_estimator", - "//tensorflow/python:util", - ], -) - -py_library( - name = "utils", - srcs = ["utils.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/tpu", - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//third_party/py/numpy", - ], -) - -py_library( - name = "utils_lib", - srcs = ["utils_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":utils", - "//tensorflow/python:util", - ], -) - -py_library( - name = "op_queue", - srcs = ["op_queue.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/python:framework_ops", - ], -) - -py_library( - name = "op_queue_lib", - srcs = ["op_queue_lib.py"], - srcs_version = "PY2AND3", - deps = [ - ":op_queue", - "//tensorflow/python:util", - ], -) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py deleted file mode 100644 index 21b5cde9b931a95110c9a5fd7930a3a4ee74b207..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Curvature matrix-vector multiplication.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.util import nest - - -class CurvatureMatrixVectorProductComputer(object): - """Class for computing matrix-vector products for Fishers, GGNs and Hessians. - - In other words we compute M*v where M is the matrix, v is the vector, and - * refers to standard matrix/vector multiplication (not element-wise - multiplication). - - The matrices are defined in terms of some differential quantity of the total - loss function with respect to a provided list of tensors ("wrt_tensors"). - For example, the Fisher associated with a log-prob loss w.r.t. the - parameters. - - The 'vecs' argument to each method are lists of tensors that must be the - size as the corresponding ones from "wrt_tensors". They represent - the vector being multiplied. - - "factors" of the matrix M are defined as matrices B such that B*B^T = M. - Methods that multiply by the factor B take a 'loss_inner_vecs' argument - instead of 'vecs', which must be a list of tensors with shapes given by the - corresponding XXX_inner_shapes property. - - Note that matrix-vector products are not normalized by the batch size, nor - are any damping terms added to the results. These things can be easily - applied externally, if desired. - - See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf - and https://arxiv.org/abs/1412.1193 for more information about the - generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector - products. - """ - - def __init__(self, losses, wrt_tensors): - """Create a CurvatureMatrixVectorProductComputer object. - - Args: - losses: A list of LossFunction instances whose sum defines the total loss. - wrt_tensors: A list of Tensors to compute the differential quantities - (defining the matrices) with respect to. See class description for more - info. - """ - self._losses = losses - self._inputs_to_losses = list(loss.inputs for loss in losses) - self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses) - self._wrt_tensors = wrt_tensors - - @property - def _total_loss(self): - return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses)) - - # Jacobian multiplication functions: - def _multiply_jacobian(self, vecs): - """Multiply vecs by the Jacobian of losses.""" - # We stop gradients at wrt_tensors to produce partial derivatives (which is - # what we want for Jacobians). - jacobian_vecs_flat = utils.fwd_gradients( - self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs, - stop_gradients=self._wrt_tensors) - return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat) - - def _multiply_jacobian_transpose(self, loss_vecs): - """Multiply vecs by the transpose Jacobian of losses.""" - loss_vecs_flat = nest.flatten(loss_vecs) - # We stop gradients at wrt_tensors to produce partial derivatives (which is - # what we want for Jacobians). - return gradients_impl.gradients( - self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat, - stop_gradients=self._wrt_tensors) - - # Losses Fisher/Hessian multiplication functions: - def _multiply_loss_fisher(self, loss_vecs): - """Multiply loss_vecs by Fisher of total loss.""" - return tuple( - loss.multiply_fisher(loss_vec) - for loss, loss_vec in zip(self._losses, loss_vecs)) - - def _multiply_loss_fisher_factor(self, loss_inner_vecs): - """Multiply loss_inner_vecs by factor of Fisher of total loss.""" - return tuple( - loss.multiply_fisher_factor(loss_vec) - for loss, loss_vec in zip(self._losses, loss_inner_vecs)) - - def _multiply_loss_fisher_factor_transpose(self, loss_vecs): - """Multiply loss_vecs by transpose factor of Fisher of total loss.""" - return tuple( - loss.multiply_fisher_factor_transpose(loss_vec) - for loss, loss_vec in zip(self._losses, loss_vecs)) - - def _multiply_loss_hessian(self, loss_vecs): - """Multiply loss_vecs by Hessian of total loss.""" - return tuple( - loss.multiply_hessian(loss_vec) - for loss, loss_vec in zip(self._losses, loss_vecs)) - - def _multiply_loss_hessian_factor(self, loss_inner_vecs): - """Multiply loss_inner_vecs by factor of Hessian of total loss.""" - return tuple( - loss.multiply_hessian_factor(loss_vec) - for loss, loss_vec in zip(self._losses, loss_inner_vecs)) - - def _multiply_loss_hessian_factor_transpose(self, loss_vecs): - """Multiply loss_vecs by transpose factor of Hessian of total loss.""" - return tuple( - loss.multiply_hessian_factor_transpose(loss_vec) - for loss, loss_vec in zip(self._losses, loss_vecs)) - - # Matrix-vector product functions: - def multiply_fisher(self, vecs): - """Multiply vecs by Fisher of total loss.""" - jacobian_vecs = self._multiply_jacobian(vecs) - loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs) - return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs) - - def multiply_fisher_factor_transpose(self, vecs): - """Multiply vecs by transpose of factor of Fisher of total loss.""" - jacobian_vecs = self._multiply_jacobian(vecs) - return self._multiply_loss_fisher_factor_transpose(jacobian_vecs) - - def multiply_fisher_factor(self, loss_inner_vecs): - """Multiply loss_inner_vecs by factor of Fisher of total loss.""" - fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose( - loss_inner_vecs) - return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs) - - def multiply_hessian(self, vecs): - """Multiply vecs by Hessian of total loss.""" - return gradients_impl.gradients( - gradients_impl.gradients(self._total_loss, self._wrt_tensors), - self._wrt_tensors, - grad_ys=vecs) - - def multiply_generalized_gauss_newton(self, vecs): - """Multiply vecs by generalized Gauss-Newton of total loss.""" - jacobian_vecs = self._multiply_jacobian(vecs) - loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs) - return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs) - - def multiply_generalized_gauss_newton_factor_transpose(self, vecs): - """Multiply vecs by transpose of factor of GGN of total loss.""" - jacobian_vecs = self._multiply_jacobian(vecs) - return self._multiply_loss_hessian_factor_transpose(jacobian_vecs) - - def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs): - """Multiply loss_inner_vecs by factor of GGN of total loss.""" - hessian_factor_transpose_vecs = ( - self._multiply_loss_hessian_factor_transpose(loss_inner_vecs)) - return self._multiply_jacobian_transpose(hessian_factor_transpose_vecs) - - # Shape properties for multiply_XXX_factor methods: - @property - def fisher_factor_inner_shapes(self): - """Shapes required by multiply_fisher_factor.""" - return tuple(loss.fisher_factor_inner_shape for loss in self._losses) - - @property - def generalized_gauss_newton_factor_inner_shapes(self): - """Shapes required by multiply_generalized_gauss_newton_factor.""" - return tuple(loss.hessian_factor_inner_shape for loss in self._losses) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py deleted file mode 100644 index 6e8c6404dcba0970785a2c8358cb4e2356e45b0e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Curvature matrix-vector multiplication.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - 'CurvatureMatrixVectorProductComputer', -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py deleted file mode 100644 index 323234c40316757b8bc33564ba8a13b07c8858e0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ /dev/null @@ -1,516 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Defines the high-level Fisher estimator class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc -import numpy as np -import six - -from tensorflow.contrib.kfac.python.ops import placement -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import variable_scope -from tensorflow.python.util import nest - - -# The linter is confused. -# pylint: disable=abstract-class-instantiated -def make_fisher_estimator(placement_strategy=None, **kwargs): - """Creates Fisher estimator instances based on the placement strategy. - - For example if the `placement_strategy` is 'round_robin' then - `FisherEstimatorRoundRobin` instance is returned. - - Args: - placement_strategy: `string`, Strategy to be used for placing covariance - variables, covariance ops and inverse ops. Check - `placement.FisherEstimatorRoundRobin` for a concrete example. - **kwargs: Arguments to be passed into `FisherEstimator` class initializer. - - Returns: - An instance of class which inherits from `FisherEstimator` and the mixin - which implements specific placement strategy. See, - `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and - `RoundRobinPlacementMixin`. - - Raises: - ValueError: If the `placement_strategy` is not equal to 'round_robin'. - """ - if placement_strategy in [None, "round_robin"]: - return FisherEstimatorRoundRobin(**kwargs) - else: - raise ValueError("Unimplemented vars and ops " - "placement strategy : {}".format(placement_strategy)) -# pylint: enable=abstract-class-instantiated - - -@six.add_metaclass(abc.ABCMeta) -class FisherEstimator(object): - """Fisher estimator class supporting various approximations of the Fisher. - - This is an abstract base class which does not implement a strategy for - placing covariance variables, covariance update ops and inverse update ops. - The placement strategies are implemented in `placement.py`. See - `FisherEstimatorRoundRobin` for example of a concrete subclass with - a round-robin placement strategy. - """ - - def __init__(self, - variables, - cov_ema_decay, - damping, - layer_collection, - exps=(-1,), - estimation_mode="gradients", - colocate_gradients_with_ops=True, - name="FisherEstimator", - compute_cholesky=False, - compute_cholesky_inverse=False): - """Create a FisherEstimator object. - - Args: - variables: A `list` of variables or `callable` which returns the variables - for which to estimate the Fisher. This must match the variables - registered in layer_collection (if it is not None). - cov_ema_decay: The decay factor used when calculating the covariance - estimate moving averages. - damping: float. The damping factor used to stabilize training due to - errors in the local approximation with the Fisher information matrix, - and to regularize the update direction by making it closer to the - gradient. (Higher damping means the update looks more like a standard - gradient update - see Tikhonov regularization.) - layer_collection: The layer collection object, which holds the Fisher - blocks, Kronecker factors, and losses associated with the - graph. - exps: List of floats or ints. These represent the different matrix - powers of the approximate Fisher that the FisherEstimator will be able - to multiply vectors by. If the user asks for a matrix power other - one of these (or 1, which is always supported), there will be a - failure. (Default: (-1,)) - estimation_mode: The type of estimator to use for the Fishers. Can be - 'gradients', 'empirical', 'curvature_prop', or 'exact'. - (Default: 'gradients'). 'gradients' is the basic estimation approach - from the original K-FAC paper. 'empirical' computes the 'empirical' - Fisher information matrix (which uses the data's distribution for the - targets, as opposed to the true Fisher which uses the model's - distribution) and requires that each registered loss have specified - targets. 'curvature_propagation' is a method which estimates the - Fisher using self-products of random 1/-1 vectors times "half-factors" - of the Fisher, as described here: https://arxiv.org/abs/1206.6464 . - Finally, 'exact' is the obvious generalization of Curvature - Propagation to compute the exact Fisher (modulo any additional - diagonal or Kronecker approximations) by looping over one-hot vectors - for each coordinate of the output instead of using 1/-1 vectors. It - is more expensive to compute than the other three options by a factor - equal to the output dimension, roughly speaking. - colocate_gradients_with_ops: Whether we should request gradients be - colocated with their respective ops. (Default: True) - name: A string. A name given to this estimator, which is added to the - variable scope when constructing variables and ops. - (Default: "FisherEstimator") - compute_cholesky: Bool. Whether or not the FisherEstimator will be - able to multiply vectors by the Cholesky factor. - (Default: False) - compute_cholesky_inverse: Bool. Whether or not the FisherEstimator - will be able to multiply vectors by the Cholesky factor inverse. - (Default: False) - Raises: - ValueError: If no losses have been registered with layer_collection. - """ - self._variables = variables - self._cov_ema_decay = cov_ema_decay - self._damping = damping - self._estimation_mode = estimation_mode - self._layers = layer_collection - self._gradient_fns = { - "gradients": self._get_grads_lists_gradients, - "empirical": self._get_grads_lists_empirical, - "curvature_prop": self._get_grads_lists_curvature_prop, - "exact": self._get_grads_lists_exact - } - self._colocate_gradients_with_ops = colocate_gradients_with_ops - - self._made_vars = False - self._exps = exps - self._compute_cholesky = compute_cholesky - self._compute_cholesky_inverse = compute_cholesky_inverse - - self._name = name - - @property - def variables(self): - if callable(self._variables): - return self._variables() - else: - return self._variables - - @property - def damping(self): - return self._damping - - @property - def blocks(self): - """All registered FisherBlocks.""" - return self._layers.get_blocks() - - @property - def factors(self): - """All registered FisherFactors.""" - return self._layers.get_factors() - - @property - def name(self): - return self._name - - @abc.abstractmethod - def make_vars_and_create_op_thunks(self, scope=None): - """Make vars and create op thunks with a specific placement strategy. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the cov_devices - argument. If cov_devices is None then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all thunks will execute, inside of a variable scope of the given - name. (Default: None) - - Returns: - cov_update_thunks: List of cov update thunks. Corresponds one-to-one with - the list of factors given by the "factors" property. - inv_update_thunks: List of inv update thunks. Corresponds one-to-one with - the list of factors given by the "factors" property. - """ - pass - - def _apply_transformation(self, vecs_and_vars, transform): - """Applies an block-wise transformation to the corresponding vectors. - - Args: - vecs_and_vars: List of (vector, variable) pairs. - transform: A function of the form f(fb, vec), where vec is the vector - to transform and fb is its corresponding block in the matrix, that - returns the transformed vector. - - Returns: - A list of (transformed vector, var) pairs in the same order as - vecs_and_vars. - """ - - vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars) - - trans_vecs = utils.SequenceDict() - - for params, fb in self._layers.fisher_blocks.items(): - trans_vecs[params] = transform(fb, vecs[params]) - - return [(trans_vecs[var], var) for _, var in vecs_and_vars] - - def multiply_inverse(self, vecs_and_vars): - """Multiplies the vecs by the corresponding (damped) inverses of the blocks. - - Args: - vecs_and_vars: List of (vector, variable) pairs. - - Returns: - A list of (transformed vector, var) pairs in the same order as - vecs_and_vars. - """ - return self.multiply_matpower(-1, vecs_and_vars) - - def multiply(self, vecs_and_vars): - """Multiplies the vectors by the corresponding (damped) blocks. - - Args: - vecs_and_vars: List of (vector, variable) pairs. - - Returns: - A list of (transformed vector, var) pairs in the same order as - vecs_and_vars. - """ - return self.multiply_matpower(1, vecs_and_vars) - - def multiply_matpower(self, exp, vecs_and_vars): - """Multiplies the vecs by the corresponding matrix powers of the blocks. - - Args: - exp: A float representing the power to raise the blocks by before - multiplying it by the vector. - vecs_and_vars: List of (vector, variable) pairs. - - Returns: - A list of (transformed vector, var) pairs in the same order as - vecs_and_vars. - """ - assert exp in self._exps - - fcn = lambda fb, vec: fb.multiply_matpower(vec, exp) - return self._apply_transformation(vecs_and_vars, fcn) - - def multiply_cholesky(self, vecs_and_vars, transpose=False): - """Multiplies the vecs by the corresponding Cholesky factors. - - Args: - vecs_and_vars: List of (vector, variable) pairs. - transpose: Bool. If true the Cholesky factors are transposed before - multiplying the vecs. (Default: False) - - Returns: - A list of (transformed vector, var) pairs in the same order as - vecs_and_vars. - """ - assert self._compute_cholesky - - fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose) - return self._apply_transformation(vecs_and_vars, fcn) - - def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False): - """Mults the vecs by the inverses of the corresponding Cholesky factors. - - Note: if you are using Cholesky inverse multiplication to sample from - a matrix-variate Gaussian you will want to multiply by the transpose. - Let L be the Cholesky factor of F and observe that - - L^-T * L^-1 = (L * L^T)^-1 = F^-1 . - - Thus we want to multiply by L^-T in order to sample from Gaussian with - covariance F^-1. - - Args: - vecs_and_vars: List of (vector, variable) pairs. - transpose: Bool. If true the Cholesky factor inverses are transposed - before multiplying the vecs. (Default: False) - - Returns: - A list of (transformed vector, var) pairs in the same order as - vecs_and_vars. - """ - assert self._compute_cholesky_inverse - - fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose) - return self._apply_transformation(vecs_and_vars, fcn) - - def _instantiate_factors(self): - """Instantiates FisherFactors' variables. - - Raises: - ValueError: If estimation_mode was improperly specified at construction. - """ - blocks = self.blocks - tensors_to_compute_grads = [ - block.tensors_to_compute_grads() for block in blocks - ] - - try: - grads_lists = self._gradient_fns[self._estimation_mode]( - tensors_to_compute_grads) - except KeyError: - raise ValueError("Unrecognized value {} for estimation_mode.".format( - self._estimation_mode)) - - for grads_list, block in zip(grads_lists, blocks): - block.instantiate_factors(grads_list, self.damping) - - def _check_vars_unmade_and_set_made_flag(self): - if self._made_vars: - raise Exception("Already made variables.") - self._made_vars = True - - def made_vars(self): - return self._made_vars - - def _register_matrix_functions(self): - for block in self.blocks: - for exp in self._exps: - block.register_matpower(exp) - if self._compute_cholesky: - block.register_cholesky() - if self._compute_cholesky_inverse: - block.register_cholesky_inverse() - - def _finalize_layer_collection(self): - self._layers.create_subgraph() - self._layers.check_registration(self.variables) - self._instantiate_factors() - self._register_matrix_functions() - - def create_ops_and_vars_thunks(self, scope=None): - """Create thunks that make the ops and vars on demand. - - This function returns 4 lists of thunks: cov_variable_thunks, - cov_update_thunks, inv_variable_thunks, and inv_update_thunks. - - The length of each list is the number of factors and the i-th element of - each list corresponds to the i-th factor (given by the "factors" property). - - Note that the execution of these thunks must happen in a certain - partial order. The i-th element of cov_variable_thunks must execute - before the i-th element of cov_update_thunks (and also the i-th element - of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks - must execute before the i-th element of inv_update_thunks. - - TL;DR (oversimplified): Execute the thunks according to the order that - they are returned. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All thunks will execute inside - of a variable scope of the given name. (Default: None) - Returns: - cov_variable_thunks: A list of thunks that make the cov variables. - cov_update_thunks: A list of thunks that make the cov update ops. - inv_variable_thunks: A list of thunks that make the inv variables. - inv_update_thunks: A list of thunks that make the inv update ops. - """ - self._check_vars_unmade_and_set_made_flag() - - self._finalize_layer_collection() - - scope = self.name if scope is None else scope - - cov_variable_thunks = [ - self._create_cov_variable_thunk(factor, scope) - for factor in self.factors - ] - cov_update_thunks = [ - self._create_cov_update_thunk(factor, scope) for factor in self.factors - ] - inv_variable_thunks = [ - self._create_inv_variable_thunk(factor, scope) - for factor in self.factors - ] - inv_update_thunks = [ - self._create_inv_update_thunk(factor, scope) for factor in self.factors - ] - - return (cov_variable_thunks, cov_update_thunks, - inv_variable_thunks, inv_update_thunks) - - def _create_cov_variable_thunk(self, factor, scope): - """Constructs a covariance variable thunk for a single FisherFactor.""" - - def thunk(): - with variable_scope.variable_scope(scope): - return factor.instantiate_cov_variables() - - return thunk - - def _create_cov_update_thunk(self, factor, scope): - """Constructs a covariance update thunk for a single FisherFactor.""" - - def thunk(): - with variable_scope.variable_scope(scope): - return factor.make_covariance_update_op(self._cov_ema_decay) - - return thunk - - def _create_inv_variable_thunk(self, factor, scope): - """Constructs a inverse variable thunk for a single FisherFactor.""" - - def thunk(): - with variable_scope.variable_scope(scope): - return factor.instantiate_inv_variables() - - return thunk - - def _create_inv_update_thunk(self, factor, scope): - """Constructs an inverse update thunk for a single FisherFactor.""" - - def thunk(): - with variable_scope.variable_scope(scope): - return control_flow_ops.group(factor.make_inverse_update_ops()) - - return thunk - - def _get_grads_lists_gradients(self, tensors): - # Passing in a list of loss values is better than passing in the sum as - # the latter creates unnessesary ops on the default device - grads_flat = gradients_impl.gradients( - self._layers.eval_losses_on_samples(), - nest.flatten(tensors), - colocate_gradients_with_ops=self._colocate_gradients_with_ops) - grads_all = nest.pack_sequence_as(tensors, grads_flat) - return tuple((grad,) for grad in grads_all) - - def _get_grads_lists_empirical(self, tensors): - # Passing in a list of loss values is better than passing in the sum as - # the latter creates unnecessary ops on the default device - grads_flat = gradients_impl.gradients( - self._layers.eval_losses(), - nest.flatten(tensors), - colocate_gradients_with_ops=self._colocate_gradients_with_ops) - grads_all = nest.pack_sequence_as(tensors, grads_flat) - return tuple((grad,) for grad in grads_all) - - def _get_transformed_random_signs(self): - transformed_random_signs = [] - for loss in self._layers.losses: - with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): - transformed_random_signs.append( - loss.multiply_fisher_factor( - utils.generate_random_signs(loss.fisher_factor_inner_shape))) - return transformed_random_signs - - def _get_grads_lists_curvature_prop(self, tensors): - loss_inputs = list(loss.inputs for loss in self._layers.losses) - transformed_random_signs = self._get_transformed_random_signs() - grads_flat = gradients_impl.gradients( - nest.flatten(loss_inputs), - nest.flatten(tensors), - grad_ys=nest.flatten(transformed_random_signs), - colocate_gradients_with_ops=self._colocate_gradients_with_ops) - grads_all = nest.pack_sequence_as(tensors, grads_flat) - return tuple((grad,) for grad in grads_all) - - def _get_grads_lists_exact(self, tensors): - """No docstring required.""" - # Loop over all coordinates of all losses. - grads_all = [] - for loss in self._layers.losses: - with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]): - for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): - transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( - index) - grads_flat = gradients_impl.gradients( - loss.inputs, - nest.flatten(tensors), - grad_ys=transformed_one_hot, - colocate_gradients_with_ops=self._colocate_gradients_with_ops) - grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) - return zip(*grads_all) - - -class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin, - FisherEstimator): - """Fisher estimator which provides round robin device placement strategy.""" - pass diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py deleted file mode 100644 index 9c9fef471f8033bec53ceb1e4f073dd921cbe3c7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/estimator_lib.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Defines the high-level Fisher estimator class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.estimator import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - 'FisherEstimator', - 'make_fisher_estimator', -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py deleted file mode 100644 index 9fa6eb7dcd12d7c6474d176198c1e47f1ec6fd4c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ /dev/null @@ -1,1752 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""FisherBlock definitions. - -This library contains classes for estimating blocks in a model's Fisher -Information matrix. Suppose one has a model that parameterizes a posterior -distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its -Fisher Information matrix is given by, - - $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$ - -where, - - $$v(x, y, params) = (d / d params) log p(y | x, params)$$ - -and the expectation is taken with respect to the data's distribution for 'x' and -the model's posterior distribution for 'y', - - x ~ p(x) - y ~ p(y | x, params) - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc -import enum # pylint: disable=g-bad-import-order - -import numpy as np -import six - -from tensorflow.contrib.kfac.python.ops import fisher_factors -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.util import nest - -# For blocks corresponding to convolutional layers, or any type of block where -# the parameters can be thought of as being replicated in time or space, -# we want to adjust the scale of the damping by -# damping /= num_replications ** NORMALIZE_DAMPING_POWER -NORMALIZE_DAMPING_POWER = 1.0 - -# Methods for adjusting damping for FisherBlocks. See -# compute_pi_adjusted_damping() for details. -PI_OFF_NAME = "off" -PI_TRACENORM_NAME = "tracenorm" -PI_TYPE = PI_TRACENORM_NAME - - -def set_global_constants(normalize_damping_power=None, pi_type=None): - """Sets various global constants used by the classes in this module.""" - global NORMALIZE_DAMPING_POWER - global PI_TYPE - - if normalize_damping_power is not None: - NORMALIZE_DAMPING_POWER = normalize_damping_power - - if pi_type is not None: - PI_TYPE = pi_type - - -def normalize_damping(damping, num_replications): - """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER.""" - if NORMALIZE_DAMPING_POWER: - return damping / (num_replications ** NORMALIZE_DAMPING_POWER) - return damping - - -def compute_pi_tracenorm(left_cov, right_cov): - r"""Computes the scalar constant pi for Tikhonov regularization/damping. - - $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ - See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. - - Args: - left_cov: A LinearOperator object. The left Kronecker factor "covariance". - right_cov: A LinearOperator object. The right Kronecker factor "covariance". - - Returns: - The computed scalar constant pi for these Kronecker Factors (as a Tensor). - """ - # Instead of dividing by the dim of the norm, we multiply by the dim of the - # other norm. This works out the same in the ratio. - left_norm = left_cov.trace() * int(right_cov.domain_dimension) - right_norm = right_cov.trace() * int(left_cov.domain_dimension) - return math_ops.sqrt(left_norm / right_norm) - - -def compute_pi_adjusted_damping(left_cov, right_cov, damping): - - if PI_TYPE == PI_TRACENORM_NAME: - pi = compute_pi_tracenorm(left_cov, right_cov) - return (damping * pi, damping / pi) - - elif PI_TYPE == PI_OFF_NAME: - return (damping, damping) - - -class PackagedFunc(object): - """A Python thunk with a stable ID. - - Enables stable names for lambdas. - """ - - def __init__(self, func, func_id): - """Initializes PackagedFunc. - - Args: - func: a zero-arg Python function. - func_id: a hashable, function that produces a hashable, or a list/tuple - thereof. - """ - self._func = func - func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,) - self._func_id = func_id - - def __call__(self): - return self._func() - - @property - def func_id(self): - """A hashable identifier for this function.""" - return tuple(elt() if callable(elt) else elt for elt in self._func_id) - - -def _package_func(func, func_id): - return PackagedFunc(func, func_id) - - -@six.add_metaclass(abc.ABCMeta) -class FisherBlock(object): - """Abstract base class for objects modeling approximate Fisher matrix blocks. - - Subclasses must implement register_matpower, multiply_matpower, - instantiate_factors, tensors_to_compute_grads, and num_registered_towers - methods. - """ - - def __init__(self, layer_collection): - self._layer_collection = layer_collection - - @abc.abstractmethod - def instantiate_factors(self, grads_list, damping): - """Creates and registers the component factors of this Fisher block. - - Args: - grads_list: A list gradients (each a Tensor or tuple of Tensors) with - respect to the tensors returned by tensors_to_compute_grads() that - are to be used to estimate the block. - damping: The damping factor (float or Tensor). - """ - pass - - @abc.abstractmethod - def register_matpower(self, exp): - """Registers a matrix power to be computed by the block. - - Args: - exp: A float representing the power to raise the block by. - """ - pass - - @abc.abstractmethod - def register_cholesky(self): - """Registers a Cholesky factor to be computed by the block.""" - pass - - @abc.abstractmethod - def register_cholesky_inverse(self): - """Registers an inverse Cholesky factor to be computed by the block.""" - pass - - def register_inverse(self): - """Registers a matrix inverse to be computed by the block.""" - self.register_matpower(-1) - - @abc.abstractmethod - def multiply_matpower(self, vector, exp): - """Multiplies the vector by the (damped) matrix-power of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - exp: A float representing the power to raise the block by before - multiplying it by the vector. - - Returns: - The vector left-multiplied by the (damped) matrix-power of the block. - """ - pass - - def multiply_inverse(self, vector): - """Multiplies the vector by the (damped) inverse of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - - Returns: - The vector left-multiplied by the (damped) inverse of the block. - """ - return self.multiply_matpower(vector, -1) - - def multiply(self, vector): - """Multiplies the vector by the (damped) block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - - Returns: - The vector left-multiplied by the (damped) block. - """ - return self.multiply_matpower(vector, 1) - - @abc.abstractmethod - def multiply_cholesky(self, vector, transpose=False): - """Multiplies the vector by the (damped) Cholesky-factor of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - transpose: Bool. If true the Cholesky factor is transposed before - multiplying the vector. (Default: False) - - Returns: - The vector left-multiplied by the (damped) Cholesky-factor of the block. - """ - pass - - @abc.abstractmethod - def multiply_cholesky_inverse(self, vector, transpose=False): - """Multiplies vector by the (damped) inverse Cholesky-factor of the block. - - Args: - vector: The vector (a Tensor or tuple of Tensors) to be multiplied. - transpose: Bool. If true the Cholesky factor inverse is transposed - before multiplying the vector. (Default: False) - Returns: - Vector left-multiplied by (damped) inverse Cholesky-factor of the block. - """ - pass - - @abc.abstractmethod - def tensors_to_compute_grads(self): - """Returns the Tensor(s) with respect to which this FisherBlock needs grads. - """ - pass - - @abc.abstractproperty - def num_registered_towers(self): - """Number of towers registered for this FisherBlock. - - Typically equal to the number of towers in a multi-tower setup. - """ - pass - - -class FullFB(FisherBlock): - """FisherBlock using a full matrix estimate (no approximations). - - FullFB uses a full matrix estimate (no approximations), and should only ever - be used for very low dimensional parameters. - - Note that this uses the naive "square the sum estimator", and so is applicable - to any type of parameter in principle, but has very high variance. - """ - - def __init__(self, layer_collection, params): - """Creates a FullFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters of this layer (Tensor or tuple of Tensors). - """ - self._batch_sizes = [] - self._params = params - - super(FullFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - self._damping_func = _package_func(lambda: damping, (damping,)) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullFactor, (grads_list, self._batch_size)) - - def register_matpower(self, exp): - self._factor.register_matpower(exp, self._damping_func) - - def register_cholesky(self): - self._factor.register_cholesky(self._damping_func) - - def register_cholesky_inverse(self): - self._factor.register_cholesky_inverse(self._damping_func) - - def _multiply_matrix(self, matrix, vector, transpose=False): - vector_flat = utils.tensors_to_column(vector) - out_flat = matrix.matmul(vector_flat, adjoint=transpose) - return utils.column_to_tensors(vector, out_flat) - - def multiply_matpower(self, vector, exp): - matrix = self._factor.get_matpower(exp, self._damping_func) - return self._multiply_matrix(matrix, vector) - - def multiply_cholesky(self, vector, transpose=False): - matrix = self._factor.get_cholesky(self._damping_func) - return self._multiply_matrix(matrix, vector, transpose=transpose) - - def multiply_cholesky_inverse(self, vector, transpose=False): - matrix = self._factor.get_cholesky_inverse(self._damping_func) - return self._multiply_matrix(matrix, vector, transpose=transpose) - - def full_fisher_block(self): - """Explicitly constructs the full Fisher block.""" - return self._factor.get_cov_as_linear_operator().to_dense() - - def tensors_to_compute_grads(self): - return self._params - - def register_additional_tower(self, batch_size): - """Register an additional tower. - - Args: - batch_size: The batch size, used in the covariance estimator. - """ - self._batch_sizes.append(batch_size) - - @property - def num_registered_towers(self): - return len(self._batch_sizes) - - @property - def _batch_size(self): - return math_ops.reduce_sum(self._batch_sizes) - - -@six.add_metaclass(abc.ABCMeta) -class DiagonalFB(FisherBlock): - """A base class for FisherBlocks that use diagonal approximations.""" - - def register_matpower(self, exp): - # Not needed for this. Matrix powers are computed on demand in the - # diagonal case - pass - - def register_cholesky(self): - # Not needed for this. Cholesky's are computed on demand in the - # diagonal case - pass - - def register_cholesky_inverse(self): - # Not needed for this. Cholesky inverses's are computed on demand in the - # diagonal case - pass - - def _multiply_matrix(self, matrix, vector): - vector_flat = utils.tensors_to_column(vector) - out_flat = matrix.matmul(vector_flat) - return utils.column_to_tensors(vector, out_flat) - - def multiply_matpower(self, vector, exp): - matrix = self._factor.get_matpower(exp, self._damping_func) - return self._multiply_matrix(matrix, vector) - - def multiply_cholesky(self, vector, transpose=False): - matrix = self._factor.get_cholesky(self._damping_func) - return self._multiply_matrix(matrix, vector) - - def multiply_cholesky_inverse(self, vector, transpose=False): - matrix = self._factor.get_cholesky_inverse(self._damping_func) - return self._multiply_matrix(matrix, vector) - - def full_fisher_block(self): - return self._factor.get_cov_as_linear_operator().to_dense() - - -class NaiveDiagonalFB(DiagonalFB): - """FisherBlock using a diagonal matrix approximation. - - This type of approximation is generically applicable but quite primitive. - - Note that this uses the naive "square the sum estimator", and so is applicable - to any type of parameter in principle, but has very high variance. - """ - - def __init__(self, layer_collection, params): - """Creates a NaiveDiagonalFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters of this layer (Tensor or tuple of Tensors). - """ - self._params = params - self._batch_sizes = [] - - super(NaiveDiagonalFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - self._damping_func = _package_func(lambda: damping, (damping,)) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) - - def tensors_to_compute_grads(self): - return self._params - - def register_additional_tower(self, batch_size): - """Register an additional tower. - - Args: - batch_size: The batch size, used in the covariance estimator. - """ - self._batch_sizes.append(batch_size) - - @property - def num_registered_towers(self): - return len(self._batch_sizes) - - @property - def _batch_size(self): - return math_ops.reduce_sum(self._batch_sizes) - - -class InputOutputMultiTower(object): - """Mix-in class for blocks with inputs & outputs and multiple mini-batches.""" - - def __init__(self, *args, **kwargs): - self.__inputs = [] - self.__outputs = [] - super(InputOutputMultiTower, self).__init__(*args, **kwargs) - - def _process_data(self, grads_list): - """Process data into the format used by the factors. - - This function takes inputs and grads_lists data and processes it into - one of the formats expected by the FisherFactor classes (depending on - the value of the global configuration variable TOWER_STRATEGY). - - The initial format of self._inputs is expected to be a list of Tensors - over towers. Similarly grads_lists is expected to be a list over sources - of such lists. - - If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single - tensor (represented as a PartitionedTensor object) equal to the - concatenation (across towers) of all of the elements of self._inputs. And - similarly grads_list is formatted into a tuple (over sources) of such - tensors (also represented as PartitionedTensors). - - If TOWER_STRATEGY is "separate", formatting of inputs and grads_list - remains unchanged from the initial format (although possibly converting - from lists into tuples). - - Args: - grads_list: grads_list in its initial format (see above). - - Returns: - inputs: self._inputs transformed into the appropriate format (see - above). - grads_list: grads_list transformed into the appropriate format (see - above). - - Raises: - ValueError: if TOWER_STRATEGY is not one of "separate" or "concat". - """ - inputs = self._inputs - # inputs is a list over towers of Tensors - # grads_list is a list of list with the first index being sources and the - # second being towers. - if fisher_factors.TOWER_STRATEGY == "concat": - # Merge towers together into a PartitionedTensor. We package it in - # a singleton tuple since the factors will expect a list over towers - inputs = (utils.PartitionedTensor(inputs),) - # Do the same for grads_list but preserve leading sources dimension - grads_list = tuple((utils.PartitionedTensor(grads),) - for grads in grads_list) - elif fisher_factors.TOWER_STRATEGY == "separate": - inputs = tuple(inputs) - grads_list = tuple(grads_list) - - else: - raise ValueError("Global config variable TOWER_STRATEGY must be one of " - "'concat' or 'separate'.") - - return inputs, grads_list - - def tensors_to_compute_grads(self): - """Tensors to compute derivative of loss with respect to.""" - return tuple(self._outputs) - - def register_additional_tower(self, inputs, outputs): - self._inputs.append(inputs) - self._outputs.append(outputs) - - @property - def num_registered_towers(self): - result = len(self._inputs) - assert result == len(self._outputs) - return result - - @property - def _inputs(self): - return self.__inputs - - @property - def _outputs(self): - return self.__outputs - - -class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB): - """FisherBlock for fully-connected (dense) layers using a diagonal approx. - - Estimates the Fisher Information matrix's diagonal entries for a fully - connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of - squares" estimator. - - Let 'params' be a vector parameterizing a model and 'i' an arbitrary index - into it. We are interested in Fisher(params)[i, i]. This is, - - $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ]$$ - - Consider fully connected layer in this model with (unshared) weight matrix - 'w'. For an example 'x' that produces layer inputs 'a' and output - preactivations 's', - - $$v(x, y, w) = vec( a (d loss / d s)^T )$$ - - This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding - to the layer's parameters 'w'. - """ - - def __init__(self, layer_collection, has_bias=False): - """Creates a FullyConnectedDiagonalFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - has_bias: Whether the component Kronecker factors have an additive bias. - (Default: False) - """ - self._has_bias = has_bias - - super(FullyConnectedDiagonalFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedDiagonalFactor, - (inputs, grads_list, self._has_bias)) - - self._damping_func = _package_func(lambda: damping, (damping,)) - - -class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB): - """FisherBlock for 2-D convolutional layers using a diagonal approx. - - Estimates the Fisher Information matrix's diagonal entries for a convolutional - layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" - estimator. - - Let 'params' be a vector parameterizing a model and 'i' an arbitrary index - into it. We are interested in Fisher(params)[i, i]. This is, - - $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] - = E[ v(x, y, params)[i] ^ 2 ]$$ - - Consider a convoluational layer in this model with (unshared) filter matrix - 'w'. For an example image 'x' that produces layer inputs 'a' and output - preactivations 's', - - $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$ - - where 'loc' is a single (x, y) location in an image. - - This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding - to the layer's parameters 'w'. - """ - - def __init__(self, - layer_collection, - params, - strides, - padding, - data_format=None, - dilations=None): - """Creates a ConvDiagonalFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters (Tensor or tuple of Tensors) of this layer. If - kernel alone, a Tensor of shape [kernel_height, kernel_width, - in_channels, out_channels]. If kernel and bias, a tuple of 2 elements - containing the previous and a Tensor of shape [out_channels]. - strides: The stride size in this layer (1-D Tensor of length 4). - padding: The padding in this layer (e.g. "SAME"). - data_format: str or None. Format of input data. - dilations: List of 4 ints or None. Rate for dilation along all dimensions. - - Raises: - ValueError: if strides is not length-4. - ValueError: if dilations is not length-4. - ValueError: if channel is not last dimension. - """ - if len(strides) != 4: - raise ValueError("strides must contain 4 numbers.") - - if dilations is None: - dilations = [1, 1, 1, 1] - - if len(dilations) != 4: - raise ValueError("dilations must contain 4 numbers.") - - if not utils.is_data_format_channel_last(data_format): - raise ValueError("data_format must be channels-last.") - - self._strides = maybe_tuple(strides) - self._padding = padding - self._data_format = data_format - self._dilations = maybe_tuple(dilations) - self._has_bias = isinstance(params, (tuple, list)) - - fltr = params[0] if self._has_bias else params - self._filter_shape = tuple(fltr.shape.as_list()) - - if len(self._filter_shape) != 4: - raise ValueError( - "Convolution filter must be of shape" - " [filter_height, filter_width, in_channels, out_channels].") - - super(ConvDiagonalFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs[0].shape.as_list(), - self._strides) - - self._factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvDiagonalFactor, - (inputs, grads_list, self._filter_shape, self._strides, self._padding, - self._data_format, self._dilations, self._has_bias)) - - def damping_func(): - return self._num_locations * normalize_damping(damping, - self._num_locations) - - damping_id = (self._num_locations, "mult", "normalize_damping", damping, - self._num_locations) - self._damping_func = _package_func(damping_func, damping_id) - - -class KroneckerProductFB(FisherBlock): - """A base class for blocks with separate input and output Kronecker factors. - - The Fisher block is approximated as a Kronecker product of the input and - output factors. - """ - - def _setup_damping(self, damping, normalization=None): - """Makes functions that compute the damping values for both factors.""" - def compute_damping(): - if normalization is not None: - maybe_normalized_damping = normalize_damping(damping, normalization) - else: - maybe_normalized_damping = damping - - return compute_pi_adjusted_damping( - self._input_factor.get_cov_as_linear_operator(), - self._output_factor.get_cov_as_linear_operator(), - maybe_normalized_damping**0.5) - - if normalization is not None: - damping_id = ("compute_pi_adjusted_damping", - "cov", self._input_factor.name, - "cov", self._output_factor.name, - "normalize_damping", damping, normalization, "power", 0.5) - else: - damping_id = ("compute_pi_adjusted_damping", - "cov", self._input_factor.name, - "cov", self._output_factor.name, - damping, "power", 0.5) - - self._input_damping_func = _package_func(lambda: compute_damping()[0], - damping_id + ("ref", 0)) - self._output_damping_func = _package_func(lambda: compute_damping()[1], - damping_id + ("ref", 1)) - - def register_matpower(self, exp): - self._input_factor.register_matpower(exp, self._input_damping_func) - self._output_factor.register_matpower(exp, self._output_damping_func) - - def register_cholesky(self): - self._input_factor.register_cholesky(self._input_damping_func) - self._output_factor.register_cholesky(self._output_damping_func) - - def register_cholesky_inverse(self): - self._input_factor.register_cholesky_inverse(self._input_damping_func) - self._output_factor.register_cholesky_inverse(self._output_damping_func) - - @property - def _renorm_coeff(self): - """Kronecker factor multiplier coefficient. - - If this FisherBlock is represented as 'FB = c * kron(left, right)', then - this is 'c'. - - Returns: - 0-D Tensor. - """ - return 1.0 - - def _multiply_factored_matrix(self, left_factor, right_factor, vector, - extra_scale=1.0, transpose_left=False, - transpose_right=False): - reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = right_factor.matmul_right(reshaped_vector, - adjoint=transpose_right) - reshaped_out = left_factor.matmul(reshaped_out, - adjoint=transpose_left) - if extra_scale != 1.0: - reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - def multiply_matpower(self, vector, exp): - left_factor = self._input_factor.get_matpower( - exp, self._input_damping_func) - right_factor = self._output_factor.get_matpower( - exp, self._output_damping_func) - extra_scale = float(self._renorm_coeff)**exp - return self._multiply_factored_matrix(left_factor, right_factor, vector, - extra_scale=extra_scale) - - def multiply_cholesky(self, vector, transpose=False): - left_factor = self._input_factor.get_cholesky(self._input_damping_func) - right_factor = self._output_factor.get_cholesky(self._output_damping_func) - extra_scale = float(self._renorm_coeff)**0.5 - return self._multiply_factored_matrix(left_factor, right_factor, vector, - extra_scale=extra_scale, - transpose_left=transpose, - transpose_right=not transpose) - - def multiply_cholesky_inverse(self, vector, transpose=False): - left_factor = self._input_factor.get_cholesky_inverse( - self._input_damping_func) - right_factor = self._output_factor.get_cholesky_inverse( - self._output_damping_func) - extra_scale = float(self._renorm_coeff)**-0.5 - return self._multiply_factored_matrix(left_factor, right_factor, vector, - extra_scale=extra_scale, - transpose_left=transpose, - transpose_right=not transpose) - - def full_fisher_block(self): - """Explicitly constructs the full Fisher block. - - Used for testing purposes. (In general, the result may be very large.) - - Returns: - The full Fisher block. - """ - left_factor = self._input_factor.get_cov_as_linear_operator().to_dense() - right_factor = self._output_factor.get_cov_as_linear_operator().to_dense() - return self._renorm_coeff * utils.kronecker_product(left_factor, - right_factor) - - -class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB): - """K-FAC FisherBlock for embedding layers. - - This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its - input factor is approximated by a diagonal matrix. In the case that each - example references exactly one embedding, this approximation is exact. - - Does not support bias parameters. - """ - - def __init__(self, layer_collection, vocab_size): - """Creates a EmbeddingKFACFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - vocab_size: int. Size of vocabulary for this embedding layer. - """ - self._vocab_size = vocab_size - - super(EmbeddingKFACFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - """Instantiate Kronecker Factors for this FisherBlock. - - Args: - grads_list: List of list of Tensors. grads_list[i][j] is the - gradient of the loss with respect to 'outputs' from source 'i' and - tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. - damping: 0-D Tensor or float. 'damping' * identity is approximately added - to this FisherBlock's Fisher approximation. - """ - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.EmbeddingInputKroneckerFactor, - (inputs, self._vocab_size)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedKroneckerFactor, (grads_list,)) - self._setup_damping(damping) - - -class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB): - """K-FAC FisherBlock for fully-connected (dense) layers. - - This uses the Kronecker-factorized approximation from the original - K-FAC paper (https://arxiv.org/abs/1503.05671) - """ - - def __init__(self, layer_collection, has_bias=False): - """Creates a FullyConnectedKFACBasicFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - has_bias: Whether the component Kronecker factors have an additive bias. - (Default: False) - """ - self._has_bias = has_bias - - super(FullyConnectedKFACBasicFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - """Instantiate Kronecker Factors for this FisherBlock. - - Args: - grads_list: List of list of Tensors. grads_list[i][j] is the - gradient of the loss with respect to 'outputs' from source 'i' and - tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. - damping: 0-D Tensor or float. 'damping' * identity is approximately added - to this FisherBlock's Fisher approximation. - """ - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedKroneckerFactor, - ((inputs,), self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedKroneckerFactor, - (grads_list,)) - self._setup_damping(damping) - - -class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): - r"""FisherBlock for convolutional layers using the basic KFC approx. - - Estimates the Fisher Information matrix's blog for a convolutional - layer. - - Consider a convolutional layer in this model with (unshared) filter matrix - 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', - this FisherBlock estimates, - - $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T], - E[flat(ds) flat(ds)^T])$$ - - where - - $$ds = (d / ds) log p(y | x, w)$$ - #locations = number of (x, y) locations where 'w' is applied. - - where the expectation is taken over all examples and locations and flat() - concatenates an array's leading dimensions. - - See equation 23 in https://arxiv.org/abs/1602.01407 for details. - """ - - def __init__(self, - layer_collection, - params, - padding, - strides=None, - dilation_rate=None, - data_format=None, - extract_patches_fn=None): - """Creates a ConvKFCBasicFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters (Tensor or tuple of Tensors) of this layer. If - kernel alone, a Tensor of shape [..spatial_filter_shape.., - in_channels, out_channels]. If kernel and bias, a tuple of 2 elements - containing the previous and a Tensor of shape [out_channels]. - padding: str. Padding method. - strides: List of ints or None. Contains [..spatial_filter_strides..] if - 'extract_patches_fn' is compatible with tf.nn.convolution(), else - [1, ..spatial_filter_strides, 1]. - dilation_rate: List of ints or None. Rate for dilation along each spatial - dimension if 'extract_patches_fn' is compatible with - tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. - data_format: str or None. Format of input data. - extract_patches_fn: str or None. Name of function that extracts image - patches. One of "extract_convolution_patches", "extract_image_patches", - "extract_pointwise_conv2d_patches". - """ - self._padding = padding - self._strides = maybe_tuple(strides) - self._dilation_rate = maybe_tuple(dilation_rate) - self._data_format = data_format - self._extract_patches_fn = extract_patches_fn - self._has_bias = isinstance(params, (tuple, list)) - - fltr = params[0] if self._has_bias else params - self._filter_shape = tuple(fltr.shape.as_list()) - - super(ConvKFCBasicFB, self).__init__(layer_collection) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs[0].shape.as_list(), - self._strides) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvInputKroneckerFactor, - (inputs, self._filter_shape, self._padding, self._strides, - self._dilation_rate, self._data_format, self._extract_patches_fn, - self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - - self._setup_damping(damping, normalization=self._num_locations) - - @property - def _renorm_coeff(self): - return self._num_locations - - -class DepthwiseConvDiagonalFB(ConvDiagonalFB): - """FisherBlock for depthwise_conv2d(). - - Equivalent to ConvDiagonalFB applied to each input channel in isolation. - """ - - def __init__(self, - layer_collection, - params, - strides, - padding, - rate=None, - data_format=None): - """Creates a DepthwiseConvKFCBasicFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: Tensor of shape [filter_height, filter_width, in_channels, - channel_multiplier]. - strides: List of 4 ints. Strides along all dimensions. - padding: str. Padding method. - rate: List of 4 ints or None. Rate for dilation along all dimensions. - data_format: str or None. Format of input data. - - Raises: - NotImplementedError: If parameters contains bias. - ValueError: If filter is not 4-D. - ValueError: If strides is not length-4. - ValueError: If rates is not length-2. - ValueError: If channels are not last dimension. - """ - if isinstance(params, (tuple, list)): - raise NotImplementedError("Bias not yet supported.") - - if params.shape.ndims != 4: - raise ValueError("Filter must be 4-D.") - - if len(strides) != 4: - raise ValueError("strides must account for 4 dimensions.") - - if rate is not None: - if len(rate) != 2: - raise ValueError("rate must only account for spatial dimensions.") - rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. - - if not utils.is_data_format_channel_last(data_format): - raise ValueError("data_format must be channels-last.") - - super(DepthwiseConvDiagonalFB, self).__init__( - layer_collection=layer_collection, - params=params, - strides=strides, - padding=padding, - dilations=rate, - data_format=data_format) - - # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). - filter_height, filter_width, in_channels, channel_multiplier = ( - params.shape.as_list()) - self._filter_shape = (filter_height, filter_width, in_channels, - in_channels * channel_multiplier) - - def _multiply_matrix(self, matrix, vector): - conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super( - DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector) - return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) - - -class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): - """FisherBlock for depthwise_conv2d(). - - Equivalent to ConvKFCBasicFB applied to each input channel in isolation. - """ - - def __init__(self, - layer_collection, - params, - strides, - padding, - rate=None, - data_format=None): - """Creates a DepthwiseConvKFCBasicFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: Tensor of shape [filter_height, filter_width, in_channels, - channel_multiplier]. - strides: List of 4 ints. Strides along all dimensions. - padding: str. Padding method. - rate: List of 4 ints or None. Rate for dilation along all dimensions. - data_format: str or None. Format of input data. - - Raises: - NotImplementedError: If parameters contains bias. - ValueError: If filter is not 4-D. - ValueError: If strides is not length-4. - ValueError: If rates is not length-2. - ValueError: If channels are not last dimension. - """ - if isinstance(params, (tuple, list)): - raise NotImplementedError("Bias not yet supported.") - - if params.shape.ndims != 4: - raise ValueError("Filter must be 4-D.") - - if len(strides) != 4: - raise ValueError("strides must account for 4 dimensions.") - - if rate is not None: - if len(rate) != 2: - raise ValueError("rate must only account for spatial dimensions.") - rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate. - - if not utils.is_data_format_channel_last(data_format): - raise ValueError("data_format must be channels-last.") - - super(DepthwiseConvKFCBasicFB, self).__init__( - layer_collection=layer_collection, - params=params, - padding=padding, - strides=strides, - dilation_rate=rate, - data_format=data_format, - extract_patches_fn="extract_image_patches") - - # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__(). - filter_height, filter_width, in_channels, channel_multiplier = ( - params.shape.as_list()) - self._filter_shape = (filter_height, filter_width, in_channels, - in_channels * channel_multiplier) - - def _multiply_factored_matrix(self, left_factor, right_factor, vector, - extra_scale=1.0, transpose_left=False, - transpose_right=False): - conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super( - DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix( - left_factor, right_factor, conv2d_vector, extra_scale=extra_scale, - transpose_left=transpose_left, transpose_right=transpose_right) - return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) - - -def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin - """Converts a convolution filter for use with conv2d. - - Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's - compatible with tf.nn.conv2d(). - - Args: - filter: Tensor of shape [height, width, in_channels, channel_multiplier]. - name: None or str. Name of Op. - - Returns: - Tensor of shape [height, width, in_channels, out_channels]. - - """ - with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter", - [filter]): - filter = ops.convert_to_tensor(filter) - filter_height, filter_width, in_channels, channel_multiplier = ( - filter.shape.as_list()) - - results = [] - for i in range(in_channels): - # Slice out one in_channel's filter. Insert zeros around it to force it - # to affect that channel and that channel alone. - elements = [] - if i > 0: - elements.append( - array_ops.zeros( - [filter_height, filter_width, i, channel_multiplier])) - elements.append(filter[:, :, i:(i + 1), :]) - if i + 1 < in_channels: - elements.append( - array_ops.zeros([ - filter_height, filter_width, in_channels - (i + 1), - channel_multiplier - ])) - - # Concat along in_channel. - results.append( - array_ops.concat(elements, axis=-2, name="in_channel_%d" % i)) - - # Concat along out_channel. - return array_ops.concat(results, axis=-1, name="out_channel") - - -def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin - """Converts a convolution filter for use with depthwise_conv2d. - - Transforms a filter for use with tf.nn.conv2d() to one that's - compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along - the diagonal. - - Args: - filter: Tensor of shape [height, width, in_channels, out_channels]. - name: None or str. Name of Op. - - Returns: - Tensor of shape, - [height, width, in_channels, channel_multiplier] - - Raises: - ValueError: if out_channels is not evenly divisible by in_channels. - """ - with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter", - [filter]): - filter = ops.convert_to_tensor(filter) - filter_height, filter_width, in_channels, out_channels = ( - filter.shape.as_list()) - - if out_channels % in_channels != 0: - raise ValueError("out_channels must be evenly divisible by in_channels.") - channel_multiplier = out_channels // in_channels - - results = [] - filter = array_ops.reshape(filter, [ - filter_height, filter_width, in_channels, in_channels, - channel_multiplier - ]) - for i in range(in_channels): - # Slice out output corresponding to the correct filter. - filter_slice = array_ops.reshape( - filter[:, :, i, i, :], - [filter_height, filter_width, 1, channel_multiplier]) - results.append(filter_slice) - - # Concat along out_channel. - return array_ops.concat(results, axis=-2, name="in_channels") - - -def maybe_tuple(obj): - if not isinstance(obj, list): - return obj - return tuple(obj) - - -def num_conv_locations(input_shape, strides): - """Returns the number of spatial locations a 2D Conv kernel is applied to. - - Args: - input_shape: List of ints representing shape of inputs to - tf.nn.convolution(). - strides: List of ints representing strides along spatial dimensions as - passed in to tf.nn.convolution(). - - Returns: - A scalar |T| denoting the number of spatial locations for the Conv layer. - """ - spatial_input_locations = np.prod(input_shape[1:-1]) - - if strides is None: - spatial_strides_divisor = 1 - else: - spatial_strides_divisor = np.prod(strides) - - return spatial_input_locations // spatial_strides_divisor - - -class InputOutputMultiTowerMultiUse(InputOutputMultiTower): - """Adds methods for multi-use/time-step case to InputOutputMultiTower.""" - - def __init__(self, num_uses=None, *args, **kwargs): - self._num_uses = num_uses - super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs) - - def _process_data(self, grads_list): - """Process temporal/multi-use data into the format used by the factors. - - This function takes inputs and grads_lists data and processes it into - one of the formats expected by the FisherFactor classes (depending on - the value of the global configuration variable TOWER_STRATEGY). - - It accepts the data in one of two initial formats. The first possible - format is where self._inputs is a list of list of Tensors. The first index - is tower, the second is use/time-step. grads_list, meanwhile, is a list - over sources of such lists of lists. - - The second possible data format is where self._inputs is a Tensor with - uses/times-steps folded into the batch dimension. i.e. it is a Tensor - of shape [num_uses * size_batch, ...] which represents a reshape of a - Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is - a list over sources of such Tensors. - - There are two possible formats which inputs and grads_list are transformed - into. - - If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing - a single tensor (represented as a PartitionedTensor object) with all of - the data from the towers, as well as the uses/time-steps, concatenated - together. In this tensor the leading dimension is the batch and - use/time-step dimensions folded together (with 'use' being the major of - these two, so that the tensors can be thought of as reshapes of ones of - shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a - tuple over sources of such tensors. - - If TOWER_STRATEGY is "separate" the inputs are formatted into lists of - tensors over towers. Each of these tensors has a similar format to - the tensor produced by the "concat" option, except that each contains - only the data from a single tower. grads_list is similarly formatted - into a tuple over sources of such tuples. - - Args: - grads_list: grads_list in its initial format (see above). - - Returns: - inputs: self._inputs transformed into the appropriate format (see - above). - grads_list: grads_list transformed into the appropriate format (see - above). - - Raises: - ValueError: If TOWER_STRATEGY is not one of "separate" or "concat". - ValueError: If the given/initial format of self._inputs and grads_list - isn't recognized, or doesn't agree with self._num_uses. - """ - - inputs = self._inputs - - if isinstance(inputs[0], (list, tuple)): - num_uses = len(inputs[0]) - if self._num_uses is not None and self._num_uses != num_uses: - raise ValueError("num_uses argument doesn't match length of inputs.") - else: - self._num_uses = num_uses - - # Check that all mini-batches/towers have the same number of uses - if not all(len(input_) == num_uses for input_ in inputs): - raise ValueError("Length of inputs argument is inconsistent across " - "towers.") - - if fisher_factors.TOWER_STRATEGY == "concat": - # Reverse the tower and use/time-step indices, so that use is now first, - # and towers is second - inputs = tuple(zip(*inputs)) - - # Flatten the two dimensions - inputs = nest.flatten(inputs) - - # Merge everything together into a PartitionedTensor. We package it in - # a singleton tuple since the factors will expect a list over towers - inputs = (utils.PartitionedTensor(inputs),) - - elif fisher_factors.TOWER_STRATEGY == "separate": - # Merge together the uses/time-step dimension into PartitionedTensors, - # but keep the leading dimension (towers) intact for the factors to - # process individually. - inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs) - - else: - raise ValueError("Global config variable TOWER_STRATEGY must be one of " - "'concat' or 'separate'.") - else: - inputs = tuple(inputs) - - # Now we perform the analogous processing for grads_list - if isinstance(grads_list[0][0], (list, tuple)): - num_uses = len(grads_list[0][0]) - if self._num_uses is not None and self._num_uses != num_uses: - raise ValueError("num_uses argument doesn't match length of outputs, " - "or length of outputs is inconsistent with length of " - "inputs.") - else: - self._num_uses = num_uses - - if not all(len(grad) == num_uses for grads in grads_list - for grad in grads): - raise ValueError("Length of outputs argument is inconsistent across " - "towers.") - - if fisher_factors.TOWER_STRATEGY == "concat": - # Reverse the tower and use/time-step indices, so that use is now first, - # and towers is second - grads_list = tuple(tuple(zip(*grads)) for grads in grads_list) - - # Flatten the two dimensions, leaving the leading dimension (source) - # intact - grads_list = tuple(nest.flatten(grads) for grads in grads_list) - - # Merge inner dimensions together into PartitionedTensors. We package - # them in a singleton tuple since the factors will expect a list over - # towers - grads_list = tuple((utils.PartitionedTensor(grads),) - for grads in grads_list) - - elif fisher_factors.TOWER_STRATEGY == "separate": - # Merge together the uses/time-step dimension into PartitionedTensors, - # but keep the leading dimension (towers) intact for the factors to - # process individually. - grads_list = tuple(tuple(utils.PartitionedTensor(grad) - for grad in grads) - for grads in grads_list) - - else: - raise ValueError("Global config variable TOWER_STRATEGY must be one of " - "'concat' or 'separate'.") - else: - grads_list = tuple(tuple(grads) for grads in grads_list) - - if self._num_uses is None: - raise ValueError("You must supply a value for the num_uses argument if " - "the number of uses cannot be inferred from inputs or " - "outputs arguments (e.g. if they are both given in the " - "single Tensor format, instead of as lists of Tensors.") - - return inputs, grads_list - - -class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """FisherBlock for fully-connected layers that share parameters. - - This class implements the "independence across time" approximation from the - following paper: - https://openreview.net/pdf?id=HyMTkQZAb - """ - - def __init__(self, layer_collection, has_bias=False, num_uses=None): - """Creates a FullyConnectedMultiIndepFB block. - - Args: - layer_collection: LayerCollection instance. - has_bias: bool. If True, estimates Fisher with respect to a bias - parameter as well as the layer's parameters. - num_uses: int or None. Number of uses of the layer in the model's graph. - Only required if the data is formatted with uses/time folded into the - batch dimension (instead of uses/time being a list dimension). - (Default: None) - """ - self._has_bias = has_bias - - super(FullyConnectedMultiIndepFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, - ((inputs,), self._num_uses, self._has_bias)) - - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - - self._setup_damping(damping, normalization=self._num_uses) - - @property - def _renorm_coeff(self): - return float(self._num_uses) - - -class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """FisherBlock for 2D convolutional layers using the basic KFC approx. - - Similar to ConvKFCBasicFB except that this version supports multiple - uses/time-steps via a standard independence approximation. Similar to the - "independence across time" used in FullyConnectedMultiIndepFB but generalized - in the obvious way to conv layers. - """ - - def __init__(self, - layer_collection, - params, - padding, - strides=None, - dilation_rate=None, - data_format=None, - extract_patches_fn=None, - num_uses=None): - """Creates a ConvKFCBasicMultiIndepFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - params: The parameters (Tensor or tuple of Tensors) of this layer. If - kernel alone, a Tensor of shape [..spatial_filter_shape.., - in_channels, out_channels]. If kernel and bias, a tuple of 2 elements - containing the previous and a Tensor of shape [out_channels]. - padding: str. Padding method. - strides: List of ints or None. Contains [..spatial_filter_strides..] if - 'extract_patches_fn' is compatible with tf.nn.convolution(), else - [1, ..spatial_filter_strides, 1]. - dilation_rate: List of ints or None. Rate for dilation along each spatial - dimension if 'extract_patches_fn' is compatible with - tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. - data_format: str or None. Format of input data. - extract_patches_fn: str or None. Name of function that extracts image - patches. One of "extract_convolution_patches", "extract_image_patches", - "extract_pointwise_conv2d_patches". - num_uses: int or None. Number of uses of the layer in the model's graph. - Only required if the data is formatted with uses/time folded into the - batch dimension (instead of uses/time being a list dimension). - (Default: None) - """ - self._padding = padding - self._strides = maybe_tuple(strides) - self._dilation_rate = maybe_tuple(dilation_rate) - self._data_format = data_format - self._extract_patches_fn = extract_patches_fn - self._has_bias = isinstance(params, (tuple, list)) - - fltr = params[0] if self._has_bias else params - self._filter_shape = tuple(fltr.shape.as_list()) - - super(ConvKFCBasicMultiIndepFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs[0].shape.as_list(), - self._strides) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvInputKroneckerFactor, - (inputs, self._filter_shape, self._padding, self._strides, - self._dilation_rate, self._data_format, self._extract_patches_fn, - self._has_bias)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - - self._setup_damping(damping, normalization= - (self._num_locations * self._num_uses)) - - @property - def _renorm_coeff(self): - return self._num_locations * self._num_uses - - -class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """K-FAC FisherBlock for embedding layers used multiple times in the graph. - - Similar to EmbeddingKFACFB except that this version supports multiple uses - of the parameter within a single model. These uses could correspond to time - steps in an RNN architecture, but they don't have to. - - Does not support bias parameters. - """ - - def __init__(self, layer_collection, vocab_size, num_uses=None): - """Creates a EmbeddingKFACMultiIndepFB block. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - vocab_size: int. Size of vocabulary for this embedding layer. - num_uses: int or None. Number of uses of the layer in the model's graph. - Only required if the data is formatted with time folded into the batch - dimension (instead of time being a list dimension). (Default: None) - """ - self._vocab_size = vocab_size - - super(EmbeddingKFACMultiIndepFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - def instantiate_factors(self, grads_list, damping): - """Instantiate Kronecker Factors for this FisherBlock. - - Args: - grads_list: List of list of list of Tensors. grads_list[i][j][k] is the - gradient of the loss with respect to 'outputs' from source 'i', - tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape - [tower_minibatch_size, output_size]. - damping: 0-D Tensor or float. 'damping' * identity is approximately added - to this FisherBlock's Fisher approximation. - """ - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.EmbeddingInputKroneckerFactor, - (inputs, self._vocab_size)) - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - self._setup_damping(damping, normalization=self._num_uses) - - @property - def _renorm_coeff(self): - return float(self._num_uses) - - -class SeriesFBApproximation(enum.IntEnum): - """See FullyConnectedSeriesFB.__init__ for description and usage.""" - option1 = 1 - option2 = 2 - - -class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, - KroneckerProductFB): - """FisherBlock for fully-connected layers that share parameters across time. - - This class implements the "Option 1" and "Option 2" approximation from the - following paper: - https://openreview.net/pdf?id=HyMTkQZAb - - See the end of the appendix of the paper for a pseudo-code of the - algorithm being implemented by multiply_matpower here. Note that we are - using pre-computed versions of certain matrix-matrix products to speed - things up. This is explicitly explained wherever it is done. - """ - - def __init__(self, - layer_collection, - has_bias=False, - num_uses=None, - option=SeriesFBApproximation.option2): - """Constructs a new `FullyConnectedSeriesFB`. - - Args: - layer_collection: The collection of all layers in the K-FAC approximate - Fisher information matrix to which this FisherBlock belongs. - has_bias: Whether the layer includes a bias parameter. - num_uses: int or None. Number of time-steps over which the layer - is used. Only required if the data is formatted with time folded into - the batch dimension (instead of time being a list dimension). - (Default: None) - option: A `SeriesFBApproximation` specifying the simplifying assumption - to be used in this block. `option1` approximates the cross-covariance - over time as a symmetric matrix, while `option2` makes - the assumption that training sequences are infinitely long. See section - 3.5 of the paper for more details. - """ - - self._has_bias = has_bias - self._option = option - - super(FullyConnectedSeriesFB, self).__init__( - layer_collection=layer_collection, - num_uses=num_uses) - - @property - def _num_timesteps(self): - return self._num_uses - - @property - def _renorm_coeff(self): - # This should no longer be used since the multiply_X functions from the base - # class have been overridden - assert False - - def instantiate_factors(self, grads_list, damping): - inputs, grads_list = self._process_data(grads_list) - - self._input_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, - ((inputs,), self._num_uses, self._has_bias)) - self._input_factor.register_cov_dt1() - - self._output_factor = self._layer_collection.make_or_get_factor( - fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses)) - self._output_factor.register_cov_dt1() - - self._setup_damping(damping, normalization=self._num_uses) - - def register_matpower(self, exp): - if exp != -1: - raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" - "multiplications.") - - if self._option == SeriesFBApproximation.option1: - self._input_factor.register_option1quants(self._input_damping_func) - self._output_factor.register_option1quants(self._output_damping_func) - elif self._option == SeriesFBApproximation.option2: - self._input_factor.register_option2quants(self._input_damping_func) - self._output_factor.register_option2quants(self._output_damping_func) - else: - raise ValueError( - "Unrecognized FullyConnectedSeriesFB approximation: {}".format( - self._option)) - - def multiply_matpower(self, vector, exp): - if exp != -1: - raise NotImplementedError("FullyConnectedSeriesFB only supports inverse" - "multiplications.") - - # pylint: disable=invalid-name - - Z = utils.layer_params_to_mat2d(vector) - - # Derivations were done for "batch_dim==1" case so we need to convert to - # that orientation: - Z = array_ops.transpose(Z) - - if self._option == SeriesFBApproximation.option1: - - # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\) - L_A, psi_A = self._input_factor.get_option1quants( - self._input_damping_func) - L_G, psi_G = self._output_factor.get_option1quants( - self._output_damping_func) - - def gamma(x): - # We are assuming that each case has the same number of time-steps. - # If this stops being the case one shouldn't simply replace this T - # with its average value. Instead, one needs to go back to the - # definition of the gamma function from the paper. - T = self._num_timesteps - return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) - - # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise) - # Even though Y is Z-independent we are recomputing it from the psi's - # each since Y depends on both A and G quantities, and it is relatively - # cheap to compute. - Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) - - # \\(Z = L_G^T * Z * L_A\\) - # This is equivalent to the following computation from the original - # pseudo-code: - # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) - # \\(Z = U_G^T * Z * U_A\\) - Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) - - # \\(Z = Z .* Y\\) - Z *= Y - - # \\(Z = L_G * Z * L_A^T\\) - # This is equivalent to the following computation from the original - # pseudo-code: - # \\(Z = U_G * Z * U_A^T\\) - # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) - Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) - - elif self._option == SeriesFBApproximation.option2: - - # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\), - # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\) - P_A, K_A, mu_A = self._input_factor.get_option2quants( - self._input_damping_func) - P_G, K_G, mu_G = self._output_factor.get_option2quants( - self._output_damping_func) - - # Our approach differs superficially from the pseudo-code in the paper - # in order to reduce the total number of matrix-matrix multiplies. - # In particular, the first three computations in the pseudo code are - # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\) - # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\) - # \\(Z = E_G^T * Z * E_A\\) - # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that - # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\) - # the entire computation can be written as - # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) - # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\) - # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\) - # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\) - # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\) - # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\) - # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\) - # This final expression is computed by the following two lines: - # \\(Z = Z - P_G * Z * P_A^T\\) - Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) - # \\(Z = K_G^T * Z * K_A\\) - Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) - - # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\) - # Be careful with the outer product. We don't want to accidentally - # make it an inner-product instead. - tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A - # Prevent some numerical issues by setting any 0.0 eigs to 1.0 - tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype) - Z /= tmp - - # We now perform the transpose/reverse version of the operations - # derived above, whose derivation from the original pseudo-code is - # analgous. - # \\(Z = K_G * Z * K_A^T\\) - Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) - - # \\(Z = Z - P_G^T * Z * P_A\\) - Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) - - # \\(Z = normalize (1/E[T]) * Z\\) - # Note that this normalization is done because we compute the statistics - # by averaging, not summing, over time. (And the gradient is presumably - # summed over time, not averaged, and thus their scales are different.) - Z /= math_ops.cast(self._num_timesteps, Z.dtype) - - # Convert back to the "batch_dim==0" orientation. - Z = array_ops.transpose(Z) - - return utils.mat2d_to_layer_params(vector, Z) - - # pylint: enable=invalid-name - - def multiply_cholesky(self, vector): - raise NotImplementedError("FullyConnectedSeriesFB does not support " - "Cholesky computations.") - - def multiply_cholesky_inverse(self, vector): - raise NotImplementedError("FullyConnectedSeriesFB does not support " - "Cholesky computations.") - diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py deleted file mode 100644 index c04cf727fa958160d61c7a3638ec65f6c93c2f24..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""FisherBlock definitions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.fisher_blocks import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - 'FisherBlock', - 'FullFB', - 'NaiveDiagonalFB', - 'FullyConnectedDiagonalFB', - 'KroneckerProductFB', - 'EmbeddingKFACFB', - 'FullyConnectedKFACBasicFB', - 'ConvKFCBasicFB', - 'ConvDiagonalFB', - 'set_global_constants', - 'compute_pi_tracenorm', - 'compute_pi_adjusted_damping', - 'num_conv_locations', - 'normalize_damping', - 'LEFT_MULTIPLY', - 'RIGHT_MULTIPLY', -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py deleted file mode 100644 index afa2fd1ca72d703e42a9beaac2c86964e22de3e3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ /dev/null @@ -1,1830 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""FisherFactor definitions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc -import contextlib - -import numpy as np -import six - -from tensorflow.contrib.kfac.python.ops import linear_operator as lo -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import special_math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.training import moving_averages -from tensorflow.python.util import nest - - -# Whether to initialize covariance estimators at a zero matrix (or the identity -# matrix). -INIT_COVARIANCES_AT_ZERO = True - -# Whether to zero-debias the moving averages. -ZERO_DEBIAS = True - -# Whether to initialize inverse (and other such matrices computed from the cov -# matrices) to the zero matrix (or the identity matrix). -INIT_INVERSES_AT_ZERO = True - -# When the number of inverses requested from a FisherFactor exceeds this value, -# the inverses are computed using an eigenvalue decomposition. -EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 - -# Numerical eigenvalues computed from covariance matrix estimates are clipped to -# be at least as large as this value before they are used to compute inverses or -# matrix powers. Must be nonnegative. -EIGENVALUE_CLIPPING_THRESHOLD = 0.0 - -# Used to subsample the flattened extracted image patches. The number of -# outer products per row of the covariance matrix should not exceed this -# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True. -_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1 - -# Used to subsample the inputs passed to the extract image patches. The batch -# size of number of inputs to extract image patches is multiplied by this -# factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True. -_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5 - -# If True, then subsamples the tensor passed to compute the covariance matrix. -_SUB_SAMPLE_OUTER_PRODUCTS = False - -# If True, then subsamples the tensor passed to compute the covariance matrix. -_SUB_SAMPLE_INPUTS = False - -# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data -# passed to the factors from the blocks will be concatenated across towers -# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over -# towers will be passed in, and the factors will iterate over this and do the -# cov computations separately for each one, averaging the results together. -TOWER_STRATEGY = "concat" - - -def set_global_constants(init_covariances_at_zero=None, - zero_debias=None, - init_inverses_at_zero=None, - eigenvalue_decomposition_threshold=None, - eigenvalue_clipping_threshold=None, - max_num_outer_products_per_cov_row=None, - sub_sample_outer_products=None, - inputs_to_extract_patches_factor=None, - sub_sample_inputs=None, - tower_strategy=None): - """Sets various global constants used by the classes in this module.""" - global INIT_COVARIANCES_AT_ZERO - global ZERO_DEBIAS - global INIT_INVERSES_AT_ZERO - global EIGENVALUE_DECOMPOSITION_THRESHOLD - global EIGENVALUE_CLIPPING_THRESHOLD - global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW - global _SUB_SAMPLE_OUTER_PRODUCTS - global _INPUTS_TO_EXTRACT_PATCHES_FACTOR - global _SUB_SAMPLE_INPUTS - global TOWER_STRATEGY - - if init_covariances_at_zero is not None: - INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero - if zero_debias is not None: - ZERO_DEBIAS = zero_debias - if init_inverses_at_zero is not None: - INIT_INVERSES_AT_ZERO = init_inverses_at_zero - if eigenvalue_decomposition_threshold is not None: - EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold - if eigenvalue_clipping_threshold is not None: - EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold - if max_num_outer_products_per_cov_row is not None: - _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row - if sub_sample_outer_products is not None: - _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products - if inputs_to_extract_patches_factor is not None: - _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor - if sub_sample_inputs is not None: - _SUB_SAMPLE_INPUTS = sub_sample_inputs - if tower_strategy is not None: - TOWER_STRATEGY = tower_strategy - - -def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument - if INIT_INVERSES_AT_ZERO: - return array_ops.zeros(shape, dtype=dtype) - return linalg_ops.eye(num_rows=shape[0], dtype=dtype) - - -def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument - if INIT_COVARIANCES_AT_ZERO: - return array_ops.zeros(shape, dtype=dtype) - return linalg_ops.eye(num_rows=shape[0], dtype=dtype) - - -def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument - if INIT_COVARIANCES_AT_ZERO: - return array_ops.zeros(shape, dtype=dtype) - return array_ops.ones(shape, dtype=dtype) - - -@contextlib.contextmanager -def place_on_device(device): - if device is not None and len(device): - with tf_ops.device(device): - yield - else: - yield - - -def compute_cov(tensor, tensor_right=None, normalizer=None): - """Compute the empirical second moment of the rows of a 2D Tensor. - - This function is meant to be applied to random matrices for which the true row - mean is zero, so that the true second moment equals the true covariance. - - Args: - tensor: A 2D Tensor. - tensor_right: An optional 2D Tensor. If provided, this function computes - the matrix product tensor^T * tensor_right instead of tensor^T * tensor. - normalizer: optional scalar for the estimator (by default, the normalizer is - the number of rows of tensor). - - Returns: - A square 2D Tensor with as many rows/cols as the number of input columns. - """ - if normalizer is None: - normalizer = array_ops.shape(tensor)[0] - if tensor_right is None: - cov = ( - math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( - normalizer, tensor.dtype)) - return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype) - else: - return (math_ops.matmul(tensor, tensor_right, transpose_a=True) / - math_ops.cast(normalizer, tensor.dtype)) - - -def append_homog(tensor): - """Appends a homogeneous coordinate to the last dimension of a Tensor. - - Args: - tensor: A Tensor. - - Returns: - A Tensor identical to the input but one larger in the last dimension. The - new entries are filled with ones. - """ - rank = len(tensor.shape.as_list()) - shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) - ones = array_ops.ones(shape, dtype=tensor.dtype) - return array_ops.concat([tensor, ones], axis=rank - 1) - - -def scope_string_from_params(params): - """Builds a variable scope string name from the given parameters. - - Supported parameters are: - * tensors - * booleans - * ints - * strings - * depth-1 tuples/lists of ints - * any depth tuples/lists of tensors - Other parameter types will throw an error. - - Args: - params: A parameter or list of parameters. - - Returns: - A string to use for the variable scope. - - Raises: - ValueError: if params includes an unsupported type. - """ - params = params if isinstance(params, (tuple, list)) else (params,) - - name_parts = [] - for param in params: - if param is None: - name_parts.append("None") - elif isinstance(param, (tuple, list)): - if all([isinstance(p, int) for p in param]): - name_parts.append("-".join([str(p) for p in param])) - else: - name_parts.append(scope_string_from_name(param)) - elif isinstance(param, (str, int, bool)): - name_parts.append(str(param)) - elif isinstance(param, (tf_ops.Tensor, variables.Variable)): - name_parts.append(scope_string_from_name(param)) - elif isinstance(param, utils.PartitionedTensor): - name_parts.append(scope_string_from_name(param.tensors)) - else: - raise ValueError("Encountered an unsupported param type {}".format( - type(param))) - return "_".join(name_parts) - - -def scope_string_from_name(tensor): - if isinstance(tensor, (tuple, list)): - return "__".join([scope_string_from_name(t) for t in tensor]) - # "gradients/add_4_grad/Reshape:0" -> "gradients_add_4_grad_Reshape" - return tensor.name.split(":")[0].replace("/", "_") - - -def scalar_or_tensor_to_string(val): - return repr(val) if np.isscalar(val) else scope_string_from_name(val) - - -def list_to_string(lst): - return "_".join(val if isinstance(val, six.string_types) - else scalar_or_tensor_to_string(val) for val in lst) - - -def graph_func_to_id(func): - """Returns a hashable object that represents func's computation.""" - # TODO(b/74201126): replace with Topohash of func's output - return func.func_id - - -def graph_func_to_string(func): - # TODO(b/74201126): replace with Topohash of func's output - return list_to_string(func.func_id) - - -def _subsample_for_cov_computation(array, name=None): - """Subsamples the first dimension of the array. - - `array`(A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance - matrix(A^TA) is of shape `dim_2 ** 2`. Subsample only if the number of outer - products per row of the covariance matrix is greater than - `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`. - - Args: - array: Tensor, of shape `[batch_size, dim_2]`. - name: `string`, Default(None) - - Returns: - A tensor of shape `[max_samples, dim_2]`. - - Raises: - ValueError: If array's is not matrix-shaped. - ValueError: If array's batch_size cannot be inferred. - - """ - with tf_ops.name_scope(name, "subsample", [array]): - array = tf_ops.convert_to_tensor(array) - if len(array.shape) != 2: - raise ValueError("Input param array must be a matrix.") - - batch_size = array.shape.as_list()[0] - if batch_size is None: - raise ValueError("Unable to get batch_size from input param array.") - - num_cov_rows = array.shape.as_list()[-1] - max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows) - if batch_size <= max_batch_size: - return array - - return _random_tensor_gather(array, max_batch_size) - - -def _random_tensor_gather(array, max_size): - """Generates a random set of indices and gathers the value at the indices. - - Args: - array: Tensor, of shape `[batch_size, dim_2]`. - max_size: int, Number of indices to sample. - - Returns: - A tensor of shape `[max_size, ...]`. - """ - batch_size = array.shape.as_list()[0] - indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size] - return array_ops.gather(array, indices) - - -@six.add_metaclass(abc.ABCMeta) -class FisherFactor(object): - """Base class for objects modeling factors of approximate Fisher blocks. - - A FisherFactor represents part of an approximate Fisher Information matrix. - For example, one approximation to the Fisher uses the Kronecker product of two - FisherFactors A and B, F = kron(A, B). FisherFactors are composed with - FisherBlocks to construct a block-diagonal approximation to the full Fisher. - - FisherFactors are backed by a single, non-trainable variable that is updated - by running FisherFactor.make_covariance_update_op(). The shape and type of - this variable is implementation specific. - - Note that for blocks that aren't based on approximations, a 'factor' can - be the entire block itself, as is the case for the diagonal and full - representations. - """ - - def __init__(self): - self._cov = None - - @abc.abstractproperty - def _var_scope(self): - """Variable scope for this FisherFactor instance. - - Returns: - string that unique identifies this FisherFactor instance. - """ - pass - - @property - def name(self): - return self._var_scope - - @abc.abstractproperty - def _cov_shape(self): - """The shape of the variable backing this FisherFactor.""" - pass - - @abc.abstractproperty - def _num_sources(self): - """The number of things to sum over when updating covariance variable. - - The default make_covariance_update_op function will call _compute_new_cov - with indices ranging from 0 to _num_sources-1. The typical situation is - where the factor wants to sum the statistics it computes over multiple - backpropped "gradients" (typically passed in via "tensors" or - "outputs_grads" arguments). - """ - pass - - @abc.abstractproperty - def _num_towers(self): - pass - - @abc.abstractproperty - def _dtype(self): - """dtype for variable backing this factor.""" - pass - - @property - def _cov_initializer(self): - """Function for initializing covariance variable.""" - return covariance_initializer - - def instantiate_cov_variables(self): - """Makes the internal cov variable(s).""" - assert self._cov is None - with variable_scope.variable_scope(self._var_scope): - self._cov = variable_scope.get_variable( - "cov", - initializer=self._cov_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - - @abc.abstractmethod - def _compute_new_cov(self, source, tower): - """Computes minibatch-estimated covariance for a single source. - - Args: - source: int in [0, self._num_sources). Which source to use when computing - the cov update. - tower: int in [0, self._num_towers). Which tower to use when computing - the cov update. - - Returns: - Tensor of same shape as self.get_cov(). - """ - pass - - def make_covariance_update_op(self, ema_decay): - """Constructs and returns the covariance update Op. - - Args: - ema_decay: The exponential moving average decay (float or Tensor). - Returns: - An Op for updating the covariance Variable referenced by _cov. - """ - new_cov_contribs = [] - for source in range(self._num_sources): - for tower in range(self._num_towers): - device = (self._get_data_device(tower) - if TOWER_STRATEGY == "separate" else None) - with place_on_device(device): - new_cov_contribs.append(self._compute_new_cov(source, tower)) - - new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers) - - # Compute average of 'new_cov' across all TPU cores. On a TPU, each - # instance of 'new_cov' will be based on a different minibatch. This ensures - # that by the end of assign_moving_average(), all TPU cores see the same - # value for self._cov. - # - # Other implementations of make_covariance_update_op() that accumulate - # statistics in other variables should mimic this behavior. - if utils.on_tpu(): - new_cov = utils.cross_replica_mean(new_cov) - - return moving_averages.assign_moving_average( - self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) - - @abc.abstractmethod - def _get_data_device(self, tower): - pass - - @abc.abstractmethod - def instantiate_inv_variables(self): - """Makes the internal "inverse" variable(s).""" - pass - - @abc.abstractmethod - def make_inverse_update_ops(self): - """Create and return update ops corresponding to registered computations.""" - pass - - def get_cov(self): - return self._cov - - @abc.abstractmethod - def get_cov_as_linear_operator(self): - pass - - @abc.abstractmethod - def register_matpower(self, exp, damping_func): - pass - - @abc.abstractmethod - def register_cholesky(self, damping_func): - pass - - @abc.abstractmethod - def register_cholesky_inverse(self, damping_func): - pass - - @abc.abstractmethod - def get_matpower(self, exp, damping_func): - pass - - @abc.abstractmethod - def get_cholesky(self, damping_func): - pass - - @abc.abstractmethod - def get_cholesky_inverse(self, damping_func): - pass - - -class DenseSquareMatrixFactor(FisherFactor): - """Base class for FisherFactors that are stored as dense square matrices. - - This class explicitly calculates and stores inverses of their `cov` matrices, - which must be square dense matrices. - - Subclasses must implement the _compute_new_cov method, and the _var_scope and - _cov_shape properties. - """ - - # TODO(b/69108481): This class (and its subclasses) should be refactored to - # serve the matrix quantities it computes as both (potentially stale) - # variables, updated by the inverse update ops, and fresh values stored in - # tensors that recomputed once every session.run() call. Currently matpower - # and damp_inverse have the former behavior, while eigendecomposition has - # the latter. - - def __init__(self): - self._matpower_by_exp_and_damping = {} # { (float, hashable): variable } - self._matpower_registrations = set() # { (float, hashable) } - self._eigendecomp = None - self._damping_funcs_by_id = {} # {hashable: lambda} - - self._cholesky_registrations = set() # { hashable } - self._cholesky_inverse_registrations = set() # { hashable } - - self._cholesky_by_damping = {} # { hashable: variable } - self._cholesky_inverse_by_damping = {} # { hashable: variable } - - super(DenseSquareMatrixFactor, self).__init__() - - def get_cov_as_linear_operator(self): - assert self.get_cov().shape.ndims == 2 - return lo.LinearOperatorFullMatrix(self.get_cov(), - is_self_adjoint=True, - is_square=True) - - def _register_damping(self, damping_func): - damping_id = graph_func_to_id(damping_func) - if damping_id not in self._damping_funcs_by_id: - self._damping_funcs_by_id[damping_id] = damping_func - return damping_id - - def register_inverse(self, damping_func): - # Just for backwards compatibility of some old code and tests - self.register_matpower(-1, damping_func) - - def register_matpower(self, exp, damping_func): - """Registers a matrix power to be maintained and served on demand. - - This creates a variable and signals make_inverse_update_ops to make the - corresponding update op. The variable can be read via the method - get_matpower. - - Args: - exp: float. The exponent to use in the matrix power. - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). - """ - if exp == 1.0: - return - - damping_id = self._register_damping(damping_func) - - if (exp, damping_id) not in self._matpower_registrations: - self._matpower_registrations.add((exp, damping_id)) - - def register_cholesky(self, damping_func): - """Registers a Cholesky factor to be maintained and served on demand. - - This creates a variable and signals make_inverse_update_ops to make the - corresponding update op. The variable can be read via the method - get_cholesky. - - Args: - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). - """ - damping_id = self._register_damping(damping_func) - - if damping_id not in self._cholesky_registrations: - self._cholesky_registrations.add(damping_id) - - def register_cholesky_inverse(self, damping_func): - """Registers an inverse Cholesky factor to be maintained/served on demand. - - This creates a variable and signals make_inverse_update_ops to make the - corresponding update op. The variable can be read via the method - get_cholesky_inverse. - - Args: - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). - """ - damping_id = self._register_damping(damping_func) - - if damping_id not in self._cholesky_inverse_registrations: - self._cholesky_inverse_registrations.add(damping_id) - - def instantiate_inv_variables(self): - """Makes the internal "inverse" variable(s).""" - - for (exp, damping_id) in self._matpower_registrations: - exp_string = scalar_or_tensor_to_string(exp) - damping_func = self._damping_funcs_by_id[damping_id] - damping_string = graph_func_to_string(damping_func) - with variable_scope.variable_scope(self._var_scope): - matpower = variable_scope.get_variable( - "matpower_exp{}_damp{}".format(exp_string, damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - assert (exp, damping_id) not in self._matpower_by_exp_and_damping - self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower - - for damping_id in self._cholesky_registrations: - damping_func = self._damping_funcs_by_id[damping_id] - damping_string = graph_func_to_string(damping_func) - with variable_scope.variable_scope(self._var_scope): - chol = variable_scope.get_variable( - "cholesky_damp{}".format(damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - assert damping_id not in self._cholesky_by_damping - self._cholesky_by_damping[damping_id] = chol - - for damping_id in self._cholesky_inverse_registrations: - damping_func = self._damping_funcs_by_id[damping_id] - damping_string = graph_func_to_string(damping_func) - with variable_scope.variable_scope(self._var_scope): - cholinv = variable_scope.get_variable( - "cholesky_inverse_damp{}".format(damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - assert damping_id not in self._cholesky_inverse_by_damping - self._cholesky_inverse_by_damping[damping_id] = cholinv - - def make_inverse_update_ops(self): - """Create and return update ops corresponding to registered computations.""" - ops = [] - - num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping - if exp == -1) - - num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses - - other_matrix_power_registered = num_other_matpower >= 1 - - use_eig = ( - self._eigendecomp or other_matrix_power_registered or - num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) - - # We precompute these so we don't need to evaluate them multiple times (for - # each matrix power that uses them) - damping_value_by_id = {damping_id: math_ops.cast( - self._damping_funcs_by_id[damping_id](), self._dtype) - for damping_id in self._damping_funcs_by_id} - - if use_eig: - eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence - - for (exp, damping_id), matpower in ( - self._matpower_by_exp_and_damping.items()): - damping = damping_value_by_id[damping_id] - ops.append( - matpower.assign( - math_ops.matmul(eigenvectors * - (eigenvalues + damping)**exp, - array_ops.transpose(eigenvectors)))) - # These ops share computation and should be run on a single device. - ops = [control_flow_ops.group(*ops)] - else: - for (exp, damping_id), matpower in ( - self._matpower_by_exp_and_damping.items()): - assert exp == -1 - damping = damping_value_by_id[damping_id] - ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping))) - - # TODO(b/77902055): If inverses are being computed with Cholesky's - # we can share the work. Instead this code currently just computes the - # Cholesky a second time. It does at least share work between requests for - # Cholesky's and Cholesky inverses with the same damping id. - for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items(): - cholesky_ops = [] - - damping = damping_value_by_id[damping_id] - cholesky_value = utils.cholesky(self.get_cov(), damping) - - if damping_id in self._cholesky_by_damping: - cholesky = self._cholesky_by_damping[damping_id] - cholesky_ops.append(cholesky.assign(cholesky_value)) - - identity = linalg_ops.eye(cholesky_value.shape.as_list()[0], - dtype=cholesky_value.dtype) - cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value, - identity) - cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value)) - - ops.append(control_flow_ops.group(*cholesky_ops)) - - for damping_id, cholesky in self._cholesky_by_damping.items(): - if damping_id not in self._cholesky_inverse_by_damping: - damping = damping_value_by_id[damping_id] - cholesky_value = utils.cholesky(self.get_cov(), damping) - ops.append(cholesky.assign(cholesky_value)) - - self._eigendecomp = False - return ops - - def get_inverse(self, damping_func): - # Just for backwards compatibility of some old code and tests - return self.get_matpower(-1, damping_func) - - def get_matpower(self, exp, damping_func): - # Note that this function returns a variable which gets updated by the - # inverse ops. It may be stale / inconsistent with the latest value of - # get_cov(). - if exp != 1: - damping_id = graph_func_to_id(damping_func) - matpower = self._matpower_by_exp_and_damping[(exp, damping_id)] - else: - matpower = self.get_cov() - identity = linalg_ops.eye(matpower.shape.as_list()[0], - dtype=matpower.dtype) - matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity - - assert matpower.shape.ndims == 2 - return lo.LinearOperatorFullMatrix(matpower, - is_non_singular=True, - is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - def get_cholesky(self, damping_func): - # Note that this function returns a variable which gets updated by the - # inverse ops. It may be stale / inconsistent with the latest value of - # get_cov(). - damping_id = graph_func_to_id(damping_func) - cholesky = self._cholesky_by_damping[damping_id] - assert cholesky.shape.ndims == 2 - return lo.LinearOperatorFullMatrix(cholesky, - is_non_singular=True, - is_square=True) - - def get_cholesky_inverse(self, damping_func): - # Note that this function returns a variable which gets updated by the - # inverse ops. It may be stale / inconsistent with the latest value of - # get_cov(). - damping_id = graph_func_to_id(damping_func) - cholesky_inv = self._cholesky_inverse_by_damping[damping_id] - assert cholesky_inv.shape.ndims == 2 - return lo.LinearOperatorFullMatrix(cholesky_inv, - is_non_singular=True, - is_square=True) - - def get_eigendecomp(self): - """Creates or retrieves eigendecomposition of self._cov.""" - # Unlike get_matpower this doesn't retrieve a stored variable, but instead - # always computes a fresh version from the current value of get_cov(). - if not self._eigendecomp: - eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov()) - - # The matrix self._cov is positive semidefinite by construction, but the - # numerical eigenvalues could be negative due to numerical errors, so here - # we clip them to be at least FLAGS.eigenvalue_clipping_threshold - clipped_eigenvalues = math_ops.maximum(eigenvalues, - EIGENVALUE_CLIPPING_THRESHOLD) - self._eigendecomp = (clipped_eigenvalues, eigenvectors) - - return self._eigendecomp - - -class FullFactor(DenseSquareMatrixFactor): - """FisherFactor for a full matrix representation of the Fisher of a parameter. - - Note that this uses the naive "square the sum estimator", and so is applicable - to any type of parameter in principle, but has very high variance. - """ - - def __init__(self, - params_grads, - batch_size): - self._batch_size = batch_size - self._params_grads = tuple(utils.ensure_sequence(params_grad) - for params_grad in params_grads) - super(FullFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_full_" + scope_string_from_params( - [self._params_grads, self._batch_size]) - - @property - def _cov_shape(self): - size = sum(param_grad.shape.num_elements() - for param_grad in self._params_grads[0]) - return (size, size) - - @property - def _num_sources(self): - return len(self._params_grads) - - @property - def _num_towers(self): - return 1 - - @property - def _dtype(self): - return self._params_grads[0][0].dtype - - def _compute_new_cov(self, source, tower): - assert tower == 0 - - # This will be a very basic rank 1 estimate - params_grads_flat = utils.tensors_to_column(self._params_grads[source]) - return ((params_grads_flat * array_ops.transpose( - params_grads_flat)) / math_ops.cast(self._batch_size, - params_grads_flat.dtype)) - - def _get_data_device(self, tower): - return None - - -class DiagonalFactor(FisherFactor): - """A base class for FisherFactors that use diagonal approximations. - - A DiagonalFactor's covariance variable can be of any shape, but must contain - exactly one entry per parameter. - """ - - def __init__(self): - super(DiagonalFactor, self).__init__() - - def get_cov_as_linear_operator(self): - assert self._matrix_diagonal.shape.ndims == 1 - return lo.LinearOperatorDiag(self._matrix_diagonal, - is_self_adjoint=True, - is_square=True) - - @property - def _cov_initializer(self): - return diagonal_covariance_initializer - - @property - def _matrix_diagonal(self): - return array_ops.reshape(self.get_cov(), [-1]) - - def make_inverse_update_ops(self): - return [] - - def instantiate_inv_variables(self): - pass - - def register_matpower(self, exp, damping_func): - pass - - def register_cholesky(self, damping_func): - pass - - def register_cholesky_inverse(self, damping_func): - pass - - def get_matpower(self, exp, damping_func): - matpower_diagonal = (self._matrix_diagonal - + math_ops.cast(damping_func(), self._dtype))**exp - return lo.LinearOperatorDiag(matpower_diagonal, - is_non_singular=True, - is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - def get_cholesky(self, damping_func): - return self.get_matpower(0.5, damping_func) - - def get_cholesky_inverse(self, damping_func): - return self.get_matpower(-0.5, damping_func) - - -class NaiveDiagonalFactor(DiagonalFactor): - """FisherFactor for a diagonal approximation of any type of param's Fisher. - - Note that this uses the naive "square the sum estimator", and so is applicable - to any type of parameter in principle, but has very high variance. - """ - - def __init__(self, - params_grads, - batch_size): - """Initializes NaiveDiagonalFactor instance. - - Args: - params_grads: Sequence of Tensors, each with same shape as parameters this - FisherFactor corresponds to. For example, the gradient of the loss with - respect to parameters. - batch_size: int or 0-D Tensor. Size - """ - self._params_grads = tuple(utils.ensure_sequence(params_grad) - for params_grad in params_grads) - self._batch_size = batch_size - super(NaiveDiagonalFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_naivediag_" + scope_string_from_params( - [self._params_grads, self._batch_size]) - - @property - def _cov_shape(self): - size = sum(param_grad.shape.num_elements() - for param_grad in self._params_grads[0]) - return [size, 1] - - @property - def _num_sources(self): - return len(self._params_grads) - - @property - def _num_towers(self): - return 1 - - @property - def _dtype(self): - return self._params_grads[0][0].dtype - - def _compute_new_cov(self, source, tower): - assert tower == 0 - - params_grads_flat = utils.tensors_to_column(self._params_grads[source]) - return (math_ops.square(params_grads_flat) / math_ops.cast( - self._batch_size, params_grads_flat.dtype)) - - def _get_data_device(self, tower): - return None - - -class EmbeddingInputKroneckerFactor(DiagonalFactor): - r"""FisherFactor for input to an embedding layer. - - Given input_ids = [batch_size, input_size] representing indices into an - [vocab_size, embedding_size] embedding matrix, approximate input covariance by - a diagonal matrix, - - Cov(input_ids, input_ids) = - (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2). - - where n_hot() constructs an n-hot binary vector and diag() constructs a - diagonal matrix of size [vocab_size, vocab_size]. - """ - - def __init__(self, input_ids, vocab_size, dtype=None): - """Instantiate EmbeddingInputKroneckerFactor. - - Args: - input_ids: List of Tensors of shape [batch_size, input_size] and dtype - int32. Indices into embedding matrix. List index is tower. - vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'. - dtype: dtype for covariance statistics. Must be a floating point type. - Defaults to float32. - """ - self._input_ids = input_ids - self._vocab_size = vocab_size - self._cov_dtype = dtype or dtypes.float32 - - super(EmbeddingInputKroneckerFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_diag_embedding_" + scope_string_from_params(self._input_ids) - - @property - def _cov_shape(self): - return [self._vocab_size] - - @property - def _num_sources(self): - return 1 - - @property - def _num_towers(self): - return len(self._input_ids) - - @property - def _dtype(self): - return self._cov_dtype - - def _compute_new_cov(self, source, tower): - assert source == 0 - - input_ids = self._input_ids[tower] - - if len(input_ids.shape) > 2: - raise ValueError( - "Input to embeddings must have rank <= 2. Found rank %d." % len( - input_ids.shape)) - - batch_size = array_ops.shape(input_ids)[0] - - # Transform indices into one-hot vectors. - # - # TODO(b/72714822): There must be a faster way to construct the diagonal - # covariance matrix! This operation is O(batch_size * vocab_size), where - # it should be O(batch_size * input_size). - flat_input_ids = array_ops.reshape(input_ids, [-1]) - one_hots = array_ops.one_hot(flat_input_ids, - self._vocab_size) # [?, vocab_size] - - # Take average across examples. Note that, because all entries have - # magnitude zero or one, there's no need to square the entries. - # - # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation - # within an example such as average. - # - # TODO(b/72714822): Support for partitioned embeddings. - new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] - new_cov /= math_ops.cast(batch_size, new_cov.dtype) - - return new_cov - - def _get_data_device(self, tower): - return self._input_ids[tower].device - - -class FullyConnectedDiagonalFactor(DiagonalFactor): - r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher. - - Given in = [batch_size, input_size] and out_grad = [batch_size, output_size], - approximates the covariance as, - - Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0 - - where the square is taken element-wise. - """ - - def __init__(self, - inputs, - outputs_grads, - has_bias=False): - """Instantiate FullyConnectedDiagonalFactor. - - Args: - inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this - layer. List index is towers. - outputs_grads: List of Tensors, each of shape [batch_size, output_size], - which are the gradients of the loss with respect to the layer's - outputs. First index is source, second is tower. - - has_bias: bool. If True, append '1' to each input. - """ - self._inputs = inputs - self._has_bias = has_bias - self._outputs_grads = outputs_grads - self._squared_inputs = None - - super(FullyConnectedDiagonalFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_diagfc_" + scope_string_from_params( - tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) - - @property - def _cov_shape(self): - input_size = self._inputs[0].shape[1] + self._has_bias - output_size = self._outputs_grads[0][0].shape[1] - return [input_size, output_size] - - @property - def _num_sources(self): - return len(self._outputs_grads) - - @property - def _num_towers(self): - return len(self._inputs) - - @property - def _dtype(self): - return self._outputs_grads[0][0].dtype - - def make_covariance_update_op(self, ema_decay): - - self._squared_inputs = [] - for tower in range(self._num_towers): - inputs = self._inputs[tower] - - with place_on_device(self._get_data_device(tower)): - if self._has_bias: - inputs = append_homog(inputs) - self._squared_inputs.append(math_ops.square(inputs)) - - return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op( - ema_decay) - - def _compute_new_cov(self, source, tower): - batch_size = array_ops.shape(self._squared_inputs[tower])[0] - outputs_grad = self._outputs_grads[source][tower] - - # The well-known special formula that uses the fact that the entry-wise - # square of an outer product is the outer-product of the entry-wise squares. - # The gradient is the outer product of the input and the output gradients, - # so we just square both and then take their outer-product. - new_cov = math_ops.matmul( - self._squared_inputs[tower], - math_ops.square(outputs_grad), - transpose_a=True) - new_cov /= math_ops.cast(batch_size, new_cov.dtype) - return new_cov - - def _get_data_device(self, tower): - return self._inputs[tower].device - - -class ConvDiagonalFactor(DiagonalFactor): - """FisherFactor for a diagonal approx of a convolutional layer's Fisher.""" - - def __init__(self, - inputs, - outputs_grads, - filter_shape, - strides, - padding, - data_format=None, - dilations=None, - has_bias=False): - """Creates a ConvDiagonalFactor object. - - Args: - inputs: List of Tensors of shape [batch_size, height, width, in_channels]. - Input activations to this layer. List index is towers. - outputs_grads: List of Tensors, each of shape [batch_size, - height, width, out_channels], which are the gradients of the loss - with respect to the layer's outputs. First index is source, second - index is tower. - filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, - out_channels). Represents shape of kernel used in this layer. - strides: The stride size in this layer (1-D Tensor of length 4). - padding: The padding in this layer (1-D of Tensor length 4). - data_format: None or str. Format of conv2d inputs. - dilations: None or tuple of 4 ints. - has_bias: Python bool. If True, the layer is assumed to have a bias - parameter in addition to its filter parameter. - - Raises: - ValueError: If inputs, output_grads, and filter_shape do not agree on - in_channels or out_channels. - ValueError: If strides, dilations are not length-4 lists of ints. - ValueError: If data_format does not put channel last. - """ - if not utils.is_data_format_channel_last(data_format): - raise ValueError("Channel must be last.") - if any(input_.shape.ndims != 4 for input_ in inputs): - raise ValueError("inputs must be a list of 4-D Tensors.") - if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs): - raise ValueError("inputs and filter_shape must agree on in_channels.") - for i, outputs_grad in enumerate(outputs_grads): - if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad): - raise ValueError("outputs[%d] must be 4-D Tensor." % i) - if any(output_grad.shape.as_list()[-1] != filter_shape[-1] - for output_grad in outputs_grad): - raise ValueError( - "outputs[%d] and filter_shape must agree on out_channels." % i) - if len(strides) != 4: - raise ValueError("strides must be length-4 list of ints.") - if dilations is not None and len(dilations) != 4: - raise ValueError("dilations must be length-4 list of ints.") - - self._inputs = inputs - self._outputs_grads = outputs_grads - self._filter_shape = filter_shape - self._strides = strides - self._padding = padding - self._data_format = data_format - self._dilations = dilations - self._has_bias = has_bias - self._patches = None - - super(ConvDiagonalFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_convdiag_" + scope_string_from_params( - tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) - - @property - def _cov_shape(self): - filter_height, filter_width, in_channels, out_channels = self._filter_shape - return [ - filter_height * filter_width * in_channels + self._has_bias, - out_channels - ] - - @property - def _num_sources(self): - return len(self._outputs_grads) - - @property - def _num_towers(self): - return len(self._inputs) - - @property - def _dtype(self): - return self._inputs[0].dtype - - def make_covariance_update_op(self, ema_decay): - filter_height, filter_width, _, _ = self._filter_shape - - # TODO(b/64144716): there is potential here for a big savings in terms - # of memory use. - if self._dilations is None: - rates = (1, 1, 1, 1) - else: - rates = tuple(self._dilations) - - self._patches = [] - for tower in range(self._num_towers): - with place_on_device(self._get_data_device(tower)): - patches = array_ops.extract_image_patches( - self._inputs[tower], - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=rates, - padding=self._padding) - - if self._has_bias: - patches = append_homog(patches) - - self._patches.append(patches) - - return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) - - def _compute_new_cov(self, source, tower): - patches = self._patches[tower] - batch_size = array_ops.shape(patches)[0] - outputs_grad = self._outputs_grads[source][tower] - - new_cov = self._convdiag_sum_of_squares(patches, outputs_grad) - new_cov /= math_ops.cast(batch_size, new_cov.dtype) - - return new_cov - - def _convdiag_sum_of_squares(self, patches, outputs_grad): - # This computes the sum of the squares of the per-training-case "gradients". - # It does this simply by computing a giant tensor containing all of these, - # doing an entry-wise square, and them summing along the batch dimension. - case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches, - outputs_grad) - return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0) - - def _get_data_device(self, tower): - return self._inputs[tower].device - - -class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor): - """Kronecker factor for the input or output side of a fully-connected layer. - """ - - def __init__(self, - tensors, - has_bias=False): - """Instantiate FullyConnectedKroneckerFactor. - - Args: - tensors: List of list of Tensors, each of shape [batch_size, n]. The - Tensors are typically either a layer's inputs or its output's gradients. - The first list index is source, the second is tower. - has_bias: bool. If True, append '1' to each row. - """ - # The tensor argument is either a tensor of input activations or a tensor of - # output pre-activation gradients. - self._has_bias = has_bias - self._tensors = tensors - super(FullyConnectedKroneckerFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_fckron_" + scope_string_from_params( - tuple(nest.flatten(self._tensors)) + (self._has_bias,)) - - @property - def _cov_shape(self): - size = self._tensors[0][0].shape[1] + self._has_bias - return [size, size] - - @property - def _num_sources(self): - return len(self._tensors) - - @property - def _num_towers(self): - return len(self._tensors[0]) - - @property - def _dtype(self): - return self._tensors[0][0].dtype - - def _compute_new_cov(self, source, tower): - tensor = self._tensors[source][tower] - if self._has_bias: - tensor = append_homog(tensor) - return compute_cov(tensor) - - def _get_data_device(self, tower): - return self._tensors[0][tower].device - - -class ConvInputKroneckerFactor(DenseSquareMatrixFactor): - r"""Kronecker factor for the input side of a convolutional layer. - - Estimates E[ a a^T ] where a is the inputs to a convolutional layer given - example x. Expectation is taken over all examples and locations. - - Equivalent to Omega in https://arxiv.org/abs/1602.01407 for details. See - Section 3.1 Estimating the factors. - """ - - def __init__(self, - inputs, - filter_shape, - padding, - strides=None, - dilation_rate=None, - data_format=None, - extract_patches_fn=None, - has_bias=False, - sub_sample_inputs=None, - sub_sample_patches=None): - """Initializes ConvInputKroneckerFactor. - - Args: - inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., - in_channels]. Inputs to layer. List index is tower. - filter_shape: List of ints. Contains [..spatial_filter_size.., - in_channels, out_channels]. Shape of convolution kernel. - padding: str. Padding method for layer. "SAME" or "VALID". - strides: List of ints or None. Contains [..spatial_filter_strides..] if - 'extract_patches_fn' is compatible with tf.nn.convolution(), else - [1, ..spatial_filter_strides, 1]. - dilation_rate: List of ints or None. Rate for dilation along each spatial - dimension if 'extract_patches_fn' is compatible with - tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1]. - data_format: str or None. Format of input data. - extract_patches_fn: str or None. Name of function that extracts image - patches. One of "extract_convolution_patches", "extract_image_patches", - "extract_pointwise_conv2d_patches". - has_bias: bool. If True, append 1 to in_channel. - sub_sample_inputs: `bool`. If True, then subsample the inputs from which - the image patches are extracted. (Default: None) - sub_sample_patches: `bool`, If `True` then subsample the extracted - patches.(Default: None) - """ - self._inputs = inputs - self._filter_shape = filter_shape - self._strides = strides - self._padding = padding - self._dilation_rate = dilation_rate - self._data_format = data_format - self._extract_patches_fn = extract_patches_fn - self._has_bias = has_bias - if sub_sample_inputs is None: - self._sub_sample_inputs = _SUB_SAMPLE_INPUTS - else: - self._sub_sample_inputs = sub_sample_inputs - - if sub_sample_patches is None: - self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS - else: - self._sub_sample_patches = sub_sample_patches - super(ConvInputKroneckerFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_convinkron_" + scope_string_from_params( - tuple(self._inputs) + - tuple((self._filter_shape, self._strides, self._padding, - self._dilation_rate, self._data_format, self._has_bias))) - - @property - def _cov_shape(self): - spatial_filter_shape = self._filter_shape[0:-2] - in_channels = self._filter_shape[-2] - size = np.prod(spatial_filter_shape) * in_channels + self._has_bias - return [size, size] - - @property - def _num_sources(self): - return 1 - - @property - def _num_towers(self): - return len(self._inputs) - - @property - def _dtype(self): - return self._inputs[0].dtype - - def _compute_new_cov(self, source, tower): - assert source == 0 - - inputs = self._inputs[tower] - if self._sub_sample_inputs: - batch_size = inputs.shape.as_list()[0] - max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR) - inputs = _random_tensor_gather(inputs, max_size) - - # TODO(b/64144716): there is potential here for a big savings in terms of - # memory use. - if self._extract_patches_fn in [None, "extract_convolution_patches"]: - patches = utils.extract_convolution_patches( - inputs, - self._filter_shape, - padding=self._padding, - strides=self._strides, - dilation_rate=self._dilation_rate, - data_format=self._data_format) - - elif self._extract_patches_fn == "extract_image_patches": - assert inputs.shape.ndims == 4 - assert len(self._filter_shape) == 4 - assert len(self._strides) == 4, self._strides - if self._dilation_rate is None: - rates = [1, 1, 1, 1] - else: - rates = self._dilation_rate - assert len(rates) == 4 - assert rates[0] == rates[-1] == 1 - patches = array_ops.extract_image_patches( - inputs, - ksizes=[1] + list(self._filter_shape[0:-2]) + [1], - strides=self._strides, - rates=rates, - padding=self._padding) - - elif self._extract_patches_fn == "extract_pointwise_conv2d_patches": - assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)] - assert self._filter_shape[0] == self._filter_shape[1] == 1 - patches = utils.extract_pointwise_conv2d_patches( - inputs, self._filter_shape, data_format=None) - - else: - raise NotImplementedError(self._extract_patches_fn) - - flatten_size = np.prod(self._filter_shape[0:-1]) - # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde - # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), - # where M = minibatch size, |T| = number of spatial locations, - # |Delta| = number of spatial offsets, and J = number of input maps - # for convolutional layer l. - patches_flat = array_ops.reshape(patches, [-1, flatten_size]) - - # We append a homogenous coordinate to patches_flat if the layer has - # bias parameters. This gives us [[A_l]]_H from the paper. - if self._sub_sample_patches: - patches_flat = _subsample_for_cov_computation(patches_flat) - - if self._has_bias: - patches_flat = append_homog(patches_flat) - # We call compute_cov without passing in a normalizer. compute_cov uses - # the first dimension of patches_flat i.e. M|T| as the normalizer by - # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with - # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from - # the paper but has a different scale here for consistency with - # ConvOutputKroneckerFactor. - # (Tilde omitted over A for clarity.) - return compute_cov(patches_flat) - - def _get_data_device(self, tower): - return self._inputs[tower].device - - -class ConvOutputKroneckerFactor(DenseSquareMatrixFactor): - r"""Kronecker factor for the output side of a convolutional layer. - - Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer - given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over - all examples and locations. - - Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See - Section 3.1 Estimating the factors. - """ - - def __init__(self, outputs_grads, data_format=None): - """Initializes ConvOutputKroneckerFactor. - - Args: - outputs_grads: List of list of Tensors. Each Tensor is of shape - [batch_size, ..spatial_input_size.., out_channels]. First list index - is source, the second is tower. - data_format: None or str. Format of outputs_grads. - - Raises: - ValueError: If channels are not final dimension. - """ - if not utils.is_data_format_channel_last(data_format): - raise ValueError("Channel must be last.") - self._out_channels = outputs_grads[0][0].shape.as_list()[-1] - self._outputs_grads = outputs_grads - super(ConvOutputKroneckerFactor, self).__init__() - - @property - def _var_scope(self): - return "ff_convoutkron_" + scope_string_from_params( - nest.flatten(self._outputs_grads)) - - @property - def _cov_shape(self): - size = self._out_channels - return [size, size] - - @property - def _num_sources(self): - return len(self._outputs_grads) - - @property - def _num_towers(self): - return len(self._outputs_grads[0]) - - @property - def _dtype(self): - return self._outputs_grads[0][0].dtype - - def _compute_new_cov(self, source, tower): - outputs_grad = self._outputs_grads[source][tower] - - # reshaped_tensor below is the matrix DS_l defined in the KFC paper - # (tilde omitted over S for clarity). It has shape M|T| x I, where - # M = minibatch size, |T| = number of spatial locations, and - # I = number of output maps for convolutional layer l. - reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels]) - # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, - # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l - # as defined in the paper, with shape I x I. - # (Tilde omitted over S for clarity.) - return compute_cov(reshaped_tensor) - - def _get_data_device(self, tower): - return self._outputs_grads[0][tower].device - - -class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): - """Kronecker factor for a fully connected layer used multiple times.""" - - def __init__(self, - tensors, - num_uses=None, - has_bias=False): - """Constructs a new `FullyConnectedMultiKF`. - - Args: - tensors: List of list of Tensors of shape, each of shape - [num_uses * batch_size, n], and is a reshape version of a Tensor of - shape [num_uses, batch_size, n]. Each of these tensors is usually a - layer's inputs or its output's gradients. The first list index is - sources, the second is towers. - num_uses: int. The number of time-steps / uses. - has_bias: bool. If True, '1' is appended to each row. - """ - - self._num_uses = num_uses - - self._cov_dt1 = None - self._make_cov_dt1 = False - self._option1quants_by_damping = {} - self._option2quants_by_damping = {} - self._option1quants_registrations = set() - self._option2quants_registrations = set() - - super(FullyConnectedMultiKF, self).__init__(tensors=tensors, - has_bias=has_bias) - - @property - def _num_timesteps(self): - return self._num_uses - - @property - def _var_scope(self): - return "ff_fc_multi_" + scope_string_from_params( - tuple(nest.flatten(self._tensors)) - + (self._num_timesteps, self._has_bias,)) - - def make_covariance_update_op(self, ema_decay): - - op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay) - - if self._cov_dt1 is not None: - new_cov_dt1_contribs = [] - for source in range(self._num_sources): - for tower in range(self._num_towers): - with place_on_device(self._get_data_device(tower)): - new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source, - tower)) - - new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs) - / float(self._num_towers)) - - # See comments in FisherFactor.make_covariance_update_op() for details. - if utils.on_tpu(): - new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1) - - op2 = moving_averages.assign_moving_average( - self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) - - # TODO(b/69112164): - # It's important that _cov and _cov_dt1 remain consistent with each - # other while the inverse ops are happening. How can we ensure this? - # We will need to add explicit synchronization for this to - # work with asynchronous training. - op = control_flow_ops.group(op, op2) - - return op - - def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring - tensor = self._tensors[source][tower] - if self._has_bias: - # This appending is technically done twice (the other time is for - # _compute_new_cov()) - tensor = append_homog(tensor) - - total_len = array_ops.shape(tensor)[0] - batch_size = total_len // self._num_timesteps - - tensor_present = tensor[:-batch_size, :] - tensor_future = tensor[batch_size:, :] - - # We specify a normalizer for this computation to ensure a PSD Fisher - # block estimate. This is equivalent to padding with zeros, as was done - # in Section B.2 of the appendix. - return compute_cov( - tensor_future, tensor_right=tensor_present, normalizer=total_len) - - def _get_data_device(self, tower): - return self._tensors[0][tower].device - - @property - def _vec_shape(self): - size = self._tensors[0][0].shape[1] + self._has_bias - return [size] - - def get_option1quants(self, damping_func): - damping_id = graph_func_to_id(damping_func) - return self._option1quants_by_damping[damping_id] - - def get_option2quants(self, damping_func): - damping_id = graph_func_to_id(damping_func) - return self._option2quants_by_damping[damping_id] - - def get_cov_dt1(self): - assert self._cov_dt1 is not None - return self._cov_dt1 - - def register_cov_dt1(self): - self._make_cov_dt1 = True - - def instantiate_cov_variables(self): - super(FullyConnectedMultiKF, self).instantiate_cov_variables() - assert self._cov_dt1 is None - if self._make_cov_dt1: - with variable_scope.variable_scope(self._var_scope): - self._cov_dt1 = variable_scope.get_variable( - "cov_dt1", - initializer=init_ops.zeros_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - - def register_option1quants(self, damping_func): - damping_id = self._register_damping(damping_func) - if damping_id not in self._option1quants_registrations: - self._option1quants_registrations.add(damping_id) - - def register_option2quants(self, damping_func): - damping_id = self._register_damping(damping_func) - if damping_id not in self._option2quants_registrations: - self._option2quants_registrations.add(damping_id) - - def instantiate_inv_variables(self): - super(FullyConnectedMultiKF, self).instantiate_inv_variables() - - for damping_id in self._option1quants_registrations: - damping_func = self._damping_funcs_by_id[damping_id] - damping_string = graph_func_to_string(damping_func) - # It's questionable as to whether we should initialize with stuff like - # this at all. Ideally these values should never be used until they are - # updated at least once. - with variable_scope.variable_scope(self._var_scope): - Lmat = variable_scope.get_variable( # pylint: disable=invalid-name - "Lmat_damp{}".format(damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - psi = variable_scope.get_variable( - "psi_damp{}".format(damping_string), - initializer=init_ops.ones_initializer, - shape=self._vec_shape, - trainable=False, - dtype=self._dtype) - - assert damping_id not in self._option1quants_by_damping - self._option1quants_by_damping[damping_id] = (Lmat, psi) - - for damping_id in self._option2quants_registrations: - damping_func = self._damping_funcs_by_id[damping_id] - damping_string = graph_func_to_string(damping_func) - # It's questionable as to whether we should initialize with stuff like - # this at all. Ideally these values should never be used until they are - # updated at least once. - with variable_scope.variable_scope(self._var_scope): - Pmat = variable_scope.get_variable( # pylint: disable=invalid-name - "Lmat_damp{}".format(damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - Kmat = variable_scope.get_variable( # pylint: disable=invalid-name - "Kmat_damp{}".format(damping_string), - initializer=inverse_initializer, - shape=self._cov_shape, - trainable=False, - dtype=self._dtype) - mu = variable_scope.get_variable( - "mu_damp{}".format(damping_string), - initializer=init_ops.ones_initializer, - shape=self._vec_shape, - trainable=False, - dtype=self._dtype) - - assert damping_id not in self._option2quants_by_damping - self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu) - - def make_inverse_update_ops(self): - """Create and return update ops corresponding to registered computations.""" - # TODO(b/69918258): Add correctness tests for this method. - # pylint: disable=invalid-name - - ops = [] - - if (len(self._option1quants_by_damping) + - len(self._option2quants_by_damping)): - - # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from - # the pseudo-code in the original paper. Because the computations for - # the A and G case are essentially the same they can both be performed by - # the same class (this one). - - C1 = self.get_cov_dt1() - - # Get the eigendecomposition of C0 (= self.get_cov()) - eigen_e, eigen_V = self.get_eigendecomp() - - # TODO(b/69678661): Note, there is an implicit assumption here that C1 - # and C0 (as represented here by its eigen-decomp) are consistent. This - # could fail to be the case if self._cov and self._cov_dt1 are not updated - # consistently, or are somehow read between or during the cov updates. - # Can this possibly happen? Is there a way to prevent it? - - for damping_id, (Lmat_var, - psi_var) in self._option1quants_by_damping.items(): - - damping = self._damping_funcs_by_id[damping_id]() - damping = math_ops.cast(damping, self._dtype) - - invsqrtC0 = math_ops.matmul( - eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) - - # Might need to enforce symmetry lost due to numerical issues. - invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 - - # The following line imposes the symmetry assumed by "Option 1" on C1. - # Strangely the code can work okay with this line commented out, - # depending on how psd_eig is defined. I'm not sure why. - C1 = (C1 + array_ops.transpose(C1)) / 2.0 - - # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) - hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0) - - # Compute the decomposition U*diag(psi)*U^T = hPsi - psi, U = utils.posdef_eig(hPsi) - - # L = C0^(-1/2) * U - Lmat = math_ops.matmul(invsqrtC0, U) - - ops.append(Lmat_var.assign(Lmat)) - ops.append(psi_var.assign(psi)) - - for damping_id, (Pmat_var, Kmat_var, - mu_var) in self._option2quants_by_damping.items(): - - damping = self._damping_funcs_by_id[damping_id]() - damping = math_ops.cast(damping, self._dtype) - - # compute C0^(-1/2) - invsqrtC0 = math_ops.matmul( - eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) - - # Might need to enforce symmetry lost due to numerical issues. - invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 - - # Compute the product C0^(-1/2) * C1 - invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1) - - # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) - hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0) - - # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi - # Note that we using the notation mu instead of "m" for the eigenvalues. - # Instead of computing the product hPsi^T * hPsi and then doing an - # eigen-decomposition of this we just compute the SVD of hPsi and then - # square the singular values to get the eigenvalues. For a justification - # of this approach, see: - # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition - sqrtmu, _, E = linalg_ops.svd(hPsi) - mu = math_ops.square(sqrtmu) - - # Mathematically, the eigenvalues should not should not exceed 1.0, but - # due to numerical issues, or possible issues with inconsistent - # values of C1 and (the eigen-decomposition of) C0 they might. So - # we enforce this condition. - mu = math_ops.minimum(mu, 1.0) - - # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1) - Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True) - - # K = C_0^(-1/2) * E - Kmat = math_ops.matmul(invsqrtC0, E) - - ops.append(Pmat_var.assign(Pmat)) - ops.append(Kmat_var.assign(Kmat)) - ops.append(mu_var.assign(mu)) - - ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops() - return [control_flow_ops.group(*ops)] - - # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py deleted file mode 100644 index 2d8e378a932c16d48360bc4b15ff4f3239c0ed1f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""FisherFactor definitions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.fisher_factors import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - "inverse_initializer", "covariance_initializer", - "diagonal_covariance_initializer", "scope_string_from_params", - "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor", - "InverseProvidingFactor", "FullFactor", "DiagonalFactor", - "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor", - "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor", - "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor", - "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with", - "compute_cov", "append_homog" -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py deleted file mode 100644 index 43aa713edcbc4f55ba76385c962c7ceb77fd83c8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ /dev/null @@ -1,1269 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Registry for layers and their parameters/variables. - -This represents the collection of all layers in the approximate Fisher -information matrix to which a particular FisherBlock may belong. That is, we -might have several layer collections for one TF graph (if we have multiple K-FAC -optimizers being used, for example.) -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from collections import defaultdict -from collections import OrderedDict -from contextlib import contextmanager -from functools import partial -import warnings - -import math -import six - -from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb -from tensorflow.contrib.kfac.python.ops import loss_functions as lf -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.util import nest - -# Names for various approximations that can be requested for Fisher blocks. -APPROX_KRONECKER_NAME = "kron" -APPROX_DIAGONAL_NAME = "diagonal" -APPROX_FULL_NAME = "full" - -_GENERIC_APPROX_TO_BLOCK_TYPES = { - APPROX_FULL_NAME: fb.FullFB, - APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, -} - -_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, - APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, -} - -_CONV2D_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, - APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, -} - -_EMBEDDING_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB -} - -APPROX_KRONECKER_INDEP_NAME = "kron_indep" -APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" -APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" - -_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB, - APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB, - option=1), - APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB, - option=2) -} - -_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB -} - -_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = { - APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB -} - -# Possible value for `reuse` keyword argument. Sets `reuse` to -# tf.get_variable_scope().reuse. -VARIABLE_SCOPE = "VARIABLE_SCOPE" - -_DEFAULT_LAYER_COLLECTION = None - - -def get_default_layer_collection(): - """Get default LayerCollection.""" - if _DEFAULT_LAYER_COLLECTION is None: - raise ValueError( - "Attempted to retrieve default LayerCollection when none is set. Use " - "LayerCollection.as_default().") - - return _DEFAULT_LAYER_COLLECTION - - -def set_default_layer_collection(layer_collection): - global _DEFAULT_LAYER_COLLECTION - - if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None: - raise ValueError("Default LayerCollection is already set.") - - _DEFAULT_LAYER_COLLECTION = layer_collection - - -class LayerParametersDict(OrderedDict): - """An OrderedDict where keys are Tensors or tuples of Tensors. - - Ensures that no Tensor is associated with two different keys. - """ - - def __init__(self, *args, **kwargs): - self._tensors = set() - super(LayerParametersDict, self).__init__(*args, **kwargs) - - def __setitem__(self, key, value): - key = self._canonicalize_key(key) - tensors = key if isinstance(key, (tuple, list)) else (key,) - key_collisions = self._tensors.intersection(tensors) - if key_collisions: - raise ValueError("Key(s) already present: {}".format(key_collisions)) - self._tensors.update(tensors) - super(LayerParametersDict, self).__setitem__(key, value) - - def __delitem__(self, key): - key = self._canonicalize_key(key) - self._tensors.remove(key) - super(LayerParametersDict, self).__delitem__(key) - - def __getitem__(self, key): - key = self._canonicalize_key(key) - return super(LayerParametersDict, self).__getitem__(key) - - def __contains__(self, key): - key = self._canonicalize_key(key) - return super(LayerParametersDict, self).__contains__(key) - - def _canonicalize_key(self, key): - if isinstance(key, (list, tuple)): - return tuple(key) - return key - - -# TODO(b/68034464): add capability for LayerCollection to be "finalized" -# and do this when it gets used by FisherEstimator / KfacOptimizer. - - -class LayerCollection(object): - """Registry of information about layers and losses. - - Note that you need to create a new one of these for each MatrixEstimator or - KfacOptimizer. - - Attributes: - fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer - parameters (Tensors or tuples of Tensors) to FisherBlock instances. - fisher_factors: an OrderedDict mapping tuples to FisherFactor instances. - losses: a list of LossFunction objects. The loss to be optimized is their - sum. - loss_colocation_ops: ops to colocate loss function evaluations with. These - will typically be the inputs to the losses. - """ - - def __init__(self, - graph=None, - name="LayerCollection"): - warnings.warn( - "tf.contrib.kfac is deprecated and will be removed by 2018-11-01. " - "Use https://pypi.python.org/pypi/kfac instead.") - self.fisher_blocks = LayerParametersDict() - self.fisher_factors = OrderedDict() - self._linked_parameters = dict( - ) # dict mapping sets of variables to optionally specified approximations. - self._graph = graph or ops.get_default_graph() - self._loss_dict = {} # {str: LossFunction} - self._subgraph = None - self._default_generic_approximation = APPROX_DIAGONAL_NAME - self._default_embedding_approximation = APPROX_KRONECKER_NAME - self._default_fully_connected_approximation = APPROX_KRONECKER_NAME - self._default_conv2d_approximation = APPROX_KRONECKER_NAME - self._default_fully_connected_multi_approximation = ( - APPROX_KRONECKER_INDEP_NAME) - self._default_conv2d_multi_approximation = ( - APPROX_KRONECKER_INDEP_NAME) - self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME - self.loss_colocation_ops = {} - self._vars_to_uses = defaultdict(lambda: 0) - - with variable_scope.variable_scope(None, default_name=name) as scope: - self._var_scope = scope.name - - @property - def losses(self): - """Tuple of LossFunction objects registered with this LayerCollection.""" - return nest.flatten(self.towers_by_loss) - - @property - def towers_by_loss(self): - """Tuple across losses of LossFunction objects registered to each tower.""" - return tuple(tuple(lst) for lst in self._loss_dict.values()) - - @property - def registered_variables(self): - """A tuple of all of the variables currently registered.""" - tuple_of_tuples = (utils.ensure_sequence(key) for key, block - in six.iteritems(self.fisher_blocks)) - flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) - return flat_tuple - - @property - def linked_parameters(self): - """Groups of parameters with an optionally specified approximation. - - Linked parameters can be added using `define_linked_parameters`. - If an approximation is specified, then this approximation will be used - when registering a layer with exactly these parameters, unless an - approximation is specified when calling the registration function. - - Returns: - A `dict` mapping tuples of parameters to an optional string. - """ - return self._linked_parameters - - @property - def default_embedding_approximation(self): - return self._default_embedding_approximation - - def set_default_embedding_approximation(self, value): - if value != APPROX_KRONECKER_NAME: - raise ValueError( - "{} is not a valid approximation for embedding variables.".format( - value)) - self._default_embedding_approximation = value - - @property - def default_generic_approximation(self): - return self._default_generic_approximation - - def set_default_generic_approximation(self, value): - if value not in _GENERIC_APPROX_TO_BLOCK_TYPES: - raise ValueError( - "{} is not a valid approximation for generic variables.".format( - value)) - self._default_generic_approximation = value - - @property - def default_fully_connected_approximation(self): - return self._default_fully_connected_approximation - - def set_default_fully_connected_approximation(self, value): - if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: - raise ValueError( - "{} is not a valid approximation for fully connected layers.".format( - value)) - self._default_fully_connected_approximation = value - - @property - def default_conv2d_approximation(self): - return self._default_conv2d_approximation - - def set_default_conv2d_approximation(self, value): - if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: - raise ValueError( - "{} is not a valid approximation for 2d convolutional layers.".format( - value)) - self._default_conv2d_approximation = value - - @property - def default_fully_connected_multi_approximation(self): - return self._default_fully_connected_multi_approximation - - def set_default_fully_connected_multi_approximation(self, value): - if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: - raise ValueError("{} is not a valid approximation for a fully-connected " - "multi layer.".format(value)) - self._default_fully_connected_multi_approximation = value - - @property - def default_conv2d_multi_approximation(self): - return self._default_conv2d_multi_approximation - - @property - def default_embedding_multi_approximation(self): - return self._default_embedding_multi_approximation - - def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): - """Validates and registers the layer_key associated with the fisher_block. - - Args: - layer_key: A variable or tuple of variables. The key to check for in - existing registrations and to register if valid. - fisher_block: The associated `FisherBlock`. - reuse: Method to use for inserting new `FisherBlock's. One of True, False, - or `VARIABLE_SCOPE`. - - Raises: - ValueError: If `layer_key` was already registered and reuse is `False`, - if `layer_key` was registered with a different block type, or if - `layer_key` shares any variables with but is not equal to a previously - registered key. - KeyError: If `reuse` is `True` but `layer_key` was not previously - registered. - - Returns: - The `FisherBlock` registered under `layer_key`. If `layer_key` was already - registered, this will be the previously registered `FisherBlock`. - """ - if reuse is VARIABLE_SCOPE: - reuse = variable_scope.get_variable_scope().reuse - - if reuse is True or (reuse is variable_scope.AUTO_REUSE and - layer_key in self.fisher_blocks): - result = self.fisher_blocks[layer_key] - if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck - raise ValueError( - "Attempted to register FisherBlock of type %s when existing " - "FisherBlock has type %s." % (type(fisher_block), type(result))) - return result - if reuse is False and layer_key in self.fisher_blocks: - raise ValueError("FisherBlock for %s is already in LayerCollection." % - (layer_key,)) - - # Insert fisher_block into self.fisher_blocks. - if layer_key in self.fisher_blocks: - raise ValueError("Duplicate registration: {}".format(layer_key)) - # Raise an error if any variable in layer_key has been registered in any - # other blocks. - variable_to_block = { - var: (params, block) - for (params, block) in self.fisher_blocks.items() - for var in utils.ensure_sequence(params) - } - for variable in utils.ensure_sequence(layer_key): - if variable in variable_to_block: - prev_key, prev_block = variable_to_block[variable] - raise ValueError( - "Attempted to register layer_key {} with block {}, but variable {}" - " was already registered in key {} with block {}.".format( - layer_key, fisher_block, variable, prev_key, prev_block)) - self.fisher_blocks[layer_key] = fisher_block - return fisher_block - - def register_loss_function(self, - loss, - colocation_op, - base_name, - name=None, - reuse=VARIABLE_SCOPE): - """Registers a LossFunction object. - - Args: - loss: The LossFunction object. - colocation_op: The op to colocate the loss function's computations with. - base_name: The name to derive a new unique name from is the name argument - is None. - name: (OPTIONAL) str or None. Unique name for this loss function. If None, - a new name is generated. (Default: None) - reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional - tower for the existing loss function. - - Raises: - ValueError: If reuse == True and name == None. - ValueError: If reuse == True and seed != None. - KeyError: If reuse == True and no existing LossFunction with `name` found. - KeyError: If reuse == False and existing LossFunction with `name` found. - """ - - name = name or self._graph.unique_name(base_name) - - if reuse == VARIABLE_SCOPE: - reuse = variable_scope.get_variable_scope().reuse - - if reuse: - if name is None: - raise ValueError( - "If reuse is enabled, loss function's name must be set.") - - loss_list = self._loss_dict.get(name, None) - - if loss_list is None: - raise KeyError( - "Unable to find loss function named {}. Register a new loss " - "function with reuse=False.".format(name)) - else: - if name in self._loss_dict: - raise KeyError( - "Loss function named {} already exists. Set reuse=True to append " - "another tower.".format(name)) - - loss_list = [] - self._loss_dict[name] = loss_list - - loss_list.append(loss) - self.loss_colocation_ops[loss] = colocation_op - - def _get_use_count_map(self): - """Returns a dict mapping variables to their number of registrations.""" - return self._vars_to_uses - - def _add_uses(self, params, uses): - """Register additional uses by params in the graph. - - Args: - params: Variable or tuple of Variables. Parameters for a layer. - uses: int or float. Number of additional uses for these parameters. - """ - params = params if isinstance(params, (tuple, list)) else (params,) - for var in params: - self._vars_to_uses[var] += uses - - def check_registration(self, variables): - """Checks that all variable uses have been registered properly. - - Args: - variables: List of variables. - - Raises: - ValueError: If any registered variables are not included in the list. - ValueError: If any variable in the list is not registered. - ValueError: If any variable in the list is registered with the wrong - number of "uses" in the subgraph recorded (vs the number of times that - variable is actually used in the subgraph). - """ - # Note that overlapping parameters (i.e. those that share variables) will - # be caught by layer_collection.LayerParametersDict during registration. - - reg_use_map = self._get_use_count_map() - - error_messages = [] - - for var in variables: - total_uses = self.subgraph.variable_uses(var) - reg_uses = reg_use_map[var] - - if reg_uses == 0: - error_messages.append("Variable {} not registered.".format(var)) - elif (not math.isinf(reg_uses)) and reg_uses != total_uses: - error_messages.append( - "Variable {} registered with wrong number of uses ({} " - "registrations vs {} uses).".format(var, reg_uses, total_uses)) - - num_get_vars = len(reg_use_map) - - if num_get_vars > len(variables): - error_messages.append("{} registered variables were not included in list." - .format(num_get_vars - len(variables))) - - if error_messages: - error_messages = [ - "Found the following errors with variable registration:" - ] + error_messages - raise ValueError("\n\t".join(error_messages)) - - def get_blocks(self): - return self.fisher_blocks.values() - - def get_factors(self): - return self.fisher_factors.values() - - @property - def graph(self): - return self._graph - - @property - def subgraph(self): - return self._subgraph - - def define_linked_parameters(self, params, approximation=None): - """Identify a set of parameters that should be grouped together. - - During automatic graph scanning, any matches containing variables that have - been identified as part of a linked group will be filtered out unless - the match parameters are exactly equal to the ones specified in the linked - group. - - Args: - params: A variable, or a tuple or list of variables. The variables - to be linked. - approximation: Optional string specifying the type of approximation to use - for these variables. If unspecified, this layer collection's default - approximation for the layer type will be used. - - Raises: - ValueError: If the parameters were already registered in a layer or - identified as part of an incompatible group. - """ - params = frozenset(utils.ensure_sequence(params)) - - # Check if any of the variables in `params` is already in - # 'self.fisher_blocks.keys()`. - for registered_params, fisher_block in self.fisher_blocks.items(): - registered_params_set = set(utils.ensure_sequence(registered_params)) - for variable in params: - if (variable in registered_params_set and - params != registered_params_set): - raise ValueError( - "Can`t link parameters {}, variable {} was already registered in " - "group {} with layer {}".format(params, variable, - registered_params, fisher_block)) - - # Check if any of the variables in `params` is already in - # 'self.linked_parameters`. - for variable in params: - for other_linked_params in self.linked_parameters: - if variable in other_linked_params: - raise ValueError("Can`t link parameters {}, variable {} was already " - "linked in group {}.".format(params, variable, - other_linked_params)) - self._linked_parameters[params] = approximation - - def create_subgraph(self): - if not self.losses: - raise ValueError("Must have at least one registered loss.") - inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses)) - self._subgraph = utils.SubGraph(inputs_to_losses) - - def eval_losses(self): - """Return evaluated losses (colocated with inputs to losses).""" - evals = [] - for loss in self.losses: - with ops.colocate_with(self.loss_colocation_ops[loss]): - evals.append(loss.evaluate()) - return evals - - def eval_losses_on_samples(self): - """Return losses evaluated on samples (colocated with inputs to losses).""" - evals = [] - for loss in self.losses: - with ops.colocate_with(self.loss_colocation_ops[loss]): - evals.append(loss.evaluate_on_sample()) - return evals - - def total_loss(self): - return math_ops.add_n(self.eval_losses()) - - def total_sampled_loss(self): - return math_ops.add_n(self.eval_losses_on_samples()) - - def _get_linked_approx(self, params): - """If params were linked, return their specified approximation.""" - params_set = frozenset(utils.ensure_sequence(params)) - if params_set in self.linked_parameters: - return self.linked_parameters[params_set] - else: - return None - - def _get_block_type(self, params, approx, default, approx_to_type): - if approx is None: - approx = self._get_linked_approx(params) - if approx is None: - approx = default - - if approx not in approx_to_type: - raise ValueError("Bad value {} for approx.".format(approx)) - - return approx_to_type[approx], approx - - def register_embedding(self, - params, - inputs, - outputs, - approx=None, - reuse=VARIABLE_SCOPE): - """Registers an embedding layer. - - Args: - params: Embedding matrix of shape [vocab_size, embedding_size]. - inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices - into embedding matrix. - outputs: Tensor of shape [batch_size, embedding_size]. Outputs - produced by layer. - approx: str or None. If not None must be "kron". The Fisher - approximation to use. If None the default value is used. (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - block_type, approx = self._get_block_type( - params, approx, self.default_embedding_approximation, - _EMBEDDING_APPROX_TO_BLOCK_TYPES) - - if isinstance(params, (tuple, list)): - raise ValueError("Bias not supported.") - vocab_size = int(params.shape[0]) - block = self.register_block( - params, block_type(self, vocab_size), reuse=reuse) - block.register_additional_tower(inputs, outputs) - - self._add_uses(params, 1) - - def register_fully_connected(self, - params, - inputs, - outputs, - approx=None, - reuse=VARIABLE_SCOPE): - """Registers a fully connected layer. - - Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [input_size, output_size]. - Bias should have shape [output_size]. - inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. - outputs: Tensor of shape [batch_size, output_size]. Outputs - produced by layer. - approx: str or None. If not None must be one of "kron" or "diagonal". - The Fisher approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - - block_type, approx = self._get_block_type( - params, approx, self.default_fully_connected_approximation, - _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES) - - has_bias = isinstance(params, (tuple, list)) - block = self.register_block(params, block_type(self, has_bias=has_bias), - reuse=reuse) - block.register_additional_tower(inputs, outputs) - - self._add_uses(params, 1) - - def register_conv2d(self, - params, - strides, - padding, - inputs, - outputs, - data_format=None, - dilations=None, - approx=None, - reuse=VARIABLE_SCOPE): - """Registers a call to tf.nn.conv2d(). - - Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [kernel_height, - kernel_width, in_channels, out_channels]. Bias should have shape - [out_channels]. - strides: List of 4 ints. Strides for convolution kernel. - padding: string. see tf.nn.conv2d for valid values. - inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs - to layer. - outputs: Tensor of shape [batch_size, height, width, out_channels]. - Output produced by layer. - data_format: str or None. Format of data. - dilations: List of 4 ints. Dilations along each dimension. - approx: str or None. If not None must be one of "kron" or "diagonal". - The Fisher approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - - block_type, approx = self._get_block_type( - params, approx, self.default_conv2d_approximation, - _CONV2D_APPROX_TO_BLOCK_TYPES) - - # It feels bad to pass in configuration that has to do with the internal - # implementation. And then we can`t use the same constructor for both - # anymore and are thus forced to use this ugly if-statement. - # TODO(b/74793309): Clean this up? - if approx == APPROX_KRONECKER_NAME: - block = self.register_block( - params, - block_type( - layer_collection=self, - params=params, - padding=padding, - strides=strides, - data_format=data_format, - dilation_rate=dilations, - extract_patches_fn="extract_image_patches"), - reuse=reuse) - elif approx == APPROX_DIAGONAL_NAME: - assert strides[0] == strides[-1] == 1 - block = self.register_block( - params, - block_type( - layer_collection=self, - params=params, - padding=padding, - strides=strides, - dilations=dilations, - data_format=data_format), - reuse=reuse) - else: - raise NotImplementedError(approx) - - block.register_additional_tower(inputs, outputs) - - self._add_uses(params, 1) - - def register_convolution(self, - params, - inputs, - outputs, - padding, - strides=None, - dilation_rate=None, - data_format=None, - approx=None, - reuse=VARIABLE_SCOPE): - """Register a call to tf.nn.convolution(). - - Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [..filter_spatial_size.., - in_channels, out_channels]. Bias should have shape [out_channels]. - inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels]. - Inputs to layer. - outputs: Tensor of shape [batch_size, ..output_spatial_size.., - out_channels]. Output produced by layer. - padding: string. see tf.nn.conv2d for valid values. - strides: List of ints of length len(..input_spatial_size..). Strides for - convolution kernel in spatial dimensions. - dilation_rate: List of ints of length len(..input_spatial_size..). - Dilations along spatial dimension. - data_format: str or None. Format of data. - approx: str or None. If not None must be one of "kron" or "diagonal". - The Fisher approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - # TODO(b/74793309): Have this use _get_block_type like the other - # registration functions? - assert approx is None or approx == APPROX_KRONECKER_NAME - - block = self.register_block( - params, - fb.ConvKFCBasicFB( - layer_collection=self, - params=params, - padding=padding, - strides=strides, - dilation_rate=dilation_rate, - data_format=data_format), - reuse=reuse) - block.register_additional_tower(inputs, outputs) - - self._add_uses(params, 1) - - def register_depthwise_conv2d(self, - params, - inputs, - outputs, - strides, - padding, - rate=None, - data_format=None, - approx=None, - reuse=VARIABLE_SCOPE): - """Register a call to tf.nn.depthwise_conv2d(). - - Args: - params: 4-D Tensor of shape [filter_height, filter_width, - in_channels, channel_multiplier]. Convolutional filter. - inputs: Tensor of shape [batch_size, input_height, input_width, - in_channels]. Inputs to layer. - outputs: Tensor of shape [batch_size, output_height, output_width, - in_channels * channel_multiplier]. Output produced by depthwise conv2d. - strides: List of ints of length 4. Strides along all dimensions. - padding: string. see tf.nn.conv2d for valid values. - rate: None or List of ints of length 2. Dilation rates in spatial - dimensions. - data_format: str or None. Format of data. - approx: str or None. If not None must "diagonal". The Fisher - approximation to use. If None the default value is used. (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - # TODO(b/74793309): Have this use _get_block_type like the other - # registration functions? - assert approx is None or approx == APPROX_DIAGONAL_NAME - assert data_format in [None, "NHWC"] - - block = self.register_block( - params, - fb.DepthwiseConvDiagonalFB( - layer_collection=self, - params=params, - strides=strides, - padding=padding, - rate=rate, - data_format=data_format), - reuse=reuse) - block.register_additional_tower(inputs, outputs) - - self._add_uses(params, 1) - - def register_separable_conv2d(self, - depthwise_params, - pointwise_params, - inputs, - depthwise_outputs, - pointwise_outputs, - strides, - padding, - rate=None, - data_format=None, - approx=None, - reuse=VARIABLE_SCOPE): - """Register a call to tf.nn.separable_conv2d(). - - Note: This requires access to intermediate outputs between depthwise and - pointwise convolutions. - - Args: - depthwise_params: 4-D Tensor of shape [filter_height, filter_width, - in_channels, channel_multiplier]. Filter for depthwise conv2d. - pointwise_params: 4-D Tensor of shape [1, 1, in_channels * - channel_multiplier, out_channels]. Filter for pointwise conv2d. - inputs: Tensor of shape [batch_size, input_height, input_width, - in_channels]. Inputs to layer. - depthwise_outputs: Tensor of shape [batch_size, output_height, - output_width, in_channels * channel_multiplier]. Output produced by - depthwise conv2d. - pointwise_outputs: Tensor of shape [batch_size, output_height, - output_width, out_channels]. Output produced by pointwise conv2d. - strides: List of ints of length 4. Strides for depthwise conv2d kernel in - all dimensions. - padding: string. see tf.nn.conv2d for valid values. - rate: None or List of ints of length 2. Dilation rate of depthwise conv2d - kernel in spatial dimensions. - data_format: str or None. Format of data. - approx: str or None. If not None must be one of "kron" or "diagonal". - The Fisher approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - self.register_depthwise_conv2d( - params=depthwise_params, - inputs=inputs, - outputs=depthwise_outputs, - strides=strides, - padding=padding, - rate=rate, - data_format=data_format, - approx=APPROX_DIAGONAL_NAME, - reuse=reuse) - - self.register_conv2d( - params=pointwise_params, - inputs=depthwise_outputs, - outputs=pointwise_outputs, - strides=[1, 1, 1, 1], - padding="VALID", - data_format=data_format, - approx=approx, - reuse=reuse) - - def register_generic(self, - params, - batch_size, - approx=None, - reuse=VARIABLE_SCOPE): - """Registers a generic layer. - - Args: - params: Tensor or tuple of Tensors corresponding to the parameters. - batch_size: 0-D Tensor. Size of the minibatch (for this tower). - approx: str or None. It not None, must be one of "full" or "diagonal". - The Fisher approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `batch_size` to the total - mini-batch size use when estimating the Fisher block for this layer - (which must have already been registered). If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - block_type, approx = self._get_block_type( - params, approx, self.default_generic_approximation, - _GENERIC_APPROX_TO_BLOCK_TYPES) - - block = self.register_block(params, block_type(self, params), reuse=reuse) - block.register_additional_tower(batch_size) - - self._add_uses(params, float("inf")) - - def register_fully_connected_multi(self, params, inputs, outputs, - num_uses=None, approx=None, - reuse=VARIABLE_SCOPE): - """Register fully connected layers with shared parameters. - - This can handle general fully-connected layers with shared parameters, but - has specialized approximations to deal with the case where there is a - meaningful linear order to the share instances (such as in an RNN). - - Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [input_size, output_size]. - Bias should have shape [output_size]. - inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs - to layer. The list indexes each use in the graph (which might - correspond to a "time-step" in an RNN). OR, can be single Tensor, of - shape [num_uses * batch_size , input_size], which is a reshaped version - of a Tensor of shape [num_uses, batch_size, input_size]. - outputs: A list of Tensors, the same length as `inputs`, each of shape - [batch_size, output_size]. Outputs produced by layer. The list indexes - each use in the graph (which might correspond to a "time-step" in an - RNN). Needs to correspond with the order used in `inputs`. OR, can be - a single Tensor of shape [num_uses * batch_size, output_size], which is - a reshaped version of a Tensor of shape [num_uses, batch_size, - output_size]. - num_uses: int or None. The number uses/time-steps in the graph where the - layer appears. Only needed if both inputs and outputs are given in the - single Tensor format. (Default: None) - approx: str or None. If not None, must be of "kron_indep", "kron_series_1" - or "kron_series_2". The Fisher approximation to use. If None the default - value is used. (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the - word `use` here has a completely different meaning to "use in the graph" - as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - """ - block_type, approx = self._get_block_type( - params, approx, self.default_fully_connected_multi_approximation, - _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES) - - # TODO(b/70283649): something along the lines of find_canonical_output - # should be added back in here (and for the other block types, arguably). - - has_bias = isinstance(params, (tuple, list)) - block = self.register_block(params, block_type(self, has_bias=has_bias, - num_uses=num_uses), - reuse=reuse) - block.register_additional_tower(inputs, outputs) - if isinstance(inputs, (tuple, list)): - assert len(inputs) == len(outputs) - self._add_uses(params, len(inputs)) - else: - self._add_uses(params, 1) - - def register_conv2d_multi(self, - params, - strides, - padding, - inputs, - outputs, - num_uses=None, - data_format=None, - dilations=None, - approx=None, - reuse=VARIABLE_SCOPE): - """Registers convolutional layers with shared parameters. - - Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [kernel_height, - kernel_width, in_channels, out_channels]. Bias should have shape - [out_channels]. - strides: 1-D Tensor of length 4. Strides for convolution kernel. - padding: string. see tf.nn.conv2d for valid values. - inputs: A list of Tensors, each of shape [batch_size, height, width, - in_channels]. Inputs to layer. The list indexes each use in the graph - (which might correspond to a "time-step" in an RNN). OR, can be single - Tensor, of shape [num_uses * batch_size, height, width, in_channels], - which is a reshaped version of a Tensor of shape [num_uses, batch_size, - height, width, in_channels]. - outputs: A list of Tensors, each of shape [batch_size, height, width, - out_channels]. Output produced by layer. The list indexes each use - in the graph (which might correspond to a "time-step" in an RNN). - Needs to correspond with the order used in `inputs`. OR, can be a - single Tensor, of shape [num_uses * batch_size, height, width, - out_channels], which is a reshaped version of a Tensor of shape - [num_uses, batch_size, height, width, out_channels]. - num_uses: int or None. The number uses/time-steps in the graph where the - layer appears. Only needed if both inputs and outputs are given in the - single Tensor format. (Default: None) - data_format: str or None. Format of data. - dilations: List of 4 ints. Dilations along each dimension. - approx: str or None. If not None must by "kron_indep". The Fisher - approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the - word `use` here has a completely different meaning to "use in the graph" - as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - block_type, approx = self._get_block_type( - params, approx, self.default_conv2d_multi_approximation, - _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES) - - block = self.register_block( - params, - block_type( - layer_collection=self, - params=params, - padding=padding, - strides=strides, - data_format=data_format, - dilation_rate=dilations, - extract_patches_fn="extract_image_patches", - num_uses=num_uses), - reuse=reuse) - - block.register_additional_tower(inputs, outputs) - if isinstance(inputs, (tuple, list)): - assert len(inputs) == len(outputs) - self._add_uses(params, len(inputs)) - else: - self._add_uses(params, 1) - - # TODO(b/74108452): change the loss registration functions names to refer - # to "loss functions" instead of distributions. Following naming convention - # of the loss function classes themselves. - - def register_embedding_multi(self, - params, - inputs, - outputs, - num_uses=None, - approx=None, - reuse=VARIABLE_SCOPE): - """Registers embedding layers with shared parameters. - - Args: - params: Embedding matrix of shape [vocab_size, embedding_size]. - inputs: A list of Tensors, each of shape [batch_size, input_size] and - dtype int32. Indices into embedding matrix. The list indexes each use - in the graph (which might correspond to a "time-step" in an RNN). - OR, can be single Tensor, of shape [num_uses*batch_size, input_size], - which is a reshaped version of a Tensor of shape [num_uses, batch_size, - input_size]. - outputs: A list of Tensors, each of shape [batch_size, embedding_size]. - Outputs produced by layer. The list indexes each use in the graph - (which might correspond to a "time-step" in an RNN). Needs to - correspond with the order used in `inputs`. OR, can be a - single Tensor, of shape [num_uses * batch_size, embedding_size], which - is a reshaped version of a Tensor of shape [num_uses, batch_size, - embedding_size]. - num_uses: int or None. The number uses/time-steps in the graph where the - layer appears. Only needed if both inputs and outputs are given in the - single Tensor format. (Default: None) - approx: str or None. If not None must by "kron_indep". The Fisher - approximation to use. If None the default value is used. - (Default: None) - reuse: bool or str. If True, this adds `inputs` and `outputs` as an - additional mini-batch/tower of data to use when estimating the Fisher - block for this layer (which must have already been registered). If - "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the - word `use` here has a completely different meaning to "use in the graph" - as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.) - (Default: "VARIABLE_SCOPE") - - Raises: - ValueError: For improper value to `approx`. - KeyError: If reuse == True but no FisherBlock found for `params`. - ValueError: If reuse == True and FisherBlock found but of the wrong type. - """ - block_type, approx = self._get_block_type( - params, approx, self.default_embedding_multi_approximation, - _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES) - - if isinstance(params, (tuple, list)): - raise ValueError("Bias not supported.") - vocab_size = int(params.shape[0]) - - block = self.register_block( - params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) - block.register_additional_tower(inputs, outputs) - - if isinstance(inputs, (tuple, list)): - self._add_uses(params, len(inputs)) - else: - self._add_uses(params, 1) - - def register_categorical_predictive_distribution(self, - logits, - seed=None, - targets=None, - name=None, - reuse=VARIABLE_SCOPE): - """Registers a categorical predictive distribution. - - Args: - logits: The logits of the distribution (i.e. its parameters). - seed: The seed for the RNG (for debugging) (Default: None) - targets: (OPTIONAL) The targets for the loss function. Only required if - one wants to call total_loss() instead of total_sampled_loss(). - total_loss() is required, for example, to estimate the - "empirical Fisher" (instead of the true Fisher). - (Default: None) - name: (OPTIONAL) str or None. Unique name for this loss function. If None, - a new name is generated. (Default: None) - reuse: bool or str. If True, this adds `logits` as an additional - mini-batch/tower of inputs to the loss-function/predictive distribution - (which must have already been registered). If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") - """ - loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets, - seed=seed) - self.register_loss_function(loss, logits, - "categorical_predictive_distribution", - name=name, reuse=reuse) - - def register_normal_predictive_distribution(self, - mean, - var=0.5, - seed=None, - targets=None, - name=None, - reuse=VARIABLE_SCOPE): - """Registers a normal predictive distribution. - - Args: - mean: The mean vector defining the distribution. - var: The variance (must be a scalar). Note that the default value of - 0.5 corresponds to a standard squared error loss (target - - prediction)**2. If your squared error loss is of the form - 0.5*(target - prediction)**2 you should use var=1.0. (Default: 0.5) - seed: The seed for the RNG (for debugging) (Default: None) - targets: (OPTIONAL) The targets for the loss function. Only required if - one wants to call total_loss() instead of total_sampled_loss(). - total_loss() is required, for example, to estimate the - "empirical Fisher" (instead of the true Fisher). - (Default: None) - name: (OPTIONAL) str or None. Unique name for this loss function. If None, - a new name is generated. (Default: None) - reuse: bool or str. If True, this adds `mean` and `var` as an additional - mini-batch/tower of inputs to the loss-function/predictive distribution - (which must have already been registered). If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") - """ - loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets, - seed=seed) - self.register_loss_function(loss, mean, - "normal_predictive_distribution", - name=name, reuse=reuse) - - def register_multi_bernoulli_predictive_distribution(self, - logits, - seed=None, - targets=None, - name=None, - reuse=VARIABLE_SCOPE): - """Registers a multi-Bernoulli predictive distribution. - - Args: - logits: The logits of the distribution (i.e. its parameters). - seed: The seed for the RNG (for debugging) (Default: None) - targets: (OPTIONAL) The targets for the loss function. Only required if - one wants to call total_loss() instead of total_sampled_loss(). - total_loss() is required, for example, to estimate the - "empirical Fisher" (instead of the true Fisher). - (Default: None) - name: (OPTIONAL) str or None. Unique name for this loss function. If None, - a new name is generated. (Default: None) - reuse: bool or str. If True, this adds `logits` as an additional - mini-batch/tower of inputs to the loss-function/predictive distribution - (which must have already been registered). If "VARIABLE_SCOPE", use - tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE") - """ - loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets, - seed=seed) - self.register_loss_function(loss, logits, - "multi_bernoulli_predictive_distribution", - name=name, reuse=reuse) - - def make_or_get_factor(self, cls, args): - """Insert `cls(args)` into 'self.fisher_factors` if not already present. - - Wraps constructor in `tf.variable_scope()` to ensure variables constructed - in `cls.__init__` are placed under this LayerCollection's scope. - - Args: - cls: Class that implements FisherFactor. - args: Tuple of arguments to pass into `cls's constructor. Must be - hashable. - - Returns: - Instance of `cls` found in self.fisher_factors. - """ - try: - hash(args) - except TypeError: - raise TypeError( - ("Unable to use (cls, args) = ({}, {}) as a key in " - "LayerCollection.fisher_factors. The pair cannot be hashed.").format( - cls, args)) - - key = cls, args - if key not in self.fisher_factors: - with variable_scope.variable_scope(self._var_scope): - self.fisher_factors[key] = cls(*args) - return self.fisher_factors[key] - - @contextmanager - def as_default(self): - """Sets this LayerCollection as the default.""" - set_default_layer_collection(self) - yield - set_default_layer_collection(None) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py deleted file mode 100644 index 9f4685380705bd409dbcd7e85d0e3bb4189a6adc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Registry for layers and their parameters/variables. - -This represents the collection of all layers in the approximate Fisher -information matrix to which a particular FisherBlock may belong. That is, we -might have several layer collections for one TF graph (if we have multiple K-FAC -optimizers being used, for example.) -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.layer_collection import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - "get_default_layer_collection", - "set_default_layer_collection", - "LayerParametersDict", - "LayerCollection", - "APPROX_KRONECKER_NAME", - "APPROX_DIAGONAL_NAME", - "APPROX_FULL_NAME", - "VARIABLE_SCOPE", - "APPROX_KRONECKER_INDEP_NAME", - "APPROX_KRONECKER_SERIES_1_NAME", - "APPROX_KRONECKER_SERIES_2_NAME" -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py deleted file mode 100644 index 61cb955ae85df9e56cbe165acba98ece750cba90..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/linear_operator.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SmartMatrices definitions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.kfac.python.ops import utils -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.linalg import linalg -from tensorflow.python.ops.linalg import linalg_impl -from tensorflow.python.ops.linalg import linear_operator_util as lou - - -class LinearOperatorExtras(object): # pylint: disable=missing-docstring - - def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): - - with self._name_scope(name, values=[x]): - if isinstance(x, ops.IndexedSlices): - return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg) - - x = ops.convert_to_tensor(x, name="x") - self._check_input_dtype(x) - - self_dim = -2 if adjoint else -1 - arg_dim = -1 if adjoint_arg else -2 - self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) - - return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) - - def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"): - - with self._name_scope(name, values=[x]): - - if isinstance(x, ops.IndexedSlices): - return self._matmul_right_sparse( - x, adjoint=adjoint, adjoint_arg=adjoint_arg) - - x = ops.convert_to_tensor(x, name="x") - self._check_input_dtype(x) - - self_dim = -1 if adjoint else -2 - arg_dim = -2 if adjoint_arg else -1 - self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) - - return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg) - - -class LinearOperatorFullMatrix(LinearOperatorExtras, - linalg.LinearOperatorFullMatrix): - - # TODO(b/78117889) Remove this definition once core LinearOperator - # has _matmul_right. - def _matmul_right(self, x, adjoint=False, adjoint_arg=False): - return lou.matmul_with_broadcast( - x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint) - - def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): - raise NotImplementedError - - def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): - assert not adjoint and not adjoint_arg - return utils.matmul_sparse_dense(x, self._matrix) - - -class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring - linalg.LinearOperatorDiag): - - def _matmul_right(self, x, adjoint=False, adjoint_arg=False): - diag_mat = math_ops.conj(self._diag) if adjoint else self._diag - x = linalg_impl.adjoint(x) if adjoint_arg else x - return diag_mat * x - - def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): - diag_mat = math_ops.conj(self._diag) if adjoint else self._diag - assert not adjoint_arg - return utils.matmul_diag_sparse(diag_mat, x) - - def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): - raise NotImplementedError diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py deleted file mode 100644 index c8cebc42cb329965410df808bc8eeef60985a603..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ /dev/null @@ -1,754 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Loss functions to be used by LayerCollection.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc - -import six - -from tensorflow.contrib.distributions.python.ops import onehot_categorical -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bernoulli -from tensorflow.python.ops.distributions import categorical -from tensorflow.python.ops.distributions import normal - - -@six.add_metaclass(abc.ABCMeta) -class LossFunction(object): - """Abstract base class for loss functions. - - Note that unlike typical loss functions used in neural networks these are - summed and not averaged across cases in the batch, since this is what the - users of this class (FisherEstimator and MatrixVectorProductComputer) will - be expecting. The implication of this is that you will may want to - normalize things like Fisher-vector products by the batch size when you - use this class. It depends on the use case. - """ - - @abc.abstractproperty - def targets(self): - """The targets being predicted by the model. - - Returns: - None or Tensor of appropriate shape for calling self._evaluate() on. - """ - pass - - @abc.abstractproperty - def inputs(self): - """The inputs to the loss function (excluding the targets).""" - pass - - def evaluate(self): - """Evaluate the loss function on the targets.""" - if self.targets is not None: - # We treat the targets as "constant". It's only the inputs that get - # "back-propped" through. - return self._evaluate(array_ops.stop_gradient(self.targets)) - else: - raise Exception("Cannot evaluate losses with unspecified targets.") - - @abc.abstractmethod - def _evaluate(self, targets): - """Evaluates the negative log probability of the targets. - - Args: - targets: Tensor that distribution can calculate log_prob() of. - - Returns: - negative log probability of each target, summed across all targets. - """ - pass - - @abc.abstractmethod - def multiply_hessian(self, vector): - """Right-multiply a vector by the Hessian. - - Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) - of the loss function with respect to its inputs. - - Args: - vector: The vector to multiply. Must be the same shape(s) as the - 'inputs' property. - - Returns: - The vector right-multiplied by the Hessian. Will be of the same shape(s) - as the 'inputs' property. - """ - pass - - @abc.abstractmethod - def multiply_hessian_factor(self, vector): - """Right-multiply a vector by a factor B of the Hessian. - - Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) - of the loss function with respect to its inputs. Typically this will be - block-diagonal across different cases in the batch, since the loss function - is typically summed across cases. - - Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, - but will agree with the one used in the other methods of this class. - - Args: - vector: The vector to multiply. Must be of the shape given by the - 'hessian_factor_inner_shape' property. - - Returns: - The vector right-multiplied by B. Will be of the same shape(s) as the - 'inputs' property. - """ - pass - - @abc.abstractmethod - def multiply_hessian_factor_transpose(self, vector): - """Right-multiply a vector by the transpose of a factor B of the Hessian. - - Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) - of the loss function with respect to its inputs. Typically this will be - block-diagonal across different cases in the batch, since the loss function - is typically summed across cases. - - Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, - but will agree with the one used in the other methods of this class. - - Args: - vector: The vector to multiply. Must be the same shape(s) as the - 'inputs' property. - - Returns: - The vector right-multiplied by B^T. Will be of the shape given by the - 'hessian_factor_inner_shape' property. - """ - pass - - @abc.abstractmethod - def multiply_hessian_factor_replicated_one_hot(self, index): - """Right-multiply a replicated-one-hot vector by a factor B of the Hessian. - - Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) - of the loss function with respect to its inputs. Typically this will be - block-diagonal across different cases in the batch, since the loss function - is typically summed across cases. - - A 'replicated-one-hot' vector means a tensor which, for each slice along the - batch dimension (assumed to be dimension 0), is 1.0 in the entry - corresponding to the given index and 0 elsewhere. - - Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, - but will agree with the one used in the other methods of this class. - - Args: - index: A tuple representing in the index of the entry in each slice that - is 1.0. Note that len(index) must be equal to the number of elements - of the 'hessian_factor_inner_shape' tensor minus one. - - Returns: - The vector right-multiplied by B^T. Will be of the same shape(s) as the - 'inputs' property. - """ - pass - - @abc.abstractproperty - def hessian_factor_inner_shape(self): - """The shape of the tensor returned by multiply_hessian_factor.""" - pass - - @abc.abstractproperty - def hessian_factor_inner_static_shape(self): - """Static version of hessian_factor_inner_shape.""" - pass - - -@six.add_metaclass(abc.ABCMeta) -class NegativeLogProbLoss(LossFunction): - """Abstract base class for loss functions that are negative log probs.""" - - def __init__(self, seed=None): - self._default_seed = seed - super(NegativeLogProbLoss, self).__init__() - - @property - def inputs(self): - return self.params - - @abc.abstractproperty - def params(self): - """Parameters to the underlying distribution.""" - pass - - @abc.abstractmethod - def multiply_fisher(self, vector): - """Right-multiply a vector by the Fisher. - - Args: - vector: The vector to multiply. Must be the same shape(s) as the - 'inputs' property. - - Returns: - The vector right-multiplied by the Fisher. Will be of the same shape(s) - as the 'inputs' property. - """ - pass - - @abc.abstractmethod - def multiply_fisher_factor(self, vector): - """Right-multiply a vector by a factor B of the Fisher. - - Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- - product of gradients) with respect to the parameters of the underlying - probability distribution (whose log-prob defines the loss). Typically this - will be block-diagonal across different cases in the batch, since the - distribution is usually (but not always) conditionally iid across different - cases. - - Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, - but will agree with the one used in the other methods of this class. - - Args: - vector: The vector to multiply. Must be of the shape given by the - 'fisher_factor_inner_shape' property. - - Returns: - The vector right-multiplied by B. Will be of the same shape(s) as the - 'inputs' property. - """ - pass - - @abc.abstractmethod - def multiply_fisher_factor_transpose(self, vector): - """Right-multiply a vector by the transpose of a factor B of the Fisher. - - Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- - product of gradients) with respect to the parameters of the underlying - probability distribution (whose log-prob defines the loss). Typically this - will be block-diagonal across different cases in the batch, since the - distribution is usually (but not always) conditionally iid across different - cases. - - Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, - but will agree with the one used in the other methods of this class. - - Args: - vector: The vector to multiply. Must be the same shape(s) as the - 'inputs' property. - - Returns: - The vector right-multiplied by B^T. Will be of the shape given by the - 'fisher_factor_inner_shape' property. - """ - pass - - @abc.abstractmethod - def multiply_fisher_factor_replicated_one_hot(self, index): - """Right-multiply a replicated-one-hot vector by a factor B of the Fisher. - - Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- - product of gradients) with respect to the parameters of the underlying - probability distribution (whose log-prob defines the loss). Typically this - will be block-diagonal across different cases in the batch, since the - distribution is usually (but not always) conditionally iid across different - cases. - - A 'replicated-one-hot' vector means a tensor which, for each slice along the - batch dimension (assumed to be dimension 0), is 1.0 in the entry - corresponding to the given index and 0 elsewhere. - - Note that B can be any matrix satisfying B * B^T = H where H is the Fisher, - but will agree with the one used in the other methods of this class. - - Args: - index: A tuple representing in the index of the entry in each slice that - is 1.0. Note that len(index) must be equal to the number of elements - of the 'fisher_factor_inner_shape' tensor minus one. - - Returns: - The vector right-multiplied by B. Will be of the same shape(s) as the - 'inputs' property. - """ - pass - - @abc.abstractproperty - def fisher_factor_inner_shape(self): - """The shape of the tensor returned by multiply_fisher_factor.""" - pass - - @abc.abstractproperty - def fisher_factor_inner_static_shape(self): - """Static version of fisher_factor_inner_shape.""" - pass - - @abc.abstractmethod - def sample(self, seed): - """Sample 'targets' from the underlying distribution.""" - pass - - def evaluate_on_sample(self, seed=None): - """Evaluates the log probability on a random sample. - - Args: - seed: int or None. Random seed for this draw from the distribution. - - Returns: - Log probability of sampled targets, summed across examples. - """ - if seed is None: - seed = self._default_seed - # We treat the targets as "constant". It's only the inputs that get - # "back-propped" through. - return self._evaluate(array_ops.stop_gradient(self.sample(seed))) - - -# TODO(jamesmartens): should this just inherit from object to avoid "diamond" -# inheritance, or is there a better way? -class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss): - """Base class for neg log prob losses whose inputs are 'natural' parameters. - - Note that the Hessian and Fisher for natural parameters of exponential- - family models are the same, hence the purpose of this class. - See here: https://arxiv.org/abs/1412.1193 - - 'Natural parameters' are defined for exponential-family models. See for - example: https://en.wikipedia.org/wiki/Exponential_family - """ - - def multiply_hessian(self, vector): - return self.multiply_fisher(vector) - - def multiply_hessian_factor(self, vector): - return self.multiply_fisher_factor(vector) - - def multiply_hessian_factor_transpose(self, vector): - return self.multiply_fisher_factor_transpose(vector) - - def multiply_hessian_factor_replicated_one_hot(self, index): - return self.multiply_fisher_factor_replicated_one_hot(index) - - @property - def hessian_factor_inner_shape(self): - return self.fisher_factor_inner_shape - - @property - def hessian_factor_inner_static_shape(self): - return self.fisher_factor_inner_shape - - -class DistributionNegativeLogProbLoss(NegativeLogProbLoss): - """Base class for neg log prob losses that use the TF Distribution classes.""" - - def __init__(self, seed=None): - super(DistributionNegativeLogProbLoss, self).__init__(seed=seed) - - @abc.abstractproperty - def dist(self): - """The underlying tf.distributions.Distribution.""" - pass - - def _evaluate(self, targets): - return -math_ops.reduce_sum(self.dist.log_prob(targets)) - - def sample(self, seed): - return self.dist.sample(seed=seed) - - -class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, - NaturalParamsNegativeLogProbLoss): - """Neg log prob loss for a normal distribution parameterized by a mean vector. - - - Note that the covariance is treated as a constant 'var' times the identity. - Also note that the Fisher for such a normal distribution with respect the mean - parameter is given by: - - F = (1/var) * I - - See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf. - """ - - def __init__(self, mean, var=0.5, targets=None, seed=None): - self._mean = mean - self._var = var - self._targets = targets - super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed) - - @property - def targets(self): - return self._targets - - @property - def dist(self): - return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var)) - - @property - def params(self): - return self._mean - - def multiply_fisher(self, vector): - return (1. / self._var) * vector - - def multiply_fisher_factor(self, vector): - return self._var**-0.5 * vector - - def multiply_fisher_factor_transpose(self, vector): - return self.multiply_fisher_factor(vector) # it's symmetric in this case - - def multiply_fisher_factor_replicated_one_hot(self, index): - assert len(index) == 1, "Length of index was {}".format(len(index)) - ones_slice = array_ops.expand_dims( - array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype), - axis=-1) - output_slice = self._var**-0.5 * ones_slice - return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]), - index[0]) - - @property - def fisher_factor_inner_shape(self): - return array_ops.shape(self._mean) - - @property - def fisher_factor_inner_static_shape(self): - return self._mean.shape - - -class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): - """Negative log prob loss for a normal distribution with mean and variance. - - This class parameterizes a multivariate normal distribution with n independent - dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not - assume the variance is held constant. The Fisher Information for n = 1 - is given by, - - F = [[1 / variance, 0], - [ 0, 0.5 / variance^2]] - - where the parameters of the distribution are concatenated into a single - vector as [mean, variance]. For n > 1, the mean parameter vector is - concatenated with the variance parameter vector. - - See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation. - """ - - def __init__(self, mean, variance, targets=None, seed=None): - assert len(mean.shape) == 2, "Expect 2D mean tensor." - assert len(variance.shape) == 2, "Expect 2D variance tensor." - self._mean = mean - self._variance = variance - self._targets = targets - super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) - - @property - def targets(self): - return self._targets - - @property - def dist(self): - return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance)) - - @property - def params(self): - return self._mean, self._variance - - def _concat(self, mean, variance): - return array_ops.concat([mean, variance], axis=-1) - - def _split(self, params): - return array_ops.split(params, 2, axis=-1) - - @property - def _fisher_mean(self): - return 1. / self._variance - - @property - def _fisher_mean_factor(self): - return 1. / math_ops.sqrt(self._variance) - - @property - def _fisher_var(self): - return 1. / (2 * math_ops.square(self._variance)) - - @property - def _fisher_var_factor(self): - return 1. / (math_ops.sqrt(2.) * self._variance) - - def multiply_fisher(self, vecs): - mean_vec, var_vec = vecs - return (self._fisher_mean * mean_vec, self._fisher_var * var_vec) - - def multiply_fisher_factor(self, vecs): - mean_vec, var_vec = self._split(vecs) - return (self._fisher_mean_factor * mean_vec, - self._fisher_var_factor * var_vec) - - def multiply_fisher_factor_transpose(self, vecs): - mean_vec, var_vec = vecs - return self._concat(self._fisher_mean_factor * mean_vec, - self._fisher_var_factor * var_vec) - - def multiply_fisher_factor_replicated_one_hot(self, index): - assert len(index) == 1, "Length of index was {}".format(len(index)) - index = index[0] - - if index < int(self._mean.shape[-1]): - # Index corresponds to mean parameter. - mean_slice = self._fisher_mean_factor[:, index] - mean_slice = array_ops.expand_dims(mean_slice, axis=-1) - mean_output = insert_slice_in_zeros(mean_slice, 1, int( - self._mean.shape[1]), index) - var_output = array_ops.zeros_like(mean_output) - else: - index -= int(self._mean.shape[-1]) - # Index corresponds to variance parameter. - var_slice = self._fisher_var_factor[:, index] - var_slice = array_ops.expand_dims(var_slice, axis=-1) - var_output = insert_slice_in_zeros(var_slice, 1, - int(self._variance.shape[1]), index) - mean_output = array_ops.zeros_like(var_output) - - return mean_output, var_output - - @property - def fisher_factor_inner_shape(self): - return array_ops.concat( - [ - array_ops.shape(self._mean)[:-1], - 2 * array_ops.shape(self._mean)[-1:] - ], - axis=0) - - @property - def fisher_factor_inner_static_shape(self): - shape = self._mean.shape.as_list() - return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]]) - - def multiply_hessian(self, vector): - raise NotImplementedError() - - def multiply_hessian_factor(self, vector): - raise NotImplementedError() - - def multiply_hessian_factor_transpose(self, vector): - raise NotImplementedError() - - def multiply_hessian_factor_replicated_one_hot(self, index): - raise NotImplementedError() - - @property - def hessian_factor_inner_shape(self): - raise NotImplementedError() - - @property - def hessian_factor_inner_static_shape(self): - raise NotImplementedError() - - -class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, - NaturalParamsNegativeLogProbLoss): - """Neg log prob loss for a categorical distribution parameterized by logits. - - - Note that the Fisher (for a single case) of a categorical distribution, with - respect to the natural parameters (i.e. the logits), is given by: - - F = diag(p) - p*p^T - - where p = softmax(logits). F can be factorized as F = B * B^T where - - B = diag(q) - p*q^T - - where q is the entry-wise square root of p. This is easy to verify using the - fact that q^T*q = 1. - """ - - def __init__(self, logits, targets=None, seed=None): - """Instantiates a CategoricalLogitsNegativeLogProbLoss. - - Args: - logits: Tensor of shape [batch_size, output_size]. Parameters for - underlying distribution. - targets: None or Tensor of shape [output_size]. Each elements contains an - index in [0, output_size). - seed: int or None. Default random seed when sampling. - """ - self._logits = logits - self._targets = targets - super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) - - @property - def targets(self): - return self._targets - - @property - def dist(self): - return categorical.Categorical(logits=self._logits) - - @property - def _probs(self): - return self.dist.probs - - @property - def _sqrt_probs(self): - return math_ops.sqrt(self._probs) - - @property - def params(self): - return self._logits - - def multiply_fisher(self, vector): - probs = self._probs - return vector * probs - probs * math_ops.reduce_sum( - vector * probs, axis=-1, keepdims=True) - - def multiply_fisher_factor(self, vector): - probs = self._probs - sqrt_probs = self._sqrt_probs - return sqrt_probs * vector - probs * math_ops.reduce_sum( - sqrt_probs * vector, axis=-1, keepdims=True) - - def multiply_fisher_factor_transpose(self, vector): - probs = self._probs - sqrt_probs = self._sqrt_probs - return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( - probs * vector, axis=-1, keepdims=True) - - def multiply_fisher_factor_replicated_one_hot(self, index): - assert len(index) == 1, "Length of index was {}".format(len(index)) - probs = self._probs - sqrt_probs = self._sqrt_probs - sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1) - padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1, - int(sqrt_probs.shape[1]), index[0]) - return padded_slice - probs * sqrt_probs_slice - - @property - def fisher_factor_inner_shape(self): - return array_ops.shape(self._logits) - - @property - def fisher_factor_inner_static_shape(self): - return self._logits.shape - - -class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, - NaturalParamsNegativeLogProbLoss): - """Neg log prob loss for multiple Bernoulli distributions param'd by logits. - - Represents N independent Bernoulli distributions where N = len(logits). Its - Fisher Information matrix is given by, - - F = diag(p * (1-p)) - p = sigmoid(logits) - - As F is diagonal with positive entries, its factor B is, - - B = diag(sqrt(p * (1-p))) - """ - - def __init__(self, logits, targets=None, seed=None): - self._logits = logits - self._targets = targets - super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed) - - @property - def targets(self): - return self._targets - - @property - def dist(self): - return bernoulli.Bernoulli(logits=self._logits) - - @property - def _probs(self): - return self.dist.probs - - @property - def params(self): - return self._logits - - def multiply_fisher(self, vector): - return self._probs * (1 - self._probs) * vector - - def multiply_fisher_factor(self, vector): - return math_ops.sqrt(self._probs * (1 - self._probs)) * vector - - def multiply_fisher_factor_transpose(self, vector): - return self.multiply_fisher_factor(vector) # it's symmetric in this case - - def multiply_fisher_factor_replicated_one_hot(self, index): - assert len(index) == 1, "Length of index was {}".format(len(index)) - probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1) - output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice)) - return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]), - index[0]) - - @property - def fisher_factor_inner_shape(self): - return array_ops.shape(self._logits) - - @property - def fisher_factor_inner_static_shape(self): - return self._logits.shape - - -def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): - """Inserts slice into a larger tensor of zeros. - - Forms a new tensor which is the same shape as slice_to_insert, except that - the dimension given by 'dim' is expanded to the size given by 'dim_size'. - 'position' determines the position (index) at which to insert the slice within - that dimension. - - Assumes slice_to_insert.shape[dim] = 1. - - Args: - slice_to_insert: The slice to insert. - dim: The dimension which to expand with zeros. - dim_size: The new size of the 'dim' dimension. - position: The position of 'slice_to_insert' in the new tensor. - - Returns: - The new tensor. - - Raises: - ValueError: If the slice's shape at the given dim is not 1. - """ - slice_shape = slice_to_insert.shape - if slice_shape[dim] != 1: - raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but " - "was {}".format(dim, slice_to_insert.shape[dim])) - - before = [0] * int(len(slice_shape)) - after = before[:] - before[dim] = position - after[dim] = dim_size - position - 1 - - return array_ops.pad(slice_to_insert, list(zip(before, after))) - - -class OnehotCategoricalLogitsNegativeLogProbLoss( - CategoricalLogitsNegativeLogProbLoss): - """Neg log prob loss for a categorical distribution with onehot targets. - - Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying - distribution is OneHotCategorical as opposed to Categorical. - """ - - @property - def dist(self): - return onehot_categorical.OneHotCategorical(logits=self._logits) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py deleted file mode 100644 index 4279cb2792854249e3e076d200e2656bc615779d..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Loss functions to be used by LayerCollection.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.loss_functions import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - "LossFunction", - "NegativeLogProbLoss", - "NaturalParamsNegativeLogProbLoss", - "DistributionNegativeLogProbLoss", - "NormalMeanNegativeLogProbLoss", - "NormalMeanVarianceNegativeLogProbLoss", - "CategoricalLogitsNegativeLogProbLoss", - "OnehotCategoricalLogitsNegativeLogProbLoss", - "MultiBernoulliNegativeLogProbLoss", - "insert_slice_in_zeros", -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py deleted file mode 100644 index b6d9d37a31a949b154b79e6f3677289a0d167373..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/op_queue.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Helper for choosing which op to run next in a distributed setting.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import ops as tf_ops - - -class OpQueue(object): - """Class for choosing which Op to run next. - - Constructs an infinitely repeating sequence of Ops in shuffled order. - - In K-FAC, this can be used to distribute inverse update operations among - workers. - """ - - def __init__(self, ops, seed=None): - """Initializes an OpQueue. - - Args: - ops: list of TensorFlow Ops. Ops to be selected from. All workers must - initialize with the same set of ops. - seed: int or None. Random seed used when shuffling order of ops. - """ - self._ops_by_name = {op.name: op for op in ops} - - # Construct a (shuffled) Dataset with Op names. - op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops))) - op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names) - .shuffle(len(ops), seed=seed).repeat()) - self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next() - - @property - def ops(self): - """Ops this OpQueue can return in next_op().""" - return self._ops_by_name.values() - - def next_op(self, sess): - """Chooses which op to run next. - - Note: This call will make a call to sess.run(). - - Args: - sess: tf.Session. - - Returns: - Next Op chosen from 'ops'. - """ - # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii') - # returns a str. - next_op_name = sess.run(self._next_op_name).decode('ascii') - return self._ops_by_name[next_op_name] diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py deleted file mode 100644 index 38605259b5f8566f4230f0f441f83d1b7b820c93..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ /dev/null @@ -1,727 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""The KFAC optimizer.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import warnings - -# pylint disable=long-line -from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp -from tensorflow.contrib.kfac.python.ops import estimator as est -# pylint enable=long-line - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as tf_variables -from tensorflow.python.training import gradient_descent - - -class KfacOptimizer(gradient_descent.GradientDescentOptimizer): - """The KFAC Optimizer (https://arxiv.org/abs/1503.05671).""" - - def __init__(self, - learning_rate, - cov_ema_decay, - damping, - layer_collection, - var_list=None, - momentum=0.9, - momentum_type="regular", - norm_constraint=None, - name="KFAC", - estimation_mode="gradients", - colocate_gradients_with_ops=True, - batch_size=None, - placement_strategy=None, - **kwargs): - """Initializes the KFAC optimizer with the given settings. - - Args: - learning_rate: The base learning rate for the optimizer. Should probably - be set to 1.0 when using momentum_type = 'qmodel', but can still be - set lowered if desired (effectively lowering the trust in the - quadratic model.) - cov_ema_decay: The decay factor used when calculating the covariance - estimate moving averages. - damping: The damping factor used to stabilize training due to errors in - the local approximation with the Fisher information matrix, and to - regularize the update direction by making it closer to the gradient. - If damping is adapted during training then this value is used for - initializing damping variable. - (Higher damping means the update looks more like a standard gradient - update - see Tikhonov regularization.) - layer_collection: The layer collection object, which holds the fisher - blocks, Kronecker factors, and losses associated with the - graph. The layer_collection cannot be modified after KfacOptimizer's - initialization. - var_list: Optional list or tuple of variables to train. Defaults to the - list of variables collected in the graph under the key - `GraphKeys.TRAINABLE_VARIABLES`. - momentum: The momentum decay constant to use. Only applies when - momentum_type is 'regular' or 'adam'. (Default: 0.9) - momentum_type: The type of momentum to use in this optimizer, one of - 'regular', 'adam', or 'qmodel'. (Default: 'regular') - norm_constraint: float or Tensor. If specified, the update is scaled down - so that its approximate squared Fisher norm v^T F v is at most the - specified value. May only be used with momentum type 'regular'. - (Default: None) - name: The name for this optimizer. (Default: 'KFAC') - estimation_mode: The type of estimator to use for the Fishers. Can be - 'gradients', 'empirical', 'curvature_propagation', or 'exact'. - (Default: 'gradients'). See the doc-string for FisherEstimator for - more a more detailed description of these options. - colocate_gradients_with_ops: Whether we should request gradients we - compute in the estimator be colocated with their respective ops. - (Default: True) - batch_size: The size of the mini-batch. Only needed when momentum_type - == 'qmodel' or when automatic adjustment is used. (Default: None) - placement_strategy: string, Device placement strategy used when creating - covariance variables, covariance ops, and inverse ops. - (Default: `None`) - **kwargs: Arguments to be passed to specific placement - strategy mixin. Check `placement.RoundRobinPlacementMixin` for example. - - Raises: - ValueError: If the momentum type is unsupported. - ValueError: If clipping is used with momentum type other than 'regular'. - ValueError: If no losses have been registered with layer_collection. - ValueError: If momentum is non-zero and momentum_type is not 'regular' - or 'adam'. - """ - warnings.warn( - "third_party.tensorflow.contrib.kfac is deprecated." - "This will be removed on 15-07-2018. Check README for further details.", - DeprecationWarning) - # Parameters to be passed to the Fisher estimator: - self._variables = var_list or tf_variables.trainable_variables - self._cov_ema_decay = cov_ema_decay - self._layers = layer_collection - self._estimation_mode = estimation_mode - self._colocate_gradients_with_ops = colocate_gradients_with_ops - - # The below parameters are required only if damping needs to be adapted. - # These parameters can be set by calling - # set_damping_adaptation_params() explicitly. - self._damping_adaptation_decay = 0.95 - self._damping_adaptation_interval = 5 - # Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval) - self._omega = ( - self._damping_adaptation_decay**self._damping_adaptation_interval) - self._adapt_damping = False - self._min_damping = 1e-5 - self._prev_train_batch = None - self._is_chief = False - self._loss_fn = None - self._damping_constant = damping - self._damping = None - self._rho = None - self._prev_loss = None - self._q_model_change = None - self._update_damping_op = None - - momentum_type = momentum_type.lower() - legal_momentum_types = ["regular", "adam", "qmodel"] - - if momentum_type not in legal_momentum_types: - raise ValueError("Unsupported momentum type {}. Must be one of {}." - .format(momentum_type, legal_momentum_types)) - if momentum_type != "regular" and norm_constraint is not None: - raise ValueError("Update clipping is only supported with momentum " - "type 'regular'.") - if momentum_type not in ["regular", "adam"] and momentum != 0: - raise ValueError("Momentum must be unspecified if using a momentum_type " - "other than 'regular' or 'adam'.") - - # Extra parameters of the optimizer - self._momentum = momentum - self._momentum_type = momentum_type - self._norm_constraint = norm_constraint - self._batch_size = batch_size - self._placement_strategy = placement_strategy - - with variable_scope.variable_scope(name): - self._fisher_est = est.make_fisher_estimator( - placement_strategy=placement_strategy, - variables=self._variables, - cov_ema_decay=self._cov_ema_decay, - damping=self.damping, - layer_collection=self._layers, - exps=(-1,), - estimation_mode=self._estimation_mode, - colocate_gradients_with_ops=self._colocate_gradients_with_ops, - **kwargs) - - super(KfacOptimizer, self).__init__(learning_rate, name=name) - - def set_damping_adaptation_params(self, - is_chief, - prev_train_batch, - loss_fn, - min_damping=1e-5, - damping_adaptation_decay=0.99, - damping_adaptation_interval=5): - """Sets parameters required to adapt damping during training. - - When called, enables damping adaptation according to the Levenberg-Marquardt - style rule described in Section 6.5 of "Optimizing Neural Networks with - Kronecker-factored Approximate Curvature". - - Note that this function creates Tensorflow variables which store a few - scalars and are accessed by the ops which update the damping (as part - of the training op returned by the minimize() method). - - Args: - is_chief: `Boolean`, `True` if the worker is chief. - prev_train_batch: Training data used to minimize loss in the previous - step. This will be used to evaluate loss by calling - `loss_fn(prev_train_batch)`. - loss_fn: `function` that takes as input training data tensor and returns - a scalar loss. - min_damping: `float`(Optional), Minimum value the damping parameter - can take. Default value 1e-5. - damping_adaptation_decay: `float`(Optional), The `damping` parameter is - multiplied by the `damping_adaptation_decay` every - `damping_adaptation_interval` number of iterations. Default value 0.99. - damping_adaptation_interval: `int`(Optional), Number of steps in between - updating the `damping` parameter. Default value 5. - - Raises: - ValueError: If `set_damping_adaptation_params` is already called and the - the `adapt_damping` is `True`. - """ - if self._adapt_damping: - raise ValueError("Damping adaptation parameters already set.") - - with variable_scope.variable_scope(self.get_name()): - self._adapt_damping = True - self._is_chief = is_chief - self._prev_train_batch = prev_train_batch - self._loss_fn = loss_fn - self._damping_adaptation_decay = damping_adaptation_decay - self._damping_adaptation_interval = damping_adaptation_interval - self._omega = ( - self._damping_adaptation_decay**self._damping_adaptation_interval) - self._min_damping = min_damping - - self._rho = variable_scope.get_variable( - "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio. - self._prev_loss = variable_scope.get_variable( - "prev_loss", shape=(), dtype=dtypes.float32, trainable=False) - self._q_model_change = variable_scope.get_variable( - "q_model_change", shape=(), dtype=dtypes.float32, trainable=False) - self._damping = variable_scope.get_variable( - "damping", initializer=self._damping_constant, trainable=False) - - @property - def variables(self): - return self._fisher_est.variables - - @property - def damping(self): - if self._damping: - return self._damping - else: - return self._damping_constant - - @property - def damping_adaptation_interval(self): - return self._damping_adaptation_interval - - def make_vars_and_create_op_thunks(self): - """Make vars and create op thunks. - - Returns: - cov_update_thunks: List of cov update thunks. Corresponds one-to-one with - the list of factors given by the "factors" property. - inv_update_thunks: List of inv update thunks. Corresponds one-to-one with - the list of factors given by the "factors" property. - """ - scope = self.get_name() + "/" + self._fisher_est.name - return self._fisher_est.make_vars_and_create_op_thunks(scope=scope) - - def create_ops_and_vars_thunks(self): - """Create thunks that make the ops and vars on demand. - - This function returns 4 lists of thunks: cov_variable_thunks, - cov_update_thunks, inv_variable_thunks, and inv_update_thunks. - - The length of each list is the number of factors and the i-th element of - each list corresponds to the i-th factor (given by the "factors" property). - - Note that the execution of these thunks must happen in a certain - partial order. The i-th element of cov_variable_thunks must execute - before the i-th element of cov_update_thunks (and also the i-th element - of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks - must execute before the i-th element of inv_update_thunks. - - TL;DR (oversimplified): Execute the thunks according to the order that - they are returned. - - Returns: - cov_variable_thunks: A list of thunks that make the cov variables. - cov_update_thunks: A list of thunks that make the cov update ops. - inv_variable_thunks: A list of thunks that make the inv variables. - inv_update_thunks: A list of thunks that make the inv update ops. - """ - scope = self.get_name() + "/" + self._fisher_est.name - return self._fisher_est.create_ops_and_vars_thunks(scope=scope) - - def minimize(self, *args, **kwargs): - # Should this variable scope encompass everything below? Or will the super- - # class make another copy of the same name scope? - with variable_scope.variable_scope(self.get_name()): - kwargs["var_list"] = kwargs.get("var_list") or self.variables - if set(kwargs["var_list"]) != set(self.variables): - raise ValueError("var_list doesn't match with set of Fisher-estimating " - "variables.") - if self._adapt_damping and self._is_chief: - global_step = kwargs.get("global_step", None) - if not global_step: - raise KeyError("global_step needs to be passed to optimizer.minimize " - "if damping parameter is adapted.") - update_damping_op = self._update_damping(self._prev_train_batch, - global_step) - with ops.control_dependencies([update_damping_op]): - loss = args[0] - loss_assign_op = state_ops.assign(self._prev_loss, loss) - train_op = super(KfacOptimizer, self).minimize(*args, **kwargs) - return control_flow_ops.group(loss_assign_op, train_op) - else: - return super(KfacOptimizer, self).minimize(*args, **kwargs) - - def compute_gradients(self, *args, **kwargs): - # args[1] could be our var_list - if len(args) > 1: - var_list = args[1] - else: - kwargs["var_list"] = kwargs.get("var_list") or self.variables - var_list = kwargs["var_list"] - - if set(var_list) != set(self.variables): - raise ValueError("var_list doesn't match with set of Fisher-estimating " - "variables.") - return super(KfacOptimizer, self).compute_gradients(*args, **kwargs) - - def apply_gradients(self, grads_and_vars, *args, **kwargs): - """Applies gradients to variables. - - Args: - grads_and_vars: List of (gradient, variable) pairs. - *args: Additional arguments for super.apply_gradients. - **kwargs: Additional keyword arguments for super.apply_gradients. - - Returns: - An `Operation` that applies the specified gradients. - """ - # In Python 3, grads_and_vars can be a zip() object which can only be - # iterated over once. By converting it to a list, we ensure that it can be - # iterated over more than once. - grads_and_vars = list(grads_and_vars) - - # Compute step. - steps_and_vars = self._compute_update_steps(grads_and_vars) - - # Update trainable variables with this step. - return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args, - **kwargs) - - def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars): - """Computes the squared (approximate) Fisher norm of the updates. - - This is defined as v^T F v, where F is the approximate Fisher matrix - as computed by the estimator, and v = F^{-1} g, where g is the gradient. - This is computed efficiently as v^T g. - - Args: - grads_and_vars: List of (gradient, variable) pairs. - precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. - Must be the result of calling `self._fisher_est.multiply_inverse` - on `grads_and_vars`. - - Returns: - Scalar representing the squared norm. - - Raises: - ValueError: if the two list arguments do not contain the same variables, - in the same order. - """ - for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars): - if gvar is not pgvar: - raise ValueError("The variables referenced by the two arguments " - "must match.") - terms = [ - math_ops.reduce_sum(grad * pgrad) - for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars) - ] - return math_ops.reduce_sum(terms) - - def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars): - """Computes the scale factor for the update to satisfy the norm constraint. - - Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint, - F is the approximate Fisher matrix, and r is the update vector, i.e. - -alpha * v, where alpha is the learning rate, and v is the preconditioned - gradient. - - This is based on Section 5 of Ba et al., Distributed Second-Order - Optimization using Kronecker-Factored Approximations. Note that they - absorb the learning rate alpha (which they denote eta_max) into the formula - for the coefficient, while in our implementation, the rescaling is done - before multiplying by alpha. Hence, our formula differs from theirs by a - factor of alpha. - - Args: - grads_and_vars: List of (gradient, variable) pairs. - precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. - Must be the result of calling `self._fisher_est.multiply_inverse` - on `grads_and_vars`. - - Returns: - Scalar representing the coefficient which should be applied to the - preconditioned gradients to satisfy the norm constraint. - """ - sq_norm_grad = self._squared_fisher_norm(grads_and_vars, - precon_grads_and_vars) - sq_norm_up = sq_norm_grad * self._learning_rate**2 - return math_ops.minimum(1., - math_ops.sqrt(self._norm_constraint / sq_norm_up)) - - def _clip_updates(self, grads_and_vars, precon_grads_and_vars): - """Rescales the preconditioned gradients to satisfy the norm constraint. - - Rescales the preconditioned gradients such that the resulting update r - (after multiplying by the learning rate) will satisfy the norm constraint. - This constraint is that r^T F r <= C, where F is the approximate Fisher - matrix, and C is the norm_constraint attribute. See Section 5 of - Ba et al., Distributed Second-Order Optimization using Kronecker-Factored - Approximations. - - Args: - grads_and_vars: List of (gradient, variable) pairs. - precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. - Must be the result of calling `self._fisher_est.multiply_inverse` - on `grads_and_vars`. - - Returns: - List of (rescaled preconditioned gradient, variable) pairs. - """ - coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars) - return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars] - - def _compute_prev_updates(self, variables): - """Computes previous updates as negative velocities scaled by learning rate. - - Args: - variables: List of variables in the graph that the update will be - applied to. - - Returns: - List of previous updates applied to the `variables`. - """ - return list( - -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name) - for var in variables) - - def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads, - variables): - """Compute optimal update hyperparameters from the quadratic model. - - More specifically, if L is the loss we minimize a quadratic approximation - of L(theta + d) which we denote by qmodel(d) with - d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where - - qmodel(d) = (1/2) * d^T * B * d + grad^T*d + L(theta) . - - Unlike in the KL clipping approach we use the non-approximated quadratic - model where the curvature matrix C is the true Fisher on the current - mini-batch (computed without any approximations beyond mini-batch sampling), - with the usual Tikhonov damping/regularization applied, - - C = F + damping * I - - See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of - the formula. See Appendix C for a discussion of the trick of using - a factorized Fisher matrix to more efficiently compute the required - vector-matrix-vector products. - - Note that the elements of all 4 lists passed to this function must - be in correspondence with each other. - - Args: - precon_grads: List of preconditioned gradients. - prev_updates: List of updates computed at the previous iteration. - grads: List of gradients. - variables: List of variables in the graph that the update will be - applied to. (Note that this function doesn't actually apply the - update.) - - Returns: - (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the - quadratic model, and - qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0) - = qmodel(alpha*precon_grad + mu*prev_update) - L(theta). - """ - - cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses, - variables) - - # compute the matrix-vector products with the transposed Fisher factor - fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads) - fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates) - batch_size = math_ops.cast( - self._batch_size, dtype=fft_precon_grads[0].dtype) - - # compute the entries of the 2x2 matrix - m_11 = ( - _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size + - self.damping * _inner_product_list(precon_grads, precon_grads)) - - m_21 = ( - _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size + - self.damping * _inner_product_list(prev_updates, precon_grads)) - - m_22 = ( - _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size + - self.damping * _inner_product_list(prev_updates, prev_updates)) - - def non_zero_prevupd_case(): - r"""Computes optimal (alpha, mu) given non-zero previous update. - - We solve the full 2x2 linear system. See Martens & Grosse (2015), - Section 7, definition of $\alpha^*$ and $\mu^*$. - - Returns: - (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize - the quadratic model, and - qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0). - """ - m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]]) - - c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)], - [_inner_product_list(grads, prev_updates)]]) - - sol = -1. * _two_by_two_solve(m, c) - alpha = sol[0] - mu = sol[1] - qmodel_change = 0.5 * math_ops.reduce_sum(sol * c) - - return alpha, mu, qmodel_change - - def zero_prevupd_case(): - r"""Computes optimal (alpha, mu) given all-zero previous update. - - The linear system reduces to 1x1. See Martens & Grosse (2015), - Section 6.4, definition of $\alpha^*$. - - Returns: - (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the - quadratic model, and - qmodel_change = qmodel(alpha*precon_grad) - qmodel(0) - """ - m = m_11 - c = _inner_product_list(grads, precon_grads) - - alpha = -c / m - mu = 0.0 - qmodel_change = 0.5 * alpha * c - - return alpha, mu, qmodel_change - - return control_flow_ops.cond( - math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case) - - def _assign_q_model_change(self, q_model_change): - """Assigns `q_model_change` to `self._q_model_change` if damping is adapted. - - Note only the chief worker does the assignment. - - Args: - q_model_change: Scalar tensor of type `float32`. - - Returns: - If `adapt_damping` is `True` then returns an assign op, Otherwise returns - a no_op(). - """ - if self._adapt_damping and self._is_chief: - q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change) - else: - q_model_assign_op = control_flow_ops.no_op() - return q_model_assign_op - - def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars, - precon_grads_and_vars): - """Wrapper function for `self._compute_qmodel_hyperparams`. - - Constructs a list of preconditioned gradients and variables. Also creates a - op to assign the computed q model change to `self._q_model_change`. - - Args: - grads_and_vars: List of (gradient, variable) pairs. - precon_grads_and_vars: List of (preconditioned gradients, variable) - pairs. - - Returns: - (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize - the quadratic model, `q_model_assign_op` assigns the computed q model - change to `self._q_model_change`. - """ - precon_grads = list( - precon_grad for (precon_grad, _) in precon_grads_and_vars) - grads = list(grad for (grad, _) in grads_and_vars) - variables = list(var for (_, var) in grads_and_vars) - prev_updates = self._compute_prev_updates(variables) - # Compute optimal velocity update parameters according to quadratic model - alpha, mu, q_model_change = self._compute_qmodel_hyperparams( - precon_grads, prev_updates, grads, variables) - - return alpha, mu, self._assign_q_model_change(q_model_change) - - def _compute_update_steps(self, grads_and_vars): - """Computes the update steps for the variables given the gradients. - - Args: - grads_and_vars: List of (gradient, variable) pairs. - - Returns: - A list of tuple (assign_op ,var) where `assign_op` assigns the update - steps to `var`. - """ - - if self._momentum_type == "regular": - # Compute "preconditioned" gradient. - precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) - - # Apply "KL clipping" if asked for. - if self._norm_constraint is not None: - precon_grads_and_vars = self._clip_updates(grads_and_vars, - precon_grads_and_vars) - - # Update the velocity with this and return it as the step. - if self._adapt_damping and self._is_chief: - _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( - grads_and_vars, precon_grads_and_vars) - with ops.control_dependencies([q_model_assign_op]): - return self._update_velocities(precon_grads_and_vars, self._momentum) - else: - return self._update_velocities(precon_grads_and_vars, self._momentum) - elif self._momentum_type == "adam": - # Update velocity. - velocities_and_vars = self._update_velocities(grads_and_vars, - self._momentum) - # Return "preconditioned" velocity vector as the step. - return self._fisher_est.multiply_inverse(velocities_and_vars) - - elif self._momentum_type == "qmodel": - # Compute "preconditioned" gradient. - precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) - - # Compute optimal velocity update parameters according to quadratic model - alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( - grads_and_vars, precon_grads_and_vars) - - with ops.control_dependencies([q_model_assign_op]): - return self._update_velocities( - precon_grads_and_vars, mu, vec_coeff=-alpha) - - def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0): - """Updates the velocities of the variables with the given vectors. - - Args: - vecs_and_vars: List of (vector, variable) pairs. - decay: How much to decay the old velocity by. This is often referred to - as the 'momentum constant'. - vec_coeff: Coefficient to apply to the vectors before adding them to the - velocity. - - Returns: - A list of (velocity, var) indicating the new velocity for each var. - """ - - def _update_velocity(vec, var): - velocity = self._zeros_slot(var, "velocity", self._name) - with ops.colocate_with(velocity): - # NOTE(mattjj): read/modify/write race condition not suitable for async. - - # Compute the new velocity for this variable. - new_velocity = decay * velocity + vec_coeff * vec - - # Save the updated velocity. - return (array_ops.identity(velocity.assign(new_velocity)), var) - - # Go through variable and update its associated part of the velocity vector. - return [_update_velocity(vec, var) for vec, var in vecs_and_vars] - - def _update_damping(self, prev_batch, global_step): - """Adapts damping parameter. Check KFAC (Section 6.5) for the details. - - The damping parameter is updated according to the Levenberg-Marquardt rule - every `self._damping_adaptation_interval` iterations. - - Args: - prev_batch: Tensor or tuple of tensors which can be passed to - `self._loss_fn` to evaluate loss. - global_step: `Variable` which keeps track of number of times the training - variables have been updated. - Returns: - A `tf.cond` op which updates the damping parameter. - """ - def compute_damping(): - """"Adapts damping parameter based on "reduction ratio". - - Reduction ratio captures how closely the quadratic approximation to the - loss function approximates the actual loss within a trust region. The - damping update tries to make the damping as small as possible while - maintaining the property that the quadratic model remains a good local - approximation to the loss function. - - Returns: - An Op to assign newly computed damping value to `self._damping`. - """ - prev_batch_loss = self._loss_fn(prev_batch) - with ops.control_dependencies([prev_batch_loss]): - rho_assign = self._rho.assign( - (prev_batch_loss - self._prev_loss) / self._q_model_change) - with ops.control_dependencies([rho_assign]): - new_damping = control_flow_ops.case( - [(self._rho < 0.25, lambda: self.damping / self._omega), - (self._rho > 0.75, lambda: self.damping * self._omega)], - lambda: self.damping) - with ops.control_dependencies([new_damping]): - new_damping_min = math_ops.maximum(new_damping, self._min_damping) - return control_flow_ops.group(self._damping.assign(new_damping_min)) - - return control_flow_ops.cond( - math_ops.equal( - math_ops.mod(global_step + 1, self._damping_adaptation_interval), - 0), compute_damping, control_flow_ops.no_op) - - -def _inner_product_list(list1, list2): - return math_ops.add_n( - [math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)]) - - -def _two_by_two_solve(m, c): - # it might be better just to crank out the exact formula for 2x2 inverses - return math_ops.matmul(linalg_ops.matrix_inverse(m), c) diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py deleted file mode 100644 index c4454325aebe131058282ff15c2734bf10d1cc49..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/placement.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Implements placement strategies for cov and inv ops, cov variables.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import itertools - -from tensorflow.python.framework import ops as tf_ops - - -def _make_thunk_on_device(func, device): - def thunk(): - with tf_ops.device(device): - return func() - return thunk - - -class RoundRobinPlacementMixin(object): - """Implements round robin placement strategy for ops and variables.""" - - def __init__(self, cov_devices=None, inv_devices=None, **kwargs): - """Initializes the RoundRobinPlacementMixin class. - - Args: - cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion - computations will be placed on these devices in a round-robin fashion. - Can be None, which means that no devices are specified. - **kwargs: Need something here? - - """ - super(RoundRobinPlacementMixin, self).__init__(**kwargs) - self._cov_devices = cov_devices - self._inv_devices = inv_devices - - def make_vars_and_create_op_thunks(self, scope=None): - """Make vars and create op thunks w/ a round-robin device placement start. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the - `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no - explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the `self._inv_devices` attribute. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all thunks will execute, inside of a variable scope of the given - name. (Default: None) - - Returns: - cov_update_thunks: List of cov update thunks. Corresponds one-to-one with - the list of factors given by the "factors" property. - inv_update_thunks: List of inv update thunks. Corresponds one-to-one with - the list of factors given by the "factors" property. - """ - # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`. - (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw, - inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope) - - if self._cov_devices: - cov_update_thunks = [] - for cov_variable_thunk, cov_update_thunk, device in zip( - cov_variable_thunks_raw, cov_update_thunks_raw, - itertools.cycle(self._cov_devices)): - with tf_ops.device(device): - cov_variable_thunk() - cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk, - device)) - else: - for cov_variable_thunk in cov_variable_thunks_raw: - cov_variable_thunk() - cov_update_thunks = cov_update_thunks_raw - - for inv_variable_thunk in inv_variable_thunks_raw: - inv_variable_thunk() - - if self._inv_devices: - inv_update_thunks = [] - for inv_update_thunk, device in zip(inv_update_thunks_raw, - itertools.cycle(self._inv_devices)): - inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk, - device)) - else: - inv_update_thunks = inv_update_thunks_raw - - return cov_update_thunks, inv_update_thunks diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py deleted file mode 100644 index 144295f4c7e36f61b4bae4178a6f57f6657204c5..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ /dev/null @@ -1,709 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility functions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.tpu.python.ops import tpu_ops -from tensorflow.contrib.tpu.python.tpu import tpu_function -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variables - -# Method used for inverting matrices. -POSDEF_INV_METHOD = "cholesky" -POSDEF_EIG_METHOD = "self_adjoint" - - -def set_global_constants(posdef_inv_method=None): - """Sets various global constants used by the classes in this module.""" - global POSDEF_INV_METHOD - - if posdef_inv_method is not None: - POSDEF_INV_METHOD = posdef_inv_method - - -class SequenceDict(object): - """A dict convenience wrapper that allows getting/setting with sequences.""" - - def __init__(self, iterable=None): - self._dict = dict(iterable or []) - - def __getitem__(self, key_or_keys): - if isinstance(key_or_keys, (tuple, list)): - return list(map(self.__getitem__, key_or_keys)) - else: - return self._dict[key_or_keys] - - def __setitem__(self, key_or_keys, val_or_vals): - if isinstance(key_or_keys, (tuple, list)): - for key, value in zip(key_or_keys, val_or_vals): - self[key] = value - else: - self._dict[key_or_keys] = val_or_vals - - def items(self): - return list(self._dict.items()) - - -def tensors_to_column(tensors): - """Converts a tensor or list of tensors to a column vector. - - Args: - tensors: A tensor or list of tensors. - - Returns: - The tensors reshaped into vectors and stacked on top of each other. - """ - if isinstance(tensors, (tuple, list)): - return array_ops.concat( - tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0) - else: - return array_ops.reshape(tensors, [-1, 1]) - - -def column_to_tensors(tensors_template, colvec): - """Converts a column vector back to the shape of the given template. - - Args: - tensors_template: A tensor or list of tensors. - colvec: A 2d column vector with the same shape as the value of - tensors_to_column(tensors_template). - - Returns: - X, where X is tensor or list of tensors with the properties: - 1) tensors_to_column(X) = colvec - 2) X (or its elements) have the same shape as tensors_template (or its - elements) - """ - if isinstance(tensors_template, (tuple, list)): - offset = 0 - tensors = [] - for tensor_template in tensors_template: - sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32) - tensor = array_ops.reshape(colvec[offset:(offset + sz)], - tensor_template.shape) - tensors.append(tensor) - offset += sz - - tensors = tuple(tensors) - else: - tensors = array_ops.reshape(colvec, tensors_template.shape) - - return tensors - - -def kronecker_product(mat1, mat2): - """Computes the Kronecker product two matrices.""" - m1, n1 = mat1.get_shape().as_list() - mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1]) - m2, n2 = mat2.get_shape().as_list() - mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2]) - return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) - - -def layer_params_to_mat2d(vector): - """Converts a vector shaped like layer parameters to a 2D matrix. - - In particular, we reshape the weights/filter component of the vector to be - 2D, flattening all leading (input) dimensions. If there is a bias component, - we concatenate it to the reshaped weights/filter component. - - Args: - vector: A Tensor or pair of Tensors shaped like layer parameters. - - Returns: - A 2D Tensor with the same coefficients and the same output dimension. - """ - if isinstance(vector, (tuple, list)): - w_part, b_part = vector - w_part_reshaped = array_ops.reshape(w_part, - [-1, w_part.shape.as_list()[-1]]) - return array_ops.concat( - (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0) - elif isinstance(vector, ops.IndexedSlices): - return vector - else: # Tensor or Tensor-like. - return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]]) - - -def mat2d_to_layer_params(vector_template, mat2d): - """Converts a canonical 2D matrix representation back to a vector. - - Args: - vector_template: A Tensor or pair of Tensors shaped like layer parameters. - mat2d: A 2D Tensor with the same shape as the value of - layer_params_to_mat2d(vector_template). - - Returns: - A Tensor or pair of Tensors with the same coefficients as mat2d and the same - shape as vector_template. - """ - if isinstance(vector_template, (tuple, list)): - w_part, b_part = mat2d[:-1], mat2d[-1] - return array_ops.reshape(w_part, vector_template[0].shape), b_part - elif isinstance(vector_template, ops.IndexedSlices): - if not isinstance(mat2d, ops.IndexedSlices): - raise TypeError( - "If vector_template is an IndexedSlices, so should mat2d.") - return mat2d - else: - return array_ops.reshape(mat2d, vector_template.shape) - - -def posdef_inv(tensor, damping): - """Computes the inverse of tensor + damping * identity.""" - identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) - damping = math_ops.cast(damping, dtype=tensor.dtype) - return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) - - -def posdef_inv_matrix_inverse(tensor, identity, damping): - """Computes inverse(tensor + damping * identity) directly.""" - return linalg_ops.matrix_inverse(tensor + damping * identity) - - -def posdef_inv_cholesky(tensor, identity, damping): - """Computes inverse(tensor + damping * identity) with Cholesky.""" - chol = linalg_ops.cholesky(tensor + damping * identity) - return linalg_ops.cholesky_solve(chol, identity) - - -def posdef_inv_eig(tensor, identity, damping): - """Computes inverse(tensor + damping * identity) with eigendecomposition.""" - eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig( - tensor + damping * identity) - return math_ops.matmul( - eigenvectors / eigenvalues, eigenvectors, transpose_b=True) - - -posdef_inv_functions = { - "matrix_inverse": posdef_inv_matrix_inverse, - "cholesky": posdef_inv_cholesky, - "eig": posdef_inv_eig, -} - - -def posdef_eig(mat): - """Computes the eigendecomposition of a positive semidefinite matrix.""" - return posdef_eig_functions[POSDEF_EIG_METHOD](mat) - - -def posdef_eig_svd(mat): - """Computes the singular values and left singular vectors of a matrix.""" - evals, evecs, _ = linalg_ops.svd(mat) - - return evals, evecs - - -def posdef_eig_self_adjoint(mat): - """Computes eigendecomposition using self_adjoint_eig.""" - evals, evecs = linalg_ops.self_adjoint_eig(mat) - evals = math_ops.abs(evals) # Should be equivalent to svd approach. - - return evals, evecs - - -posdef_eig_functions = { - "self_adjoint": posdef_eig_self_adjoint, - "svd": posdef_eig_svd, -} - - -def cholesky(tensor, damping): - """Computes the inverse of tensor + damping * identity.""" - identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) - damping = math_ops.cast(damping, dtype=tensor.dtype) - return linalg_ops.cholesky(tensor + damping * identity) - - -class SubGraph(object): - """Defines a subgraph given by all the dependencies of a given set of outputs. - """ - - def __init__(self, outputs): - # Set of all ancestor Tensors, Ops to 'outputs'. - self._members = set() - - self._iter_add(outputs) - - def _iter_add(self, root): - """Iteratively adds all of nodes' ancestors using depth first search.""" - stack = [root] - while stack: - nodes = stack.pop() - for node in nodes: - if node in self._members: - continue - self._members.add(node) - - if isinstance(node, ops.Tensor): - stack.append((node.op,)) - elif isinstance(node, ops.Operation): - stack.append(node.inputs) - - def is_member(self, node): - """Check if 'node' is in this subgraph.""" - return node in self._members - - def variable_uses(self, var): - """Computes number of times a variable is used. - - Args: - var: Variable or ResourceVariable instance. - - Returns: - Number of times a variable is used within this subgraph. - - Raises: - ValueError: If 'var' is not a variable type. - """ - if isinstance(var, resource_variable_ops.ResourceVariable): - var = var.handle - elif isinstance(var, variables.Variable): - var = var.value() - else: - raise ValueError("%s does not appear to be a variable." % str(var)) - - return len(self._members.intersection(set(var.consumers()))) - - def filter_list(self, node_list): - """Filters 'node_list' to nodes in this subgraph.""" - filtered_list = [] - for node in node_list: - if self.is_member(node): - filtered_list.append(node) - return filtered_list - - -def generate_random_signs(shape, dtype=dtypes.float32): - """Generate a random tensor with {-1, +1} entries.""" - ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32) - return 2 * math_ops.cast(ints, dtype=dtype) - 1 - - -def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): - """Compute forward-mode gradients.""" - # See b/37888268. - - # This version of forward-mode autodiff is based on code by Tim Cooijmans - # and handles list arguments and certain special cases such as when the - # ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are - # generated by the first gradients_impl.gradients call. - - us = [array_ops.zeros_like(y) + float("nan") for y in ys] - dydxs = gradients_impl.gradients( - ys, xs, grad_ys=us, stop_gradients=stop_gradients) - - # Deal with strange types that gradients_impl.gradients returns but can't - # deal with. - dydxs = [ - ops.convert_to_tensor(dydx) - if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs - ] - dydxs = [ - array_ops.zeros_like(x) if dydx is None else dydx - for x, dydx in zip(xs, dydxs) - ] - - dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) - - return dysdx - - -def on_tpu(): - """Returns True when building a TPU computation.""" - return tpu_function.get_tpu_context().number_of_shards is not None - - -def cross_replica_mean(tensor, name=None): - """Takes mean value of a Tensor across all TPU cores. - - Args: - tensor: Tensor to be synchronized. - name: None or string. Name of Op. - - Returns: - Average of Tensor across all TPU cores. - - Raises: - ValueError: If called outside of TPU context. - """ - with ops.name_scope(name, "cross_replica_mean", [tensor]): - num_shards = tpu_function.get_tpu_context().number_of_shards - if num_shards is None: - raise ValueError( - "Cannot take cross_replica_mean() outside of TPU Context.") - if num_shards == 1: - return tensor - return tpu_ops.cross_replica_sum(tensor / num_shards) - - -def ensure_sequence(obj): - """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" - if isinstance(obj, (tuple, list)): - return obj - else: - return (obj,) - - -def batch_execute(global_step, thunks, batch_size, name=None): - """Executes a subset of ops per global step. - - Given a list of thunks, each of which produces a single stateful op, - ensures that exactly 'batch_size' ops are run per global step. Ops are - scheduled in a round-robin fashion. For example, with 3 ops - - global_step | op0 | op1 | op2 - ------------+-----+-----+----- - 0 | x | x | - ------------+-----+-----+----- - 1 | x | | x - ------------+-----+-----+----- - 2 | | x | x - ------------+-----+-----+----- - 3 | x | x | - ------------+-----+-----+----- - 4 | x | | x - - Does not guarantee order of op execution within a single global step. - - Args: - global_step: Tensor indicating time. Determines which ops run. - thunks: List of thunks. Each thunk encapsulates one op. Return values are - ignored. - batch_size: int. Number of ops to execute per global_step. - name: string or None. Name scope for newly added ops. - - Returns: - List of ops. Exactly 'batch_size' ops are guaranteed to have an effect - every global step. - """ - - def true_fn(thunk): - """Ensures thunk is executed and returns an Op (not a Tensor).""" - - def result(): - with ops.control_dependencies([thunk()]): - return control_flow_ops.no_op() - - return result - - def false_fn(_): - """Executes a no-op.""" - - def result(): - return control_flow_ops.no_op() - - return result - - with ops.name_scope(name, "batch_execute"): - true_fns = [true_fn(thunk) for thunk in thunks] - false_fns = [false_fn(thunk) for thunk in thunks] - num_thunks = len(thunks) - conditions = [ - math_ops.less( - math_ops.mod(batch_size - 1 + global_step * batch_size - j, - num_thunks), batch_size) for j in range(num_thunks) - ] - result = [ - control_flow_ops.cond(condition, true_fn, false_fn) - for (condition, true_fn, - false_fn) in zip(conditions, true_fns, false_fns) - ] - return result - - -def extract_convolution_patches(inputs, - filter_shape, - padding, - strides=None, - dilation_rate=None, - name=None, - data_format=None): - """Extracts inputs to each output coordinate in tf.nn.convolution. - - This is a generalization of tf.extract_image_patches() to tf.nn.convolution(), - where the number of spatial dimensions may be something other than 2. - - Assumes, - - First dimension of inputs is batch_size - - Convolution filter is applied to all input channels. - - Args: - inputs: Tensor of shape [batch_size, ..spatial_image_shape.., - ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution(). - filter_shape: List of ints. Shape of filter passed to tf.nn.convolution(). - padding: string. Padding method. One of "VALID", "SAME". - strides: None or list of ints. Strides along spatial dimensions. - dilation_rate: None or list of ints. Dilation along spatial dimensions. - name: None or str. Name of Op. - data_format: None or str. Format of data. - - Returns: - Tensor of shape [batch_size, ..spatial_image_shape.., - ..spatial_filter_shape.., in_channels] - - Raises: - ValueError: If data_format does not put channel last. - ValueError: If inputs and filter disagree on in_channels. - """ - if not is_data_format_channel_last(data_format): - raise ValueError("Channel must be last dimension.") - with ops.name_scope(name, "extract_convolution_patches", - [inputs, filter_shape, padding, strides, dilation_rate]): - batch_size = inputs.shape.as_list()[0] - in_channels = inputs.shape.as_list()[-1] - - # filter_shape = spatial_filter_shape + [in_channels, out_channels] - spatial_filter_shape = filter_shape[:-2] - if in_channels != filter_shape[-2]: - raise ValueError("inputs and filter_shape must agree on in_channels.") - - # Map each input feature to a location in the output. - out_channels = np.prod(spatial_filter_shape) * in_channels - filters = linalg_ops.eye(out_channels) - filters = array_ops.reshape( - filters, - list(spatial_filter_shape) + [in_channels, out_channels]) - - result = nn_ops.convolution( - inputs, - filters, - padding=padding, - strides=strides, - dilation_rate=dilation_rate) - spatial_output_shape = result.shape.as_list()[1:-1] - result = array_ops.reshape(result, - [batch_size or -1] + spatial_output_shape + - list(spatial_filter_shape) + [in_channels]) - - return result - - -def extract_pointwise_conv2d_patches(inputs, - filter_shape, - name=None, - data_format=None): - """Extract patches for a 1x1 conv2d. - - Args: - inputs: 4-D Tensor of shape [batch_size, height, width, in_channels]. - filter_shape: List of 4 ints. Shape of filter to apply with conv2d() - name: None or str. Name for Op. - data_format: None or str. Format for data. See 'data_format' in - tf.nn.conv2d() for details. - - Returns: - Tensor of shape [batch_size, ..spatial_input_shape.., - ..spatial_filter_shape.., in_channels] - - Raises: - ValueError: if inputs is not 4-D. - ValueError: if filter_shape is not [1, 1, ?, ?] - ValueError: if data_format is not channels-last. - """ - if inputs.shape.ndims != 4: - raise ValueError("inputs must have 4 dims.") - if len(filter_shape) != 4: - raise ValueError("filter_shape must have 4 dims.") - if filter_shape[0] != 1 or filter_shape[1] != 1: - raise ValueError("filter_shape must have shape 1 along spatial dimensions.") - if not is_data_format_channel_last(data_format): - raise ValueError("data_format must be channels last.") - with ops.name_scope(name, "extract_pointwise_conv2d_patches", - [inputs, filter_shape]): - ksizes = [1, 1, 1, 1] # Spatial shape is 1x1. - strides = [1, 1, 1, 1] # Operate on all pixels. - rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1. - padding = "VALID" # Doesn't matter. - result = array_ops.extract_image_patches(inputs, ksizes, strides, rates, - padding) - - batch_size, input_height, input_width, in_channels = inputs.shape.as_list() - filter_height, filter_width, in_channels, _ = filter_shape - return array_ops.reshape(result, [ - batch_size, input_height, input_width, filter_height, filter_width, - in_channels - ]) - - -def is_data_format_channel_last(data_format): - """True if data_format puts channel last.""" - if data_format is None: - return True - return data_format.endswith("C") - - -def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name - """Computes matmul(A, B) where A is sparse, B is dense. - - Args: - A: tf.IndexedSlices with dense shape [m, n]. - B: tf.Tensor with shape [n, k]. - name: str. Name of op. - transpose_a: Bool. If true we transpose A before multiplying it by B. - (Default: False) - transpose_b: Bool. If true we transpose B before multiplying it by A. - (Default: False) - - Returns: - tf.IndexedSlices resulting from matmul(A, B). - - Raises: - ValueError: If A doesn't represent a matrix. - ValueError: If B is not rank-2. - """ - with ops.name_scope(name, "matmul_sparse_dense", [A, B]): - if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2: - raise ValueError("A must represent a matrix. Found: %s." % A) - if B.shape.ndims != 2: - raise ValueError("B must be a matrix.") - new_values = math_ops.matmul( - A.values, B, transpose_a=transpose_a, transpose_b=transpose_b) - return ops.IndexedSlices( - new_values, - A.indices, - dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]])) - - -def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name - """Computes matmul(A, B) where A is a diagonal matrix, B is sparse. - - Args: - A_diag: diagonal entries of matrix A of shape [m, m]. - B: tf.IndexedSlices. Represents matrix of shape [m, n]. - name: str. Name of op. - - Returns: - tf.IndexedSlices resulting from matmul(A, B). - - Raises: - ValueError: If A_diag is not rank-1. - ValueError: If B doesn't represent a matrix. - """ - with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]): - A_diag = ops.convert_to_tensor(A_diag) - if A_diag.shape.ndims != 1: - raise ValueError("A_diag must be a rank-1 Tensor.") - if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2: - raise ValueError("B must represent a matrix. Found: %s." % B) - a = array_ops.gather(A_diag, B.indices) - a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) - return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) - - -class PartitionedTensor(object): - """A Tensor partitioned across its 0-th dimension.""" - - def __init__(self, tensors): - """Initializes PartitionedTensor. - - Args: - tensors: List of Tensors. All Tensors must agree on shape (excepting - batch dimension) and dtype. - - Raises: - ValueError: If 'tensors' has length zero. - ValueError: if contents of 'tensors' don't agree on shape or dtype. - """ - if not tensors: - raise ValueError("tensors must be a list of 1+ Tensors.") - - dtype = tensors[0].dtype - if not all(tensor.dtype == dtype for tensor in tensors): - raise ValueError("all tensors must have dtype = %s." % dtype) - - shape = tensors[0].shape[1:] - if not all(tensor.shape[1:] == shape for tensor in tensors): - raise ValueError("All tensors must have shape = %s (excluding batch " - "dimension)." % shape) - - self.tensors = tensors - self._concats = {} # {device: Tensor} - - @property - def shape(self): - feature_shape = self.tensors[0].shape[1:] - batch_size = sum([tensor.shape[0] for tensor in self.tensors], - tensor_shape.Dimension(0)) - return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape) - - def get_shape(self): - return self.shape - - @property - def dtype(self): - return self.tensors[0].dtype - - def __str__(self): - return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % ( - self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list())) - - def __hash__(self): - return hash(tuple(self.tensors)) - - def __eq__(self, other): - if not isinstance(other, PartitionedTensor): - return False - return self.tensors == other.tensors - - def __ne__(self, other): - return not self == other # pylint: disable=g-comparison-negation - - def __getitem__(self, key): - return self.as_tensor()[key] - - def as_tensor(self, dtype=None, name=None, as_ref=False): - with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors): - assert not as_ref - assert dtype in [None, self.dtype] - result = array_ops.concat(self.tensors, axis=0) - - # Cache 'result' if we haven't already cached a value for this device. - if result.device not in self._concats: - self._concats[result.device] = result - return self._concats[result.device] - - @property - def device(self): - # PartitionedTensors in general do not live on a single device. If the - # device cannot be determined unambiguously this property will return None. - device = self.tensors[0].device - if all(tensor.device == device for tensor in self.tensors): - return device - return None - - -ops.register_tensor_conversion_function( - PartitionedTensor, - lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref)) - - -# TODO(b/69623235): Add a function for finding tensors that share gradients -# to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py deleted file mode 100644 index 330d222dbf70fcfa02ffd47261c0513d9dd6e0e9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility functions.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,line-too-long,wildcard-import -from tensorflow.contrib.kfac.python.ops.utils import * -from tensorflow.python.util.all_util import remove_undocumented -# pylint: enable=unused-import,line-too-long,wildcard-import - -_allowed_symbols = [ - "set_global_constants", - "SequenceDict", - "tensors_to_column", - "column_to_tensors", - "kronecker_product", - "layer_params_to_mat2d", - "mat2d_to_layer_params", - "posdef_inv", - "posdef_inv_matrix_inverse", - "posdef_inv_cholesky", - "posdef_inv_funcs", - "SubGraph", - "generate_random_signs", - "fwd_gradients", - "ensure_sequence", - "batch_execute", - "extract_convolution_patches", - "extract_pointwise_conv2d_patches", - "is_data_format_channel_last", - "matmul_sparse_dense", - "matmul_diag_sparse", -] - -remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 45a0ded7eb6e7376dfc516e5cb147d18c5448004..fc199f0a0e835c6ab3c03b1e06956bbbaafdb02a 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -235,6 +235,7 @@ def generated_test_models(): "exp", "expand_dims", "floor", + "floor_div", "fully_connected", "fused_batch_norm", "gather", @@ -266,6 +267,7 @@ def generated_test_models(): "padv2", "prelu", "pow", + "reduce_any", "reduce_max", "reduce_min", "reduce_prod", @@ -293,6 +295,7 @@ def generated_test_models(): "topk", "transpose", #"transpose_conv", # disabled due to b/111213074 + "unpack", "where", ] diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 70178b2faabe85f8a53a94c2b5d2e3ea40c8ba05..e81f9e4f514b43233d153d386f9c647c70e6d5da 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -286,6 +286,11 @@ typedef struct { int axis; } TfLiteOneHotParams; +typedef struct { + int num; + int axis; +} TfLiteUnpackParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 706f64a84acdb552d8ed1f3e29c6360ab43c9c77..9cf4bea73edd2a03c63ae735057a8bb28cd81c93 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -115,6 +115,8 @@ typedef enum { kTfLiteBuiltinLogicalNot = 87, kTfLiteBuiltinUnpack = 88, kTfLiteBuiltinReduceMin = 89, + kTfLiteBuiltinFloorDiv = 90, + kTfLiteBuiltinReduceAny = 91, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD index 8abc8285783920f0c30735fd0259e5122b0240bb..88c70fbb8a6e9d4b00c3e21de2dc0f44c4cd4387 100644 --- a/tensorflow/contrib/lite/delegates/eager/BUILD +++ b/tensorflow/contrib/lite/delegates/eager/BUILD @@ -132,6 +132,7 @@ cc_library( ], "//conditions:default": [ "//tensorflow/core:protos_all_cc", + "//tensorflow/core:framework", ], }), ) diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc index 1082b78725986ba3e6f31607f526ea2df2f1fdfb..f8467c7cb2c1ef07fc6f3d1e3e4897a362ddcb92 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel.cc +++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/delegates/eager/kernel.h" -#include "flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/builtin_ops.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/context_util.h" @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" // Note: this is part of TF Lite's Eager delegation code which is to be // completed soon. @@ -189,6 +190,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { } } + // Fill NodeDef with defaults if it's a valid op. + const tensorflow::OpRegistrationData* op_reg_data; + auto tf_status = tensorflow::OpRegistry::Global()->LookUp( + node_data.nodedef.op(), &op_reg_data); + if (tf_status.ok()) { + AddDefaultsToNodeDef(op_reg_data->op_def, &node_data.nodedef); + } + for (auto input_index : TfLiteIntArrayView(node->inputs)) { node_data.inputs.push_back(input_index); } diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc index 26d96acc82064ba1046555940e1b1132874ef23e..b8c9e2652a8c8b33ba1be9323269db56df82757f 100644 --- a/tensorflow/contrib/lite/delegates/eager/test_util.cc +++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/contrib/lite/delegates/eager/test_util.h" #include "absl/memory/memory.h" -#include "flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/string.h" namespace tflite { diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc index e6cc3dd99c2e18bf297f8fac244e5d809954a01a..980a1cb4a09c0e2bd892db2842112fcaf84dd70e 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -238,7 +238,7 @@ class NNAPIOpBuilder { tensor->params.zero_point}; CHECK_NN(context_, ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); - augmented_inputs_.push_back(ann_index); + augmented_outputs_.push_back(ann_index); *ann_tensor_index_out = ann_index; return kTfLiteOk; @@ -370,8 +370,8 @@ struct NNAPIOpMappingArgs { TfLiteContext* context; NNAPIOpBuilder* builder; TfLiteNode* node; - std::vector* model_state_inputs; - std::vector* model_state_tfl_outputs; + std::vector* model_state_outputs; + std::vector* model_state_tfl_inputs; }; // The kernel that represents the subgraph of TF Lite being run on NN API. @@ -781,8 +781,7 @@ class NNAPIDelegateKernel { break; case kTfLiteBuiltinRnn: // NNAPI only support float32 weights. - // TODO(miaowang): check the number of inputs before accessing it. - if (version == 1 && + if (version == 1 && node->inputs->size == 5 && context->tensors[node->inputs->data[/*kWeightsTensor*/ 1]].type == kTfLiteFloat32) { return [](const NNAPIOpMappingArgs& mapping_args) @@ -790,11 +789,11 @@ class NNAPIDelegateKernel { // NNAPI need both state_in and state_out. int ann_index; mapping_args.builder->AddStateFloat32Tensor( - mapping_args.node->outputs->data[/*kHiddenStateTensor*/ 0], + mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4], &ann_index); - mapping_args.model_state_inputs->push_back(ann_index); - mapping_args.model_state_tfl_outputs->push_back( - mapping_args.node->outputs->data[/*kHiddenStateTensor*/ 0]); + mapping_args.model_state_outputs->push_back(ann_index); + mapping_args.model_state_tfl_inputs->push_back( + mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4]); auto builtin = reinterpret_cast( mapping_args.node->builtin_data); mapping_args.builder->AddScalarInt32Operand(builtin->activation); @@ -806,7 +805,7 @@ class NNAPIDelegateKernel { break; case kTfLiteBuiltinSvdf: // NNAPI only support float32 weights. - if (version == 1 && + if (version == 1 && node->inputs->size == 5 && context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]] .type == kTfLiteFloat32) { return [](const NNAPIOpMappingArgs& mapping_args) @@ -814,11 +813,13 @@ class NNAPIDelegateKernel { // NNAPI need both state_in and state_out. int ann_index; mapping_args.builder->AddStateFloat32Tensor( - mapping_args.node->outputs->data[/*kStateTensor*/ 0], + mapping_args.node->inputs + ->data[/*kInputActivationStateTensor*/ 4], &ann_index); - mapping_args.model_state_inputs->push_back(ann_index); - mapping_args.model_state_tfl_outputs->push_back( - mapping_args.node->outputs->data[/*kStateTensor*/ 0]); + mapping_args.model_state_outputs->push_back(ann_index); + mapping_args.model_state_tfl_inputs->push_back( + mapping_args.node->inputs + ->data[/*kInputActivationStateTensor*/ 4]); auto builtin = reinterpret_cast( mapping_args.node->builtin_data); @@ -833,28 +834,12 @@ class NNAPIDelegateKernel { case kTfLiteBuiltinLstm: // NNAPI only support float32 weights. // TODO(miaowang): add loggings to indicate why the op is rejected. - if (version == 1 && node->inputs->size == 18 && + if (version == 1 && node->inputs->size == 20 && context->tensors[node->inputs ->data[/*kInputToOutputWeightsTensor*/ 4]] .type == kTfLiteFloat32) { return [](const NNAPIOpMappingArgs& mapping_args) -> ANeuralNetworksOperationType { - // NNAPI need both state_in and state_out for cell_state and - // output_state. - int ann_index; - mapping_args.builder->AddStateFloat32Tensor( - mapping_args.node->outputs->data[/*kOutputStateTensor*/ 0], - &ann_index); - mapping_args.model_state_inputs->push_back(ann_index); - mapping_args.model_state_tfl_outputs->push_back( - mapping_args.node->outputs->data[/*kOutputStateTensor*/ 0]); - mapping_args.builder->AddStateFloat32Tensor( - mapping_args.node->outputs->data[/*kCellStateTensor*/ 1], - &ann_index); - mapping_args.model_state_inputs->push_back(ann_index); - mapping_args.model_state_tfl_outputs->push_back( - mapping_args.node->outputs->data[/*kCellStateTensor*/ 1]); - auto builtin = reinterpret_cast( mapping_args.node->builtin_data); mapping_args.builder->AddScalarInt32Operand(builtin->activation); @@ -864,6 +849,25 @@ class NNAPIDelegateKernel { // Current NNAPI implementation requires the sratch_buffer as // output. mapping_args.builder->AddAdditionalFloat32OutputTensor(2); + + // NNAPI need both state_in and state_out for cell_state and + // output_state. + int ann_index; + mapping_args.builder->AddStateFloat32Tensor( + mapping_args.node->inputs + ->data[/*kInputActivationStateTensor*/ 18], + &ann_index); + mapping_args.model_state_outputs->push_back(ann_index); + mapping_args.model_state_tfl_inputs->push_back( + mapping_args.node->inputs + ->data[/*kInputActivationStateTensor*/ 18]); + mapping_args.builder->AddStateFloat32Tensor( + mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19], + &ann_index); + mapping_args.model_state_outputs->push_back(ann_index); + mapping_args.model_state_tfl_inputs->push_back( + mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19]); + return ANEURALNETWORKS_LSTM; }; } else { @@ -950,12 +954,10 @@ class NNAPIDelegateKernel { // Set the input tensor buffers. Note: we access tflite tensors using // absolute indices but NN api indices inputs by relative indices. int relative_input_index = 0; - int num_optional_tensors = 0; size_t input_offset = 0; for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { if (absolute_input_index == kOptionalTensor) { - num_optional_tensors++; continue; } TfLiteTensor* tensor = &context->tensors[absolute_input_index]; @@ -989,16 +991,16 @@ class NNAPIDelegateKernel { // The state_out of previous invocation need to be mapped to state_in of // current invocation. - for (size_t i = 0; i < model_state_tfl_outputs_.size(); i++) { - int state_tensor_idx = model_state_tfl_outputs_[i]; + for (size_t i = 0; i < model_state_tfl_inputs_.size(); i++) { + int state_tensor_idx = model_state_tfl_inputs_[i]; TfLiteTensor* tensor = &context->tensors[state_tensor_idx]; // Here we are using a deep copy for state_in tensors so that we are not // reading and writing into the same buffer during a invocation. // TODO(110369471): using double shared buffer to minimize the copies. - CHECK_NN(context, - ANeuralNetworksExecution_setInput( - execution, i + node->inputs->size - num_optional_tensors, - nullptr, tensor->data.raw, tensor->bytes)); + CHECK_NN(context, ANeuralNetworksExecution_setOutput( + execution, relative_output_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_output_index++; } // Invoke ANN in blocking fashion. ANeuralNetworksEvent* event = nullptr; @@ -1030,8 +1032,8 @@ class NNAPIDelegateKernel { // Track indices we use OperandMapping operand_mapping_; - std::vector model_state_inputs_; - std::vector model_state_tfl_outputs_; + std::vector model_state_outputs_; + std::vector model_state_tfl_inputs_; std::unique_ptr nn_input_memory_; std::unique_ptr nn_output_memory_; @@ -1063,9 +1065,9 @@ class NNAPIDelegateKernel { } } // Get op type and operands - int nn_op_type = Map(context, reg->builtin_code, reg->version, - node)({context, &builder, node, &model_state_inputs_, - &model_state_tfl_outputs_}); + int nn_op_type = Map(context, reg->builtin_code, reg->version, node)( + {context, &builder, node, &model_state_outputs_, + &model_state_tfl_inputs_}); // Map outputs to NN API tensor indices. for (auto output_index : TfLiteIntArrayView(node->outputs)) { TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); @@ -1098,17 +1100,17 @@ class NNAPIDelegateKernel { } } - // Add state input tensors as model inputs - for (int i : model_state_inputs_) { - inputs.push_back(i); - } - size_t total_output_byte_size = 0; for (int i : TfLiteIntArrayView(output_tensors)) { outputs.push_back(operand_mapping_.lite_index_to_ann(i)); total_output_byte_size += context->tensors[i].bytes; } + // Add state output tensors as model inputs + for (int i : model_state_outputs_) { + outputs.push_back(i); + } + // Tell ANN to declare inputs/outputs CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs( nn_model_.get(), inputs.size(), inputs.data(), diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc index 3224b23a0c3bc8456bd75f2923d16f0eed7d53ff..4b01aefd6a3103e9cad2d279666511175213ad26 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -1773,15 +1773,16 @@ class RNNOpModel : public SingleOpModelWithNNAPI { weights_ = AddInput(weights); recurrent_weights_ = AddInput(recurrent_weights); bias_ = AddInput(TensorType_FLOAT32); - hidden_state_ = AddOutput(TensorType_FLOAT32); + hidden_state_ = AddInput(TensorType_FLOAT32, true); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp( BuiltinOperator_RNN, BuiltinOptions_RNNOptions, CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); - BuildInterpreter({{batches_, input_size_}, - {units_, input_size_}, - {units_, units_}, - {units_}}); + BuildInterpreter({{batches_, input_size_}, // input tensor + {units_, input_size_}, // weights tensor + {units_, units_}, // recurrent weights tensor + {units_}, // bias tensor + {batches_, units_}}); // hidden state tensor } void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } @@ -1802,14 +1803,6 @@ class RNNOpModel : public SingleOpModelWithNNAPI { PopulateTensor(input_, offset, begin, end); } - void ResetHiddenState() { - const int zero_buffer_size = units_ * batches_; - std::unique_ptr zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(hidden_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - std::vector GetOutput() { return ExtractVector(output_); } int input_size() { return input_size_; } @@ -1835,7 +1828,6 @@ TEST(NNAPIDelegate, RnnBlackBoxTest) { rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); - rnn.ResetHiddenState(); const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / (rnn.input_size() * rnn.num_batches()); @@ -1968,16 +1960,20 @@ class BaseSVDFOpModel : public SingleOpModelWithNNAPI { weights_feature_ = AddInput(weights_feature_type); weights_time_ = AddInput(weights_time_type); bias_ = AddNullInput(); - state_ = AddOutput(TensorType_FLOAT32); + const int num_filters = units * rank; + activation_state_ = AddInput( + TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}}, + /*is_variable=*/true); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp( BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); BuildInterpreter({ - {batches_, input_size_}, // Input tensor - {units_ * rank, input_size_}, // weights_feature tensor - {units_ * rank, memory_size_}, // weights_time tensor - {units_} // bias tensor + {batches_, input_size_}, // input tensor + {units_ * rank, input_size_}, // weights_feature tensor + {units_ * rank, memory_size_}, // weights_time tensor + {units_}, // bias tensor + {batches, memory_size * num_filters} // activation_state tensor }); } @@ -1996,15 +1992,6 @@ class BaseSVDFOpModel : public SingleOpModelWithNNAPI { PopulateTensor(input_, offset, begin, end); } - // Resets the state of SVDF op by filling it with 0's. - void ResetState() { - const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_; - std::unique_ptr zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - // Extracts the output tensor from the SVDF op. std::vector GetOutput() { return ExtractVector(output_); } @@ -2017,7 +2004,7 @@ class BaseSVDFOpModel : public SingleOpModelWithNNAPI { int weights_feature_; int weights_time_; int bias_; - int state_; + int activation_state_; int output_; int batches_; @@ -2081,7 +2068,6 @@ TEST(NNAPIDelegate, SVDFBlackBoxTestRank1) { -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); - svdf.ResetState(); svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input)); } @@ -2120,7 +2106,6 @@ TEST(NNAPIDelegate, SVDFBlackBoxTestRank2) { 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); - svdf.ResetState(); svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input)); } @@ -2192,8 +2177,12 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { projection_bias_ = AddNullInput(); } - output_state_ = AddOutput(TensorType_FLOAT32); - cell_state_ = AddOutput(TensorType_FLOAT32); + // Adding the 2 input state tensors. + input_activation_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true); + input_cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true); + output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, @@ -2271,22 +2260,6 @@ class LSTMOpModel : public SingleOpModelWithNNAPI { PopulateTensor(projection_bias_, f); } - void ResetOutputState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(output_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - - void ResetCellState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(cell_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - void SetInput(int offset, const float* begin, const float* end) { PopulateTensor(input_, offset, const_cast(begin), const_cast(end)); @@ -2495,10 +2468,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -2602,10 +2571,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { lstm.SetCellToForgetWeights(cell_to_forget_weights_); lstm.SetCellToOutputWeights(cell_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -3266,10 +3231,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { lstm.SetProjectionWeights(projection_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc index 834d1ebd666db2be46394166edadf2a166d958aa..121997dcb2756df75f85b1405bb05cbb5fdd7aa3 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include -#include "flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc index 9d1e6a562f00905d1db7f7e055ac1c6b1cc34f9e..32458305c4ff3d4a5871519b3c412692a66788d6 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md index 776803da8c7126c6198e3740448888119df030b9..f255017ad9d938359b2378745dc93a86e4317920 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # TensorFlow Lite APIs diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index d979353bb3550fe53d86b2e6c76702a3970b01fe..ee6150b60e8e8511dc5552bbbf0c71c71d80d1fe 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # How to use custom operators diff --git a/tensorflow/contrib/lite/g3doc/demo_android.md b/tensorflow/contrib/lite/g3doc/demo_android.md index d79a2696b4e9cc10480aa67c7eaec5a356eff596..c38b928684848b858e3f6cc9df6f05e31f778b05 100644 --- a/tensorflow/contrib/lite/g3doc/demo_android.md +++ b/tensorflow/contrib/lite/g3doc/demo_android.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Android Demo App diff --git a/tensorflow/contrib/lite/g3doc/demo_ios.md b/tensorflow/contrib/lite/g3doc/demo_ios.md index a554898899e67a6bc2bc52733f5301767bc1c06a..7579ad84a049ec592aafb16ce95a4b703ac78c5a 100644 --- a/tensorflow/contrib/lite/g3doc/demo_ios.md +++ b/tensorflow/contrib/lite/g3doc/demo_ios.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # iOS Demo App diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md index dc9cc98c0821edff57cb9428a50637a15211cfda..90e7915c52cecc7fff108cbe829aaa97b0fc4ce3 100644 --- a/tensorflow/contrib/lite/g3doc/devguide.md +++ b/tensorflow/contrib/lite/g3doc/devguide.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Developer Guide diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md index d78d373ccfea074872773693c562253b202a646b..5ff041220955bd0cdff70bcd431bdcb9e8fda6f5 100644 --- a/tensorflow/contrib/lite/g3doc/ios.md +++ b/tensorflow/contrib/lite/g3doc/ios.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # TensorFlow Lite for iOS diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index 4ceb9a53dc0967ab6320a1bfdb1ddb859482c5dd..b984671e8998659b7ad3f6f5560feff0043756cf 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # List of Hosted Models diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md index b06f4fd3b893e5e5977f92de26109a6dd264531f..0d571ce54779547a5e3457b089b791abca858930 100644 --- a/tensorflow/contrib/lite/g3doc/ops_versioning.md +++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # TensorFlow Lite Ops Versioning diff --git a/tensorflow/contrib/lite/g3doc/overview.md b/tensorflow/contrib/lite/g3doc/overview.md index be60d7941ade824ee201bfd05400fb3e4e9fae7e..8cf43496dfef351cb094db9c9355b280d112e2fa 100644 --- a/tensorflow/contrib/lite/g3doc/overview.md +++ b/tensorflow/contrib/lite/g3doc/overview.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Introduction to TensorFlow Lite diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md index 5cd0aab44f10de1b76e1acb302fc1ee2711c8d74..28cb6aba6ec61d12d86e078e47665833df8afec7 100644 --- a/tensorflow/contrib/lite/g3doc/performance.md +++ b/tensorflow/contrib/lite/g3doc/performance.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Performance diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md index 9fcf79ba004d85566b64ce35b3693e01c4b0e2cf..8ed8640582307a64827a6b83a511c0057e727d92 100644 --- a/tensorflow/contrib/lite/g3doc/rpi.md +++ b/tensorflow/contrib/lite/g3doc/rpi.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # TensorFlow Lite for Raspberry Pi diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index aa65ec99887a61df658dd7add7b5cc3b91d81846..8660d29855899c110df9dd1746d0e6f1075f21e5 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # TensorFlow Lite & TensorFlow Compatibility Guide @@ -843,6 +841,31 @@ Outputs { } ``` +**UNPACK** + +``` +Inputs { + 0: a tensor. + 1: an integer. + 2: an integer. +} +Outputs { + 0-N: tensors of unpacked tensor. +} +``` + +**FLOOR_DIV** + +``` +Inputs { + 0: a list of tensors. + 1: a list of tensors. +} +Outputs { + 0: A tensor of floor_div output tensors. +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md index 76e16fc9db27782fe0f9454ba463722f4bf6eb4b..c7cdee07de375c165e01626154d92a81ad880eca 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Building TensorFlow on Android diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md index bd047bfceceddfd0b5a9fd0c83cb47a339299abf..d003bb2f3855141b51c6d4afc7fc5a46dc08d665 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Overview diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md index 6223707892ce7b288ecabf932b33cd39860446a6..be8b4100c89f4b02e651b1585faf438881c9119d 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Building TensorFlow on iOS diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md index 4c2071ed053125cfa643ed785fe302198f734ead..4d4bb3bc081d613714271f8b0bf7461cb1e0f4d5 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Integrating TensorFlow libraries diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md index a0192c3541483437b817e22eb92193bd7bcb4c28..7436594fd8580151ba66562eccd408cc7e6c4201 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Optimizing for mobile diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md index 6b4e4a92bd9262139be3cf650b7d16714ee3a277..d1c67d4c61608bcbc9b0bcee5b60f46a73b44692 100644 --- a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md +++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md @@ -1,5 +1,3 @@ -book_path: /mobile/_book.yaml -project_path: /mobile/_project.yaml # Preparing models for mobile deployment diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 362e5887257f1a06263aadbdaef011b3893a577f..5ab53f4c1dadacc8901df5e0dcf543804deedea1 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -476,6 +476,10 @@ TfLiteStatus Interpreter::ResetVariableTensorsToZero() { return kTfLiteOk; } +void Interpreter::ReserveNodes(int count) { + nodes_and_registration_.reserve(count); +} + TfLiteStatus Interpreter::AddNodeWithParameters( const std::vector& inputs, const std::vector& outputs, const char* init_data, size_t init_data_size, void* builtin_data, diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 7d69aa2ad3894c42ff5b2b6df1604ab5701f4aa0..2b1f1819b9acdc22b8a56cfec5a4d5b5b5c5d16f 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -136,6 +136,11 @@ class Interpreter { // interpreter. TfLiteStatus SetVariables(std::vector variables); + // Ensure the internal node storage memory allocates at least `count` + // spots for node. NOTE, this doesn't actually add operators. This is an + // efficiency optimization that is subject to change. + void ReserveNodes(int count); + // Adds a node with the given parameters and returns the index of the new // node in `node_index` (optionally). Interpreter will take ownership of // `builtin_data` and destroy it with `free`. Ownership of 'init_data' diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 1f528fdab9f264a338bdf8826340b404f87041ed..8287115f5cb1fe0302c4dc865c0c6a777b2c910a 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -172,6 +172,7 @@ cc_library( "expand_dims.cc", "fake_quant.cc", "floor.cc", + "floor_div.cc", "fully_connected.cc", "gather.cc", "hashtable_lookup.cc", @@ -211,6 +212,7 @@ cc_library( "transpose_conv.cc", "unidirectional_sequence_lstm.cc", "unidirectional_sequence_rnn.cc", + "unpack.cc", ], hdrs = [ "padding.h", @@ -1201,6 +1203,34 @@ tf_cc_test( ], ) +tf_cc_test( + name = "unpack_test", + size = "small", + srcs = ["unpack_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "floor_div_test", + size = "small", + srcs = ["floor_div_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc index 91d8dd3fa71b4f2ac70c64c4923c5240b61a2b25..1170d84553a69209e2e53b0df1e5c2426d543e12 100644 --- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc +++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" -#include "flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" // flatbuffers namespace tflite { namespace ops { diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc index 8d460fdfc610ef9a867acd492ca0558fb6eab8c3..7346b9fd80d6645b6a40884c0d1ae34677a714fc 100644 --- a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc +++ b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index c09b15b3d263d6cd639234590c99a50a9a48f4a7..c5a5c0182ffe28c6724240bbac1e14ef6e2a259e 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -31,8 +31,10 @@ constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kRecurrentWeightsTensor = 2; constexpr int kBiasTensor = 3; -constexpr int kHiddenStateTensor = 0; -constexpr int kOutputTensor = 1; +constexpr int kHiddenStateTensor = 4; + +// Output tensor. +constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; @@ -46,14 +48,16 @@ void Free(TfLiteContext* context, void* buffer) { TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check we have all the inputs and outputs we need. - TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor); const TfLiteTensor* recurrent_weights = GetInput(context, node, kRecurrentWeightsTensor); const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); + const TfLiteTensor* hidden_state = + GetInput(context, node, kHiddenStateTensor); // Check all the parameters of tensor match within themselves and match the // input configuration. @@ -65,20 +69,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type); + TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2); + TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size); + TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units); - TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - // Resize state. - TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); - hidden_state_size_array->data[0] = batch_size; - hidden_state_size_array->data[1] = num_units; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state, - hidden_state_size_array)); - - // Mark hidden state as a persistent tensor. - hidden_state->allocation_type = kTfLiteArenaRwPersistent; - // Resize output. TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); output_size_array->data[0] = batch_size; @@ -205,7 +201,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* recurrent_weights = GetInput(context, node, kRecurrentWeightsTensor); const TfLiteTensor* bias = GetInput(context, node, kBiasTensor); - TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor); + TfLiteTensor* hidden_state = + &context->tensors[node->inputs->data[kHiddenStateTensor]]; TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // We already checked that weight types are consistent, so branch on one. diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc index 96465fcaf0a78527237faa7b82ddbc32ec56d114..d1797354044c2f2086f1af0cffb7f1edff65f24c 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -181,15 +181,16 @@ class RNNOpModel : public SingleOpModel { weights_ = AddInput(weights); recurrent_weights_ = AddInput(recurrent_weights); bias_ = AddInput(TensorType_FLOAT32); - hidden_state_ = AddOutput(TensorType_FLOAT32); + hidden_state_ = AddInput(TensorType_FLOAT32, true); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp( BuiltinOperator_RNN, BuiltinOptions_RNNOptions, CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); - BuildInterpreter({{batches_, input_size_}, - {units_, input_size_}, - {units_, units_}, - {units_}}); + BuildInterpreter({{batches_, input_size_}, // input tensor + {units_, input_size_}, // weights tensor + {units_, units_}, // recurrent weights tensor + {units_}, // bias tensor + {batches_, units_}}); // hidden state tensor } void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } @@ -210,14 +211,6 @@ class RNNOpModel : public SingleOpModel { PopulateTensor(input_, offset, begin, end); } - void ResetHiddenState() { - const int zero_buffer_size = units_ * batches_; - std::unique_ptr zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(hidden_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - std::vector GetOutput() { return ExtractVector(output_); } int input_size() { return input_size_; } @@ -258,7 +251,6 @@ TEST(RnnOpTest, BlackBoxTest) { rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); - rnn.ResetHiddenState(); const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / (rnn.input_size() * rnn.num_batches()); @@ -286,7 +278,6 @@ TEST(HybridRnnOpTest, BlackBoxTest) { rnn.SetBias(rnn_bias); rnn.SetRecurrentWeights(rnn_recurrent_weights); - rnn.ResetHiddenState(); const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / (rnn.input_size() * rnn.num_batches()); diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index 517309a226bcfb717186be8c1d02d68e3b337f8e..4162d9bb889fa5703116b44e568b4c36ed45cf14 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -44,25 +45,37 @@ constexpr int kFwOutputTensor = 1; constexpr int kBwHiddenStateTensor = 2; constexpr int kBwOutputTensor = 3; +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* scratch_tensor_index = new int; + context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); + return scratch_tensor_index; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 7); TF_LITE_ENSURE_EQ(context, node->outputs->size, 4); - TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; - TfLiteTensor* fw_input_weights = - &context->tensors[node->inputs->data[kFwWeightsTensor]]; - TfLiteTensor* fw_recurrent_weights = - &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]]; - TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]]; - TfLiteTensor* bw_input_weights = - &context->tensors[node->inputs->data[kBwWeightsTensor]]; - TfLiteTensor* bw_recurrent_weights = - &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]]; - TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]]; + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* fw_input_weights = + GetInput(context, node, kFwWeightsTensor); + const TfLiteTensor* fw_recurrent_weights = + GetInput(context, node, kFwRecurrentWeightsTensor); + const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor); + const TfLiteTensor* bw_input_weights = + GetInput(context, node, kBwWeightsTensor); + const TfLiteTensor* bw_recurrent_weights = + GetInput(context, node, kBwRecurrentWeightsTensor); + const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor); // Check all the parameters of tensor match within themselves and match the // input configuration. + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + const int batch_size = input->dims->data[0]; const int max_time = input->dims->data[1]; const int fw_num_units = fw_input_weights->dims->data[0]; @@ -76,17 +89,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1], bw_bias->dims->data[0]); - TfLiteTensor* fw_output = - &context->tensors[node->outputs->data[kFwOutputTensor]]; - TfLiteTensor* bw_output = - &context->tensors[node->outputs->data[kBwOutputTensor]]; + TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); + TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); // Resize hidden states. TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2); fw_hidden_state_size_array->data[0] = batch_size; fw_hidden_state_size_array->data[1] = fw_num_units; TfLiteTensor* fw_hidden_state = - &context->tensors[node->outputs->data[kFwHiddenStateTensor]]; + GetOutput(context, node, kFwHiddenStateTensor); TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state, fw_hidden_state_size_array)); @@ -94,7 +105,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { bw_hidden_state_size_array->data[0] = batch_size; bw_hidden_state_size_array->data[1] = fw_num_units; TfLiteTensor* bw_hidden_state = - &context->tensors[node->outputs->data[kBwHiddenStateTensor]]; + GetOutput(context, node, kBwHiddenStateTensor); TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state, bw_hidden_state_size_array)); @@ -102,6 +113,50 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent; bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent; + const bool is_hybrid_op = + (fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32); + + if (is_hybrid_op) { + int* scratch_tensor_index = reinterpret_cast(node->user_data); + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries->data[0] = *scratch_tensor_index; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[1] = *scratch_tensor_index + 1; + TfLiteTensor* fw_hidden_state_quantized = + GetTemporary(context, node, /*index=*/1); + fw_hidden_state_quantized->type = kTfLiteUInt8; + fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims, + fw_hidden_state->dims)) { + TfLiteIntArray* fw_hidden_state_quantized_size = + TfLiteIntArrayCopy(fw_hidden_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, fw_hidden_state_quantized, + fw_hidden_state_quantized_size)); + } + node->temporaries->data[2] = *scratch_tensor_index + 2; + TfLiteTensor* bw_hidden_state_quantized = + GetTemporary(context, node, /*index=*/2); + bw_hidden_state_quantized->type = kTfLiteUInt8; + bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims, + bw_hidden_state->dims)) { + TfLiteIntArray* bw_hidden_state_quantized_size = + TfLiteIntArrayCopy(bw_hidden_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, bw_hidden_state_quantized, + bw_hidden_state_quantized_size)); + } + } + // Resize outputs. TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3); fw_output_size_array->data[0] = batch_size; @@ -119,30 +174,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - - TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; - TfLiteTensor* fw_input_weights = - &context->tensors[node->inputs->data[kFwWeightsTensor]]; - TfLiteTensor* fw_recurrent_weights = - &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]]; - TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]]; - TfLiteTensor* fw_hidden_state = - &context->tensors[node->outputs->data[kFwHiddenStateTensor]]; - TfLiteTensor* fw_output = - &context->tensors[node->outputs->data[kFwOutputTensor]]; - - TfLiteTensor* bw_input_weights = - &context->tensors[node->inputs->data[kBwWeightsTensor]]; - TfLiteTensor* bw_recurrent_weights = - &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]]; - TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]]; - TfLiteTensor* bw_hidden_state = - &context->tensors[node->outputs->data[kBwHiddenStateTensor]]; - TfLiteTensor* bw_output = - &context->tensors[node->outputs->data[kBwOutputTensor]]; - +TfLiteStatus EvalFloat(const TfLiteTensor* input, + const TfLiteTensor* fw_input_weights, + const TfLiteTensor* fw_recurrent_weights, + const TfLiteTensor* fw_bias, + const TfLiteTensor* bw_input_weights, + const TfLiteTensor* bw_recurrent_weights, + const TfLiteTensor* bw_bias, + const TfLiteSequenceRNNParams* params, + TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, + TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { const int batch_size = input->dims->data[0]; const int max_time = input->dims->data[1]; const int input_size = input->dims->data[2]; @@ -190,12 +231,139 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* fw_input_weights, + const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias, + const TfLiteTensor* bw_input_weights, + const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias, + const TfLiteSequenceRNNParams* params, TfLiteTensor* input_quantized, + TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_scaling_factors, + TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, + TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_scaling_factors, + TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { + const int batch_size = input->dims->data[0]; + const int max_time = input->dims->data[1]; + const int input_size = input->dims->data[2]; + + const int fw_num_units = fw_input_weights->dims->data[0]; + const float* fw_bias_ptr = fw_bias->data.f; + const int8_t* fw_input_weights_ptr = + reinterpret_cast(fw_input_weights->data.uint8); + float fw_input_weights_scale = fw_input_weights->params.scale; + const int8_t* fw_recurrent_weights_ptr = + reinterpret_cast(fw_recurrent_weights->data.uint8); + float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale; + + const int bw_num_units = bw_input_weights->dims->data[0]; + const float* bw_bias_ptr = bw_bias->data.f; + const int8_t* bw_input_weights_ptr = + reinterpret_cast(bw_input_weights->data.uint8); + float bw_input_weights_scale = bw_input_weights->params.scale; + const int8_t* bw_recurrent_weights_ptr = + reinterpret_cast(bw_recurrent_weights->data.uint8); + float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale; + + // Initialize temporary storage for quantized values. + int8_t* quantized_input_ptr = + reinterpret_cast(input_quantized->data.uint8); + int8_t* fw_quantized_hidden_state_ptr = + reinterpret_cast(fw_hidden_state_quantized->data.uint8); + int8_t* bw_quantized_hidden_state_ptr = + reinterpret_cast(bw_hidden_state_quantized->data.uint8); + float* fw_scaling_factors_ptr = fw_scaling_factors->data.f; + float* bw_scaling_factors_ptr = bw_scaling_factors->data.f; + + for (int b = 0; b < batch_size; b++) { + // Forward cell. + float* fw_hidden_state_ptr_batch = + fw_hidden_state->data.f + b * fw_num_units; + for (int s = 0; s < max_time; s++) { + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units; + + kernel_utils::RnnBatchStep( + input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale, + fw_recurrent_weights_ptr, fw_recurrent_weights_scale, fw_bias_ptr, + input_size, fw_num_units, /*batch_size=*/1, params->activation, + quantized_input_ptr, fw_quantized_hidden_state_ptr, + fw_scaling_factors_ptr, fw_hidden_state_ptr_batch, output_ptr_batch); + } + // Backward cell. + float* bw_hidden_state_ptr_batch = + bw_hidden_state->data.f + b * bw_num_units; + for (int s = max_time - 1; s >= 0; s--) { + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units; + + kernel_utils::RnnBatchStep( + input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale, + bw_recurrent_weights_ptr, bw_recurrent_weights_scale, bw_bias_ptr, + input_size, bw_num_units, /*batch_size=*/1, params->activation, + quantized_input_ptr, bw_quantized_hidden_state_ptr, + bw_scaling_factors_ptr, bw_hidden_state_ptr_batch, output_ptr_batch); + } + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* params = + reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + const TfLiteTensor* fw_input_weights = + GetInput(context, node, kFwWeightsTensor); + const TfLiteTensor* fw_recurrent_weights = + GetInput(context, node, kFwRecurrentWeightsTensor); + const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor); + const TfLiteTensor* bw_input_weights = + GetInput(context, node, kBwWeightsTensor); + const TfLiteTensor* bw_recurrent_weights = + GetInput(context, node, kBwRecurrentWeightsTensor); + const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor); + + TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); + TfLiteTensor* fw_hidden_state = + GetOutput(context, node, kFwHiddenStateTensor); + TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + TfLiteTensor* bw_hidden_state = + GetOutput(context, node, kBwHiddenStateTensor); + + switch (fw_input_weights->type) { + case kTfLiteFloat32: + return EvalFloat(input, fw_input_weights, fw_recurrent_weights, fw_bias, + bw_input_weights, bw_recurrent_weights, bw_bias, params, + fw_hidden_state, fw_output, bw_hidden_state, bw_output); + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, 0); + TfLiteTensor* fw_hidden_state_quantized = GetTemporary(context, node, 1); + TfLiteTensor* bw_hidden_state_quantized = GetTemporary(context, node, 2); + TfLiteTensor* fw_scaling_factors = GetTemporary(context, node, 3); + TfLiteTensor* bw_scaling_factors = GetTemporary(context, node, 4); + return EvalHybrid(input, fw_input_weights, fw_recurrent_weights, fw_bias, + bw_input_weights, bw_recurrent_weights, bw_bias, params, + input_quantized, fw_hidden_state_quantized, + fw_scaling_factors, fw_hidden_state, fw_output, + bw_hidden_state_quantized, bw_scaling_factors, + bw_hidden_state, bw_output); + } + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + } // namespace bidirectional_sequence_rnn TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() { - static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, - bidirectional_sequence_rnn::Prepare, - bidirectional_sequence_rnn::Eval}; + static TfLiteRegistration r = { + bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free, + bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 50fe5c2e042fc94d665b05632cd029c9c05f550b..51989f541fbe3b0e726b6f90363405934db16201 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" #include "tensorflow/contrib/lite/kernels/padding.h" @@ -60,6 +61,8 @@ struct OpData { // memory buffers. int im2col_id = kTensorNotAllocated; int hwcn_weights_id = kTensorNotAllocated; + int input_quantized_id = kTensorNotAllocated; + int scaling_factors_id = kTensorNotAllocated; TfLitePaddingValues padding; // The scaling factor from input to output (aka the 'real multiplier') can @@ -74,6 +77,8 @@ struct OpData { // of the allocated temporaries. int32_t im2col_index; int32_t hwcn_weights_index; + int32_t input_quantized_index; + int32_t scaling_factors_index; bool need_hwcn_weights; bool have_weights_been_transposed; bool need_im2col; @@ -125,6 +130,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; + const bool is_hybrid = + (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8); + int filter_width = filter->dims->data[2]; int filter_height = filter->dims->data[1]; @@ -145,8 +153,8 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, // buffer to store the results. // This path is only used for float processing, so only create the buffer if // we're running with that data type. - data->need_hwcn_weights = - (input->type == kTfLiteFloat32 && data->run_multithreaded_kernel); + data->need_hwcn_weights = (input->type == kTfLiteFloat32 && + data->run_multithreaded_kernel && !is_hybrid); int temporaries_count = 0; if (data->need_im2col) { @@ -164,6 +172,25 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, ++temporaries_count; } + if (is_hybrid) { + // Allocate tensor to store the on-the-fly quantized inputs. + data->input_quantized_index = temporaries_count; + if (data->input_quantized_id == kTensorNotAllocated) { + TF_LITE_ENSURE_OK( + context, context->AddTensors(context, 1, &data->input_quantized_id)); + } + ++temporaries_count; + + // Allocate tensor to store the quantization params computed during + // on-the-fly input quantization. + data->scaling_factors_index = temporaries_count; + if (data->scaling_factors_id == kTensorNotAllocated) { + TF_LITE_ENSURE_OK( + context, context->AddTensors(context, 1, &data->scaling_factors_id)); + } + ++temporaries_count; + } + TfLiteIntArrayFree(node->temporaries); node->temporaries = TfLiteIntArrayCreate(temporaries_count); @@ -174,10 +201,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); - data->run_multithreaded_kernel = context->recommended_num_threads != 1; - - TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node)); - bool has_bias = node->inputs->size == 3; // Check number of inputs/outputs TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2); @@ -193,11 +216,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, input->dims->data[3], filter->dims->data[3]); // Check types. (We assume that UINT8 refers to quantized tensors) - TfLiteType data_type = input->type; + TfLiteType input_type = input->type; TF_LITE_ENSURE(context, - data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8); - TF_LITE_ENSURE_EQ(context, output->type, data_type); - TF_LITE_ENSURE_EQ(context, filter->type, data_type); + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->type, input_type); TfLiteTensor* bias = nullptr; @@ -207,15 +229,26 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (has_bias) { bias = &context->tensors[node->inputs->data[2]]; - if (data_type == kTfLiteUInt8) { + if (input_type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); } else { - TF_LITE_ENSURE_EQ(context, bias->type, data_type); + TF_LITE_ENSURE_EQ(context, bias->type, input_type); } TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0)); } + const bool is_hybrid = + (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8); + + data->run_multithreaded_kernel = context->recommended_num_threads != 1; + // Hybrid kernels don't support multithreading yet. + if (is_hybrid) { + data->run_multithreaded_kernel = false; + } + + TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node)); + int channels_out = filter->dims->data[0]; int width = input->dims->data[2]; int height = input->dims->data[1]; @@ -250,9 +283,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, has_bias); - // Note that quantized inference requires that all tensors have their + // Note that full fixed-point inference requires that all tensors have their // parameters set. This is usually done during quantized training. - if (data_type != kTfLiteFloat32) { + if (input_type != kTfLiteFloat32) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); @@ -287,7 +320,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* im2col = &context->tensors[node->temporaries->data[data->im2col_index]]; - im2col->type = data_type; + im2col->type = input->type; + if (is_hybrid) { + im2col->type = kTfLiteUInt8; + } im2col->allocation_type = kTfLiteArenaRw; auto im2col_status = context->ResizeTensor(context, im2col, im2col_size); if (im2col_status != kTfLiteOk) return im2col_status; @@ -307,7 +343,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* hwcn_weights = &context->tensors[node->temporaries->data[data->hwcn_weights_index]]; - hwcn_weights->type = data_type; + hwcn_weights->type = input_type; hwcn_weights->allocation_type = kTfLiteArenaRwPersistent; auto hwcn_weights_status = @@ -319,6 +355,35 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { data->have_weights_been_transposed = false; } + if (is_hybrid) { + node->temporaries->data[data->input_quantized_index] = + data->input_quantized_id; + TfLiteTensor* input_quantized = + GetTemporary(context, node, data->input_quantized_index); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + + node->temporaries->data[data->scaling_factors_index] = + data->scaling_factors_id; + TfLiteTensor* scaling_factors = + GetTemporary(context, node, data->scaling_factors_index); + scaling_factors->type = kTfLiteInt32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + // Only one scale factor per batch is typically necessary. See optimized + // implementation for why we need to allocate for height elements here. + scaling_factors_size->data[0] = height; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + } + return kTfLiteOk; } @@ -455,6 +520,57 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, } } +template +void EvalHybrid(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col, + TfLiteTensor* hwcn_weights, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRange(params->activation, &output_activation_min, + &output_activation_max); + + const int input_size = NumElements(input) / SizeOfDimension(input, 0); + const int batch_size = SizeOfDimension(input, 0); + + const TfLiteTensor* input_quantized = + GetTemporary(context, node, data->input_quantized_index); + int8_t* quantized_input_ptr_batch = + reinterpret_cast(input_quantized->data.uint8); + float* scaling_factors_ptr = + GetTemporary(context, node, data->scaling_factors_index)->data.f; + + // Per-batch input quantization for higher accuracy. + for (int b = 0; b < batch_size; ++b) { + float unused_min, unused_max; + const int offset = b * input_size; + tensor_utils::SymmetricQuantizeFloats( + input->data.f + offset, input_size, quantized_input_ptr_batch + offset, + &unused_min, &unused_max, &scaling_factors_ptr[b]); + scaling_factors_ptr[b] *= filter->params.scale; + } + + int8_t* im2col_ptr = reinterpret_cast(im2col->data.uint8); + int8_t* filter_ptr = reinterpret_cast(filter->data.uint8); + + switch (kernel_type) { + case kReference: + case kGenericOptimized: + case kMultithreadOptimized: + case kCblasOptimized: + // There is only one implementation for hybrid kernel. Note + // this does not make use of gemmlowp nor supports multithreading. + optimized_ops::HybridConv( + quantized_input_ptr_batch, GetTensorDims(input), filter_ptr, + GetTensorDims(filter), GetTensorData(bias), + GetTensorDims(bias), params->stride_width, params->stride_height, + data->padding.width, data->padding.height, scaling_factors_ptr, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), im2col_ptr, + GetTensorDims(im2col)); + break; + } +} + template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); @@ -484,7 +600,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // separate ops to avoid dispatch overhead here. switch (input->type) { // Already know in/outtypes are same. case kTfLiteFloat32: - if (data->run_multithreaded_kernel) { + if (filter->type == kTfLiteUInt8) { + EvalHybrid(context, node, params, data, input, filter, + bias, im2col, hwcn_weights, output); + } else if (data->run_multithreaded_kernel) { EvalFloat(context, node, params, data, input, filter, bias, im2col, hwcn_weights, output); } else { diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc index 98152043c99f772eea2e75c7a90bbc8332cd8100..a4b9fb1a0bf4fad18718ca3045744cc1b4962c74 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -142,6 +142,41 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32) { })); } +// This test's output is equivalent to the SimpleTestFloat32 +// because we break each input into two channels, each with half of the value, +// while keeping the filters for each channel equivalent. +// +// 2 * (A/2) * B = A * B, where the left side is this new test. +TEST_P(ConvolutionOpTest, SimpleTestFloat32WithChannels) { + ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}}, + {TensorType_FLOAT32, {3, 2, 2, 2}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1 + 1, 1, 1, 1, 1, 1, 1, 1, // row = 2 + // Second batch + 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1 + 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2 + }); + m.SetFilter({ + 1, 1, 2, 2, 3, 3, 4, 4, // first 2x2 filter + -1, -1, 1, 1, -1, -1, 1, 1, // second 2x2 filter + -1, -1, -1, -1, 1, 1, 1, 1 // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + })); +} + TEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, @@ -624,6 +659,116 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithDilation) { ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5})); } +class HybridConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetFilter(std::initializer_list f) { + SymmetricQuantizeAndPopulate(filter_, f); + } + + void SetBias(std::initializer_list data) { + PopulateTensor(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST_P(ConvolutionOpTest, SimpleTestHybrid) { + HybridConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_UINT8, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + // Example: we get 17.1577 instead of 17. + // + // Second batch: + // 1 2 3 4 -> 32 64 95 127 with scale factor 127/4. + // 1 2 3 4 32 64 95 127 + // + // First filter: + // 1 2 -> 32 64 with scale factor of 127/4. + // 3 4 95 127 + // + // The left half of the input gives us 16288. Multiply by (4/127)^2 for + // dequantization and adding 1 for the bias gives us the result. and adding + // the bias gives us the result. + // + // The optimized kernel converts the input into this matrix via Im2Col + // + // 1 1 2 2 + // 1 1 2 2 + // 1 2 1 2 + // 3 4 3 4 + // + // and multiplies it with the filter directly. + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 0.16))); +} + +// This test's output is equivalent to the SimpleTestHybrid +// because we break each input into two channels, each with half of the value, +// while keeping the filters for each channel equivalent. +// +// 2 * (A/2) * B = A * B, where the left side is this new test. +TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannels) { + HybridConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}}, + {TensorType_UINT8, {3, 2, 2, 2}}, {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1 + 1, 1, 1, 1, 1, 1, 1, 1, // row = 2 + // Second batch + 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1 + 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2 + }); + m.SetFilter({ + 1, 1, 2, 2, 3, 3, 4, 4, // first 2x2 filter + -1, -1, 1, 1, -1, -1, 1, 1, // second 2x2 filter + -1, -1, -1, -1, 1, 1, 1, 1 // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 0.16))); +} + INSTANTIATE_TEST_CASE_P( ConvolutionOpTest, ConvolutionOpTest, ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc index d7bde0ff79bd23fa4c277dd04ec4343663e0ad00..136697f945bceb9c3bda871aacff76f19db70fc6 100644 --- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc +++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc @@ -15,7 +15,7 @@ limitations under the License. #include #include #include -#include "flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc index 4e0f8484a328d7d1668afd096ad3d08204fbb4a1..94c91a6bd6030eee91e045d1dd0453e4ffa72b17 100644 --- a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc +++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc new file mode 100644 index 0000000000000000000000000000000000000000..3c177ea330f2725476b956a003b84a0ed1dd0084 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/floor_div.cc @@ -0,0 +1,146 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace floor_div { +namespace { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for floor_div op. +struct OpData { + bool requires_broadcast; +}; + +template +T FloorDiv(T input1, T input2) { + return std::floor(std::divides()(static_cast(input1), + static_cast(input2))); +} + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Reinterprete the opaque data provided by user. + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + + const TfLiteType type = input1->type; + if (type != kTfLiteInt32) { + context->ReportError(context, "Currently floor_div only supports int32."); + return kTfLiteError; + } + output->type = type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast, + const TfLiteTensor* input1, const TfLiteTensor* input2, + TfLiteTensor* output) { + const T* denominator_data = GetTensorData(input2); + + // Validate the denominator. + for (int i = 0; i < NumElements(input2); ++i) { + if (std::equal_to()(denominator_data[i], 0)) { + context->ReportError(context, "Division by 0"); + return kTfLiteError; + } + } + if (requires_broadcast) { + reference_ops::BroadcastBinaryFunction( + GetTensorData(input1), GetTensorDims(input1), denominator_data, + GetTensorDims(input2), GetTensorData(output), GetTensorDims(output), + FloorDiv); + } else { + reference_ops::BinaryFunction( + GetTensorData(input1), GetTensorDims(input1), + GetTensorData(input2), GetTensorDims(input2), + GetTensorData(output), GetTensorDims(output), FloorDiv); + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input1->type) { + case kTfLiteInt32: { + return EvalImpl(context, data->requires_broadcast, input1, + input2, output); + } + default: { + context->ReportError(context, "Currently floor_div only supports int32."); + return kTfLiteError; + } + } +} + +} // namespace +} // namespace floor_div + +TfLiteRegistration* Register_FLOOR_DIV() { + // Init, Free, Prepare, Eval are satisfying the Interface required by + // TfLiteRegistration. + static TfLiteRegistration r = {floor_div::Init, floor_div::Free, + floor_div::Prepare, floor_div::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/floor_div_test.cc b/tensorflow/contrib/lite/kernels/floor_div_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..eea69b61ac161ea66d62e06e6d778666f289f510 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/floor_div_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +template +class FloorDivModel : public SingleOpModel { + public: + FloorDivModel(const TensorData& input1, const TensorData& input2, + const TensorData& output) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_FLOOR_DIV, BuiltinOptions_FloorDivOptions, + CreateFloorDivOptions(builder_).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(PowOpModel, Simple) { + FloorDivModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {10, 9, 11, 3}); + model.PopulateTensor(model.input2(), {2, 2, 3, 4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(5, 4, 3, 0)); +} + +TEST(PowOpModel, NegativeValue) { + FloorDivModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {10, -9, -11, 7}); + model.PopulateTensor(model.input2(), {2, 2, -3, -4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(5, -5, 3, -2)); +} + +TEST(PowOpModel, BroadcastFloorDiv) { + FloorDivModel model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1}}, {TensorType_INT32, {}}); + model.PopulateTensor(model.input1(), {10, -9, -11, 7}); + model.PopulateTensor(model.input2(), {-3}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(-4, 3, 3, -3)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 96798c900e53b06873548a40ff5e57cb49e59cbb..464163bd78da8114aba7a65d1ea2b76ed7833600 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -160,6 +160,7 @@ cc_library( ":types", ":reference_base", ":round", + ":tensor_utils", "//third_party/eigen3", "@gemmlowp", "//tensorflow/contrib/lite:builtin_op_data", @@ -191,6 +192,7 @@ cc_library( deps = [ ":quantization_util", ":strided_slice_logic", + ":tensor_utils", ":types", ":legacy_reference_base", ":round", diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 200f2f151582c2361dd2403164d0bbe119cbed72..88a0622286bef5f8b19169abc289cc98a77edd5e 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -127,6 +127,47 @@ void LstmStep( float* cell_state_ptr, float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch) { + LstmStepWithAuxInput( + input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr, + input_to_cell_weights_ptr, input_to_output_weights_ptr, + /*aux_input_ptr_batch=*/nullptr, + /*aux_input_to_input_weights_ptr=*/nullptr, + /*aux_input_to_forget_weights_ptr=*/nullptr, + /*aux_input_to_cell_weights_ptr=*/nullptr, + /*aux_input_to_output_weights_ptr=*/nullptr, + recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, + recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, + cell_to_input_weights_ptr, cell_to_forget_weights_ptr, + cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, + cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, + projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, + output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, output_ptr_batch); +} + +void LstmStepWithAuxInput( + const float* input_ptr_batch, const float* input_to_input_weights_ptr, + const float* input_to_forget_weights_ptr, + const float* input_to_cell_weights_ptr, + const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch, + const float* aux_input_to_input_weights_ptr, + const float* aux_input_to_forget_weights_ptr, + const float* aux_input_to_cell_weights_ptr, + const float* aux_input_to_output_weights_ptr, + const float* recurrent_to_input_weights_ptr, + const float* recurrent_to_forget_weights_ptr, + const float* recurrent_to_cell_weights_ptr, + const float* recurrent_to_output_weights_ptr, + const float* cell_to_input_weights_ptr, + const float* cell_to_forget_weights_ptr, + const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const float* projection_weights_ptr, + const float* projection_bias_ptr, const TfLiteLSTMParams* params, + int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, + float* cell_state_ptr, float* input_gate_scratch, + float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, + float* output_ptr_batch) { // Since we have already checked that weights are all there or none, we can // check the existense of only one to the get the condition. const bool use_cifg = (input_to_input_weights_ptr == nullptr); @@ -160,6 +201,25 @@ void LstmStep( input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1); + // If auxiliary input is available then compute aux_input_weight * aux_input + if (aux_input_ptr_batch != nullptr) { + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_input_weights_ptr, n_cell, n_input, aux_input_ptr_batch, + n_batch, input_gate_scratch, /*result_stride=*/1); + } + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_forget_weights_ptr, n_cell, n_input, aux_input_ptr_batch, + n_batch, forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_cell_weights_ptr, n_cell, n_input, aux_input_ptr_batch, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_output_weights_ptr, n_cell, n_input, aux_input_ptr_batch, + n_batch, output_gate_scratch, /*result_stride=*/1); + } + // For each batch and cell: compute recurrent_weight * output_state. if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( @@ -286,227 +346,362 @@ void LstmStep( int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr, float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch) { - // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. - const bool use_cifg = (input_to_input_weights_ptr == nullptr); - const bool use_peephole = (cell_to_output_weights_ptr != nullptr); - // Initialize scratch buffers with bias. - if (!use_cifg) { - tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, - input_gate_scratch); - } - tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, - forget_gate_scratch); - tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, - cell_scratch); - tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, - output_gate_scratch); - - if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_input; - tensor_utils::SymmetricQuantizeFloats( - input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, - &unused_min, &unused_max, &scaling_factors[b]); + LstmStepWithAuxInput( + input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + /*aux_input_ptr_batch=*/nullptr, + /*aux_input_to_input_weights_ptr=*/nullptr, + /*aux_input_to_input_weights_scale=*/0.0f, + /*aux_input_to_forget_weights_ptr=*/nullptr, + /*aux_input_to_forget_weights_scale=*/0.0f, + /*aux_input_to_cell_weights_ptr=*/nullptr, + /*aux_input_to_cell_weights_scale=*/0.0f, + /*aux_input_to_output_weights_ptr=*/nullptr, + /*aux_input_to_output_weights_scale=*/0.0f, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, + projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, + input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, scaling_factors, product_scaling_factors, + recovered_cell_weights, quantized_input_ptr_batch, + /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr, + quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, + output_ptr_batch); } - // For each batch and cell: compute input_weight * input. - if (!use_cifg) { - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_input_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_input_weights_ptr, n_cell, n_input, - quantized_input_ptr_batch, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_forget_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, - product_scaling_factors, n_batch, forget_gate_scratch, - /*result_stride=*/1); - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_cell_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, - product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * input_to_output_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, - product_scaling_factors, n_batch, output_gate_scratch, - /*result_stride=*/1); - } - - if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_output; - tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, - quantized_output_state_ptr + offset, - &unused_min, &unused_max, - &scaling_factors[b]); - } - // For each batch and cell: compute recurrent_weight * output_state. - if (!use_cifg) { - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_input_weights_scale; + void LstmStepWithAuxInput( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, + float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, const float* aux_input_ptr_batch, + const int8_t* aux_input_to_input_weights_ptr, + float aux_input_to_input_weights_scale, + const int8_t* aux_input_to_forget_weights_ptr, + float aux_input_to_forget_weights_scale, + const int8_t* aux_input_to_cell_weights_ptr, + float aux_input_to_cell_weights_scale, + const int8_t* aux_input_to_output_weights_ptr, + float aux_input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, + float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, + int8_t* quantized_aux_input_ptr_batch, + int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr, + float* output_state_ptr, float* cell_state_ptr, + float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we + // can check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, + n_batch, input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, + n_batch, output_gate_scratch); + + if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, n_input, + quantized_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + forget_gate_scratch, + /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + output_gate_scratch, + /*result_stride=*/1); } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_input_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - input_gate_scratch, /*result_stride=*/1); - } - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_forget_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_forget_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - forget_gate_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_cell_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_cell_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - cell_scratch, /*result_stride=*/1); - - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * recurrent_to_output_weights_scale; - } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - recurrent_to_output_weights_ptr, n_cell, n_output, - quantized_output_state_ptr, product_scaling_factors, n_batch, - output_gate_scratch, /*result_stride=*/1); - } - - // Save quantization and matmul computation for all zero input. - bool is_cell_state_all_zeros = - tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); - // For each batch and cell: update input gate. - if (!use_cifg) { - if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, - cell_to_input_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - input_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, - input_gate_scratch); - } + if (aux_input_ptr_batch != nullptr && + !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + aux_input_ptr_batch + offset, n_input, + quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * aux_input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_input_weights_ptr, n_cell, n_input, + quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * aux_input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_forget_weights_ptr, n_cell, n_input, + quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * aux_input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_cell_weights_ptr, n_cell, n_input, + quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * aux_input_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_output_weights_ptr, n_cell, n_input, + quantized_aux_input_ptr_batch, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } - // For each batch and cell: update forget gate. - if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, - cell_to_forget_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - forget_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, - forget_gate_scratch); + if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_output; + tensor_utils::SymmetricQuantizeFloats( + output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } - // For each batch and cell: update the cell. - tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, - n_batch * n_cell, cell_state_ptr); - tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, - params->activation, cell_scratch); - if (use_cifg) { - tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, - forget_gate_scratch); - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); - } else { - tensor_utils::VectorVectorCwiseProductAccumulate( - cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); - } - if (params->cell_clip > 0.0) { - tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, - params->cell_clip, cell_state_ptr); - } + // Save quantization and matmul computation for all zero input. + bool is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, + cell_to_input_weights_scale, + recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } - is_cell_state_all_zeros = - tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); - // For each batch and cell: update the output gate. - if (use_peephole && !is_cell_state_all_zeros) { - tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, - cell_to_output_weights_scale, - recovered_cell_weights); - tensor_utils::VectorBatchVectorCwiseProductAccumulate( - recovered_cell_weights, n_cell, cell_state_ptr, n_batch, - output_gate_scratch); - } - tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, - output_gate_scratch); - tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, - params->activation, cell_scratch); - tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, - n_batch * n_cell, output_gate_scratch); + // For each batch and cell: update forget gate. + if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, + cell_to_forget_weights_scale, + recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, + cell_state_ptr, n_batch * n_cell, + cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, + cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, + params->cell_clip, cell_state_ptr); + } - // For each batch: update the projection and output_state. - const bool use_projection_weight = (projection_weights_ptr != nullptr); - const bool use_projection_bias = (projection_bias_ptr != nullptr); - if (use_projection_weight) { - if (use_projection_bias) { - tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, - n_batch, output_ptr_batch); - } else { - tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); - } - if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) { - // Save quantization and matmul computation for all zero input. - float unused_min, unused_max; - for (int b = 0; b < n_batch; ++b) { - const int offset = b * n_cell; - tensor_utils::SymmetricQuantizeFloats( - output_gate_scratch + offset, n_cell, - quantized_cell_state_ptr + offset, &unused_min, &unused_max, - &scaling_factors[b]); + is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + // For each batch and cell: update the output gate. + if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, + cell_to_output_weights_scale, + recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); } - for (int b = 0; b < n_batch; ++b) { - product_scaling_factors[b] = - scaling_factors[b] * projection_weights_scale; + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, + output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + if (!tensor_utils::IsZeroVector(output_gate_scratch, + n_batch * n_cell)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_cell; + tensor_utils::SymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * projection_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, + quantized_cell_state_ptr, product_scaling_factors, n_batch, + output_ptr_batch, + /*result_stride=*/1); + } + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, + params->proj_clip, output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); } - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr, - product_scaling_factors, n_batch, output_ptr_batch, - /*result_stride=*/1); - } - if (params->proj_clip > 0.0) { - tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, - params->proj_clip, output_ptr_batch); + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); } - } else { - tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, - output_ptr_batch); - } - tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, - output_state_ptr); -} } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index 2a11b37a6069367e8232350c2fc68d4c385e14ba..599850db607b0e52d9067ec18a34976df7b7407e 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -66,8 +66,7 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, // - n_input: the input size, // - n_output: the output size. // -// The pointers to the cell and output state and the output are updated. Unless -// projection is specified output and output state contain the same data. +// The pointers to the cell and output state and the output are updated. // // The pointers with the suffix "_batch" point to data aligned in batch_major // order, and each step processes batch_size many inputs from input_ptr_batch, @@ -92,6 +91,31 @@ void LstmStep( float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch); +// Same as above but includes an auxiliary input with the corresponding weights. +void LstmStepWithAuxInput( + const float* input_ptr_batch, const float* input_to_input_weights_ptr, + const float* input_to_forget_weights_ptr, + const float* input_to_cell_weights_ptr, + const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch, + const float* aux_input_to_input_weights_ptr, + const float* aux_input_to_forget_weights_ptr, + const float* aux_input_to_cell_weights_ptr, + const float* aux_input_to_output_weights_ptr, + const float* recurrent_to_input_weights_ptr, + const float* recurrent_to_forget_weights_ptr, + const float* recurrent_to_cell_weights_ptr, + const float* recurrent_to_output_weights_ptr, + const float* cell_to_input_weights_ptr, + const float* cell_to_forget_weights_ptr, + const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const float* projection_weights_ptr, + const float* projection_bias_ptr, const TfLiteLSTMParams* params, + int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, + float* cell_state_ptr, float* input_gate_scratch, + float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, + float* output_ptr_batch); + // Same as above but with quantized weight matrices. In detail: // Input of size 'n_batch * n_input': // input_ptr_batch @@ -175,6 +199,46 @@ void LstmStep( int8_t* quantized_cell_state_ptr, float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch); +void LstmStepWithAuxInput( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, const float* aux_input_ptr_batch, + const int8_t* aux_input_to_input_weights_ptr, + float aux_input_to_input_weights_scale, + const int8_t* aux_input_to_forget_weights_ptr, + float aux_input_to_forget_weights_scale, + const int8_t* aux_input_to_cell_weights_ptr, + float aux_input_to_cell_weights_scale, + const int8_t* aux_input_to_output_weights_ptr, + float aux_input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_aux_input_ptr_batch, + int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr, + float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch); + } // namespace kernel_utils } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 51a9aa5a420f5d923e927ab15a9d9f365bec1866..e4bb4e0534b892fd271ccdcd58bc91ecf25807e4 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/contrib/lite/kernels/internal/round.h" #include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { @@ -319,6 +320,7 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data, #endif } +// Note: This to be converted to RuntimeShapes along with Conv. // legacy, for compatibility with old checked-in code template void AddBiasAndEvalActivationFunction(const float* bias_data, @@ -1934,6 +1936,85 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, output_activation_max); } +inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims, + const int8_t* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float* scaling_factors_ptr, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims, + int8_t* im2col_data, const Dims<4>& im2col_dims) { + const int batch_size = input_dims.sizes[3]; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + + const int8* gemm_input_data = nullptr; + int num_input; + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + // symmetric quantization assumes zero point of 0. + const int input_zero_point = 0; + Im2col(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_height, filter_width, input_zero_point, + im2col_data, im2col_dims); + gemm_input_data = im2col_data; + num_input = im2col_dims.sizes[0] * im2col_dims.sizes[1] * + im2col_dims.sizes[2] * im2col_dims.sizes[3]; + } else { + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + num_input = input_dims.sizes[0] * input_dims.sizes[1] * + input_dims.sizes[2] * input_dims.sizes[3]; + } + + // Flatten 4D matrices into 2D matrices for matrix multiplication. + + // Flatten so that each filter has its own row. + const int filter_rows = filter_dims.sizes[3]; + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + + // In MatrixBatchVectorMultiplyAccumulate, each output value is the + // dot product of one row of the first matrix with one row of the second + // matrix. Therefore, the number of cols in each matrix are equivalent. + // + // After Im2Col, each input patch becomes a row. + const int gemm_input_cols = filter_cols; + const int gemm_input_rows = num_input / gemm_input_cols; + + const int output_cols = output_dims.sizes[0]; + const int output_rows = + output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + TFLITE_DCHECK_EQ(output_cols, filter_rows); + TFLITE_DCHECK_EQ(output_rows, gemm_input_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_cols); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + + // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second + // input matrix has its own scale factor. This code duplicates the scale + // factors for each row in the same batch. + const int rows_per_batch = gemm_input_rows / batch_size; + for (int i = gemm_input_rows - 1; i >= 0; --i) { + scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch]; + } + + tensor_utils::ZeroVector(output_data, output_rows * output_cols); + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + filter_data, filter_rows, filter_cols, gemm_input_data, + scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data, + /*result_stride=*/1); + + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims, output_activation_min, + output_activation_max); +} + template void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, @@ -2142,38 +2223,6 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims, im2col_data, im2col_dims, gemm_context); } -template -inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, - int block_size, T* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("DepthToSpace"); - - const int input_depth = ArraySize(input_dims, 0); - const int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - - const int output_depth = ArraySize(output_dims, 0); - const int batch_size = ArraySize(output_dims, 3); - - // Number of continuous values that we can copy in one interation. - const int stride = block_size * output_depth; - - for (int batch = 0; batch < batch_size; ++batch) { - for (int in_h = 0; in_h < input_height; ++in_h) { - const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch); - for (int offset_h = 0; offset_h < block_size; ++offset_h) { - const T* src = input_ptr; - for (int in_w = 0; in_w < input_width; ++in_w) { - memcpy(output_data, src, stride * sizeof(T)); - output_data += stride; - src += input_depth; - } - input_ptr += stride; - } - } - } -} - // legacy, for compatibility with old checked-in code template void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, @@ -2249,25 +2298,87 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, } template -inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, +inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + gemmlowp::ScopedProfilingLabel label("DepthToSpace"); + + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int input_depth = input_shape.Dims(3); + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + + const int output_depth = output_shape.Dims(3); + const int batch_size = output_shape.Dims(0); + + // Number of continuous values that we can copy in one interation. + const int stride = op_params.block_size * output_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0); + for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) { + const T* src = input_ptr; + for (int in_w = 0; in_w < input_width; ++in_w) { + memcpy(output_data, src, stride * sizeof(T)); + output_data += stride; + src += input_depth; + } + input_ptr += stride; + } + } + } +} + +// Legacy Dims<4>. +template +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, int block_size, T* output_data, const Dims<4>& output_dims) { + tflite::DepthToSpaceParams op_params; + op_params.block_size = block_size; + + DepthToSpace(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +template +inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { gemmlowp::ScopedProfilingLabel label("SpaceToDepth"); - const int output_depth = ArraySize(output_dims, 0); - const int output_width = ArraySize(output_dims, 1); - const int output_height = ArraySize(output_dims, 2); + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); - const int input_depth = ArraySize(input_dims, 0); - const int batch_size = ArraySize(input_dims, 3); + const int output_depth = output_shape.Dims(3); + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + + const int input_depth = input_shape.Dims(3); + const int batch_size = input_shape.Dims(0); // Number of continuous values that we can copy in one interation. - const int stride = block_size * input_depth; + const int stride = op_params.block_size * input_depth; for (int batch = 0; batch < batch_size; ++batch) { for (int out_h = 0; out_h < output_height; ++out_h) { - T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch); - for (int offset_h = 0; offset_h < block_size; ++offset_h) { + T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0); + for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) { T* dst = output_ptr; for (int out_w = 0; out_w < output_width; ++out_w) { memcpy(dst, input_data, stride * sizeof(T)); @@ -2280,51 +2391,16 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, } } -template -void NonGlobalBatchNormalization( - const float* input_data, const Dims<4>& input_dims, const float* mean_data, - const Dims<4>& mean_dims, const float* multiplier_data, - const Dims<4>& multiplier_dims, const float* offset_data, - const Dims<4>& offset_dims, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization"); - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int inner_size = MatchingFlatSizeSkipDim( - input_dims, 3, mean_dims, multiplier_dims, offset_dims, output_dims); - - for (int b = 0; b < batches; ++b) { - for (int i = 0; i < inner_size; ++i) { - *output_data = ActivationFunction( - (*input_data - mean_data[i]) * multiplier_data[i] + offset_data[i]); - ++output_data; - ++input_data; - } - } -} - -template -void GlobalBatchNormalization(const float* input_data, - const Dims<4>& input_dims, const float* mean_data, - const Dims<4>& mean_dims, - const float* multiplier_data, - const Dims<4>& multiplier_dims, - const float* offset_data, - const Dims<4>& offset_dims, float* output_data, - const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization"); - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = - MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, - offset_dims, 0, output_dims, 0); +// Legacy Dims<4>. +template +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + tflite::SpaceToDepthParams op_params; + op_params.block_size = block_size; - for (int i = 0; i < outer_size; ++i) { - for (int c = 0; c < depth; ++c) { - *output_data = ActivationFunction( - (*input_data - mean_data[c]) * multiplier_data[c] + offset_data[c]); - ++output_data; - ++input_data; - } - } + SpaceToDepth(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); } inline void Relu(const RuntimeShape& input_shape, const float* input_data, @@ -2336,11 +2412,12 @@ inline void Relu(const RuntimeShape& input_shape, const float* input_data, output = input.cwiseMax(0.0f); } -template -void L2Normalization(const float* input_data, const RuntimeShape& input_shape, - float* output_data, const RuntimeShape& output_shape) { +inline void L2Normalization(const tflite::L2NormalizationParams& op_params, + const RuntimeShape& input_shape, + const float* input_data, + const RuntimeShape& output_shape, + float* output_data) { gemmlowp::ScopedProfilingLabel label("L2Normalization"); - static_assert(Ac == FusedActivationFunctionType::kNone, ""); const int trailing_dim = input_shape.DimensionsCount() - 1; const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); @@ -2361,6 +2438,18 @@ void L2Normalization(const float* input_data, const RuntimeShape& input_shape, } } +// Legacy. +template +void L2Normalization(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + static_assert(Ac == FusedActivationFunctionType::kNone, ""); + tflite::L2NormalizationParams op_params; + // No params need to be set for float. + + L2Normalization(op_params, input_shape, input_data, output_shape, + output_data); +} + inline void GetInvSqrtQuantizedMultiplierExp(int32 input, int32* output_inv_sqrt, int* output_shift) { @@ -2409,16 +2498,18 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input, *output_shift *= kReverseShift; } -inline void L2Normalization(const uint8* input_data, +inline void L2Normalization(const tflite::L2NormalizationParams& op_params, const RuntimeShape& input_shape, - int32 input_zero_point, uint8* output_data, - const RuntimeShape& output_shape) { + const uint8* input_data, + const RuntimeShape& output_shape, + uint8* output_data) { gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit"); const int trailing_dim = input_shape.DimensionsCount() - 1; const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int32 input_zero_point = op_params.input_zero_point; for (int i = 0; i < outer_size; ++i) { int32 square_l2_norm = 0; for (int c = 0; c < depth; c++) { @@ -2444,6 +2535,18 @@ inline void L2Normalization(const uint8* input_data, } } +// Legacy. +inline void L2Normalization(const uint8* input_data, + const RuntimeShape& input_shape, + int32 input_zero_point, uint8* output_data, + const RuntimeShape& output_shape) { + tflite::L2NormalizationParams op_params; + op_params.input_zero_point = input_zero_point; + + L2Normalization(op_params, input_shape, input_data, output_shape, + output_data); +} + inline void Add(const ArithmeticParams& params, const RuntimeShape& input1_shape, const float* input1_data, const RuntimeShape& input2_shape, const float* input2_data, @@ -2725,17 +2828,16 @@ inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params, } } -inline void Mul(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float output_activation_min, float output_activation_max, - float* output_data, const Dims<4>& output_dims) { +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const float* input1_data, + const RuntimeShape& input2_shape, const float* input2_data, + const RuntimeShape& output_shape, float* output_data) { gemmlowp::ScopedProfilingLabel label("Mul"); - TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); - TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + const float output_activation_min = params.float_activation_min; + const float output_activation_max = params.float_activation_max; int i = 0; - const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape); #ifdef USE_NEON const auto activation_min = vdupq_n_f32(output_activation_min); const auto activation_max = vdupq_n_f32(output_activation_max); @@ -2786,6 +2888,20 @@ inline void Mul(const float* input1_data, const Dims<4>& input1_dims, } } +// Legacy Dims<4>. +inline void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + op_params.float_activation_min = output_activation_min; + op_params.float_activation_max = output_activation_max; + + Mul(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + // legacy, for compatibility with old checked-in code template void Mul(const float* input1_data, const Dims<4>& input1_dims, @@ -2798,13 +2914,16 @@ void Mul(const float* input1_data, const Dims<4>& input1_dims, output_activation_max, output_data, output_dims); } -inline void Mul(const int32* input1_data, const Dims<4>& input1_dims, - const int32* input2_data, const Dims<4>& input2_dims, - int32 output_activation_min, int32 output_activation_max, - int32* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Mul/int32"); +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int32* input1_data, + const RuntimeShape& input2_shape, const int32* input2_data, + const RuntimeShape& output_shape, int32* output_data) { + gemmlowp::ScopedProfilingLabel label("Mul/int32/activation"); - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] * input2_data[i], output_activation_min, @@ -2812,22 +2931,38 @@ inline void Mul(const int32* input1_data, const Dims<4>& input1_dims, } } -template -void Mul(const int32* input1_data, const Dims<4>& input1_dims, - const int32* input2_data, const Dims<4>& input2_dims, - int32* output_data, const Dims<4>& output_dims) { +// Legacy Dims<4>. +inline void Mul(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32 output_activation_min, int32 output_activation_max, + int32* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + + Mul(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +inline void MulNoActivation(const ArithmeticParams& params, + const RuntimeShape& input1_shape, + const int32* input1_data, + const RuntimeShape& input2_shape, + const int32* input2_data, + const RuntimeShape& output_shape, + int32* output_data) { gemmlowp::ScopedProfilingLabel label("Mul/int32"); - TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); - auto input1_map = MapAsVector(input1_data, input1_dims); - auto input2_map = MapAsVector(input2_data, input2_dims); - auto output_map = MapAsVector(output_data, output_dims); - if (AreSameDims(input1_dims, input2_dims)) { + auto input1_map = MapAsVector(input1_data, input1_shape); + auto input2_map = MapAsVector(input2_data, input2_shape); + auto output_map = MapAsVector(output_data, output_shape); + if (input1_shape == input2_shape) { output_map.array() = input1_map.array() * input2_map.array(); - } else if (FlatSize(input2_dims) == 1) { + } else if (input2_shape.FlatSize() == 1) { auto scalar = input2_data[0]; output_map.array() = input1_map.array() * scalar; - } else if (FlatSize(input1_dims) == 1) { + } else if (input1_shape.FlatSize() == 1) { auto scalar = input1_data[0]; output_map.array() = scalar * input2_map.array(); } else { @@ -2836,14 +2971,30 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims, } } -inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, - const int16* input2_data, const Dims<4>& input2_dims, - int16* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("Mul/Int16"); +// Legacy Dims<4>. +template +void Mul(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + tflite::ArithmeticParams op_params; + // No parameters needed. + + MulNoActivation(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int16* input1_data, + const RuntimeShape& input2_shape, const int16* input2_data, + const RuntimeShape& output_shape, int16* output_data) { + gemmlowp::ScopedProfilingLabel label("Mul/Int16/NoActivation"); // This is a copy of the reference implementation. We do not currently have a // properly optimized version. - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -2855,17 +3006,32 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, } } +// Legacy Dims<4>. inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, const int16* input2_data, const Dims<4>& input2_dims, - int32 output_offset, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { + int16* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + // No parameters needed. + + Mul(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int16* input1_data, + const RuntimeShape& input2_shape, const int16* input2_data, + const RuntimeShape& output_shape, uint8* output_data) { gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8"); // This is a copy of the reference implementation. We do not currently have a // properly optimized version. + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + const int32 output_offset = params.output_offset; TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -2883,62 +3049,51 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, } } -// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary -// dimensionality if the runtime code does a single loop over one dimension -// that handles broadcasting as the base case. The code generator would then -// generate max(D1, D2) nested for loops. -// TODO(benoitjacob): BroadcastMul is intentionally duplicated from -// reference_ops.h. Once an optimized version is implemented and NdArrayDesc -// is no longer referenced in this file, move NdArrayDesc from types.h to -// reference_ops.h. +// Legacy Dims<4>. +inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, + const int16* input2_data, const Dims<4>& input2_dims, + int32 output_offset, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + op_params.output_offset = output_offset; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + + Mul(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +// Legacy Dims<4>. template void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, T output_activation_min, T output_activation_max, T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastMul"); + tflite::ArithmeticParams op_params; + SetActivationParams(output_activation_min, output_activation_max, &op_params); - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest stride, - // typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for the - // best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] * - input2_data[SubscriptToIndex(desc2, c, x, y, b)], - output_activation_min, output_activation_max); - } - } - } - } + BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); } +// Legacy Dims<4>. // legacy, for compatibility with old checked-in code -template -void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T* output_data, const Dims<4>& output_dims) { - T output_activation_min, output_activation_max; - GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); +template +inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + float float_activation_min; + float float_activation_max; + GetActivationMinMax(Ac, &float_activation_min, &float_activation_max); + SetActivationParams(float_activation_min, float_activation_max, &op_params); - BroadcastMul(input1_data, input1_dims, input2_data, input2_dims, - output_activation_min, output_activation_max, output_data, - output_dims); + BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); } // Element-wise mul that can often be used for inner loop of broadcast Mul as @@ -4034,29 +4189,28 @@ inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape, } } -inline void LocalResponseNormalization(const float* input_data, - const Dims<4>& input_dims, int range, - float bias, float alpha, float beta, - float* output_data, - const Dims<4>& output_dims) { +inline void LocalResponseNormalization( + const tflite::LocalResponseNormalizationParams& op_params, + const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization"); - MatchingFlatSize(input_dims, output_dims); + MatchingFlatSize(input_shape, output_shape); - const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); - auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape); + auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape); // Carry out local response normalization, vector by vector. // Since the data are stored column major, making row-wise operation // probably not memory efficient anyway, we do an explicit for loop over // the columns. - const int double_range = range * 2; + const int double_range = op_params.range * 2; Eigen::VectorXf padded_square(data_in.rows() + double_range); padded_square.setZero(); for (int r = 0; r < data_in.cols(); ++r) { // Do local response normalization for data_in(:, r) // first, compute the square and store them in buffer for repeated use - padded_square.block(range, 0, data_in.rows(), 1) = - data_in.col(r).cwiseProduct(data_in.col(r)) * alpha; + padded_square.block(op_params.range, 0, data_in.rows(), 1) = + data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha; // Then, compute the scale and writes them to data_out float accumulated_scale = 0; for (int i = 0; i < double_range; ++i) { @@ -4064,21 +4218,37 @@ inline void LocalResponseNormalization(const float* input_data, } for (int i = 0; i < data_in.rows(); ++i) { accumulated_scale += padded_square(i + double_range); - data_out(i, r) = bias + accumulated_scale; + data_out(i, r) = op_params.bias + accumulated_scale; accumulated_scale -= padded_square(i); } } // In a few cases, the pow computation could benefit from speedups. - if (beta == 1) { + if (op_params.beta == 1) { data_out.array() = data_in.array() * data_out.array().inverse(); - } else if (beta == 0.5) { + } else if (op_params.beta == 0.5) { data_out.array() = data_in.array() * data_out.array().sqrt().inverse(); } else { - data_out.array() = data_in.array() * data_out.array().pow(-beta); + data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta); } } +// Legacy Dims<4>. +inline void LocalResponseNormalization(const float* input_data, + const Dims<4>& input_dims, int range, + float bias, float alpha, float beta, + float* output_data, + const Dims<4>& output_dims) { + tflite::LocalResponseNormalizationParams op_params; + op_params.range = range; + op_params.bias = bias; + op_params.alpha = alpha; + op_params.beta = beta; + + LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, const RuntimeShape& output_shape) { @@ -5012,14 +5182,22 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, } template -inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, - DstT* output_data, const Dims<4>& output_dims) { +inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data, + const RuntimeShape& output_shape, DstT* output_data) { gemmlowp::ScopedProfilingLabel label("Cast"); - auto input_map = MapAsVector(input_data, input_dims); - auto output_map = MapAsVector(output_data, output_dims); + auto input_map = MapAsVector(input_data, input_shape); + auto output_map = MapAsVector(output_data, output_shape); output_map.array() = input_map.array().template cast(); } +// Legacy Dims<4> version. +template +void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data, + const Dims<4>& output_dims) { + Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + inline void Floor(const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data) { gemmlowp::ScopedProfilingLabel label("Floor"); @@ -5134,12 +5312,14 @@ inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, int32 x, int32 y, int32 depth, int32 batch, + const RuntimeShape& input_shape, const float* input_data, - const Dims<4>& input_dims, - float* output_data, - const Dims<4>& output_dims) { - const int32 input_width = ArraySize(input_dims, 1); - const int32 output_width = ArraySize(output_dims, 1); + const RuntimeShape& output_shape, + float* output_data) { + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + const int32 input_width = input_shape.Dims(2); + const int32 output_width = output_shape.Dims(2); const int32 input_x_offset = (x1 - x0) * depth; const int32 input_y_offset = (y1 - y0) * depth * input_width; @@ -5147,7 +5327,6 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, const int32 output_y_offset = depth * output_width; #ifdef USE_NEON - TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); TFLITE_DCHECK(x1 >= x0); TFLITE_DCHECK(y1 >= y0); @@ -5157,7 +5336,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, const float* input_ptr = nullptr; float32x4x2_t x0y0; - input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)]; + input_ptr = &input_data[Offset(input_shape, batch, y0, x0, ic)]; x0y0.val[0] = vld1q_f32(input_ptr); x0y0.val[1] = vld1q_f32(input_ptr + 4); @@ -5177,7 +5356,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, x1y1.val[1] = vld1q_f32(input_ptr + 4); // Top left corner. - float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)]; + float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)]; vst1q_f32(output_ptr, x0y0.val[0]); vst1q_f32(output_ptr + 4, x0y0.val[1]); @@ -5216,14 +5395,15 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, } // Handle 4 input channels at a time. for (; ic <= depth - 4; ic += 4) { - const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)]; + const float* input_ptr = + &input_data[Offset(input_shape, batch, y0, x0, ic)]; float32x4_t x0y0 = vld1q_f32(input_ptr); float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset); float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset); float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset); // Top left corner. - float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)]; + float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)]; vst1q_f32(output_ptr, x0y0); // Top right corner. @@ -5247,7 +5427,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, } // Handle one input channel at a time. for (; ic < depth; ic++) { - const int32 input_offset = Offset(input_dims, ic, x0, y0, batch); + const int32 input_offset = Offset(input_shape, batch, y0, x0, ic); float x0y0 = input_data[input_offset]; float x1y0 = input_data[input_offset + input_x_offset]; @@ -5255,7 +5435,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; // Top left corner. - const int32 output_offset = Offset(output_dims, ic, x, y, batch); + const int32 output_offset = Offset(output_shape, batch, y, x, ic); output_data[output_offset] = x0y0; // Top right corner. @@ -5271,7 +5451,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, } #else for (int ch = 0; ch < depth; ch++) { - const int32 input_offset = Offset(input_dims, ch, x0, y0, batch); + const int32 input_offset = Offset(input_shape, batch, y0, x0, ch); float x0y0 = input_data[input_offset]; float x1y0 = input_data[input_offset + input_x_offset]; @@ -5279,7 +5459,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; // Top left corner. - const int32 output_offset = Offset(output_dims, ch, x, y, batch); + const int32 output_offset = Offset(output_shape, batch, y, x, ch); output_data[output_offset] = x0y0; // Top right corner. @@ -5296,31 +5476,30 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, #endif } -inline void ResizeBilinear2x2(const float* input_data, - const Dims<4>& input_dims, float* output_data, - const Dims<4>& output_dims, int32 batches, - int32 input_height, int32 input_width, - int32 depth, int32 output_height, - int32 output_width) { +inline void ResizeBilinear2x2(int32 batches, int32 input_height, + int32 input_width, int32 depth, + int32 output_height, int32 output_width, + const RuntimeShape& input_shape, + const float* input_data, + const RuntimeShape& output_shape, + float* output_data) { for (int b = 0; b < batches; b++) { for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) { for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) { int32 x1 = std::min(x0 + 1, input_width - 1); int32 y1 = std::min(y0 + 1, input_height - 1); - ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data, - input_dims, output_data, output_dims); + ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape, + input_data, output_shape, output_data); } } } } -inline void ResizeBilinearGeneric(const float* input_data, - const Dims<4>& input_dims, float* output_data, - const Dims<4>& output_dims, int32 batches, - int32 input_height, int32 input_width, - int32 depth, int32 output_height, - int32 output_width, float height_scale, - float width_scale) { +inline void ResizeBilinearGeneric( + int32 batches, int32 input_height, int32 input_width, int32 depth, + int32 output_height, int32 output_width, float height_scale, + float width_scale, const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { memset(output_data, 0, batches * output_height * output_width * depth * sizeof(float)); @@ -5337,22 +5516,22 @@ inline void ResizeBilinearGeneric(const float* input_data, float* output_ptr = &output_data[output_offset]; // Run kernel on the 4 corners of the bilinear resize algorithm. - int32 input_offset = Offset(input_dims, 0, x0, y0, b); + int32 input_offset = Offset(input_shape, b, y0, x0, 0); float scale = (1 - (input_y - y0)) * (1 - (input_x - x0)); const float* input_ptr = &input_data[input_offset]; ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); - input_offset = Offset(input_dims, 0, x1, y0, b); + input_offset = Offset(input_shape, b, y0, x1, 0); scale = (1 - (input_y - y0)) * (input_x - x0); input_ptr = &input_data[input_offset]; ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); - input_offset = Offset(input_dims, 0, x0, y1, b); + input_offset = Offset(input_shape, b, y1, x0, 0); scale = (input_y - y0) * (1 - (input_x - x0)); input_ptr = &input_data[input_offset]; ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); - input_offset = Offset(input_dims, 0, x1, y1, b); + input_offset = Offset(input_shape, b, y1, x1, 0); scale = (input_y - y0) * (input_x - x0); input_ptr = &input_data[input_offset]; ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); @@ -5365,10 +5544,10 @@ inline void ResizeBilinearGeneric(const float* input_data, template inline void ResizeBilinearGenericSmallChannel( - const T* input_data, const Dims<4>& input_dims, T* output_data, - const Dims<4>& output_dims, int32 batches, int32 input_height, - int32 input_width, int32 depth, int32 output_height, int32 output_width, - float height_scale, float width_scale) { + int32 batches, int32 input_height, int32 input_width, int32 depth, + int32 output_height, int32 output_width, float height_scale, + float width_scale, const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& output_shape, T* output_data) { memset(output_data, 0, batches * output_height * output_width * depth * sizeof(T)); @@ -5383,9 +5562,10 @@ inline void ResizeBilinearGenericSmallChannel( int32 x0 = static_cast(input_x); int32 x1 = std::min(x0 + 1, input_width - 1); - int32 input_offset[4] = { - Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b), - Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)}; + int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0), + Offset(input_shape, b, y0, x1, 0), + Offset(input_shape, b, y1, x0, 0), + Offset(input_shape, b, y1, x1, 0)}; float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)), (1 - (input_y - y0)) * (input_x - x0), (input_y - y0) * (1 - (input_x - x0)), @@ -5403,79 +5583,123 @@ inline void ResizeBilinearGenericSmallChannel( } } -inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, +inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, + const RuntimeShape& unextended_input_shape, + const float* input_data, + const RuntimeShape& unextended_output_size_shape, const int32* output_size_data, - const Dims<4>& output_size_dims, float* output_data, - const Dims<4>& output_dims, bool align_corners) { + const RuntimeShape& unextended_output_shape, + float* output_data) { gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); - int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); - int32 input_height = ArraySize(input_dims, 2); - int32 input_width = ArraySize(input_dims, 1); - int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); - - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); - int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; - int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_size_shape = + RuntimeShape::ExtendedShape(4, unextended_output_size_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + int32 batches = MatchingDim(input_shape, 0, output_shape, 0); + int32 input_height = input_shape.Dims(1); + int32 input_width = input_shape.Dims(2); + int32 depth = MatchingDim(input_shape, 3, output_shape, 3); + + TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1); + TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1); + TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1); + TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2); + int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)]; // Specialize for 2x2 upsample. - if (!align_corners && output_height == 2 * input_height && + if (!op_params.align_corners && output_height == 2 * input_height && output_width == 2 * input_width) { - ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches, - input_height, input_width, depth, output_height, - output_width); + ResizeBilinear2x2(batches, input_height, input_width, depth, output_height, + output_width, input_shape, input_data, output_shape, + output_data); } else { float height_scale = static_cast(input_height) / output_height; float width_scale = static_cast(input_width) / output_width; - if (align_corners && output_height > 1) { + if (op_params.align_corners && output_height > 1) { height_scale = static_cast(input_height - 1) / (output_height - 1); } - if (align_corners && output_width > 1) { + if (op_params.align_corners && output_width > 1) { width_scale = static_cast(input_width - 1) / (output_width - 1); } - ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims, - batches, input_height, input_width, depth, + ResizeBilinearGeneric(batches, input_height, input_width, depth, output_height, output_width, height_scale, - width_scale); + width_scale, input_shape, input_data, output_shape, + output_data); } } +// Legacy Dims<4> +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims, bool align_corners) { + tflite::ResizeBilinearParams op_params; + op_params.align_corners = align_corners; + ResizeBilinear(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_size_dims), output_size_data, + DimsToShape(output_dims), output_data); +} + // TODO(prabhumk): This is not a real quantized bilinear. It does not use int8 // or int16 arithmetic. -inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, +inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, + const RuntimeShape& input_shape, + const uint8* input_data, + const RuntimeShape& output_size_shape, const int32* output_size_data, - const Dims<4>& output_size_dims, uint8* output_data, - const Dims<4>& output_dims, bool align_corners) { + const RuntimeShape& output_shape, + uint8* output_data) { gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); - int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); - int32 input_height = ArraySize(input_dims, 2); - int32 input_width = ArraySize(input_dims, 1); - int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); - - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); - int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; - int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_size_shape.DimensionsCount(), 4); + TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); + + int32 batches = MatchingDim(input_shape, 0, output_shape, 0); + int32 input_height = input_shape.Dims(1); + int32 input_width = input_shape.Dims(2); + int32 depth = MatchingDim(input_shape, 3, output_shape, 3); + + TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1); + TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1); + TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1); + TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2); + int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)]; float height_scale = - (align_corners && output_height > 1) + (op_params.align_corners && output_height > 1) ? (static_cast(input_height - 1) / (output_height - 1)) : (static_cast(input_height) / output_height); float width_scale = - (align_corners && output_width > 1) + (op_params.align_corners && output_width > 1) ? (static_cast(input_width - 1) / (output_width - 1)) : (static_cast(input_width) / output_width); ResizeBilinearGenericSmallChannel( - input_data, input_dims, output_data, output_dims, batches, input_height, - input_width, depth, output_height, output_width, height_scale, - width_scale); + batches, input_height, input_width, depth, output_height, output_width, + height_scale, width_scale, input_shape, input_data, output_shape, + output_data); +} + +// Legacy Dims<4> +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims, bool align_corners) { + tflite::ResizeBilinearParams op_params; + op_params.align_corners = align_corners; + ResizeBilinear(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_size_dims), output_size_data, + DimsToShape(output_dims), output_data); } // legacy, for compatibility with old checked-in code @@ -5518,20 +5742,29 @@ inline void GetIndexRange(int spatial_index_dim, int block_shape_dim, } template -inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, - const Dims<4>& block_shape_dims, - const int32* crops_data, const Dims<4>& crops_dims, - T* output_data, const Dims<4>& output_dims) { +inline void BatchToSpaceND( + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, + const RuntimeShape& unextended_input3_shape, const int32* crops_data, + const RuntimeShape& unextended_output_shape, T* output_data) { gemmlowp::ScopedProfilingLabel label("BatchToSpaceND"); - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input1_shape = + RuntimeShape::ExtendedShape(4, unextended_input1_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch_size = output_shape.Dims(0); + + const int depth = input1_shape.Dims(3); + const int input_width = input1_shape.Dims(2); + const int input_height = input1_shape.Dims(1); + const int input_batch_size = input1_shape.Dims(0); + const int block_shape_width = block_shape_data[1]; const int block_shape_height = block_shape_data[0]; const int crops_top = crops_data[0]; @@ -5566,14 +5799,28 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, spatial_offset % block_shape_width - crops_left; TFLITE_DCHECK_GE(out_w, 0); TFLITE_DCHECK_LT(out_w, output_width); - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); - const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0); + const T* in = + input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0); memcpy(out, in, depth * sizeof(T)); } } } } +// Legacy Dims<4>. +template +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* crops_data, const Dims<4>& crops_dims, + T* output_data, const Dims<4>& output_dims) { + BatchToSpaceND(DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(crops_dims), crops_data, DimsToShape(output_dims), + output_data); +} + template void TypedMemset(void* ptr, T value, size_t num) { // Optimization for common cases where memset() will suffice. diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index b241ecbcf5acdb0bb8bc15c21a452c1607ee2298..3875b73e05c35677d65f6578b0509d7bbb95b999 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -407,18 +407,29 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims, } template -inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, - int block_size, T* output_data, - const Dims<4>& output_dims) { - const int input_depth = ArraySize(input_dims, 0); - const int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - const int input_batch = ArraySize(input_dims, 3); +inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int input_depth = input_shape.Dims(3); + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int input_batch = input_shape.Dims(0); - const int output_depth = ArraySize(output_dims, 0); - const int output_width = ArraySize(output_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_batch = ArraySize(output_dims, 3); + const int output_depth = output_shape.Dims(3); + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch = output_shape.Dims(0); + + const int32 block_size = op_params.block_size; TFLITE_DCHECK_EQ(input_width * block_size, output_width); TFLITE_DCHECK_EQ(input_height * block_size, output_height); @@ -437,9 +448,9 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, const int in_h = out_h / block_size; const int in_b = out_b; + const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d); const int output_index = - Offset(output_dims, out_d, out_w, out_h, out_b); - const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + Offset(output_shape, out_b, out_h, out_w, out_d); output_data[output_index] = input_data[input_index]; } @@ -448,19 +459,42 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. template -inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, int block_size, T* output_data, const Dims<4>& output_dims) { - const int input_depth = ArraySize(input_dims, 0); - const int input_width = ArraySize(input_dims, 1); - const int input_height = ArraySize(input_dims, 2); - const int input_batch = ArraySize(input_dims, 3); + tflite::DepthToSpaceParams op_params; + op_params.block_size = block_size; - const int output_depth = ArraySize(output_dims, 0); - const int output_width = ArraySize(output_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_batch = ArraySize(output_dims, 3); + DepthToSpace(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +template +inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int input_depth = input_shape.Dims(3); + const int input_width = input_shape.Dims(2); + const int input_height = input_shape.Dims(1); + const int input_batch = input_shape.Dims(0); + + const int output_depth = output_shape.Dims(3); + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch = output_shape.Dims(0); + + const int32 block_size = op_params.block_size; TFLITE_DCHECK_EQ(input_width, output_width * block_size); TFLITE_DCHECK_EQ(input_height, output_height * block_size); @@ -478,9 +512,9 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, const int out_h = in_h / block_size; const int out_b = in_b; + const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d); const int output_index = - Offset(output_dims, out_d, out_w, out_h, out_b); - const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + Offset(output_shape, out_b, out_h, out_w, out_d); output_data[output_index] = input_data[input_index]; } @@ -489,6 +523,18 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. +template +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + tflite::SpaceToDepthParams op_params; + op_params.block_size = block_size; + + SpaceToDepth(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, const float* weights_data, const Dims<4>& weights_dims, const float* bias_data, @@ -803,49 +849,6 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, output_activation_max, output_data, output_dims, gemm_context); } -template -void NonGlobalBatchNormalization( - const float* input_data, const Dims<4>& input_dims, const float* mean_data, - const Dims<4>& mean_dims, const float* multiplier_data, - const Dims<4>& multiplier_dims, const float* offset_data, - const Dims<4>& offset_dims, float* output_data, - const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int inner_size = MatchingFlatSizeSkipDim( - input_dims, 3, mean_dims, multiplier_dims, offset_dims, output_dims); - - for (int b = 0; b < batches; ++b) { - for (int i = 0; i < inner_size; ++i) { - output_data[b * inner_size + i] = ActivationFunction( - (input_data[b * inner_size + i] - mean_data[i]) * multiplier_data[i] + - offset_data[i]); - } - } -} - -template -void GlobalBatchNormalization(const float* input_data, - const Dims<4>& input_dims, const float* mean_data, - const Dims<4>& mean_dims, - const float* multiplier_data, - const Dims<4>& multiplier_dims, - const float* offset_data, - const Dims<4>& offset_dims, float* output_data, - const Dims<4>& output_dims) { - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = - MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, - offset_dims, 0, output_dims, 0); - - for (int i = 0; i < outer_size; ++i) { - for (int c = 0; c < depth; ++c) { - output_data[depth * i + c] = ActivationFunction( - (input_data[depth * i + c] - mean_data[c]) * multiplier_data[c] + - offset_data[c]); - } - } -} - inline void Relu(const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data) { const int flat_size = MatchingFlatSize(input_shape, output_shape); @@ -883,11 +886,14 @@ inline void Relu6(const RuntimeShape& input_shape, const float* input_data, } } -inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data, - const RuntimeShape& input_shape, uint8* output_data, - const RuntimeShape& output_shape) { +inline void ReluX(const tflite::ActivationParams& params, + const RuntimeShape& input_shape, const uint8* input_data, + + const RuntimeShape& output_shape, uint8* output_data) { gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)"); const int flat_size = MatchingFlatSize(input_shape, output_shape); + const uint8 max_value = params.quantized_activation_max; + const uint8 min_value = params.quantized_activation_min; for (int i = 0; i < flat_size; ++i) { const uint8 val = input_data[i]; const uint8 clamped = @@ -896,10 +902,21 @@ inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data, } } -template -void L2Normalization(const float* input_data, const RuntimeShape& input_shape, - float* output_data, const RuntimeShape& output_shape) { - static_assert(Ac == FusedActivationFunctionType::kNone, ""); +// Legacy. +inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data, + const RuntimeShape& input_shape, uint8* output_data, + const RuntimeShape& output_shape) { + tflite::ActivationParams params; + params.quantized_activation_max = max_value; + params.quantized_activation_min = min_value; + ReluX(params, input_shape, input_data, output_shape, output_data); +} + +inline void L2Normalization(const tflite::L2NormalizationParams& op_params, + const RuntimeShape& input_shape, + const float* input_data, + const RuntimeShape& output_shape, + float* output_data) { const int trailing_dim = input_shape.DimensionsCount() - 1; const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); @@ -918,6 +935,18 @@ void L2Normalization(const float* input_data, const RuntimeShape& input_shape, } } +// Legacy . +template +void L2Normalization(const float* input_data, const RuntimeShape& input_shape, + float* output_data, const RuntimeShape& output_shape) { + static_assert(Ac == FusedActivationFunctionType::kNone, ""); + tflite::L2NormalizationParams op_params; + // No params need to be set for float. + + L2Normalization(op_params, input_shape, input_data, output_shape, + output_data); +} + inline void GetInvSqrtQuantizedMultiplierExp(int32 input, int32* output_inv_sqrt, int* output_shift) { @@ -966,15 +995,17 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input, *output_shift *= kReverseShift; } -inline void L2Normalization(const uint8* input_data, +inline void L2Normalization(const tflite::L2NormalizationParams& op_params, const RuntimeShape& input_shape, - int32 input_zero_point, uint8* output_data, - const RuntimeShape& output_shape) { + const uint8* input_data, + const RuntimeShape& output_shape, + uint8* output_data) { const int trailing_dim = input_shape.DimensionsCount() - 1; const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int32 input_zero_point = op_params.input_zero_point; for (int i = 0; i < outer_size; ++i) { int32 square_l2_norm = 0; for (int c = 0; c < depth; c++) { @@ -997,6 +1028,18 @@ inline void L2Normalization(const uint8* input_data, } } +// Legacy. +inline void L2Normalization(const uint8* input_data, + const RuntimeShape& input_shape, + int32 input_zero_point, uint8* output_data, + const RuntimeShape& output_shape) { + tflite::L2NormalizationParams op_params; + op_params.input_zero_point = input_zero_point; + + L2Normalization(op_params, input_shape, input_data, output_shape, + output_data); +} + template inline void Add(const ArithmeticParams& params, const RuntimeShape& input1_shape, const T* input1_data, @@ -1320,11 +1363,16 @@ inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params, } template -inline void Mul(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const T* input1_data, + const RuntimeShape& input2_shape, const T* input2_data, + const RuntimeShape& output_shape, T* output_data) { + T output_activation_min; + T output_activation_max; + GetActivationParams(params, &output_activation_min, &output_activation_max); + + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; ++i) { output_data[i] = ActivationFunctionWithMinMax( input1_data[i] * input2_data[i], output_activation_min, @@ -1332,6 +1380,20 @@ inline void Mul(const T* input1_data, const Dims<4>& input1_dims, } } +// Legacy Dims<4>. +template +inline void Mul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + SetActivationParams(output_activation_min, output_activation_max, &op_params); + + Mul(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + // legacy, for compatibility with old checked-in code template void Mul(const float* input1_data, const Dims<4>& input1_dims, @@ -1340,44 +1402,65 @@ void Mul(const float* input1_data, const Dims<4>& input1_dims, float output_activation_min, output_activation_max; GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, - output_activation_max, output_data, output_dims); + tflite::ArithmeticParams op_params; + SetActivationParams(output_activation_min, output_activation_max, &op_params); + + Mul(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); } // TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. +// TODO(benoitjacob): BroadcastMul is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. template -void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, - const T* input2_data, const Dims<4>& input2_dims, - T output_activation_min, T output_activation_max, - T* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastMul"); +void BroadcastMul4DSlow(const ArithmeticParams& params, + const RuntimeShape& unextended_input1_shape, + const T* input1_data, + const RuntimeShape& unextended_input2_shape, + const T* input2_data, + const RuntimeShape& unextended_output_shape, + T* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow"); + T output_activation_min; + T output_activation_max; + GetActivationParams(params, &output_activation_min, &output_activation_max); + + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, + unextended_input2_shape, &desc1, &desc2); // In Tensorflow, the dimensions are canonically named (batch_number, row, // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest - // stride, typically 1 element). + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). // // In generated C code, we store arrays with the dimensions reversed. The // first dimension has smallest stride. // // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for - // the best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < output_shape.Dims(0); ++b) { + for (int y = 0; y < output_shape.Dims(1); ++y) { + for (int x = 0; x < output_shape.Dims(2); ++x) { + for (int c = 0; c < output_shape.Dims(3); ++c) { + output_data[Offset(output_shape, b, y, x, c)] = ActivationFunctionWithMinMax( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] * - input2_data[SubscriptToIndex(desc2, c, x, y, b)], + input1_data[SubscriptToIndex(desc1, b, y, x, c)] * + input2_data[SubscriptToIndex(desc2, b, y, x, c)], output_activation_min, output_activation_max); } } @@ -1385,6 +1468,20 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, } } +// Legacy. +template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + SetActivationParams(output_activation_min, output_activation_max, &op_params); + + BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + // legacy, for compatibility with old checked-in code template void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, @@ -1393,9 +1490,12 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, T output_activation_min, output_activation_max; GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); - BroadcastMul(input1_data, input1_dims, input2_data, input2_dims, - output_activation_min, output_activation_max, output_data, - output_dims); + tflite::ArithmeticParams op_params; + SetActivationParams(output_activation_min, output_activation_max, &op_params); + + BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); } // Element-wise mul that can often be used for inner loop of broadcast Mul as @@ -1526,6 +1626,7 @@ inline void BroadcastMul4DSlow(const ArithmeticParams& params, } } +// Legacy. // Transitional version that will be moved shortly to legacy_reference_ops, as // part of RuntimeShape revisions. inline void BroadcastMul4DSlow(const uint8* input1_data, @@ -1536,52 +1637,27 @@ inline void BroadcastMul4DSlow(const uint8* input1_data, int output_shift, int32 output_activation_min, int32 output_activation_max, uint8* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); - - NdArrayDesc<4> desc1; - NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - - // In Tensorflow, the dimensions are canonically named (batch_number, row, - // col, channel), with extents (batches, height, width, depth), with the - // trailing dimension changing most rapidly (channels has the smallest - // stride, typically 1 element). - // - // In generated C code, we store arrays with the dimensions reversed. The - // first dimension has smallest stride. - // - // We name our variables by their Tensorflow convention, but generate C code - // nesting loops such that the innermost loop has the smallest stride for - // the best cache behavior. - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; - const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; - const int32 unclamped_result = - output_offset + - MultiplyByQuantizedMultiplierSmallerThanOneExp( - input1_val * input2_val, output_multiplier, output_shift); - const int32 clamped_output = - std::min(output_activation_max, - std::max(output_activation_min, unclamped_result)); - output_data[Offset(output_dims, c, x, y, b)] = - static_cast(clamped_output); - } - } - } - } + tflite::ArithmeticParams op_params; + SetActivationParams(output_activation_min, output_activation_max, &op_params); + op_params.input1_offset = input1_offset; + op_params.input2_offset = input2_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = output_multiplier; + op_params.output_shift = output_shift; + + BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); } -inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, - const int16* input2_data, const Dims<4>& input2_dims, - int16* output_data, const Dims<4>& output_dims) { +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int16* input1_data, + const RuntimeShape& input2_shape, const int16* input2_data, + const RuntimeShape& output_shape, int16* output_data) { gemmlowp::ScopedProfilingLabel label("Mul/Int16"); - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -1593,15 +1669,30 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, } } +// Legacy Dims<4>. inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, const int16* input2_data, const Dims<4>& input2_dims, - int32 output_offset, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { + int16* output_data, const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + // No params in this version. + + Mul(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +inline void Mul(const ArithmeticParams& params, + const RuntimeShape& input1_shape, const int16* input1_data, + const RuntimeShape& input2_shape, const int16* input2_data, + const RuntimeShape& output_shape, uint8* output_data) { gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8"); + int32 output_offset = params.output_offset; + int32 output_activation_min = params.quantized_activation_min; + int32 output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_LE(output_activation_min, output_activation_max); - const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims); + const int flat_size = + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int i = 0; i < flat_size; i++) { // F0 uses 0 integer bits, range [-1, 1]. @@ -1619,6 +1710,22 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, } } +// Legacy Dims<4>. +inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, + const int16* input2_data, const Dims<4>& input2_dims, + int32 output_offset, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + tflite::ArithmeticParams op_params; + op_params.quantized_activation_min = output_activation_min; + op_params.quantized_activation_max = output_activation_max; + op_params.output_offset = output_offset; + + Mul(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + // TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then @@ -2021,6 +2128,25 @@ void Pack(int dim, const Scalar* const* input_data, } } +template +void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims, + int dimensions, int outputs_count, Scalar* const* output_datas, + const Dims<4>& output_dims) { + int outer_size = 1; + for (int i = dimensions - axis; i < 4; i++) { + outer_size *= input_dims.sizes[i]; + } + + const int copy_size = FlatSize(input_dims) / outer_size / outputs_count; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < outputs_count; ++i) { + Scalar* output_ptr = output_datas[i] + copy_size * k; + int loc = k * outputs_count * copy_size + i * copy_size; + memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar)); + } + } +} + // TODO(prabhumk): This is the same as the optimized implementation. // TODO(prabhumk): The quantized implementation of concatentation isn't fully // quantized as it takes scale as a floating point value. This should be fixed @@ -2758,29 +2884,48 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, } } -inline void LocalResponseNormalization(const float* input_data, - const Dims<4>& input_dims, int range, - float bias, float alpha, float beta, - float* output_data, - const Dims<4>& output_dims) { - const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); +inline void LocalResponseNormalization( + const tflite::LocalResponseNormalizationParams& op_params, + const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + const int trailing_dim = input_shape.DimensionsCount() - 1; + const int outer_size = + MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); + const int depth = + MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int i = 0; i < outer_size; ++i) { for (int c = 0; c < depth; ++c) { - const int begin_input_c = std::max(0, c - range); - const int end_input_c = std::min(depth, c + range); + const int begin_input_c = std::max(0, c - op_params.range); + const int end_input_c = std::min(depth, c + op_params.range); float accum = 0.f; for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) { const float input_val = input_data[i * depth + input_c]; accum += input_val * input_val; } - const float multiplier = std::pow(bias + alpha * accum, -beta); + const float multiplier = + std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta); output_data[i * depth + c] = input_data[i * depth + c] * multiplier; } } } +// Legacy Dims<4>. +inline void LocalResponseNormalization(const float* input_data, + const Dims<4>& input_dims, int range, + float bias, float alpha, float beta, + float* output_data, + const Dims<4>& output_dims) { + tflite::LocalResponseNormalizationParams op_params; + op_params.range = range; + op_params.bias = bias; + op_params.alpha = alpha; + op_params.beta = beta; + + LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + inline void Softmax(const float* input_data, const RuntimeShape& input_shape, float beta, float* output_data, const RuntimeShape& output_shape) { @@ -3310,9 +3455,9 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, } template -inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, - DstT* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data, + const RuntimeShape& output_shape, DstT* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { int offset = i; @@ -3320,9 +3465,17 @@ inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, } } -inline void Floor(const float* input_data, const Dims<4>& input_dims, - float* output_data, const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +// Legacy Dims<4> version. +template +void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data, + const Dims<4>& output_dims) { + Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + +inline void Floor(const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { int offset = i; @@ -3330,6 +3483,13 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4> version. +inline void Floor(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + template inline void Gather(const T* input_data, const Dims<4>& input_dims, int input_rank, const int32* coords_data, @@ -3349,27 +3509,41 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, } template -inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims, +inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_size_shape, const int32* output_size_data, - const Dims<4>& output_size_dims, T* output_data, - const Dims<4>& output_dims, bool align_corners) { - int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); - int32 input_height = ArraySize(input_dims, 2); - int32 input_width = ArraySize(input_dims, 1); - int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); - - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); - TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); - int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; - int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + const RuntimeShape& unextended_output_shape, + T* output_data) { + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_size_shape = + RuntimeShape::ExtendedShape(4, unextended_output_size_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + int32 batches = MatchingDim(input_shape, 0, output_shape, 0); + int32 input_height = input_shape.Dims(1); + int32 input_width = input_shape.Dims(2); + int32 depth = MatchingDim(input_shape, 3, output_shape, 3); + + TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1); + TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1); + TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1); + TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2); + int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)]; + float height_scale = static_cast(input_height) / output_height; float width_scale = static_cast(input_width) / output_width; - if (align_corners && output_height > 1) { + if (op_params.align_corners && output_height > 1) { height_scale = static_cast(input_height - 1) / (output_height - 1); } - if (align_corners && output_width > 1) { + if (op_params.align_corners && output_width > 1) { width_scale = static_cast(input_width - 1) / (output_width - 1); } @@ -3384,21 +3558,34 @@ inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims, int32 x1 = std::min(x0 + 1, input_width - 1); for (int c = 0; c < depth; ++c) { T interpolation = - static_cast(input_data[Offset(input_dims, c, x0, y0, b)] * + static_cast(input_data[Offset(input_shape, b, y0, x0, c)] * (1 - (input_y - y0)) * (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x0, y1, b)] * + input_data[Offset(input_shape, b, y1, x0, c)] * (input_y - y0) * (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x1, y0, b)] * + input_data[Offset(input_shape, b, y0, x1, c)] * (1 - (input_y - y0)) * (input_x - x0) + - input_data[Offset(input_dims, c, x1, y1, b)] * + input_data[Offset(input_shape, b, y1, x1, c)] * (input_y - y0) * (input_x - x0)); - output_data[Offset(output_dims, c, x, y, b)] = interpolation; + output_data[Offset(output_shape, b, y, x, c)] = interpolation; } } } } } +// Legacy Dims<4>. +template +inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, T* output_data, + const Dims<4>& output_dims, bool align_corners) { + tflite::ResizeBilinearParams op_params; + op_params.align_corners = align_corners; + ResizeBilinear(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_size_dims), output_size_data, + DimsToShape(output_dims), output_data); +} + // legacy, for compatibility with old checked-in code inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, @@ -3409,6 +3596,7 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, /*align_corners=*/false); } +// Legacy. inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, uint8* output_data, @@ -3419,45 +3607,56 @@ inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, } template -inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, - const Dims<4>& block_shape_dims, - const int32* paddings_data, - const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims, - const int32_t pad_value) { - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); +inline void SpaceToBatchND( + const SpaceToBatchParams& params, + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, + const RuntimeShape& unextended_input3_shape, const int32* paddings_data, + const RuntimeShape& unextended_output_shape, T* output_data) { + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input1_shape = + RuntimeShape::ExtendedShape(4, unextended_input1_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int depth = input1_shape.Dims(3); + const int input_width = input1_shape.Dims(2); + const int input_height = input1_shape.Dims(1); + const int input_batch_size = input1_shape.Dims(0); + + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch_size = output_shape.Dims(0); + const int block_shape_height = block_shape_data[0]; const int block_shape_width = block_shape_data[1]; const int padding_top = paddings_data[0]; const int padding_left = paddings_data[2]; + // For uint8 quantized, the correct padding "zero value" is the output offset. + const int32_t pad_value = params.output_offset; + for (int out_b = 0; out_b < output_batch_size; ++out_b) { int input_batch = out_b % input_batch_size; int shift_w = (out_b / input_batch_size) % block_shape_width; int shift_h = (out_b / input_batch_size) / block_shape_width; for (int out_h = 0; out_h < output_height; ++out_h) { for (int out_w = 0; out_w < output_width; ++out_w) { - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); + T* out = output_data + Offset(output_shape, out_b, out_h, out_w, 0); if (out_h * block_shape_height + shift_h < padding_top || out_h * block_shape_height + shift_h >= padding_top + input_height || out_w * block_shape_width + shift_w < padding_left || out_w * block_shape_width + shift_w >= padding_left + input_width) { + // This may not execute correctly when pad_value != 0 and T != uint8. memset(out, pad_value, depth * sizeof(T)); } else { const T* in = - input_data + - Offset(input_dims, 0, - (out_w * block_shape_width + shift_w) - padding_left, + input1_data + + Offset(input1_shape, input_batch, (out_h * block_shape_height + shift_h) - padding_top, - input_batch); + (out_w * block_shape_width + shift_w) - padding_left, 0); memcpy(out, in, depth * sizeof(T)); } } @@ -3465,30 +3664,63 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, } } +// Legacy Dims<4>. template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, const Dims<4>& block_shape_dims, const int32* paddings_data, const Dims<4>& paddings_dims, T* output_data, - const Dims<4>& output_dims) { - SpaceToBatchND(input_data, input_dims, block_shape_data, block_shape_dims, - paddings_data, paddings_dims, output_data, output_dims, 0); + const Dims<4>& output_dims, + const int32_t pad_value) { + tflite::SpaceToBatchParams op_params; + op_params.output_offset = pad_value; + + SpaceToBatchND(op_params, DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(paddings_dims), paddings_data, + DimsToShape(output_dims), output_data); } +// Legacy if no good reason to have signature with pad_value=0. template -inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, const Dims<4>& block_shape_dims, - const int32* crops_data, const Dims<4>& crops_dims, - T* output_data, const Dims<4>& output_dims) { - const int output_batch_size = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int input_batch_size = ArraySize(input_dims, 3); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int depth = ArraySize(input_dims, 0); + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + tflite::SpaceToBatchParams op_params; + op_params.output_offset = 0; + + SpaceToBatchND(op_params, DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(paddings_dims), paddings_data, + DimsToShape(output_dims), output_data); +} + +template +inline void BatchToSpaceND( + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const int32* block_shape_data, + const RuntimeShape& unextended_input3_shape, const int32* crops_data, + const RuntimeShape& unextended_output_shape, T* output_data) { + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input1_shape = + RuntimeShape::ExtendedShape(4, unextended_input1_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int output_width = output_shape.Dims(2); + const int output_height = output_shape.Dims(1); + const int output_batch_size = output_shape.Dims(0); + + const int depth = input1_shape.Dims(3); + const int input_width = input1_shape.Dims(2); + const int input_height = input1_shape.Dims(1); + const int input_batch_size = input1_shape.Dims(0); + const int block_shape_width = block_shape_data[1]; const int block_shape_height = block_shape_data[0]; const int crops_top = crops_data[0]; @@ -3510,14 +3742,28 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, if (out_w < 0 || out_w >= output_width) { continue; } - T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); - const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0); + const T* in = + input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0); memcpy(out, in, depth * sizeof(T)); } } } } +// Legacy Dims<4>. +template +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* crops_data, const Dims<4>& crops_dims, + T* output_data, const Dims<4>& output_dims) { + BatchToSpaceND(DimsToShape(input_dims), input_data, + DimsToShape(block_shape_dims), block_shape_data, + DimsToShape(crops_dims), crops_data, DimsToShape(output_dims), + output_data); +} + // There are two versions of pad: Pad and PadV2. In PadV2 there is a second // scalar input that provides the padding value. Therefore pad_value_ptr can be // equivalent to a simple input1_data. For Pad, it should point to a zero @@ -3962,6 +4208,23 @@ inline bool ReduceProd(const T* input_data, const int* input_dims, resolved_axis, init_value, reducer); } +// Computes the logical_or of elements across dimensions given in axis. +inline bool ReduceAny(const bool* input_data, const int* input_dims, + const int input_num_dims, bool* output_data, + const int* output_dims, const int output_num_dims, + const int* axis, const int64_t num_axis_dimensions, + bool keep_dims, int* temp_index, int* resolved_axis) { + bool init_value = false; + + auto reducer = [](const bool current, const bool in) -> bool { + return current || in; + }; + return ReduceGeneric(input_data, input_dims, input_num_dims, + output_data, output_dims, output_num_dims, axis, + num_axis_dimensions, keep_dims, temp_index, + resolved_axis, init_value, reducer); +} + // Computes the mean of elements across dimensions given in axis. // It does so in two stages, first calculates the sum of elements along the axis // then divides it by the number of element in axis. @@ -4053,6 +4316,70 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims, } } +// Computes the mean of elements across dimensions given in axis. +// It does so in two stages, first calculates the sum of elements along the axis +// then divides it by the number of element in axis for quantized values. +template +inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale, + const int* input_dims, const int input_num_dims, + T* output_data, int32 output_zero_point, float output_scale, + const int* output_dims, const int output_num_dims, + const int* axis, const int num_axis_dimensions, bool keep_dims, + int* temp_index, int* resolved_axis, U* temp_sum) { + // Reset output data. + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + size_t current = static_cast(output_dims[idx]); + // Overflow prevention. + if (num_outputs > std::numeric_limits::max() / current) { + return false; + } + num_outputs *= current; + } + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = T(); + temp_sum[idx] = U(); + } + + // Resolve axis. + int num_resolved_axis = 0; + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; + } + + if (!ReduceSumImpl(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, temp_sum)) { + return false; + } + + // Calculate mean by dividing output_data by num of aggregated element. + U num_elements_in_axis = 1; + for (int idx = 0; idx < num_resolved_axis; ++idx) { + size_t current = static_cast(input_dims[resolved_axis[idx]]); + // Overflow prevention. + if (current > (std::numeric_limits::max() / num_elements_in_axis)) { + return false; + } + num_elements_in_axis *= current; + } + + if (num_elements_in_axis > 0) { + const float scale = input_scale / output_scale; + const float bias = -input_zero_point * scale; + for (size_t idx = 0; idx < num_outputs; ++idx) { + float float_mean = static_cast(temp_sum[idx]) / + static_cast(num_elements_in_axis); + + // Convert to float value. + output_data[idx] = + static_cast(round(float_mean * scale + bias)) + output_zero_point; + } + } + return true; +} + template void Minimum(const RuntimeShape& input1_shape, const T* input1_data, const T* input2_data, const RuntimeShape& output_shape, @@ -4697,6 +5024,21 @@ inline void BroadcastBinaryFunction(const T1* input1_data, DimsToShape(output_dims), output_data, func); } +// Legacy Dims<4> version. +// +// R: Result type. T1: Input 1 type. T2: Input 2 type. +// TODO(renjieliu): Refactor other binary functions to use this one. +template +inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims, + const T2* input2_data, const Dims<4>& input2_dims, + R* output_data, const Dims<4>& output_dims, + R (*func)(T1, T2)) { + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = func(input1_data[i], input2_data[i]); + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 204df9ab19a1e69c054bc8bd36efb0d81f9cd754..8e17eaa964a8b76367786352717446142326243c 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -668,9 +668,9 @@ static_assert(sizeof(MinMax) == 8, ""); struct ActivationParams { FusedActivationFunctionType activation_type; - // Quantized inference params. - int32 activation_min; - int32 activation_max; + // uint8, etc, activation params. + int32 quantized_activation_min; + int32 quantized_activation_max; }; // For Add, Sub, Mul ops. @@ -745,7 +745,7 @@ struct ConvParams { }; struct DepthToSpaceParams { - int16 block_size; + int32 block_size; }; struct DepthwiseParams { @@ -871,8 +871,13 @@ struct SoftmaxParams { int diff_min; }; +struct SpaceToBatchParams { + // "Zero" padding for uint8 means padding with the output offset. + int32 output_offset; +}; + struct SpaceToDepthParams { - int16 block_size; + int32 block_size; }; struct SplitParams { @@ -908,23 +913,30 @@ struct TanhParams { int input_left_shift; }; -template -inline void SetActivationParams(T min, T max, ArithmeticParams* params); - -template <> -inline void SetActivationParams(float min, float max, - ArithmeticParams* params) { +template +inline void SetActivationParams(float min, float max, P* params) { params->float_activation_min = min; params->float_activation_max = max; } -template <> -inline void SetActivationParams(int32 min, int32 max, - ArithmeticParams* params) { +template +inline void SetActivationParams(int32 min, int32 max, P* params) { params->quantized_activation_min = min; params->quantized_activation_max = max; } +template +inline void GetActivationParams(const P& params, int32* min, int32* max) { + *min = params.quantized_activation_min; + *max = params.quantized_activation_max; +} + +template +inline void GetActivationParams(const P& params, float* min, float* max) { + *min = params.float_activation_min; + *max = params.float_activation_max; +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index ba251c451e549a09d265fc43fed7dc7eb6896d61..74dc3f25f96c8f302e85bb9cac5482fab1c5c4f6 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -37,7 +37,7 @@ namespace builtin { namespace lstm { struct OpData { - // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel + // Which kernel type to use. Full kernel (20 inputs) or basic kernel // (5 inputs). TfLiteLSTMKernelType kernel_type; @@ -47,7 +47,7 @@ struct OpData { int scratch_tensor_index; }; -// For full inputs kernel (18 or 20 inputs). +// For full inputs kernel (20-inputs). namespace full { // Input Tensors of size {n_batch, n_input} @@ -81,19 +81,13 @@ constexpr int kProjectionWeightsTensor = 16; // Optional // Projection bias tensor of size {n_output} constexpr int kProjectionBiasTensor = 17; // Optional -// If the node has 20 inputs, the following 2 tensors are used as state tensors. -// These are defined as variable tensors, and will be modified by this op. +// These state tensors are defined as variable tensors, and will be modified by +// this op. constexpr int kInputActivationStateTensor = 18; constexpr int kInputCellStateTensor = 19; // Output tensors. -// * If the node has 18 inputs, these 2 tensors are used as state tensors. -// * If the node has 20 inputs, these 2 tensors are ignored. -// TODO(ycling): Make the 2 output state tensors optional, and propagate the -// state to output tensors when the 2 tensors present. -constexpr int kOutputStateTensor = 0; -constexpr int kCellStateTensor = 1; -constexpr int kOutputTensor = 2; +constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData(); @@ -258,30 +252,12 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpData* op_data = reinterpret_cast(node->user_data); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 3); - - // True if the node is using input variable state tensors. It means: - // * The state tensors are defined as inputs. In this case it would be the - // 19th and 20th input tensors. - // * Otherwise, the output tensors are used to store states. - bool use_input_variable_states; - if (node->inputs->size == 20) { - use_input_variable_states = true; - op_data->activation_state_tensor_index = - node->inputs->data[kInputActivationStateTensor]; - op_data->cell_state_tensor_index = - node->inputs->data[kInputCellStateTensor]; - } else if (node->inputs->size == 18) { - use_input_variable_states = false; - op_data->activation_state_tensor_index = - node->outputs->data[kOutputStateTensor]; - op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor]; - } else { - context->ReportError( - context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs", - node->inputs->size); - return kTfLiteError; - } + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 20); + + op_data->activation_state_tensor_index = + node->inputs->data[kInputActivationStateTensor]; + op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor]; // Inferring batch size, number of outputs and number of cells from the // input tensors. @@ -316,31 +292,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* cell_state = &context->tensors[op_data->cell_state_tensor_index]; - if (use_input_variable_states) { - // Check the shape of input state tensors. - // These tensor may be 1D or 2D. It's fine as long as the total size is - // correct. - TF_LITE_ENSURE_EQ(context, NumElements(activation_state), - n_batch * n_output); - TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); - } else { - // If the state tensors are outputs, this function takes the - // responsibility to resize the state tensors. - TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2); - activation_state_size->data[0] = n_batch; - activation_state_size->data[1] = n_output; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state, - activation_state_size)); - - TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); - cell_size->data[0] = n_batch; - cell_size->data[1] = n_cell; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, cell_state, cell_size)); - // Mark state tensors as persistent tensors. - activation_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - } + // Check the shape of input state tensors. + // These tensor may be 1D or 2D. It's fine as long as the total size is + // correct. + TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output); + TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); // Resize the output tensors. TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index 0266f5fe57e6c60ea19ad5f8de05e879e7da9304..e7ddfceb4527c4c32cece224e9b155db4ff0ea4f 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -106,14 +106,13 @@ class LSTMOpModel : public SingleOpModel { input_cell_state_ = AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); - output_state_ = AddOutput(TensorType_FLOAT32); - cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, CreateLSTMOptions(builder_, ActivationFunctionType_TANH, cell_clip, proj_clip) .Union()); + BuildInterpreter(input_shapes); } @@ -185,22 +184,6 @@ class LSTMOpModel : public SingleOpModel { PopulateTensor(projection_bias_, f); } - void ResetOutputState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(output_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - - void ResetCellState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(cell_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - void SetInput(int offset, const float* begin, const float* end) { PopulateTensor(input_, offset, const_cast(begin), const_cast(end)); @@ -469,10 +452,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -529,10 +508,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.0157651); } @@ -637,10 +612,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { lstm.SetCellToForgetWeights(cell_to_forget_weights_); lstm.SetCellToOutputWeights(cell_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } @@ -698,14 +669,10 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetCellToForgetWeights(cell_to_forget_weights_); lstm.SetCellToOutputWeights(cell_to_output_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); } -class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { +class NoCifgPeepholeProjectionNoClippingLstmTest : public BaseLstmTest { void SetUp() override { input_to_input_weights_ = { 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, @@ -1304,7 +1271,7 @@ class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { } }; -TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { +TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -1362,14 +1329,10 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { lstm.SetProjectionWeights(projection_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); } -TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { +TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -1428,10 +1391,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { lstm.SetProjectionWeights(projection_weights_); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc index 3f5bc4d68a57daa8423953f591ac139dc55eacb9..306f67661987dfa7def1b7e8d3abdb993e47b220 100644 --- a/tensorflow/contrib/lite/kernels/mfcc.cc +++ b/tensorflow/contrib/lite/kernels/mfcc.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/internal/mfcc.h" -#include "flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h" diff --git a/tensorflow/contrib/lite/kernels/mfcc_test.cc b/tensorflow/contrib/lite/kernels/mfcc_test.cc index 0291ca8c1c58ea6ab3bb7c22bc436ed3404cba74..c9124adcafac009f93aabdb61bcfee829178e418 100644 --- a/tensorflow/contrib/lite/kernels/mfcc_test.cc +++ b/tensorflow/contrib/lite/kernels/mfcc_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "flatbuffers/flexbuffers.h" +#include "flatbuffers/flexbuffers.h" // flatbuffers #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index 1c728a473326564a85a5e7d3d72718265979e29a..90a915bb023b2b3db86e8334e93e2f1d41e0a9f2 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -101,8 +101,6 @@ class LSTMOpModel : public SingleOpModel { input_cell_state_ = AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); - output_state_ = AddOutput(TensorType_FLOAT32); - cell_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, @@ -180,22 +178,6 @@ class LSTMOpModel : public SingleOpModel { PopulateTensor(projection_bias_, f); } - void ResetOutputState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(output_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - - void ResetCellState() { - const int zero_buffer_size = n_cell_ * n_batch_; - std::unique_ptr zero_buffer(new float[zero_buffer_size]); - memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); - PopulateTensor(cell_state_, 0, zero_buffer.get(), - zero_buffer.get() + zero_buffer_size); - } - void SetInput(int offset, float* begin, float* end) { PopulateTensor(input_, offset, begin, end); } @@ -238,8 +220,6 @@ class LSTMOpModel : public SingleOpModel { int input_cell_state_; int output_; - int output_state_; - int cell_state_; int n_batch_; int n_input_; @@ -324,10 +304,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { lstm.SetCellToOutputWeights( {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); - // Resetting cell_state and output_state - lstm.ResetCellState(); - lstm.ResetOutputState(); - // Verify the model by unpacking it. lstm.Verify(); } diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc index 29374a0c27b2ce038dbefea126ed8a1cbf85c490..4001cf357f151ab486dba900b4003b2507ce99d1 100644 --- a/tensorflow/contrib/lite/kernels/reduce.cc +++ b/tensorflow/contrib/lite/kernels/reduce.cc @@ -177,6 +177,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, case kTfLiteUInt8: temp_sum->type = kTfLiteInt32; break; + case kTfLiteBool: + temp_sum->type = kTfLiteBool; + break; default: return kTfLiteError; } @@ -204,6 +207,13 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + const TfLiteTensor* input = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteBool); + return PrepareSimple(context, node); +} + TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, PrepareSimple(context, node)); @@ -256,11 +266,27 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int64_t, int64_t)); break; case kTfLiteUInt8: - TF_LITE_ENSURE_EQ(context, op_context.input->params.scale, - op_context.output->params.scale); - TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, - op_context.output->params.zero_point); - TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int)); + if (op_context.input->params.zero_point == + op_context.output->params.zero_point && + op_context.input->params.scale == op_context.output->params.scale) { + TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int)); + } else { + TF_LITE_ENSURE( + context, + reference_ops::Mean<>( + GetTensorData(op_context.input), + op_context.input->params.zero_point, + op_context.input->params.scale, op_context.input->dims->data, + op_context.input->dims->size, + GetTensorData(op_context.output), + op_context.output->params.zero_point, + op_context.output->params.scale, + op_context.output->dims->data, op_context.output->dims->size, + GetTensorData(op_context.axis), num_axis, + op_context.params->keep_dims, GetTensorData(temp_index), + GetTensorData(resolved_axis), + GetTensorData(temp_sum))); + } break; default: return kTfLiteError; @@ -460,6 +486,31 @@ TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) { #undef TF_LITE_MIN return kTfLiteOk; } + +template +TfLiteStatus EvalAny(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + int64_t num_axis = NumElements(op_context.axis); + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + if (kernel_type == kReference) { + reference_ops::ReduceAny( + GetTensorData(op_context.input), op_context.input->dims->data, + op_context.input->dims->size, GetTensorData(op_context.output), + op_context.output->dims->data, op_context.output->dims->size, + GetTensorData(op_context.axis), num_axis, + op_context.params->keep_dims, GetTensorData(temp_index), + GetTensorData(resolved_axis)); + } + + return kTfLiteOk; +} } // namespace reduce TfLiteRegistration* Register_MEAN_REF() { @@ -497,6 +548,12 @@ TfLiteRegistration* Register_REDUCE_MIN_REF() { return &r; } +TfLiteRegistration* Register_REDUCE_ANY_REF() { + static TfLiteRegistration r = {reduce::Init, reduce::Free, reduce::PrepareAny, + reduce::EvalAny}; + return &r; +} + // TODO(kanlig): add optimized implementation of Mean. TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); } TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); } @@ -505,6 +562,7 @@ TfLiteRegistration* Register_REDUCE_PROD() { } TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); } TfLiteRegistration* Register_REDUCE_MIN() { return Register_REDUCE_MIN_REF(); } +TfLiteRegistration* Register_REDUCE_ANY() { return Register_REDUCE_ANY_REF(); } } // namespace builtin } // namespace ops diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc index d9aca64356b85f9720da6827e7130ce5c29d496c..6d289b14d8964c1265daf3202b951a5aade54457 100644 --- a/tensorflow/contrib/lite/kernels/reduce_test.cc +++ b/tensorflow/contrib/lite/kernels/reduce_test.cc @@ -198,6 +198,35 @@ class MinOpDynamicModel : public BaseOpModel { } }; +// Model for the tests case where axis is a const tensor. +class AnyOpConstModel : public BaseOpModel { + public: + AnyOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. +class AnyOpDynamicModel : public BaseOpModel { + public: + AnyOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions, + CreateReducerOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + // for quantized Add, the error shouldn't exceed step float GetTolerance(int min, int max) { return (max - min) / 255.0; } @@ -338,6 +367,33 @@ TEST(DynamicUint8MeanOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance))); } +TEST(DynamicUint8MeanOpTest, QuantizedScalar) { + float kQuantizedTolerance = GetTolerance(-10.0, 12.0); + std::vector data = {0.643}; + MeanOpDynamicModel m({TensorType_UINT8, {}, 0.0, 1.0}, + {TensorType_UINT8, {}, -10.0, 12.0}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), IsEmpty()); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.643}, kQuantizedTolerance))); +} + +TEST(ConstUint8MeanOpTest, QuantizedKeepDims) { + float kQuantizedTolerance = GetTolerance(-5.0, 5.0); + std::vector data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6}; + MeanOpConstModel m({TensorType_UINT8, {3, 2}, 0.0, 1.0}, + {TensorType_UINT8, {3}, -5.0, 5.0}, {1}, {1}, true); + m.QuantizeAndPopulate(m.Input(), data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1})); + EXPECT_THAT( + m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance))); +} + // Tests for reduce_sum TEST(ConstFloatSumOpTest, NotKeepDims) { @@ -751,7 +807,7 @@ TEST(DynamicFloatMinOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({1, 3, 5}))); } -TEST(DynamicFloatMinOpTest, Scale) { +TEST(DynamicFloatMinOpTest, Scalar) { std::vector data = {9.527}; MinOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, {TensorType_INT32, {1}}, true); @@ -835,6 +891,68 @@ TEST(DynamicUint8MinOpTest, Scalar) { ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance))); } +// Tests for reduce_any + +TEST(ConstAnyOpTest, NotKeepDims) { + std::vector data = {false, false, false, false, false, false, + false, true, false, false, false, true}; + AnyOpConstModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {2}}, {4}, + {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({false, true})); +} + +TEST(ConstAnyOpTest, KeepDims) { + std::vector data = {false, false, false, false, false, false, + false, true, false, false, false, true}; + AnyOpConstModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {3}}, {2}, + {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({true, false, true})); +} + +TEST(DynamicAnyOpTest, NotKeepDims) { + std::vector data = {false, false, false, false, false, false, + false, true, false, false, false, true}; + AnyOpDynamicModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {2}}, + {TensorType_INT32, {4}}, false); + std::vector axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({false, true})); +} + +TEST(DynamicAnyOpTest, KeepDims) { + std::vector data = {false, false, false, false, false, false, + false, true, false, false, false, true}; + AnyOpDynamicModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {3}}, + {TensorType_INT32, {2}}, true); + std::vector axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({true, false, true})); +} + +TEST(DynamicAnyOpTest, Scalar) { + std::vector data = {false}; + AnyOpDynamicModel m({TensorType_BOOL, {1}}, {TensorType_BOOL, {1}}, + {TensorType_INT32, {1}}, true); + std::vector axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({false})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 10d1fcc5bc3e759816ccf8e2c306c650f1232afc..7b859dc3323b1ab52a0b556754f214e6cabc73d4 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -95,6 +95,7 @@ TfLiteRegistration* Register_SUM(); TfLiteRegistration* Register_REDUCE_PROD(); TfLiteRegistration* Register_REDUCE_MAX(); TfLiteRegistration* Register_REDUCE_MIN(); +TfLiteRegistration* Register_REDUCE_ANY(); TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); @@ -113,6 +114,8 @@ TfLiteRegistration* Register_ONE_HOT(); TfLiteRegistration* Register_LOGICAL_OR(); TfLiteRegistration* Register_LOGICAL_AND(); TfLiteRegistration* Register_LOGICAL_NOT(); +TfLiteRegistration* Register_UNPACK(); +TfLiteRegistration* Register_FLOOR_DIV(); TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) { context->ReportError( @@ -221,6 +224,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD()); AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX()); AddBuiltin(BuiltinOperator_REDUCE_MIN, Register_REDUCE_MIN()); + AddBuiltin(BuiltinOperator_REDUCE_ANY, Register_REDUCE_ANY()); AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); @@ -235,6 +239,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR()); AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND()); AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT()); + AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK()); + AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 9e8ed3cbf32cb041cce725422187182adf258db2..6ba7959752ff7aa16b28c497b58876f5eb748cc4 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -105,16 +105,11 @@ constexpr int kInputTensor = 0; constexpr int kWeightsFeatureTensor = 1; constexpr int kWeightsTimeTensor = 2; constexpr int kBiasTensor = 3; - -// * If the node has 5 inputs the following tensor is used as state tensor. -// This is defined to be a variable tensor, and will be modified by this op. +// This is a variable tensor, and will be modified by this op. constexpr int kInputActivationStateTensor = 4; -// Output tensors. -// * If node has 4 inputs, kStateTensor will be used as state tensor. -// * If node has 5 inputs, kStateTensor is ignored. -constexpr int kStateTensor = 0; -constexpr int kOutputTensor = 1; +// Output tensor. +constexpr int kOutputTensor = 0; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* op_data = new OpData(); @@ -134,21 +129,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { int scratch_tensor_index = op_data->scratch_tensor_index; // Check we have all the inputs and outputs we need. - TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); - bool use_input_variable_states; - if (node->inputs->size == 5) { - use_input_variable_states = true; - op_data->activation_state_tensor_index = - node->inputs->data[kInputActivationStateTensor]; - } else if (node->inputs->size == 4) { - use_input_variable_states = false; - op_data->activation_state_tensor_index = node->outputs->data[kStateTensor]; - } else { - context->ReportError(context, - "The SVDF kernel expects 4 or 5 inputs. Got %d inputs", - node->inputs->size); - return kTfLiteError; - } + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); + op_data->activation_state_tensor_index = + node->inputs->data[kInputActivationStateTensor]; const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* weights_feature = @@ -178,28 +162,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { &context->tensors[op_data->activation_state_tensor_index]; TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (use_input_variable_states) { - // Check the shape of input state tensors. - TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2); - TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0), - batch_size); - TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1), - memory_size * num_filters); - } else { - // Resize activation_state. - // For each batch, the state is a 2-D tensor: memory_size * num_filters - // The left most column is used to save current cycle activation. - // The right most column is used to save temporary output which will be - // reduced to num_units outputs. - TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2); - state_size_array->data[0] = batch_size; - state_size_array->data[1] = memory_size * num_filters; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state, - state_size_array)); - - // Mark state as a persistent tensor. - activation_state->allocation_type = kTfLiteArenaRwPersistent; - } + // Check the shape of input state tensors. + TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0), batch_size); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1), + memory_size * num_filters); // Resize output. TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index e485938343a073c4a290838b0bb36c72671a1f4b..6d60dc63f401144a5eda84d9f88992ce1f9ee47e 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -145,7 +145,6 @@ class BaseSVDFOpModel : public SingleOpModel { activation_state_ = AddInput( TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}}, /*is_variable=*/true); - state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp( BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, @@ -187,7 +186,6 @@ class BaseSVDFOpModel : public SingleOpModel { int weights_time_; int bias_; int activation_state_; - int state_; int output_; int batches_; diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc new file mode 100644 index 0000000000000000000000000000000000000000..4998f88b41fd6b46f14d9342aca7c2ce2fb7fa68 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unpack.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace unpack { +namespace { + +constexpr int kInputTensor = 0; + +// Op data for unpack op. +struct OpData { + int num; + int axis; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->axis = 0; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const OpData* data = reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE(context, NumDimensions(input) <= 4); + TF_LITE_ENSURE(context, NumDimensions(input) > 1); + TF_LITE_ENSURE(context, NumDimensions(input) > data->axis); + // TODO(renjieliu): Support negative axis. + TF_LITE_ENSURE(context, data->axis >= 0); + if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) { + context->ReportError(context, + "Currently pack only supports int32 and float32."); + return kTfLiteError; + } + + const TfLiteIntArray* input_shape = input->dims; + // Num should be equal to the shape[axis]. + // Resize outputs. rank will be R - 1. + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) - 1); + int o = 0; + for (int index = 0; index < NumDimensions(input); ++index) { + if (index != data->axis) { + output_shape->data[o++] = input_shape->data[index]; + } + } + + TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[data->axis]); + for (int i = 0; i < data->num; ++i) { + TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape); + TfLiteTensor* output = GetOutput(context, node, i); + TF_LITE_ENSURE_EQ(context, output->type, input->type); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, output, copied_output_shape)); + } + + TfLiteIntArrayFree(output_shape); + return kTfLiteOk; +} + +template +void UnpackImpl(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* input, int output_count, int axis) { + VectorOfTensors all_outputs(*context, *node->outputs); + reference_ops::Unpack(axis, GetTensorData(input), GetTensorDims(input), + NumDimensions(input), output_count, + all_outputs.data(), **all_outputs.dims()); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const OpData* data = reinterpret_cast(node->builtin_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + switch (input->type) { + case kTfLiteFloat32: { + UnpackImpl(context, node, input, data->num, data->axis); + break; + } + case kTfLiteInt32: { + UnpackImpl(context, node, input, data->num, data->axis); + break; + } + default: { + context->ReportError(context, + "Currently pack only supports int32 and float32."); + return kTfLiteError; + } + } + + return kTfLiteOk; +} +} // namespace +} // namespace unpack + +TfLiteRegistration* Register_UNPACK() { + static TfLiteRegistration r = {unpack::Init, unpack::Free, unpack::Prepare, + unpack::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/unpack_test.cc b/tensorflow/contrib/lite/kernels/unpack_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4efc92a0fdd68082164c5788f99226f81717f91c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unpack_test.cc @@ -0,0 +1,225 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +template +class UnpackOpModel : public SingleOpModel { + public: + UnpackOpModel(const TensorData& input, int axis) { + CHECK_LE(axis, input.shape.size()); + const int num_outputs = input.shape[axis]; + input_ = AddInput(input); + for (int i = 0; i < num_outputs; ++i) { + outputs_.push_back(AddOutput(input.type)); + } + SetBuiltinOp(BuiltinOperator_UNPACK, BuiltinOptions_UnpackOptions, + CreatePackOptions(builder_, num_outputs, axis).Union()); + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector> GetOutputDatas() { + std::vector> output_datas; + for (const int output : outputs_) { + std::cerr << "the output is " << output << std::endl; + output_datas.push_back(ExtractVector(output)); + } + return output_datas; + } + + std::vector> GetOutputShapes() { + std::vector> output_shapes; + for (const int output : outputs_) { + output_shapes.push_back(GetTensorShape(output)); + } + return output_shapes; + } + + private: + int input_; + std::vector outputs_; +}; + +// float32 tests. +TEST(UnpackOpTest, FloatThreeOutputs) { + UnpackOpModel model({TensorType_FLOAT32, {3, 2}}, 0); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 3); + EXPECT_THAT(output_shapes[0], ElementsAre(2)); + EXPECT_THAT(output_shapes[1], ElementsAre(2)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + + // Check outputs values. + const std::vector>& output_datas = model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 3); + EXPECT_THAT(output_datas[0], ElementsAre(1, 2)); + EXPECT_THAT(output_datas[1], ElementsAre(3, 4)); + EXPECT_THAT(output_datas[2], ElementsAre(5, 6)); +} + +TEST(UnpackOpTest, FloatThreeOutputsAxisOne) { + UnpackOpModel model({TensorType_FLOAT32, {3, 2}}, 1); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 2); + EXPECT_THAT(output_shapes[0], ElementsAre(3)); + EXPECT_THAT(output_shapes[1], ElementsAre(3)); + + // Check outputs values. + const std::vector>& output_datas = model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 2); + EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5)); + EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6)); +} + +TEST(UnpackOpTest, FloatOneOutput) { + UnpackOpModel model({TensorType_FLOAT32, {1, 6}}, 0); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 1); + EXPECT_THAT(output_shapes[0], ElementsAre(6)); + + // Check outputs values. + const std::vector>& output_datas = model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 1); + EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6)); +} + +TEST(UnpackOpTest, FloatThreeDimensionsOutputs) { + UnpackOpModel model({TensorType_FLOAT32, {2, 2, 2}}, 2); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + model.Invoke(); + + // Check outputs shapes. + const std::vector>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 2); + EXPECT_THAT(output_shapes[0], ElementsAre(2, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(2, 2)); + + // Check outputs values. + const std::vector>& output_datas = model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 2); + EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7)); + EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8)); +} + +// int32 tests. +TEST(UnpackOpTest, IntThreeOutputs) { + UnpackOpModel model({TensorType_INT32, {3, 2}}, 0); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 3); + EXPECT_THAT(output_shapes[0], ElementsAre(2)); + EXPECT_THAT(output_shapes[1], ElementsAre(2)); + EXPECT_THAT(output_shapes[2], ElementsAre(2)); + + // Check outputs values. + const std::vector>& output_datas = + model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 3); + EXPECT_THAT(output_datas[0], ElementsAre(1, 2)); + EXPECT_THAT(output_datas[1], ElementsAre(3, 4)); + EXPECT_THAT(output_datas[2], ElementsAre(5, 6)); +} + +TEST(UnpackOpTest, IntThreeOutputsAxisOne) { + UnpackOpModel model({TensorType_INT32, {3, 2}}, 1); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 2); + EXPECT_THAT(output_shapes[0], ElementsAre(3)); + EXPECT_THAT(output_shapes[1], ElementsAre(3)); + + // Check outputs values. + const std::vector>& output_datas = + model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 2); + EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5)); + EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6)); +} + +TEST(UnpackOpTest, IntOneOutput) { + UnpackOpModel model({TensorType_INT32, {1, 6}}, 0); + model.SetInput({1, 2, 3, 4, 5, 6}); + model.Invoke(); + + // Check outputs shapes. + const std::vector>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 1); + EXPECT_THAT(output_shapes[0], ElementsAre(6)); + + // Check outputs values. + const std::vector>& output_datas = + model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 1); + EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6)); +} + +TEST(UnpackOpTest, IntThreeDimensionsOutputs) { + UnpackOpModel model({TensorType_INT32, {2, 2, 2}}, 2); + model.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + model.Invoke(); + + // Check outputs shapes. + const std::vector>& output_shapes = model.GetOutputShapes(); + EXPECT_EQ(output_shapes.size(), 2); + EXPECT_THAT(output_shapes[0], ElementsAre(2, 2)); + EXPECT_THAT(output_shapes[1], ElementsAre(2, 2)); + + // Check outputs values. + const std::vector>& output_datas = + model.GetOutputDatas(); + EXPECT_EQ(output_datas.size(), 2); + EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7)); + EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh index b58ae266017caf8781c28331f49a8f5bc1550767..6195426d6d441e858fbe225c132b409ac0a0be32 100755 --- a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh +++ b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh @@ -14,6 +14,7 @@ # limitations under the License. # ============================================================================== +# TODO(ycling): Refactoring - Move this script into `tools/make`. set -e echo "Starting" @@ -32,7 +33,7 @@ echo "Headers, populating: TensorFlow Lite" cd $TFLITE_DIR/../../.. find tensorflow/contrib/lite -name '*.h' \ - -not -path 'tensorflow/contrib/lite/downloads/*' \ + -not -path 'tensorflow/contrib/lite/tools/*' \ -not -path 'tensorflow/contrib/lite/examples/*' \ -not -path 'tensorflow/contrib/lite/gen/*' \ -not -path 'tensorflow/contrib/lite/toco/*' \ @@ -44,7 +45,7 @@ tar xf tmp.tar rm -f tmp.tar echo "Headers, populating: Flatbuffer" -cd $TFLITE_DIR/downloads/flatbuffers/include/ +cd $TFLITE_DIR/tools/make/downloads/flatbuffers/include/ find . -name '*.h' | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T - cd $FW_DIR_TFLITE_HDRS tar xf tmp.tar @@ -57,7 +58,7 @@ cp $TFLITE_DIR/../../../bazel-genfiles/tensorflow/tools/lib_package/include/tens $FW_DIR_TFLITE echo "Copying static libraries" -cp $TFLITE_DIR/gen/lib/libtensorflow-lite.a \ +cp $TFLITE_DIR/tools/make/gen/lib/libtensorflow-lite.a \ $FW_DIR_TFLITE/tensorflow_lite # This is required, otherwise they interfere with the documentation of the diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 5f8d5c318a38ca03aad79ea8c20fd8cf47081a82..aa410ab002c15596cc7535f55a177735a2a9bd99 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -624,7 +624,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_REDUCE_MAX: case BuiltinOperator_REDUCE_MIN: case BuiltinOperator_REDUCE_PROD: - case BuiltinOperator_SUM: { + case BuiltinOperator_SUM: + case BuiltinOperator_REDUCE_ANY: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ReducerOptions()) { params->keep_dims = schema_params->keep_dims(); @@ -745,6 +746,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = static_cast(params); break; } + case BuiltinOperator_UNPACK: { + TfLiteUnpackParams* params = MallocPOD(); + if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) { + params->num = unpack_params->num(); + params->axis = unpack_params->axis(); + } + *builtin_data = reinterpret_cast(params); + break; + } // Below are the ops with no builtin_data strcture. case BuiltinOperator_BATCH_TO_SPACE_ND: @@ -790,7 +800,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_LOGICAL_OR: case BuiltinOperator_LOGICAL_AND: case BuiltinOperator_LOGICAL_NOT: - case BuiltinOperator_UNPACK: + case BuiltinOperator_FLOOR_DIV: break; } return kTfLiteOk; @@ -802,6 +812,10 @@ TfLiteStatus InterpreterBuilder::ParseNodes( const flatbuffers::Vector>* operators, Interpreter* interpreter) { TfLiteStatus status = kTfLiteOk; + + // Reduce the number of redundant allocations + interpreter->ReserveNodes(operators->Length()); + for (int i = 0; i < operators->Length(); ++i) { const auto* op = operators->Get(i); int index = op->opcode_index(); diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc index 206de1962d196400d2a58162c5ef692e2091e8d4..8ecf0b6154a622fa355c060ba7f2d61e6c670de2 100644 --- a/tensorflow/contrib/lite/models/speech_test.cc +++ b/tensorflow/contrib/lite/models/speech_test.cc @@ -102,7 +102,7 @@ class SpeechTest : public ::testing::TestWithParam { int GetMaxInvocations() { return GetParam(); } }; -TEST_P(SpeechTest, HotwordOkGoogleRank1Test) { +TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank1Test) { std::stringstream os; ASSERT_TRUE(ConvertCsvData( "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv", @@ -114,7 +114,7 @@ TEST_P(SpeechTest, HotwordOkGoogleRank1Test) { << test_driver.GetErrorMessage(); } -TEST_P(SpeechTest, HotwordOkGoogleRank2Test) { +TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank2Test) { std::stringstream os; ASSERT_TRUE(ConvertCsvData( "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv", @@ -126,7 +126,7 @@ TEST_P(SpeechTest, HotwordOkGoogleRank2Test) { << test_driver.GetErrorMessage(); } -TEST_P(SpeechTest, SpeakerIdOkGoogleTest) { +TEST_P(SpeechTest, DISABLED_SpeakerIdOkGoogleTest) { std::stringstream os; ASSERT_TRUE(ConvertCsvData( "speech_speakerid_model.tflite", "speech_speakerid_model_in.csv", @@ -139,7 +139,7 @@ TEST_P(SpeechTest, SpeakerIdOkGoogleTest) { << test_driver.GetErrorMessage(); } -TEST_P(SpeechTest, AsrAmTest) { +TEST_P(SpeechTest, DISABLED_AsrAmTest) { std::stringstream os; ASSERT_TRUE( ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv", @@ -156,7 +156,7 @@ TEST_P(SpeechTest, AsrAmTest) { // through the interpreter and stored the sum of all the output, which was them // compared for correctness. In this test we are comparing all the intermediate // results. -TEST_P(SpeechTest, AsrLmTest) { +TEST_P(SpeechTest, DISABLED_AsrLmTest) { std::ifstream in_file; testing::TfLiteDriver test_driver(/*use_nnapi=*/false); ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file)); @@ -165,7 +165,7 @@ TEST_P(SpeechTest, AsrLmTest) { << test_driver.GetErrorMessage(); } -TEST_P(SpeechTest, EndpointerTest) { +TEST_P(SpeechTest, DISABLED_EndpointerTest) { std::stringstream os; ASSERT_TRUE(ConvertCsvData( "speech_endpointer_model.tflite", "speech_endpointer_model_in.csv", @@ -178,7 +178,7 @@ TEST_P(SpeechTest, EndpointerTest) { << test_driver.GetErrorMessage(); } -TEST_P(SpeechTest, TtsTest) { +TEST_P(SpeechTest, DISABLED_TtsTest) { std::stringstream os; ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite", "speech_tts_model_in.csv", diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index d287aa635cc1320ee390f0e5666f504448b0a546..38f3e9881bc0e773765fc650fa92a9fef66cb862 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -649,6 +649,8 @@ TfLiteStatus AddOpsAndParams( case tflite::BuiltinOperator_LOGICAL_AND: case tflite::BuiltinOperator_LOGICAL_NOT: case tflite::BuiltinOperator_UNPACK: + case tflite::BuiltinOperator_FLOOR_DIV: + case tflite::BuiltinOperator_REDUCE_ANY: logError("Op code %d is currently not delegated to NNAPI", builtin); return kTfLiteError; break; diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index 7d7a4ba94a4d026e038bebc29cfa54b8e5d4aa1d..46bdb3e55336b3a0c55009da060cb173d38f4ce8 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -312,7 +312,7 @@ def run_main(_): "quantization via \"dummy quantization\". (default None)")) parser.add_argument( "--quantize_weights", - type=bool, + action="store_true", help=("Store float weights as quantized weights followed by dequantize " "operations. Inference is still done in FLOAT, but reduces model " "size (at the cost of accuracy and latency).")) diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD index b616e449e6ddae6467a6b86269cd108c7eec0c26..28a7e5000349b63844df472da3baafd3e6c71450 100644 --- a/tensorflow/contrib/lite/schema/BUILD +++ b/tensorflow/contrib/lite/schema/BUILD @@ -48,7 +48,7 @@ exports_files([ "schema_v3.fbs", ]) -load("//third_party/flatbuffers:build_defs.bzl", "flatbuffer_cc_library") +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") # Generic schema for inference on device. flatbuffer_cc_library( diff --git a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc index cd46a06f7d173d87d04c2ff0910190ecd40a1954..11057203a816713a3d075baec5622ed7bb3f4717 100644 --- a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc +++ b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include #include -#include "flatbuffers/flatc.h" +#include "flatbuffers/flatc.h" // flatbuffers #include "tensorflow/core/platform/platform.h" #ifdef PLATFORM_GOOGLE diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index dcad77ccbbd1f447c75c90d5b6a90d55bb4cdd14..cf66403ec935ebfee2df2398f68276d740c520b1 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -171,6 +171,8 @@ enum BuiltinOperator : byte { LOGICAL_NOT = 87, UNPACK = 88, REDUCE_MIN = 89, + FLOOR_DIV = 90, + REDUCE_ANY = 91, } // Options for the builtin operators. @@ -239,6 +241,7 @@ union BuiltinOptions { LogicalAndOptions, LogicalNotOptions, UnpackOptions, + FloorDivOptions, } enum Padding : byte { SAME, VALID } @@ -573,6 +576,9 @@ table UnpackOptions { axis:int; } +table FloorDivOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index a2ea43f37012dc4718e06c12593312a919cc98fc..6d9630d75e53f4045debdce72acf29354c491720 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -223,6 +223,9 @@ struct LogicalNotOptionsT; struct UnpackOptions; struct UnpackOptionsT; +struct FloorDivOptions; +struct FloorDivOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -378,11 +381,13 @@ enum BuiltinOperator { BuiltinOperator_LOGICAL_NOT = 87, BuiltinOperator_UNPACK = 88, BuiltinOperator_REDUCE_MIN = 89, + BuiltinOperator_FLOOR_DIV = 90, + BuiltinOperator_REDUCE_ANY = 91, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_REDUCE_MIN + BuiltinOperator_MAX = BuiltinOperator_REDUCE_ANY }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[89] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[91] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -472,7 +477,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[89] { BuiltinOperator_LOGICAL_AND, BuiltinOperator_LOGICAL_NOT, BuiltinOperator_UNPACK, - BuiltinOperator_REDUCE_MIN + BuiltinOperator_REDUCE_MIN, + BuiltinOperator_FLOOR_DIV, + BuiltinOperator_REDUCE_ANY }; return values; } @@ -569,6 +576,8 @@ inline const char **EnumNamesBuiltinOperator() { "LOGICAL_NOT", "UNPACK", "REDUCE_MIN", + "FLOOR_DIV", + "REDUCE_ANY", nullptr }; return names; @@ -645,11 +654,12 @@ enum BuiltinOptions { BuiltinOptions_LogicalAndOptions = 62, BuiltinOptions_LogicalNotOptions = 63, BuiltinOptions_UnpackOptions = 64, + BuiltinOptions_FloorDivOptions = 65, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_UnpackOptions + BuiltinOptions_MAX = BuiltinOptions_FloorDivOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[65] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[66] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -715,7 +725,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[65] { BuiltinOptions_OneHotOptions, BuiltinOptions_LogicalAndOptions, BuiltinOptions_LogicalNotOptions, - BuiltinOptions_UnpackOptions + BuiltinOptions_UnpackOptions, + BuiltinOptions_FloorDivOptions }; return values; } @@ -787,6 +798,7 @@ inline const char **EnumNamesBuiltinOptions() { "LogicalAndOptions", "LogicalNotOptions", "UnpackOptions", + "FloorDivOptions", nullptr }; return names; @@ -1057,6 +1069,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_UnpackOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1600,6 +1616,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_UnpackOptions ? reinterpret_cast(value) : nullptr; } + FloorDivOptionsT *AsFloorDivOptions() { + return type == BuiltinOptions_FloorDivOptions ? + reinterpret_cast(value) : nullptr; + } + const FloorDivOptionsT *AsFloorDivOptions() const { + return type == BuiltinOptions_FloorDivOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -5739,6 +5763,46 @@ inline flatbuffers::Offset CreateUnpackOptions( flatbuffers::Offset CreateUnpackOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct FloorDivOptionsT : public flatbuffers::NativeTable { + typedef FloorDivOptions TableType; + FloorDivOptionsT() { + } +}; + +struct FloorDivOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FloorDivOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + FloorDivOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FloorDivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct FloorDivOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit FloorDivOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + FloorDivOptionsBuilder &operator=(const FloorDivOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateFloorDivOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + FloorDivOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateFloorDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -6064,6 +6128,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const UnpackOptions *builtin_options_as_UnpackOptions() const { return builtin_options_type() == BuiltinOptions_UnpackOptions ? static_cast(builtin_options()) : nullptr; } + const FloorDivOptions *builtin_options_as_FloorDivOptions() const { + return builtin_options_type() == BuiltinOptions_FloorDivOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -6351,6 +6418,10 @@ template<> inline const UnpackOptions *Operator::builtin_options_as inline const FloorDivOptions *Operator::builtin_options_as() const { + return builtin_options_as_FloorDivOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -8567,6 +8638,29 @@ inline flatbuffers::Offset CreateUnpackOptions(flatbuffers::FlatB _axis); } +inline FloorDivOptionsT *FloorDivOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new FloorDivOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void FloorDivOptions::UnPackTo(FloorDivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset FloorDivOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateFloorDivOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateFloorDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FloorDivOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateFloorDivOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -9012,6 +9106,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_FloorDivOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -9286,6 +9384,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_FloorDivOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -9548,6 +9650,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateUnpackOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_FloorDivOptions: { + auto ptr = reinterpret_cast(value); + return CreateFloorDivOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -9810,6 +9916,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new UnpackOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_FloorDivOptions: { + value = new FloorDivOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -10137,6 +10247,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_FloorDivOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 599c82940e188ad37dc6386d1cc8246fab054331..57134ccd15787568e7863e9825ab94af5b8090f6 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -780,10 +780,15 @@ def make_binary_op_tests(zip_path, binary_operator): "input_shape_2": [[5]], "activation": [False, True] }, { - "dtype": [tf.float32], + "dtype": [tf.float32, tf.int32], "input_shape_1": [[1, 3, 4, 3]], "input_shape_2": [[3]], - "activation": [True] + "activation": [True, False] + }, { + "dtype": [tf.float32, tf.int32], + "input_shape_1": [[3]], + "input_shape_2": [[1, 3, 4, 3]], + "activation": [True, False] }, { "dtype": [tf.float32], "input_shape_1": [[]], @@ -821,13 +826,17 @@ def make_binary_op_tests(zip_path, binary_operator): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_reduce_tests(reduce_op, min_value=-10, max_value=10): +def make_reduce_tests(reduce_op, + min_value=-10, + max_value=10, + boolean_tensor_only=False): """Make a set of tests to do reduce operation. Args: reduce_op: TensorFlow reduce operation to test, i.e. `tf.reduce_mean`. min_value: min value for created tensor data. max_value: max value for created tensor data. + boolean_tensor_only: If true, will only generate tensor with boolean value. Returns: a function representing the true generator with `reduce_op_in` curried. @@ -867,10 +876,11 @@ def make_reduce_tests(reduce_op, min_value=-10, max_value=10): def build_graph(parameters): """Build the mean op testing graph.""" + dtype = parameters["input_dtype"] + if boolean_tensor_only: + dtype = tf.bool input_tensor = tf.placeholder( - dtype=parameters["input_dtype"], - name="input", - shape=parameters["input_shape"]) + dtype=dtype, name="input", shape=parameters["input_shape"]) # Get axis as either a placeholder or constants. if parameters["const_axis"]: @@ -889,9 +899,12 @@ def make_reduce_tests(reduce_op, min_value=-10, max_value=10): return input_tensors, [out] def build_inputs(parameters, sess, inputs, outputs): + dtype = parameters["input_dtype"] + if boolean_tensor_only: + dtype = tf.bool values = [ create_tensor_data( - parameters["input_dtype"], + dtype, parameters["input_shape"], min_value=min_value, max_value=max_value) @@ -931,6 +944,11 @@ def make_reduce_min_tests(zip_path): return make_reduce_tests(tf.reduce_min)(zip_path) +def make_reduce_any_tests(zip_path): + """Make a set of tests to do any.""" + return make_reduce_tests(tf.reduce_any, boolean_tensor_only=True)(zip_path) + + def make_exp_tests(zip_path): """Make a set of tests to do exp.""" @@ -1085,6 +1103,10 @@ def make_pow_tests(zip_path): make_binary_op_tests(zip_path, tf.pow) +def make_floor_div_tests(zip_path): + make_binary_op_tests(zip_path, tf.floor_div) + + def make_gather_tests(zip_path): """Make a set of tests to do gather.""" @@ -2378,7 +2400,7 @@ def make_lstm_tests(zip_path): "time_step_size": [1], "input_vec_size": [3], "num_cells": [4], - "split_tflite_lstm_inputs": [True, False], + "split_tflite_lstm_inputs": [False], }, ] @@ -3149,6 +3171,36 @@ def make_pack_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_unpack_tests(zip_path): + """Make a set of tests to do unstack.""" + + test_parameters = [{ + "base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]], + "axis": [0, 1, 2, 3], + }] + + def get_valid_axis(parameters): + """Return a tweaked version of 'axis'.""" + axis = parameters["axis"] + shape = parameters["base_shape"][:] + while axis > len(shape) - 1: + axis -= 1 + return axis + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name=("input"), shape=parameters["base_shape"]) + outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters)) + return [input_tensor], outs + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(np.float32, shape=parameters["base_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def _make_logical_tests(op): """Make a set of tests to do logical operations.""" diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index e67fee2a1ca40790a171dc236dd2d85203690a62..37c7ae0e1cd31835d9df966b2b8ae692b09208e4 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -101,6 +101,15 @@ std::map kBrokenTests = { "77546240"}, {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"}, + + // No Support for float. + {R"(^\/floor_div.*dtype=tf\.float32)", "112859002"}, + + // Relu does not support int32. + // These test cases appends a Relu after the tested ops when + // activation=True. The tests are failing since Relu doesn't support int32. + {R"(^\/div.*activation=True.*dtype=tf\.int32)", "112968789"}, + {R"(^\/floor_div.*activation=True.*dtype=tf\.int32)", "112968789"}, }; // Allows test data to be unarchived into a temporary directory and makes diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 4dacf9c84ba725ba04ce25a6cbd1f1a20c60891a..1836eb53b9af2743cd11ed8e8ff990c1eb2dcf30 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -302,28 +302,6 @@ bool TfLiteDriver::CheckResults() { void TfLiteDriver::ResetLSTMStateTensors() { interpreter_->ResetVariableTensorsToZero(); - - // Below is a workaround for initializing state tensors for LSTM. - // TODO(ycling): Remove the code below after nobody is using the 18-inputs - // definition. - for (auto node_index : interpreter_->execution_plan()) { - const auto& node_and_reg = interpreter_->node_and_registration(node_index); - const auto& node = node_and_reg->first; - const auto& registration = node_and_reg->second; - - if (registration.builtin_code == tflite::BuiltinOperator_LSTM) { - const auto* params = - reinterpret_cast(node.builtin_data); - if (params->kernel_type == kTfLiteLSTMFullKernel && - node.inputs->size == 18 && node.outputs->size >= 2) { - // The first 2 outputs of LSTM are state tensors. - for (int i = 0; i < 2; ++i) { - int node_index = node.outputs->data[i]; - ResetTensor(node_index); - } - } - } - } } } // namespace testing diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index f489c5ac653b4c4a765bdb8345d97123b2026ea3..6fdf47dedc0943e037fbfc75470d5acd72708819 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1900,21 +1900,6 @@ void ConvertPowOperator(const Model& model, const PowOperator& src_op, (*pow_op->mutable_attr())["T"].set_type(data_type); } -void ConvertAnyOperator(const Model& model, const AnyOperator& src_op, - GraphDef* tensorflow_graph) { - tensorflow::NodeDef* any_op = tensorflow_graph->add_node(); - any_op->set_op("Any"); - any_op->set_name(src_op.outputs[0]); - CHECK_EQ(src_op.inputs.size(), 2); - for (int i = 0; i < 2; ++i) { - *any_op->add_input() = src_op.inputs[i]; - } - const tensorflow::DataType data_type = - GetTensorFlowDataType(model, src_op.inputs[1]); - (*any_op->mutable_attr())["Tidx"].set_type(data_type); - (*any_op->mutable_attr())["keep_dims"].set_b(src_op.keep_dims); -} - void ConvertLogicalAndOperator(const Model& model, const LogicalAndOperator& src_op, GraphDef* tensorflow_graph) { @@ -1967,6 +1952,20 @@ void ConvertCTCBeamSearchDecoderOperator( (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated); } +void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node(); + unpack_op->set_op(op_name); + unpack_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *unpack_op->add_input() = src_op.inputs[0]; + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*unpack_op->mutable_attr())["T"].set_type(data_type); + (*unpack_op->mutable_attr())["num"].set_i(src_op.num); + (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -2207,8 +2206,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertPowOperator(model, static_cast(src_op), "Pow", tensorflow_graph); } else if (src_op.type == OperatorType::kAny) { - ConvertAnyOperator(model, static_cast(src_op), - tensorflow_graph); + ConvertReduceOperator(model, + static_cast(src_op), + tensorflow_graph, "Any"); } else if (src_op.type == OperatorType::kLogicalAnd) { ConvertLogicalAndOperator(model, static_cast(src_op), @@ -2228,6 +2228,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertCTCBeamSearchDecoderOperator( model, static_cast(src_op), "CTCBeamSearchDecoder", tensorflow_graph); + } else if (src_op.type == OperatorType::kUnpack) { + ConvertUnpackOperator(model, static_cast(src_op), + "Unpack", tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index c8310161cb33bcc7137e8b163ea6469698ed2fd7..323eefcd3a7665a8c01da1bc10d6f8d80da7a15d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -227,6 +227,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { ArrayDataType::kFloat; break; } + case OperatorType::kUnpack: { + CHECK_EQ(op->inputs.size(), 1); + const int output_size = op->outputs.size(); + for (int i = 0; i < output_size; ++i) { + model->GetArray(op->outputs[i]).data_type = + model->GetArray(op->inputs[0]).data_type; + } + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 91e290439ae4bfd491c8201b02b161fe2caf2f8d..28effc2a6730baa9ffba8dda934f02cd2a920cec 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -539,6 +539,8 @@ bool KeepDims(const Operator& op) { return static_cast(op).keep_dims; case OperatorType::kMean: return static_cast(op).keep_dims; + case OperatorType::kAny: + return static_cast(op).keep_dims; default: LOG(FATAL) << "Not a reduction operator!"; return false; @@ -1515,65 +1517,6 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) { } } -void ProcessAnyOperator(Model* model, AnyOperator* op) { - CHECK_EQ(op->inputs.size(), 2); - CHECK_EQ(op->outputs.size(), 1); - - auto& output_array = model->GetArray(op->outputs[0]); - if (output_array.has_shape()) { - // We have already run. - return; - } - - const auto& input_array = model->GetArray(op->inputs[0]); - if (!input_array.has_shape()) { - // Yield until input dims have been resolved. - return; - } - const auto& input_shape = input_array.shape(); - - auto& reduction_indices_array = model->GetArray(op->inputs[1]); - if (!reduction_indices_array.has_shape()) { - // Yield until reduction indices shape been resolved. - return; - } - if (!reduction_indices_array.buffer) { - // Yield until the reduction indices are constant. - return; - } - CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32) - << "Any reduction input must be int32"; - - int input_rank = input_shape.dimensions_count(); - std::set true_indices; - const auto& reduction_indices = - reduction_indices_array.GetBuffer().data; - for (int i = 0; i < reduction_indices.size(); ++i) { - const int32 reduction_index = reduction_indices[i]; - if (reduction_index < -input_rank || reduction_index >= input_rank) { - CHECK(false) << "Invalid reduction dimension " << reduction_index - << " for input with " << input_rank << " dimensions"; - } - int32 wrapped_index = reduction_index; - if (wrapped_index < 0) { - wrapped_index += input_rank; - } - true_indices.insert(wrapped_index); - } - - auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); - mutable_dims->clear(); - for (int i = 0; i < input_rank; ++i) { - if (true_indices.count(i) > 0) { - if (op->keep_dims) { - mutable_dims->emplace_back(1); - } - } else { - mutable_dims->emplace_back(input_shape.dims(i)); - } - } -} - void ProcessOneHotOperator(Model* model, OneHotOperator* op) { CHECK_EQ(op->inputs.size(), 4); CHECK_EQ(op->outputs.size(), 1); @@ -1629,6 +1572,32 @@ void ProcessOneHotOperator(Model* model, OneHotOperator* op) { } } +void ProcessUnpackOperator(Model* model, UnpackOperator* op) { + CHECK_EQ(op->inputs.size(), 1); + const auto& input_array = model->GetArray(op->inputs[0]); + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + + const std::vector& input_dims = input_array.shape().dims(); + std::vector output_dims; + + output_dims.reserve(input_dims.size() - 1); + for (int i = 0; i < input_dims.size(); ++i) { + if (i != op->axis) { + output_dims.push_back(input_dims[i]); + } + } + for (const string& output_name : op->outputs) { + auto& output_array = model->GetArray(output_name); + if (output_array.has_shape()) { + return; + } + *output_array.mutable_shape()->mutable_dims() = output_dims; + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1743,6 +1712,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kSum: case OperatorType::kReduceProd: case OperatorType::kMean: + case OperatorType::kAny: ProcessTensorFlowReductionOperator(model, op); break; case OperatorType::kSelect: @@ -1874,12 +1844,13 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTile: ProcessTileOperator(model, static_cast(op)); break; - case OperatorType::kAny: - ProcessAnyOperator(model, static_cast(op)); break; case OperatorType::kOneHot: ProcessOneHotOperator(model, static_cast(op)); break; + case OperatorType::kUnpack: + ProcessUnpackOperator(model, static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index b7fffbce2223a71ac1e16ec1ce18ba9f610cc2ac..cb6da21039540cc7a1588ba10c19f31893028b42 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1576,6 +1576,26 @@ tensorflow::Status ConvertPackOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertUnpackOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Unpack"); + auto op = absl::make_unique(); + const int num_inputs = GetInputsCount(node, tf_import_flags); + QCHECK_EQ(num_inputs, 1); + op->inputs.push_back(node.input(0)); + op->num = GetIntAttr(node, "num"); + op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0; + op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T")); + + op->outputs.push_back(node.name()); // Implicit :0. + for (int i = 1; i < op->num; ++i) { + op->outputs.push_back(node.name() + ":" + std::to_string(i)); + } + model->operators.emplace_back(std::move(op)); + return tensorflow::Status::OK(); +} + // Some TensorFlow ops only occur in graph cycles, representing // control flow. We do not currently support control flow, so we wouldn't // be able to fully support such graphs, including performing inference, @@ -1618,24 +1638,6 @@ tensorflow::Status ConvertShapeOperator( return tensorflow::Status::OK(); } -tensorflow::Status ConvertAnyOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Any"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - const auto idx_type = - HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; - CHECK(idx_type == DT_INT32); - auto op = absl::make_unique(); - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - op->keep_dims = - HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false; - model->operators.push_back(std::move(op)); - return tensorflow::Status::OK(); -} - void StripCaretFromArrayNames(Model* model) { for (auto& op : model->operators) { for (auto& input : op->inputs) { @@ -1917,7 +1919,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Add", ConvertSimpleOperator}, {"AddN", ConvertSimpleOperator}, {"All", ConvertSimpleOperator}, - {"Any", ConvertAnyOperator}, + {"Any", ConvertReduceOperator}, {"ArgMax", ConvertArgMaxOperator}, {"ArgMin", ConvertArgMinOperator}, {"Assert", ConvertSimpleOperator}, @@ -2020,6 +2022,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"TopK", ConvertTopKV2Operator}, {"TopKV2", ConvertTopKV2Operator}, {"Transpose", ConvertSimpleOperator}, + {"Unpack", ConvertUnpackOperator}, }); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 412e14c4ada3280dafcd2fcfa59e2908dd785f9f..fa1c459f0ecf7b2880727db1963775d702386cfe 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -149,6 +149,7 @@ enum class OperatorType : uint8 { kLogicalNot, kLogicalOr, kCTCBeamSearchDecoder, + kUnpack, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -1767,11 +1768,11 @@ struct PowOperator : Operator { // // Inputs: // Inputs[0]: required: A boolean input tensor. -// Inputs[1]: required: reduction_indices. // // TensorFlow equivalent: tf.reduce_any. -struct AnyOperator : Operator { - AnyOperator() : Operator(OperatorType::kAny) {} +struct TensorFlowAnyOperator : Operator { + TensorFlowAnyOperator() : Operator(OperatorType::kAny) {} + std::vector axis; bool keep_dims = false; }; @@ -1828,6 +1829,20 @@ struct LogicalOrOperator : Operator { LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {} }; +// Unpack operator: +// +// Inputs: +// Inputs[0]: required: A boolean input tensor. +// Inputs[1]: required: reduction_indices. +// +// TensorFlow equivalent: tf.unstack. +struct UnpackOperator : Operator { + UnpackOperator() : Operator(OperatorType::kUnpack) {} + int num; + int axis; + ArrayDataType dtype = ArrayDataType::kNone; +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index dcb5fff39fe8a2bc52a5ebbd7b50eea297a0ad1d..a314c8d53ac430632cc1fbbbb4226a14eb7eb1bd 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -769,7 +769,7 @@ class Sum }; class ReduceMax - : public BuiltinOperator { public: using BuiltinOperator::BuiltinOperator; @@ -788,7 +788,7 @@ class ReduceMax }; class ReduceMin - : public BuiltinOperator { public: using BuiltinOperator::BuiltinOperator; @@ -807,7 +807,26 @@ class ReduceMin }; class ReduceProd - : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class ReduceAny + : public BuiltinOperator { public: using BuiltinOperator::BuiltinOperator; @@ -1110,6 +1129,24 @@ class CTCBeamSearchDecoder int GetVersion(const Operator& op) const override { return 1; } }; +class Unpack : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis); + } + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->num = options.num(); + op->axis = options.axis(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -1318,6 +1355,8 @@ std::vector> BuildOperatorList() { OperatorType::kReduceMax)); ops.push_back(MakeUnique(::tflite::BuiltinOperator_REDUCE_MIN, OperatorType::kReduceMin)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_REDUCE_ANY, + OperatorType::kAny)); ops.push_back( MakeUnique(::tflite::BuiltinOperator_RESIZE_BILINEAR, OperatorType::kResizeBilinear)); @@ -1353,6 +1392,8 @@ std::vector> BuildOperatorList() { MakeUnique(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); ops.push_back(MakeUnique(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot)); + ops.push_back(MakeUnique(::tflite::BuiltinOperator_UNPACK, + OperatorType::kUnpack)); // Custom Operators. ops.push_back( @@ -1417,6 +1458,8 @@ std::vector> BuildOperatorList() { "LOGICAL_AND", OperatorType::kLogicalAnd)); ops.emplace_back(new SimpleOperator( "LOGICAL_NOT", OperatorType::kLogicalNot)); + ops.emplace_back(new SimpleOperator( + "FLOOR_DIV", OperatorType::kFloorDiv)); // Element-wise operator ops.push_back( MakeUnique>("SIN", OperatorType::kSin)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index fc854461b4e816e12e12590479501b6542258fef..519a3a4e015bed6822ce80487e8e44d61aa0ca58 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -97,6 +97,16 @@ class OperatorTest : public ::testing::Test { ASSERT_NE(nullptr, output_toco_op.get()); } + + template + void CheckReducerOperator(const string& name, OperatorType type) { + T op; + + op.keep_dims = false; + + auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op); + EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims); + } }; TEST_F(OperatorTest, SimpleOperators) { @@ -133,6 +143,7 @@ TEST_F(OperatorTest, SimpleOperators) { OperatorType::kLogicalAnd); CheckSimpleOperator("LOGICAL_NOT", OperatorType::kLogicalNot); + CheckSimpleOperator("FLOOR_DIV", OperatorType::kFloorDiv); } TEST_F(OperatorTest, BuiltinAdd) { @@ -144,13 +155,16 @@ TEST_F(OperatorTest, BuiltinAdd) { output_toco_op->fused_activation_function); } -TEST_F(OperatorTest, BuiltinMean) { - MeanOperator op; - op.keep_dims = false; - - auto output_toco_op = - SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op); - EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims); +TEST_F(OperatorTest, BuiltinReducerOps) { + CheckReducerOperator("MEAN", OperatorType::kMean); + CheckReducerOperator("SUM", OperatorType::kSum); + CheckReducerOperator("REDUCE_PROD", + OperatorType::kReduceProd); + CheckReducerOperator("REDUCE_MAX", + OperatorType::kReduceMax); + CheckReducerOperator("REDUCE_MIN", + OperatorType::kReduceMin); + CheckReducerOperator("REDUCE_ANY", OperatorType::kAny); } TEST_F(OperatorTest, BuiltinCast) { @@ -476,6 +490,16 @@ TEST_F(OperatorTest, BuiltinOneHot) { EXPECT_EQ(op.axis, output_toco_op->axis); } +TEST_F(OperatorTest, BuiltinUnpack) { + UnpackOperator op; + op.num = 5; + op.axis = 2; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("UNPACK", OperatorType::kUnpack), op); + EXPECT_EQ(op.num, output_toco_op->num); + EXPECT_EQ(op.axis, output_toco_op->axis); +} + TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) { CTCBeamSearchDecoderOperator op; op.beam_width = 3; diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 3a4542f52242ba73fd9a9208ca4b2e574b73b1a1..6ab93d931694d34583091dfbdf6c2a6b5b7049c6 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -405,6 +405,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(LogicalNot) HANDLE_OPERATORTYPENAME_CASE(LogicalOr) HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder) + HANDLE_OPERATORTYPENAME_CASE(Unpack) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE diff --git a/tensorflow/contrib/lite/tools/accuracy/BUILD b/tensorflow/contrib/lite/tools/accuracy/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..21941f5c8b928b5bb528016a27a0583988bb57d1 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/BUILD @@ -0,0 +1,314 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") + +common_linkopts = tflite_linkopts() + select({ + "//conditions:default": [], + "//tensorflow:android": [ + "-pie", + "-llog", + ], +}) + +cc_library( + name = "utils", + srcs = ["utils.cc"], + hdrs = ["utils.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + ], + }, + ), +) + +tf_cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + args = [ + "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)", + ], + data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":utils", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], + }, + ), +) + +cc_library( + name = "run_tflite_model_op", + srcs = ["run_tflite_model_op.cc"], + copts = tflite_copts(), + deps = [ + ":utils", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:tensorflow", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + ], + }, + ), + alwayslink = 1, +) + +cc_library( + name = "android_required_build_flags", + srcs = ["android_required_build_flags.cc"], + copts = tflite_copts(), +) + +tf_cc_test( + name = "run_tflite_model_op_test", + srcs = ["run_tflite_model_op_test.cc"], + args = [ + "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)", + ], + data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + ":run_tflite_model_op", + ":android_required_build_flags", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], + }, + ), +) + +cc_library( + name = "stage", + hdrs = ["stage.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/cc:scope", + ], +) + +cc_library( + name = "file_reader_stage", + srcs = ["file_reader_stage.cc"], + hdrs = ["file_reader_stage.h"], + deps = [ + ":stage", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + ], +) + +tf_cc_test( + name = "file_reader_stage_test", + srcs = ["file_reader_stage_test.cc"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":file_reader_stage", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core/kernels:android_whole_file_read_ops", + "//tensorflow/core:android_tensorflow_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:tensorflow", + ], + }, + ), +) + +cc_library( + name = "run_tflite_model_stage", + srcs = ["run_tflite_model_stage.cc"], + hdrs = ["run_tflite_model_stage.h"], + copts = tflite_copts(), + deps = [ + ":run_tflite_model_op", + ":stage", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + ], +) + +cc_library( + name = "accuracy_eval_stage", + hdrs = ["accuracy_eval_stage.h"], + copts = tflite_copts(), + deps = [ + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + ], + }, + ), +) + +cc_library( + name = "eval_pipeline", + srcs = ["eval_pipeline.cc"], + hdrs = ["eval_pipeline.h"], + copts = tflite_copts(), + deps = [ + ":accuracy_eval_stage", + ":stage", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:core_cpu", + ], + }, + ), +) + +tf_cc_test( + name = "eval_pipeline_test", + srcs = ["eval_pipeline_test.cc"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":eval_pipeline", + "//tensorflow/cc:cc_ops", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:core_cpu", + "//tensorflow/core:ops", + "//tensorflow/core:tensorflow", + ], + }, + ), +) + +cc_library( + name = "eval_pipeline_builder", + srcs = ["eval_pipeline_builder.cc"], + hdrs = ["eval_pipeline_builder.h"], + copts = tflite_copts(), + deps = [ + ":eval_pipeline", + ":accuracy_eval_stage", + ":stage", + "@com_google_absl//absl/memory", + "//tensorflow/cc:cc_ops", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:core_cpu", + "//tensorflow/core:ops", + "//tensorflow/core:tensorflow", + ], + }, + ), +) + +tf_cc_test( + name = "eval_pipeline_builder_test", + srcs = ["eval_pipeline_builder_test.cc"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":eval_pipeline_builder", + "//tensorflow/cc:cc_ops", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:core_cpu", + "//tensorflow/core:ops", + "//tensorflow/core:tensorflow", + ], + }, + ), +) + +cc_library( + name = "csv_writer", + hdrs = ["csv_writer.h"], + copts = tflite_copts(), + deps = select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:lib", + ], + }, + ), +) diff --git a/tensorflow/contrib/lite/tools/accuracy/README.md b/tensorflow/contrib/lite/tools/accuracy/README.md new file mode 100644 index 0000000000000000000000000000000000000000..769ef201d2379b117e859f63596e3b17beea93d5 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/README.md @@ -0,0 +1,40 @@ +## TFLite accuracy library. + +This library provides evaluation pipelines that can be used to evaluate +accuracy and other metrics of a model. The resulting binary can be run on +a desktop or on a mobile device. + +## Usage +The tool provides an evaluation pipeline with different stages. Each +stage outputs a Tensorflow graph. +A sample usage is shown below. + +```C++ +// First build the pipeline. +EvalPipelineBuilder builder; +std::unique_ptr eval_pipeline; +auto status = builder.WithInput("pipeline_input", DT_FLOAT) + .WithInputStage(&input_stage) + .WithRunModelStage(&run_model_stage) + .WithPreprocessingStage(&preprocess_stage) + .WithAccuracyEval(&eval) + .Build(scope, &eval_pipeline); +TF_CHECK_OK(status); + +// Now run the pipeline with inputs and outputs. +std::unique_ptr session(NewSession(SessionOptions())); +TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session))); +Tensor input = ... read input for the model ... +Tensor ground_truth = ... read ground truth for the model ... +TF_CHECK_OK(eval_pipeline.Run(input1, ground_truth1)); +``` +For further examples, check the usage in [imagenet accuracy evaluation binary] +(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc) + +## Measuring accuracy of published models. + +### ILSVRC (Imagenet Large Scale Visual Recognition Contest) classification task +For measuring accuracy for [ILSVRC 2012 image classification task] +(http://www.image-net.org/challenges/LSVRC/2012/), the binary can be built +using these +[instructions.](ilsvrc/) diff --git a/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h b/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h new file mode 100644 index 0000000000000000000000000000000000000000..9cb843729aa8c127814be23f1183b5a9edcb1702 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace metrics { + +// Base class for evaluation stage that evaluates the accuracy of the model. +// This stage calculates the accuracy metrics given the model outputs and +// expected ground truth. +class AccuracyEval { + public: + AccuracyEval() = default; + AccuracyEval(const AccuracyEval&) = delete; + AccuracyEval& operator=(const AccuracyEval&) = delete; + + AccuracyEval(const AccuracyEval&&) = delete; + AccuracyEval& operator=(const AccuracyEval&&) = delete; + + virtual ~AccuracyEval() = default; + + // Evaluates the accuracy of the model for given `model_outputs` and the + // `ground truth`. + // Derived classes can do additional book keeping, calculate aggregrate + // statistics etc for the given model. + virtual Status ComputeEval(const std::vector& model_outputs, + const Tensor& ground_truth) = 0; +}; +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc b/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..7fa8986716b8cbc2251c9a22274f7b5d1cf467b1 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tensorflow on Android requires selective registration to be enabled in order +// for certain types (e.g. DT_UINT8) to work. +// Checks below ensure that for Android build, the right flags are passed to +// the compiler. + +#if defined(__ANDROID__) && (!defined(__ANDROID_TYPES_FULL__) || \ + !defined(SUPPORT_SELECTIVE_REGISTRATION)) +#error \ + "Binary needs custom kernel support. For enabling custom kernels on " \ + "Android, please pass -D__ANDROID_TYPES_FULL__ && " \ + "-DSUPPORT_SELECTIVE_REGISTRATION for including the kernel in the binary." +#endif diff --git a/tensorflow/contrib/lite/tools/accuracy/csv_writer.h b/tensorflow/contrib/lite/tools/accuracy/csv_writer.h new file mode 100644 index 0000000000000000000000000000000000000000..806b0d9418e8b03b92c0f33b6d531ce248ae43a6 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/csv_writer.h @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ + +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace metrics { +// A simple CSV writer that writes values of same type for fixed number of +// columns. This supports a very limited set of CSV spec and doesn't do any +// escaping. +// Usage: +// std::ofstream * output_stream = ... +// CSVWriter writer({"column1", "column2"}, output_stream); +// writer.WriteRow({4, 5}); +// writer.Flush(); // flush results immediately. +class CSVWriter { + public: + CSVWriter(const std::vector& columns, std::ofstream* output_stream) + : num_columns_(columns.size()), output_stream_(output_stream) { + TF_CHECK_OK(WriteRow(columns, output_stream_)); + } + + template + Status WriteRow(const std::vector& values) { + if (values.size() != num_columns_) { + return errors::InvalidArgument("Invalid size for row:", values.size(), + " expected: ", num_columns_); + } + return WriteRow(values, output_stream_); + } + + void Flush() { output_stream_->flush(); } + + ~CSVWriter() { output_stream_->flush(); } + + private: + template + static Status WriteRow(const std::vector& values, + std::ofstream* output_stream) { + bool first = true; + for (const auto& v : values) { + if (!first) { + (*output_stream) << ", "; + } else { + first = false; + } + (*output_stream) << v; + } + (*output_stream) << "\n"; + if (!output_stream->good()) { + return errors::Internal("Writing to stream failed."); + } + return Status::OK(); + } + const size_t num_columns_; + std::ofstream* output_stream_; +}; +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc new file mode 100644 index 0000000000000000000000000000000000000000..a03aba6a2685db7a535829f98303174e9399b94d --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" + +namespace tensorflow { +namespace metrics { + +Status EvalPipeline::AttachSession(std::unique_ptr session) { + session_ = std::move(session); + TF_RETURN_IF_ERROR(session_->Create(model_graph_)); + return Status::OK(); +} + +Status EvalPipeline::Run(const Tensor& input, const Tensor& ground_truth) { + if (session_ == nullptr) { + return errors::Internal("No session is associated with the graph."); + } + std::vector outputs; + TF_RETURN_IF_ERROR(session_->Run({{params_.model_input_node_name, input}}, + {params_.model_output_node_name}, {}, + &outputs)); + TF_RETURN_IF_ERROR(eval_->ComputeEval(outputs, ground_truth)); + return Status::OK(); +} +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h new file mode 100644 index 0000000000000000000000000000000000000000..c9cfc866139da86d7de2036a07315e66dfaf60f0 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ + +#include + +#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h" +#include "tensorflow/contrib/lite/tools/accuracy/stage.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace metrics { + +// Pipeline for evaluating a model. +// Runs the graph and passes the output of graph to +// the provided instance of AccuracyEval. +// Example usage: +// AccuracyEval *eval; +// GraphDef graph_def; +// ... populate graph_def... +// +// EvalPipeline eval_pipeline(&graph_def, +// {.model_input_node_name = "model_input", +// .model_output_node_name = "model_output"}, +// eval); +// std::unique_ptr session(NewSession(SessionOptions())); +// TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session))); +// Tensor input = ... read input for the model ... +// Tensor ground_truth = ... read ground truth for the model ... +// TF_CHECK_OK(eval_pipeline.Run(input, ground_truth)); +// +class EvalPipeline { + public: + struct Params { + string model_input_node_name; + string model_output_node_name; + }; + + // Creates a new `EvalPipeline` object. The ownership of the `accuracy_eval` + // is retained by the caller. Lifetime of `accuracy_eval` instance should + // be longer than the lifetime of this instance of pipeline. + EvalPipeline(const GraphDef& graph, const Params& params, + AccuracyEval* accuracy_eval) + : model_graph_(graph), + params_(params), + eval_(accuracy_eval), + session_(nullptr) {} + + EvalPipeline(const EvalPipeline&) = delete; + EvalPipeline& operator=(const EvalPipeline&) = delete; + + EvalPipeline(const EvalPipeline&&) = delete; + EvalPipeline& operator=(const EvalPipeline&&) = delete; + + // Attaches the given session to this instance of pipeline. + // The provided session object will be reused for subsequent calls to + // EvalPipeline::Run. + Status AttachSession(std::unique_ptr session); + + // Runs the model by feeding `input` and then passes the output of the model + // along with provided `ground_truth` to the AccuracyEval instance by calling + // AccuracyEval::ComputeEval. + Status Run(const Tensor& input, const Tensor& ground_truth); + + private: + GraphDef model_graph_; + Params params_; + AccuracyEval* eval_; + std::unique_ptr session_; +}; +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e16437e1588b400b915a488e402a52efa3b755c --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h" + +#include "absl/memory/memory.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace metrics { + +EvalPipelineBuilder& EvalPipelineBuilder::WithInputStage(Stage* input_stage) { + input_stage_ = input_stage; + return *this; +} + +EvalPipelineBuilder& EvalPipelineBuilder::WithPreprocessingStage( + Stage* preprocessing_stage) { + preprocessing_stage_ = preprocessing_stage; + return *this; +} + +EvalPipelineBuilder& EvalPipelineBuilder::WithRunModelStage( + Stage* run_model_stage) { + run_model_stage_ = run_model_stage; + return *this; +} + +EvalPipelineBuilder& EvalPipelineBuilder::WithAccuracyEval( + AccuracyEval* accuracy_eval) { + accuracy_eval_ = accuracy_eval; + return *this; +} + +EvalPipelineBuilder& EvalPipelineBuilder::WithInput(const string& input_name, + DataType input_type) { + input_name_ = input_name; + input_type_ = input_type; + return *this; +} + +Status EvalPipelineBuilder::Build( + const Scope& scope, std::unique_ptr* eval_pipeline) { + if (input_stage_ == nullptr) { + return errors::InvalidArgument("Input stage is null."); + } + if (preprocessing_stage_ == nullptr) { + return errors::InvalidArgument("Preprocessing stage is null."); + } + if (run_model_stage_ == nullptr) { + return errors::InvalidArgument("Run model stage is null."); + } + if (accuracy_eval_ == nullptr) { + return errors::InvalidArgument("accuracy_eval is null."); + } + if (input_name_.empty()) { + return errors::InvalidArgument("input name is not set."); + } + if (input_type_ == DT_INVALID) { + return errors::InvalidArgument("input type is not set."); + } + + auto input_placeholder = + ops::Placeholder(scope.WithOpName(input_name_), input_type_); + TF_RETURN_IF_ERROR(scope.status()); + + input_stage_->AddToGraph(scope, input_placeholder); + TF_RETURN_IF_ERROR(scope.status()); + + preprocessing_stage_->AddToGraph(scope, input_stage_->Output()); + TF_RETURN_IF_ERROR(scope.status()); + + run_model_stage_->AddToGraph(scope, preprocessing_stage_->Output()); + TF_RETURN_IF_ERROR(scope.status()); + + GraphDef graph_def; + TF_RETURN_IF_ERROR(scope.ToGraphDef(&graph_def)); + EvalPipeline::Params params; + params.model_input_node_name = input_name_; + params.model_output_node_name = run_model_stage_->output_name(); + *eval_pipeline = + absl::make_unique(graph_def, params, accuracy_eval_); + + return Status::OK(); +} + +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..692db022f8bc747979337dec7f08af9fcb6932fa --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ + +#include +#include + +#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h" +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" +#include "tensorflow/contrib/lite/tools/accuracy/stage.h" + +namespace tensorflow { +namespace metrics { + +// A builder to simplify construction of an `EvalPipeline` instance. +// The `Build` method creates an |EvalPipeline| with the following structure: +// |input| -> |input_stage| +// |--> |preprocessing_stage| +// |--> |run_model_stage| -> |accuracy_eval_stage|. +// The stages are chained in the order shown above. Any missing stage results in +// an error. The ownership of the stage object is retained by the caller. Stage +// objects need to exist until the |Build| method is called. +// +// Currently only single inputs are supported. +// +// Example Usage: +// EvalPipelineBuilder builder; +// std::unique_ptr eval_pipeline; +// auto status = builder.WithInput("pipeline_input", DT_FLOAT) +// .WithInputStage(&input_stage) +// .WithRunModelStage(&run_model_stage) +// .WithPreprocessingStage(&preprocess_stage) +// .WithAccuracyEval(&eval) +// .Build(scope, &eval_pipeline); +// TF_CHECK_OK(status); +class EvalPipelineBuilder { + public: + EvalPipelineBuilder() = default; + EvalPipelineBuilder(const EvalPipelineBuilder&) = delete; + EvalPipeline& operator=(const EvalPipelineBuilder&) = delete; + + EvalPipelineBuilder(const EvalPipelineBuilder&&) = delete; + EvalPipeline& operator=(const EvalPipelineBuilder&&) = delete; + + // Sets the input stage for the pipeline. + // Input stage converts the input, say filename into appropriate format + // that can be consumed by the preprocessing stage. + EvalPipelineBuilder& WithInputStage(Stage* input_stage); + + // Sets the preprocessing stage for the pipeline. + // Preprocessing stage converts the input into a format that can be used to + // run the model. + EvalPipelineBuilder& WithPreprocessingStage(Stage* preprocessing_stage); + + // Sets the run model stage for the pipeline. + // This stage receives the preprocessing input and output of this stage is + // fed to the accuracy eval stage. + EvalPipelineBuilder& WithRunModelStage(Stage* run_model_stage); + + // Sets the accuracy eval for the pipeline. + // Results of evaluating the pipeline are fed to the `accuracy_eval` instance. + EvalPipelineBuilder& WithAccuracyEval(AccuracyEval* accuracy_eval); + + // Sets the name and type of input for the pipeline. + // TODO(shashishekhar): Support multiple inputs for the pipeline, use a vector + // here. + EvalPipelineBuilder& WithInput(const string& input_name, DataType input_type); + + // Builds the pipeline and assigns the pipeline to `eval_pipeline`. + // If the pipeline creation fails `eval_pipeline` is untouched. + Status Build(const Scope& scope, + std::unique_ptr* eval_pipeline); + + private: + Stage* input_stage_ = nullptr; + Stage* preprocessing_stage_ = nullptr; + Stage* run_model_stage_ = nullptr; + AccuracyEval* accuracy_eval_ = nullptr; + string input_name_; + DataType input_type_ = DT_INVALID; +}; + +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d41929b7920f403cb6b9858a7c54cb13273fb95 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc @@ -0,0 +1,229 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h" +#include +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace metrics { +namespace { + +class IdentityStage : public Stage { + public: + IdentityStage(const string& name, const string& output) + : name_(name), output_(output) {} + + void AddToGraph(const Scope& scope, const Input& input) override { + called_count_++; + inputs_.push_back(input.node()->name()); + stage_output_ = ops::Identity(scope.WithOpName(output_), input); + } + + string name() const override { return name_; } + string output_name() const override { return output_; } + + int times_called() const { return called_count_; } + + const std::vector input_params() { return inputs_; } + + private: + string name_; + string output_; + int called_count_ = 0; + std::vector inputs_; +}; + +class FailingStage : public Stage { + public: + FailingStage(const string& name, const string& output) + : name_(name), output_(output) {} + + void AddToGraph(const Scope& scope, const Input& input) override { + called_count_++; + scope.UpdateStatus(errors::Internal("Stage failed:", name_)); + } + + string name() const override { return name_; } + string output_name() const override { return output_; } + + int times_called() const { return called_count_; } + + private: + string name_; + string output_; + int called_count_ = 0; +}; + +class SimpleAccuracyEval : public AccuracyEval { + public: + SimpleAccuracyEval() {} + + Status ComputeEval(const std::vector& model_outputs, + const Tensor& ground_truth) override { + return Status::OK(); + } +}; + +TEST(EvalPipelineBuilder, MissingPipelineStages) { + IdentityStage input_stage("input_stage", "input_stage_out"); + IdentityStage run_model_stage("run_model", "run_model_out"); + IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out"); + const string pipeline_input = "pipeline_input"; + + SimpleAccuracyEval eval; + + Scope scope = Scope::NewRootScope(); + std::unique_ptr eval_pipeline; + EvalPipelineBuilder builder; + auto status = + builder.WithInputStage(&input_stage).Build(scope, &eval_pipeline); + EXPECT_FALSE(status.ok()); + EXPECT_FALSE(eval_pipeline); + + status = + builder.WithRunModelStage(&run_model_stage).Build(scope, &eval_pipeline); + EXPECT_FALSE(status.ok()); + EXPECT_FALSE(eval_pipeline); + + status = builder.WithPreprocessingStage(&preprocess_stage) + .Build(scope, &eval_pipeline); + EXPECT_FALSE(status.ok()); + EXPECT_FALSE(eval_pipeline); + + status = + builder.WithInput(pipeline_input, DT_FLOAT).Build(scope, &eval_pipeline); + EXPECT_FALSE(status.ok()); + EXPECT_FALSE(eval_pipeline); + + status = builder.WithAccuracyEval(&eval).Build(scope, &eval_pipeline); + TF_CHECK_OK(status); + EXPECT_TRUE(eval_pipeline); +} + +TEST(EvalPipeline, InputStageFailure) { + FailingStage input_stage("input_stage", "input_stage_out"); + IdentityStage run_model_stage("run_model", "run_model_out"); + IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out"); + const string pipeline_input = "pipeline_input"; + + SimpleAccuracyEval eval; + + Scope scope = Scope::NewRootScope(); + std::unique_ptr eval_pipeline; + EvalPipelineBuilder builder; + auto status = builder.WithInputStage(&input_stage) + .WithRunModelStage(&run_model_stage) + .WithPreprocessingStage(&preprocess_stage) + .WithInput(pipeline_input, DT_FLOAT) + .WithAccuracyEval(&eval) + .Build(scope, &eval_pipeline); + + EXPECT_FALSE(scope.status().ok()); + // None of the other stages would have been called. + EXPECT_EQ(1, input_stage.times_called()); + EXPECT_EQ(0, preprocess_stage.times_called()); + EXPECT_EQ(0, run_model_stage.times_called()); +} + +TEST(EvalPipeline, PreprocessingFailure) { + IdentityStage input_stage("input_stage", "input_stage_out"); + FailingStage preprocess_stage("preprocess_stage", "preprocess_stage_out"); + IdentityStage run_model_stage("run_model", "run_model_out"); + const string pipeline_input = "pipeline_input"; + + SimpleAccuracyEval eval; + + Scope scope = Scope::NewRootScope(); + std::unique_ptr eval_pipeline; + EvalPipelineBuilder builder; + auto status = builder.WithInputStage(&input_stage) + .WithRunModelStage(&run_model_stage) + .WithPreprocessingStage(&preprocess_stage) + .WithInput(pipeline_input, DT_FLOAT) + .WithAccuracyEval(&eval) + .Build(scope, &eval_pipeline); + + EXPECT_FALSE(status.ok()); + // None of the other stages would have been called. + EXPECT_EQ(1, input_stage.times_called()); + EXPECT_EQ(1, preprocess_stage.times_called()); + EXPECT_EQ(0, run_model_stage.times_called()); +} + +TEST(EvalPipeline, GraphEvalFailure) { + IdentityStage input_stage("input_stage", "input_stage_out"); + IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out"); + FailingStage run_model_stage("run_model", "run_model_out"); + const string pipeline_input = "pipeline_input"; + + SimpleAccuracyEval eval; + + Scope scope = Scope::NewRootScope(); + std::unique_ptr eval_pipeline; + EvalPipelineBuilder builder; + auto status = builder.WithInputStage(&input_stage) + .WithRunModelStage(&run_model_stage) + .WithPreprocessingStage(&preprocess_stage) + .WithInput(pipeline_input, DT_FLOAT) + .WithAccuracyEval(&eval) + .Build(scope, &eval_pipeline); + + EXPECT_FALSE(status.ok()); + // None of the other stages would have been called. + EXPECT_EQ(1, input_stage.times_called()); + EXPECT_EQ(1, preprocess_stage.times_called()); + EXPECT_EQ(1, run_model_stage.times_called()); +} + +TEST(EvalPipeline, PipelineHasCorrectSequence) { + IdentityStage input_stage("input_stage", "input_stage_out"); + IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out"); + IdentityStage run_model_stage("run_model", "run_model_out"); + const string pipeline_input = "pipeline_input"; + + SimpleAccuracyEval eval; + + Scope scope = Scope::NewRootScope(); + std::unique_ptr eval_pipeline; + EvalPipelineBuilder builder; + auto status = builder.WithInputStage(&input_stage) + .WithRunModelStage(&run_model_stage) + .WithPreprocessingStage(&preprocess_stage) + .WithInput(pipeline_input, DT_FLOAT) + .WithAccuracyEval(&eval) + .Build(scope, &eval_pipeline); + TF_CHECK_OK(status); + + ASSERT_EQ(1, input_stage.times_called()); + ASSERT_EQ(1, run_model_stage.times_called()); + ASSERT_EQ(1, preprocess_stage.times_called()); + + EXPECT_EQ(pipeline_input, input_stage.input_params()[0]); + EXPECT_EQ(input_stage.output_name(), preprocess_stage.input_params()[0]); + EXPECT_EQ(preprocess_stage.output_name(), run_model_stage.input_params()[0]); +} + +} // namespace + +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea0f6e19df46d8934dc9eabb1c57a01bb5e91a1f --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" +#include +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace metrics { +namespace { + +Tensor CreateFloatTensor(float value) { + Tensor tensor(DT_FLOAT, TensorShape({})); + tensor.scalar()() = value; + return tensor; +} + +class NoOpAccuracyEval : public AccuracyEval { + public: + explicit NoOpAccuracyEval(const Status& status_to_return) + : status_to_return_(status_to_return) {} + + Status ComputeEval(const std::vector& model_outputs, + const Tensor& ground_truth) override { + model_outputs_ = model_outputs; + ground_truth_ = ground_truth; + was_called_ = true; + return status_to_return_; + } + + bool WasCalled() { return was_called_; } + std::vector model_outputs() { return model_outputs_; } + Tensor ground_truth() { return ground_truth_; } + + private: + std::vector model_outputs_; + Tensor ground_truth_; + Status status_to_return_; + bool was_called_ = false; +}; + +TEST(EvalPipeline, AccuracyEvalIsCalled) { + Scope scope = Scope::NewRootScope(); + // A graph that adds 1 to input. + auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT); + auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f); + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + EvalPipeline::Params params; + params.model_input_node_name = "input"; + params.model_output_node_name = "output"; + NoOpAccuracyEval accuracy_eval(Status::OK()); + + EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval); + std::unique_ptr session(NewSession(SessionOptions())); + TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session))); + TF_CHECK_OK(eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27))); + + EXPECT_TRUE(accuracy_eval.WasCalled()); + auto outputs = accuracy_eval.model_outputs(); + ASSERT_EQ(1, outputs.size()); + EXPECT_EQ(6.0f, outputs[0].scalar()()); + // Ground truth is unchanged. + EXPECT_EQ(27, accuracy_eval.ground_truth().scalar()()); +} + +TEST(EvalPipeline, EvalIsNotCalledOnGraphRunFailure) { + Scope scope = Scope::NewRootScope(); + // A graph that adds 1 to input. + auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT); + auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f); + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + EvalPipeline::Params params; + params.model_input_node_name = "input"; + params.model_output_node_name = "output"; + NoOpAccuracyEval accuracy_eval(Status::OK()); + + EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval); + std::unique_ptr session(NewSession(SessionOptions())); + TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session))); + + // Pass a string tensor instead of a float tensor. + Tensor string_tensor(DT_STRING, TensorShape{}); + auto status = eval_pipeline.Run(string_tensor, CreateFloatTensor(27)); + EXPECT_FALSE(accuracy_eval.WasCalled()); + EXPECT_FALSE(status.ok()); +} + +TEST(EvalPipeline, AccuracyEvalFailureResultsInFailure) { + Scope scope = Scope::NewRootScope(); + // A graph that adds 1 to input. + auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT); + auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f); + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + EvalPipeline::Params params; + params.model_input_node_name = "input"; + params.model_output_node_name = "output"; + NoOpAccuracyEval accuracy_eval(errors::Internal("accuracy_fail")); + + EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval); + std::unique_ptr session(NewSession(SessionOptions())); + TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session))); + auto status = eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27)); + + EXPECT_TRUE(accuracy_eval.WasCalled()); + EXPECT_FALSE(status.ok()); +} + +} // namespace + +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc new file mode 100644 index 0000000000000000000000000000000000000000..61bed369f8b4f659ee12834efdc23f6315dd8d42 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h" + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace metrics { +void FileReaderStage::AddToGraph(const Scope& scope, const Input& input) { + if (!scope.ok()) return; + Scope s = scope.WithOpName(name()); + this->stage_output_ = ops::ReadFile(s.WithOpName(output_name()), input); +} +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h new file mode 100644 index 0000000000000000000000000000000000000000..18db5837c1717ca5be966d8a4d764ea88d2674d3 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ + +#include + +#include "tensorflow/contrib/lite/tools/accuracy/stage.h" + +namespace tensorflow { +namespace metrics { +// A stage for reading a file into |string|. +// Inputs: a string tensor: |file_name|. +// Outputs: a string tensor: contents of |file_name|. +class FileReaderStage : public Stage { + public: + string name() const override { return "stage_filereader"; } + string output_name() const override { return "stage_filereader_output"; } + + void AddToGraph(const Scope& scope, const Input& input) override; +}; +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a75f99187d6ea0918398899ccef1511faa3ee0a6 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace metrics { +namespace { + +class TempFile { + public: + TempFile() { + string file_path; + if (Env::Default()->LocalTempFilename(&file_path)) { + file_path_ = file_path; + created_ = true; + } + } + + string filepath() { return file_path_; } + bool CreateFileWithContents(const std::string& contents) { + if (!created_) { + return false; + } + std::fstream file(file_path_, std::ios_base::out); + if (file) { + file << contents; + } + return file.good(); + } + + ~TempFile() { + if (created_) { + std::remove(file_path_.c_str()); + } + } + + private: + bool created_ = false; + string file_path_; +}; + +TEST(FileReaderStageTest, FileIsRead) { + TempFile file; + const string kFileContents = "Hello world."; + ASSERT_TRUE(file.CreateFileWithContents(kFileContents)); + Scope scope = Scope::NewRootScope(); + FileReaderStage reader_stage; + reader_stage.AddToGraph(scope, file.filepath()); + TF_CHECK_OK(scope.status()); + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + std::unique_ptr session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + std::vector outputs; + auto run_status = + session->Run({}, /*inputs*/ + {reader_stage.output_name()}, {}, /*target node names */ + &outputs); + TF_CHECK_OK(run_status); + EXPECT_EQ(1, outputs.size()); + string contents = outputs[0].scalar()(); + EXPECT_EQ(kFileContents, contents); +} + +TEST(FileReaderStageTest, InvalidFile) { + Scope scope = Scope::NewRootScope(); + FileReaderStage reader_stage; + reader_stage.AddToGraph(scope, string("non_existent_file")); + TF_CHECK_OK(scope.status()); + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + std::unique_ptr session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + std::vector outputs; + auto run_status = + session->Run({}, /*inputs*/ + {reader_stage.output_name()}, {}, /*target node names */ + &outputs); + EXPECT_FALSE(run_status.ok()); +} + +} // namespace + +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..db4b688a4537cbe6a6bad3c5694d9054e8e5d4d8 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD @@ -0,0 +1,171 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "tflite_linkopts") + +common_linkopts = tflite_linkopts() + select({ + "//conditions:default": [], + "//tensorflow:android": [ + "-pie", + "-llog", + ], +}) + +cc_library( + name = "inception_preprocessing", + srcs = ["inception_preprocessing.cc"], + hdrs = ["inception_preprocessing.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags", + "//tensorflow/contrib/lite/tools/accuracy:stage", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core/kernels:android_tensorflow_image_op", + ], + "//conditions:default": [ + "//tensorflow/core:tensorflow", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + ], + }, + ), +) + +tf_cc_test( + name = "inception_preprocessing_test", + srcs = ["inception_preprocessing_test.cc"], + args = [ + "--test_image=$(location :testdata/grace_hopper.jpg)", + ], + data = [":testdata/grace_hopper.jpg"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":inception_preprocessing", + "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], + }, + ), +) + +cc_library( + name = "imagenet_topk_eval", + srcs = ["imagenet_topk_eval.cc"], + hdrs = ["imagenet_topk_eval.h"], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite/tools/accuracy:accuracy_eval_stage", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + ], + }, + ), +) + +tf_cc_test( + name = "imagenet_topk_eval_test", + srcs = ["imagenet_topk_eval_test.cc"], + linkopts = common_linkopts, + linkstatic = 1, + deps = [ + ":imagenet_topk_eval", + "@com_google_googletest//:gtest", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core:android_tensorflow_test_lib", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + ], + }, + ), +) + +cc_library( + name = "imagenet_model_evaluator", + srcs = ["imagenet_model_evaluator.cc"], + hdrs = ["imagenet_model_evaluator.h"], + copts = tflite_copts(), + deps = [ + ":imagenet_topk_eval", + ":inception_preprocessing", + "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags", + "//tensorflow/contrib/lite/tools/accuracy:eval_pipeline", + "//tensorflow/contrib/lite/tools/accuracy:eval_pipeline_builder", + "//tensorflow/contrib/lite/tools/accuracy:file_reader_stage", + "//tensorflow/contrib/lite/tools/accuracy:run_tflite_model_stage", + "//tensorflow/contrib/lite/tools/accuracy:utils", + "@com_google_absl//absl/memory", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + "//tensorflow/core/kernels:android_whole_file_read_ops", + "//tensorflow/core/kernels:android_tensorflow_image_op", + ], + "//conditions:default": [ + "//tensorflow/core:tensorflow", + "//tensorflow/core:framework_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:core_cpu", + ], + }, + ), +) + +tf_cc_binary( + name = "imagenet_accuracy_eval", + srcs = ["imagenet_accuracy_eval.cc"], + copts = tflite_copts(), + linkopts = common_linkopts, + deps = [ + ":imagenet_model_evaluator", + ":imagenet_topk_eval", + "@com_google_absl//absl/memory", + "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags", + "//tensorflow/contrib/lite/tools/accuracy:csv_writer", + ] + select( + { + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:lib", + "//tensorflow/core:framework_internal", + ], + }, + ), +) diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9b3b99451dbeb6d72042aed001fe3f72f0216511 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md @@ -0,0 +1,138 @@ +## Accuracy evaluation for ILSVRC 2012 (Imagenet Large Scale Visual Recognition Challenge) image classification task + +This binary can evaluate the accuracy of TFLite models trained for the [ILSVRC 2012 image classification task] +(http://www.image-net.org/challenges/LSVRC/2012/). +The binary takes the path to validation images and labels as inputs. It outputs the accuracy after running the TFLite model on the validation sets. + +To run the binary download the ILSVRC 2012 devkit [see instructions](#downloading-ilsvrc) and run the [`generate_validation_ground_truth` script](#ground-truth-label-generation) to generate the ground truth labels. + +## Parameters +The binary takes the following parameters: + +* `model_file` : `string` \ + Path to the TFlite model file. + +* `ground_truth_images_path`: `string` \ + The path to the directory containing ground truth images. + +* `ground_truth_labels`: `string` \ + Path to ground truth labels file. This file should contain the same number of labels as the number images in the ground truth directory. The labels are assumed to be in the + same order as the sorted filename of images. See [ground truth label generation](#ground-truth-label-generation) + section for more information about how to generate labels for images. + +* `model_output_labels`: `string` \ + Path to the file containing labels, that is used to interpret the output of + the model. E.g. in case of mobilenets, this is the path to + `mobilenet_labels.txt` where each label is in the same order as the output + 1001 dimension tensor. + +* `output_path`: `string` \ + This is the path to the output file. The output is a CSV file that has top-10 accuracies in each row. Each line of output file is the cumulative accuracy after processing images in a sorted order. So first line is accuracy after processing the first image, second line is accuracy after procesing first two images. The last line of the file is accuracy after processing the entire validation set. + +and the following optional parameters: +* `num_images`: `int` (default=0) \ + The number of images to process, if 0, all images in the directory are processed otherwise only num_images will be processed. + +## Downloading ILSVRC +In order to use this tool to run evaluation on the full 50K ImageNet dataset, +download the data set from http://image-net.org/request. + +## Ground truth label generation +The ILSVRC 2012 devkit `validation_ground_truth.txt` contains IDs that correspond to synset of the image. +The accuracy binary however expects the ground truth labels to contain the actual name of +category instead of synset ids. A conversion script has been provided to convert the validation ground truth to +category labels. The `validation_ground_truth.txt` can be converted by the following steps: + +``` +ILSVRC_2012_DEVKIT_DIR=[set to path to ILSVRC 2012 devkit] +VALIDATION_LABELS=[set to path to output] + +python generate_validation_labels.py -- \ +--ilsvrc_devkit_dir=${ILSVRC_2012_DEVKIT_DIR} \ +--validation_labels_output=${VALIDATION_LABELS} +``` + +## Running the binary + +### On Android + +(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android for configuring NDK and SDK. + +(1) Build using the following command: + +``` +bazel build -c opt \ + --config=android_arm \ + --config=monolithic \ + --cxxopt='--std=c++11' \ + --copt=-D__ANDROID_TYPES_FULL__ \ + --copt=-DSUPPORT_SELECTIVE_REGISTRATION \ + //tensorflow/contrib/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval +``` + +(2) Connect your phone. Push the binary to your phone with adb push + (make the directory if required): + +``` +adb push bazel-bin/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval /data/local/tmp +``` + +(3) Make the binary executable. + +``` +adb shell chmod +x /data/local/tmp/imagenet_accuracy_eval +``` + +(4) Push the TFLite model that you need to test. For example: + +``` +adb push mobilenet_quant_v1_224.tflite /data/local/tmp +``` + +(5) Push the imagenet images to device, make sure device has sufficient storage available before pushing the dataset: + +``` +adb shell mkdir /data/local/tmp/ilsvrc_images && \ +adb push ${IMAGENET_IMAGES_DIR} /data/local/tmp/ilsvrc_images +``` + +(6) Push the generated validation ground labels to device. + +``` +adb push ${VALIDATION_LABELS} /data/local/tmp/ilsvrc_validation_labels.txt +``` + +(7) Push the model labels text file to device. + +``` +adb push ${MODEL_LABELS_TXT} /data/local/tmp/model_output_labels.txt +``` + +(8) Run the binary. + +``` +adb shell /data/local/tmp/imagenet_accuracy_eval \ + --model_file=/data/local/tmp/mobilenet_quant_v1_224.tflite \ + --ground_truth_images_path=/data/local/tmp/ilsvrc_images \ + --ground_truth_labels=/data/local/tmp/ilsvrc_validation_labels.txt \ + --model_output_labels=/data/local/tmp/model_output_labels.txt \ + --output_file_path=/data/local/tmp/accuracy_output.txt \ + --num_images=0 # Run on all images. +``` + +### On Desktop + +(1) Build and run using the following command: + +``` +bazel run -c opt \ + --cxxopt='--std=c++11' \ + -- \ + //tensorflow/contrib/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval \ + --model_file=mobilenet_quant_v1_224.tflite \ + --ground_truth_images_path=${IMAGENET_IMAGES_DIR} \ + --ground_truth_labels=${VALIDATION_LABELS} \ + --model_output_labels=${MODEL_LABELS_TXT} \ + --output_file_path=/tmp/accuracy_output.txt \ + --num_images=0 # Run on all images. +``` diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..c32a41e50d3a88536fc9b2d59d0a6c6842f3a531 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py @@ -0,0 +1,105 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tool to convert ILSVRC devkit validation ground truth to synset labels.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +from os import path +import sys +import scipy.io + +_SYNSET_ARRAYS_RELATIVE_PATH = 'data/meta.mat' +_VALIDATION_FILE_RELATIVE_PATH = 'data/ILSVRC2012_validation_ground_truth.txt' + + +def _synset_to_word(filepath): + """Returns synset to word dictionary by reading sysnset arrays.""" + mat = scipy.io.loadmat(filepath) + entries = mat['synsets'] + # These fields are listed in devkit readme.txt + fields = [ + 'synset_id', 'WNID', 'words', 'gloss', 'num_children', 'children', + 'wordnet_height', 'num_train_images' + ] + synset_index = fields.index('synset_id') + words_index = fields.index('words') + synset_to_word = {} + for entry in entries: + entry = entry[0] + synset_id = int(entry[synset_index][0]) + first_word = entry[words_index][0].split(',')[0] + synset_to_word[synset_id] = first_word + return synset_to_word + + +def _validation_file_path(ilsvrc_dir): + return path.join(ilsvrc_dir, _VALIDATION_FILE_RELATIVE_PATH) + + +def _synset_array_path(ilsvrc_dir): + return path.join(ilsvrc_dir, _SYNSET_ARRAYS_RELATIVE_PATH) + + +def _generate_validation_labels(ilsvrc_dir, output_file): + synset_to_word = _synset_to_word(_synset_array_path(ilsvrc_dir)) + with open(_validation_file_path(ilsvrc_dir), 'r') as synset_id_file, open( + output_file, 'w') as output: + for synset_id in synset_id_file: + synset_id = int(synset_id) + output.write('%s\n' % synset_to_word[synset_id]) + + +def _check_arguments(args): + if not args.validation_labels_output: + raise ValueError('Invalid path to output file.') + ilsvrc_dir = args.ilsvrc_devkit_dir + if not ilsvrc_dir or not path.isdir(ilsvrc_dir): + raise ValueError('Invalid path to ilsvrc_dir') + if not path.exists(_validation_file_path(ilsvrc_dir)): + raise ValueError('Invalid path to ilsvrc_dir, cannot find validation file.') + if not path.exists(_synset_array_path(ilsvrc_dir)): + raise ValueError( + 'Invalid path to ilsvrc_dir, cannot find synset arrays file.') + + +def main(): + parser = argparse.ArgumentParser( + description='Converts ILSVRC devkit validation_ground_truth.txt to synset' + ' labels file that can be used by the accuracy script.') + parser.add_argument( + '--validation_labels_output', + type=str, + help='Full path for outputting validation labels.') + parser.add_argument( + '--ilsvrc_devkit_dir', + type=str, + help='Full path to ILSVRC 2012 devikit directory.') + args = parser.parse_args() + try: + _check_arguments(args) + except ValueError as e: + parser.print_usage() + file_name = path.basename(sys.argv[0]) + sys.stderr.write('{0}: error: {1}\n'.format(file_name, str(e))) + sys.exit(1) + _generate_validation_labels(args.ilsvrc_devkit_dir, + args.validation_labels_output) + + +if __name__ == '__main__': + main() diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc new file mode 100644 index 0000000000000000000000000000000000000000..f361341f7c20021a2bf448ff2e15405660f4093a --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc @@ -0,0 +1,148 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/contrib/lite/tools/accuracy/csv_writer.h" +#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h" +#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace metrics { + +namespace { + +std::vector GetAccuracies( + const ImagenetTopKAccuracy::AccuracyStats& accuracy_stats) { + std::vector results; + results.reserve(accuracy_stats.number_of_images); + if (accuracy_stats.number_of_images > 0) { + for (int n : accuracy_stats.topk_counts) { + double accuracy = 0; + if (accuracy_stats.number_of_images > 0) { + accuracy = (n * 100.0) / accuracy_stats.number_of_images; + } + results.push_back(accuracy); + } + } + return results; +} + +} // namespace + +// Writes results to a CSV file. +class ResultsWriter : public ImagenetModelEvaluator::Observer { + public: + explicit ResultsWriter(std::unique_ptr writer) + : writer_(std::move(writer)) {} + + void OnEvaluationStart(int total_number_of_images) override {} + + void OnSingleImageEvaluationComplete( + const ImagenetTopKAccuracy::AccuracyStats& stats, + const string& image) override; + + private: + std::unique_ptr writer_; +}; + +void ResultsWriter::OnSingleImageEvaluationComplete( + const ImagenetTopKAccuracy::AccuracyStats& stats, const string& image) { + TF_CHECK_OK(writer_->WriteRow(GetAccuracies(stats))); + writer_->Flush(); +} + +// Logs results to standard output with `kLogDelayUs` microseconds. +class ResultsLogger : public ImagenetModelEvaluator::Observer { + public: + void OnEvaluationStart(int total_number_of_images) override; + + void OnSingleImageEvaluationComplete( + const ImagenetTopKAccuracy::AccuracyStats& stats, + const string& image) override; + + private: + int total_num_images_ = 0; + uint64 last_logged_time_us_ = 0; + static constexpr int kLogDelayUs = 500 * 1000; +}; + +void ResultsLogger::OnEvaluationStart(int total_number_of_images) { + total_num_images_ = total_number_of_images; + LOG(ERROR) << "Starting model evaluation: " << total_num_images_; +} + +void ResultsLogger::OnSingleImageEvaluationComplete( + const ImagenetTopKAccuracy::AccuracyStats& stats, const string& image) { + int num_evaluated = stats.number_of_images; + + double current_percent = num_evaluated * 100.0 / total_num_images_; + auto now_us = Env::Default()->NowMicros(); + + if ((now_us - last_logged_time_us_) >= kLogDelayUs) { + last_logged_time_us_ = now_us; + + LOG(ERROR) << "Evaluated " << num_evaluated << "/" << total_num_images_ + << " images, " << std::setprecision(2) << std::fixed + << current_percent << "%"; + } +} + +int Main(int argc, char* argv[]) { + // TODO(shashishekhar): Make this binary configurable and model + // agnostic. + string output_file_path; + std::vector flag_list = { + Flag("output_file_path", &output_file_path, "Path to output file."), + }; + Flags::Parse(&argc, argv, flag_list); + + std::unique_ptr evaluator; + CHECK(!output_file_path.empty()) << "Invalid output file path."; + + TF_CHECK_OK(ImagenetModelEvaluator::Create(argc, argv, &evaluator)); + + std::ofstream output_stream(output_file_path, std::ios::out); + CHECK(output_stream) << "Unable to open output file path: '" + << output_file_path << "'"; + + output_stream << std::setprecision(3) << std::fixed; + std::vector columns; + columns.reserve(evaluator->params().num_ranks); + for (int i = 0; i < evaluator->params().num_ranks; i++) { + string column_name = "Top "; + tensorflow::strings::StrAppend(&column_name, i + 1); + columns.push_back(column_name); + } + + ResultsWriter results_writer( + absl::make_unique(columns, &output_stream)); + ResultsLogger logger; + evaluator->AddObserver(&results_writer); + evaluator->AddObserver(&logger); + TF_CHECK_OK(evaluator->EvaluateModel()); + return 0; +} + +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char* argv[]) { + return tensorflow::metrics::Main(argc, argv); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc new file mode 100644 index 0000000000000000000000000000000000000000..a88a4a0fce7dd49e8ca412569af554c50b96ba85 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc @@ -0,0 +1,206 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h" +#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h" +#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h" +#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" +#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h" +#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h" +#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +using tensorflow::string; + +string StripTrailingSlashes(const string& path) { + int end = path.size(); + while (end > 0 && path[end - 1] == '/') { + end--; + } + return path.substr(0, end); +} + +tensorflow::Tensor CreateStringTensor(const string& value) { + tensorflow::Tensor tensor(tensorflow::DT_STRING, tensorflow::TensorShape({})); + tensor.scalar()() = value; + return tensor; +} + +template +std::vector GetFirstN(const std::vector& v, int n) { + if (n >= v.size()) return v; + std::vector result(v.begin(), v.begin() + n); + return result; +} + +// File pattern for imagenet files. +const char* const kImagenetFilePattern = "*.[jJ][pP][eE][gG]"; + +} // namespace + +namespace tensorflow { +namespace metrics { + +/*static*/ Status ImagenetModelEvaluator::Create( + int argc, char* argv[], + std::unique_ptr* model_evaluator) { + Params params; + const std::vector flag_list = { + Flag("model_output_labels", ¶ms.model_output_labels_path, + "Path to labels that correspond to output of model." + " E.g. in case of mobilenet, this is the path to label " + "file where each label is in the same order as the output" + " of the model."), + Flag("ground_truth_images_path", ¶ms.ground_truth_images_path, + "Path to ground truth images."), + Flag("ground_truth_labels", ¶ms.ground_truth_labels_path, + "Path to ground truth labels."), + Flag("num_images", ¶ms.number_of_images, + "Number of examples to evaluate, pass 0 for all " + "examples. Default: 100"), + tensorflow::Flag("model_file", ¶ms.model_file_path, + "Path to test tflite model file."), + }; + const bool parse_result = Flags::Parse(&argc, argv, flag_list); + if (!parse_result) + return errors::InvalidArgument("Invalid command line flags"); + ::tensorflow::port::InitMain(argv[0], &argc, &argv); + + TF_RETURN_WITH_CONTEXT_IF_ERROR( + Env::Default()->IsDirectory(params.ground_truth_images_path), + "Invalid ground truth data path."); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + Env::Default()->FileExists(params.ground_truth_labels_path), + "Invalid ground truth labels path."); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + Env::Default()->FileExists(params.model_output_labels_path), + "Invalid model output labels path."); + + if (params.number_of_images < 0) { + return errors::InvalidArgument("Invalid: num_examples"); + } + + utils::ModelInfo model_info; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + utils::GetTFliteModelInfo(params.model_file_path, &model_info), + "Invalid TFLite model."); + + *model_evaluator = + absl::make_unique(model_info, params); + return Status::OK(); +} + +Status ImagenetModelEvaluator::EvaluateModel() { + if (model_info_.input_shapes.size() != 1) { + return errors::InvalidArgument("Invalid input shape"); + } + + const TensorShape& input_shape = model_info_.input_shapes[0]; + // Input should be of the shape {1, height, width, 3} + if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) { + return errors::InvalidArgument("Invalid input shape for the model."); + } + + const int image_height = input_shape.dim_size(1); + const int image_width = input_shape.dim_size(2); + const bool is_quantized = (model_info_.input_types[0] == DT_UINT8); + + RunTFLiteModelStage::Params tfl_model_params; + tfl_model_params.model_file_path = params_.model_file_path; + if (is_quantized) { + tfl_model_params.input_type = {DT_UINT8}; + tfl_model_params.output_type = {DT_UINT8}; + } else { + tfl_model_params.input_type = {DT_FLOAT}; + tfl_model_params.output_type = {DT_FLOAT}; + } + + Scope root = Scope::NewRootScope(); + FileReaderStage reader; + InceptionPreprocessingStage inc(image_height, image_width, is_quantized); + RunTFLiteModelStage tfl_model_stage(tfl_model_params); + EvalPipelineBuilder builder; + std::vector model_labels; + TF_RETURN_IF_ERROR( + utils::ReadFileLines(params_.model_output_labels_path, &model_labels)); + if (model_labels.size() != 1001) { + return errors::InvalidArgument("Invalid number of labels: ", + model_labels.size()); + } + + ImagenetTopKAccuracy eval(model_labels, params_.num_ranks); + std::unique_ptr eval_pipeline; + + auto build_status = builder.WithInputStage(&reader) + .WithPreprocessingStage(&inc) + .WithRunModelStage(&tfl_model_stage) + .WithAccuracyEval(&eval) + .WithInput("input_file", DT_STRING) + .Build(root, &eval_pipeline); + TF_RETURN_WITH_CONTEXT_IF_ERROR(build_status, + "Failure while building eval pipeline."); + + std::unique_ptr session(NewSession(SessionOptions())); + + TF_RETURN_IF_ERROR(eval_pipeline->AttachSession(std::move(session))); + string data_path = + StripTrailingSlashes(params_.ground_truth_images_path) + "/"; + + const string imagenet_file_pattern = data_path + kImagenetFilePattern; + std::vector image_files; + TF_CHECK_OK( + Env::Default()->GetMatchingPaths(imagenet_file_pattern, &image_files)); + std::vector image_labels; + TF_CHECK_OK( + utils::ReadFileLines(params_.ground_truth_labels_path, &image_labels)); + CHECK_EQ(image_files.size(), image_labels.size()); + + // Process files in filename sorted order. + std::sort(image_files.begin(), image_files.end()); + if (params_.number_of_images > 0) { + image_files = GetFirstN(image_files, params_.number_of_images); + image_labels = GetFirstN(image_labels, params_.number_of_images); + } + + for (Observer* observer : observers_) { + observer->OnEvaluationStart(image_files.size()); + } + + for (int i = 0; i < image_files.size(); i++) { + TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_files[i]), + CreateStringTensor(image_labels[i]))); + auto stats = eval.GetTopKAccuracySoFar(); + + for (Observer* observer : observers_) { + observer->OnSingleImageEvaluationComplete(stats, image_files[i]); + } + } + return Status::OK(); +} + +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h new file mode 100644 index 0000000000000000000000000000000000000000..5f42b2a50ecb1d55647998f8ec0aab17234e2b88 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_ +#include +#include + +#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" +#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace metrics { + +// Evaluates models accuracy for ILSVRC dataset. +// +// Generates the top-1, top-k accuracy counts where k is +// controlled by |num_ranks|. +// Usage: +// ModelInfo model_info = .. +// ImagenetModelEvaluator::Params params; +// .. set params to image, label, output label and model file path.. +// SomeObserver observer; +// ImagenetModelEvaluator evaluator(model_info, params); +// evaluator.AddObserver(&observer); +// TF_CHECK_OK(evaluator.EvaluateModel()); +class ImagenetModelEvaluator { + public: + struct Params { + // Path to ground truth images. + string ground_truth_images_path; + + // Path to labels file for ground truth image. + // This file should be generated with the scripts. + string ground_truth_labels_path; + + // This is word labels generated by the model. The category + // indices of output probabilities generated by the model maybe different + // from the indices in the imagenet dataset. + string model_output_labels_path; + + // Path to the model file. + string model_file_path; + + // The maximum number of images to calculate accuracy. + // 0 means all images, a positive number means only the specified + // number of images. + int number_of_images = 0; + + // Number of ranks, top K. + int num_ranks = 10; + }; + + // An evaluation observer. + class Observer { + public: + Observer() = default; + Observer(const Observer&) = delete; + Observer& operator=(const Observer&) = delete; + + Observer(const Observer&&) = delete; + Observer& operator=(const Observer&&) = delete; + + // Called on start of evaluation. + virtual void OnEvaluationStart(int total_number_of_images) = 0; + + // Called when evaluation was complete for `image`. + virtual void OnSingleImageEvaluationComplete( + const ImagenetTopKAccuracy::AccuracyStats& stats, + const string& image) = 0; + + virtual ~Observer() = default; + }; + + ImagenetModelEvaluator(const utils::ModelInfo& model_info, + const Params& params) + : model_info_(model_info), params_(params) {} + + // Factory method to create the evaluator by parsing command line arguments. + static Status Create(int argc, char* argv[], + std::unique_ptr* evaluator); + + // Adds an observer that can observe evaluation events.. + void AddObserver(Observer* observer) { observers_.push_back(observer); } + + const Params& params() { return params_; } + + // Evaluates the provided model over the dataset. + Status EvaluateModel(); + + private: + std::vector observers_; + const utils::ModelInfo model_info_; + const Params params_; +}; + +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc new file mode 100644 index 0000000000000000000000000000000000000000..d46075d234313b7d23909abfd1e3f0062b6886e2 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc @@ -0,0 +1,107 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" + +#include + +namespace { +constexpr int kNumCategories = 1001; +std::vector GetTopK(const std::vector& values, int k) { + CHECK_LE(k, values.size()); + std::vector indices(values.size()); + + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), + [&values](int a, int b) { return values[a] > values[b]; }); + + indices.resize(k); + return indices; +} +} // namespace + +namespace tensorflow { +namespace metrics { +ImagenetTopKAccuracy::ImagenetTopKAccuracy( + const std::vector& ground_truth_labels, int k) + : ground_truth_labels_(ground_truth_labels), + k_(k), + accuracy_counts_(k_, 0), + num_samples_(0) { + CHECK_EQ(kNumCategories, ground_truth_labels.size()); +} + +Status ImagenetTopKAccuracy::ComputeEval( + const std::vector& model_outputs, const Tensor& ground_truth) { + if (model_outputs.size() != 1) { + return errors::InvalidArgument("Invalid model output: ", + model_outputs.size()); + } + const Tensor& output = model_outputs[0]; + if (!output.shape().IsSameSize({1, kNumCategories})) { + return errors::InvalidArgument("Invalid shape of model output: ", + output.shape().DebugString()); + } + if (ground_truth.dtype() != DT_STRING && ground_truth.dims() != 0) { + return errors::InvalidArgument("Invalid ground truth type: ", + ground_truth.DebugString()); + } + string ground_truth_label = ground_truth.scalar()(); + + std::vector probabilities; + probabilities.reserve(kNumCategories); + if (output.dtype() == DT_FLOAT) { + auto probs = output.flat(); + for (size_t i = 0; i < probs.size(); i++) { + probabilities.push_back(probs(i)); + } + } else { + auto probs = output.flat(); + for (size_t i = 0; i < probs.size(); i++) { + probabilities.push_back(probs(i)); + } + } + + CHECK_EQ(kNumCategories, probabilities.size()); + std::vector topK = GetTopK(probabilities, k_); + int ground_truth_index = GroundTruthIndex(ground_truth_label); + for (size_t i = 0; i < topK.size(); ++i) { + if (ground_truth_index == topK[i]) { + for (size_t j = i; j < topK.size(); j++) { + accuracy_counts_[j] += 1; + } + break; + } + } + num_samples_++; + return Status::OK(); +} + +const ImagenetTopKAccuracy::AccuracyStats +ImagenetTopKAccuracy::GetTopKAccuracySoFar() const { + AccuracyStats stats; + stats.number_of_images = num_samples_; + stats.topk_counts = accuracy_counts_; + return stats; +} + +int ImagenetTopKAccuracy::GroundTruthIndex(const string& label) const { + auto index = std::find(ground_truth_labels_.cbegin(), + ground_truth_labels_.cend(), label); + CHECK(index != ground_truth_labels_.end()) << "Invalid label: " << label; + return std::distance(ground_truth_labels_.cbegin(), index); +} +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h new file mode 100644 index 0000000000000000000000000000000000000000..5a575ff244fc08977e9fbf0cca117c6638116453 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_ + +#include +#include + +#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace metrics { +// An |AccuracyEval| stage that calculates the top K error rate for model +// evaluations on imagenet like datasets. +// Inputs: A {1, 1001} shaped tensor that contains the probabilities for objects +// predicted by the model. +// Ground truth: A |string| label for the image. +// From the input object probabilities, the stage computes the predicted labels +// and finds the top K error rates by comparing the predictions with ground +// truths. +class ImagenetTopKAccuracy : public AccuracyEval { + public: + // Accuracy statistics. + struct AccuracyStats { + // Number of images evaluated. + int number_of_images; + // A vector of size |k| that contains the number of images + // that have correct labels in top K. + // E.g. topk_counts[0] contains number of images for which + // model returned the correct label as the first result. + // Similarly topk_counts[4] contains the number of images for which + // model returned the correct label in top 5 results. + // This can be used to compute the top K error-rate for the model. + std::vector topk_counts; + }; + + // Creates a new instance of |ImagenetTopKAccuracy| with the given + // |ground_truth_labels| and |k|. + // Args: + // |ground_truth_labels| : an ordered vector of labels for images. This is + // used to compute the index for the predicted labels and ground_truth label. + ImagenetTopKAccuracy(const std::vector& ground_truth_labels, int k); + + // Computes accuracy for a given image. The |model_outputs| should + // be a vector containing exactly one Tensor of shape: {1, 1001} where each + // item is a probability of the predicted object representing the image as + // output by the model. + // Uses |ground_truth_labels| to compute the index of |model_outputs| and + // |ground_truth| and computes the top K error rate. + Status ComputeEval(const std::vector& model_outputs, + const Tensor& ground_truth) override; + + // Gets the topK accuracy for images that have been evaluated till now. + const AccuracyStats GetTopKAccuracySoFar() const; + + private: + int GroundTruthIndex(const string& label) const; + std::vector ground_truth_labels_; + const int k_; + std::vector accuracy_counts_; + int num_samples_; +}; +} // namespace metrics +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff332af5c5e56ec2e14b9e4ee509c6344be22c66 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h" +#include + +namespace tensorflow { +namespace metrics { +namespace { + +const int kNumCategories = 1001; + +Tensor CreateStringTensor(const string& value) { + Tensor tensor(DT_STRING, TensorShape({})); + tensor.scalar()() = value; + return tensor; +} + +Tensor CreateOutputTensor() { + Tensor tensor(DT_FLOAT, TensorShape({1, kNumCategories})); + for (int i = 0; i < kNumCategories; i++) { + tensor.flat()(i) = 0; + } + return tensor; +} + +std::vector CreateGroundTruth() { + std::vector ground_truth; + ground_truth.reserve(kNumCategories); + for (int i = 0; i < kNumCategories; i++) { + string category; + strings::StrAppend(&category, i); + ground_truth.push_back(category); + } + return ground_truth; +} + +TEST(ImagenetTopKAccuracy, AllCorrect) { + ImagenetTopKAccuracy acc_top_5(CreateGroundTruth(), 5); + auto accuracies = acc_top_5.GetTopKAccuracySoFar(); + EXPECT_EQ(0, accuracies.number_of_images); + EXPECT_EQ(5, accuracies.topk_counts.size()); + + for (int i : accuracies.topk_counts) { + EXPECT_EQ(0, i); + } + // First image was correctly identified as "0". + Tensor tensor = CreateOutputTensor(); + tensor.flat()(0) = 0.8; + + TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("0"))); + accuracies = acc_top_5.GetTopKAccuracySoFar(); + EXPECT_EQ(1, accuracies.number_of_images); + + for (int i : accuracies.topk_counts) { + EXPECT_EQ(1, i); + } + tensor.flat()(1) = 0.9; + TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("1"))); + accuracies = acc_top_5.GetTopKAccuracySoFar(); + EXPECT_EQ(2, accuracies.number_of_images); + + for (int i : accuracies.topk_counts) { + EXPECT_EQ(2, i); + } +} + +TEST(ImagenetTopKAccuracy, Top5) { + ImagenetTopKAccuracy acc_top_5(CreateGroundTruth(), 5); + auto accuracies = acc_top_5.GetTopKAccuracySoFar(); + EXPECT_EQ(0, accuracies.number_of_images); + EXPECT_EQ(5, accuracies.topk_counts.size()); + + // For first image, with ground truth "0" probabilities were + // 0.5 for "0", + // "0.6" for 1, + // "0.7" for 2, + // "0.8" for 3, + // "0.9" for 4. + // remaining all zeroes. + + // First image was correctly identified as "0". + Tensor tensor = CreateOutputTensor(); + tensor.flat()(0) = 0.5; + tensor.flat()(1) = 0.6; + tensor.flat()(2) = 0.7; + tensor.flat()(3) = 0.8; + tensor.flat()(4) = 0.9; + + TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("0"))); + accuracies = acc_top_5.GetTopKAccuracySoFar(); + EXPECT_EQ(1, accuracies.number_of_images); + EXPECT_EQ(1, accuracies.topk_counts[4]); + + for (int i = 0; i < 4; i++) { + EXPECT_EQ(0, accuracies.topk_counts[i]); + } + + // Now for "1" only last two buckets are going to be affected. + TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("1"))); + accuracies = acc_top_5.GetTopKAccuracySoFar(); + EXPECT_EQ(2, accuracies.number_of_images); + EXPECT_EQ(1, accuracies.topk_counts[3]); + EXPECT_EQ(2, accuracies.topk_counts[4]); + for (int i = 0; i < 3; i++) { + EXPECT_EQ(0, accuracies.topk_counts[i]); + } + + // All buckets will be affected. + TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("4"))); + accuracies = acc_top_5.GetTopKAccuracySoFar(); + EXPECT_EQ(3, accuracies.number_of_images); + EXPECT_EQ(1, accuracies.topk_counts[0]); + EXPECT_EQ(1, accuracies.topk_counts[1]); + EXPECT_EQ(1, accuracies.topk_counts[2]); + EXPECT_EQ(2, accuracies.topk_counts[3]); + EXPECT_EQ(3, accuracies.topk_counts[4]); + + // No buckets will be affected + TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("10"))); + accuracies = acc_top_5.GetTopKAccuracySoFar(); + EXPECT_EQ(4, accuracies.number_of_images); + EXPECT_EQ(1, accuracies.topk_counts[0]); + EXPECT_EQ(1, accuracies.topk_counts[1]); + EXPECT_EQ(1, accuracies.topk_counts[2]); + EXPECT_EQ(2, accuracies.topk_counts[3]); + EXPECT_EQ(3, accuracies.topk_counts[4]); +} + +} // namespace + +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc new file mode 100644 index 0000000000000000000000000000000000000000..7512b39c32f98faed9b41f829666bf1d4d145d82 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h" + +#include + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace metrics { + +namespace { +void CentralCropImage(const Scope& s, const tensorflow::Output& decoded_image, + double crop_fraction, tensorflow::Output* cropped_image) { + auto image_dims = ops::Slice(s, ops::Shape(s, decoded_image), {0}, {2}); + auto height_width = ops::Cast(s, image_dims, DT_DOUBLE); + auto cropped_begin = ops::Div( + s, ops::Sub(s, height_width, ops::Mul(s, height_width, crop_fraction)), + 2.0); + auto bbox_begin = ops::Cast(s, cropped_begin, DT_INT32); + auto bbox_size = ops::Sub(s, image_dims, ops::Mul(s, bbox_begin, 2)); + auto slice_begin = ops::Concat(s, {bbox_begin, Input({0})}, 0); + auto slice_size = ops::Concat(s, {bbox_size, {-1}}, 0); + *cropped_image = ops::Slice(s, decoded_image, slice_begin, slice_size); +} + +} // namespace + +void InceptionPreprocessingStage::AddToGraph(const Scope& scope, + const Input& input) { + if (!scope.ok()) return; + Scope s = scope.WithOpName(name()); + ops::DecodeJpeg::Attrs attrs; + attrs.channels_ = 3; + auto decoded_jpeg = ops::DecodeJpeg(s, input, attrs); + tensorflow::Output cropped_image; + CentralCropImage(s, decoded_jpeg, params_.cropping_fraction, &cropped_image); + auto dims_expander = ops::ExpandDims(s, cropped_image, 0); + auto resized_image = ops::ResizeBilinear( + s, dims_expander, + ops::Const(s.WithOpName("size"), {image_height_, image_width_})); + if (is_quantized_) { + this->stage_output_ = + ops::Cast(s.WithOpName(output_name()), resized_image, DT_UINT8); + } else { + auto squeezed_image = ops::Squeeze(s, resized_image); + auto normalized_image = + ops::Div(s, + ops::Sub(s, squeezed_image, + {params_.input_means[0], params_.input_means[1], + params_.input_means[2]}), + {params_.scale}); + this->stage_output_ = + ops::ExpandDims(s.WithOpName(output_name()), normalized_image, {0}); + } +} + +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h new file mode 100644 index 0000000000000000000000000000000000000000..15df71981756f6171b8e12bd9ed2a337c4867b64 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h @@ -0,0 +1,75 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_ + +#include + +#include "tensorflow/contrib/lite/tools/accuracy/stage.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace metrics { + +// A stage that does inception preprocessing. +// Inputs: A tensor containing bytes of a JPEG image. +// Outputs: A tensor containing rescaled and preprocessed image that has +// shape {1, image_height, image_width, 3}, where 3 is the number of channels. +class InceptionPreprocessingStage : public Stage { + public: + struct Params { + std::vector input_means; + float scale; + double cropping_fraction; + }; + + static Params DefaultParams() { + return {.input_means = {127.5, 127.5, 127.5}, + .scale = 127.5, + .cropping_fraction = 0.875}; + } + + // Creates a new preprocessing stage object with provided |image_width| + // |image_height| as the size of output image. + // If |is_quantized| is set to true then |params| is ignored since quantized + // images don't go through any preprocessing. + InceptionPreprocessingStage(int image_width, int image_height, + bool is_quantized, + Params params = DefaultParams()) + : image_width_(image_width), + image_height_(image_height), + is_quantized_(is_quantized), + params_(std::move(params)) {} + + string name() const override { return "stage_inception_preprocess"; } + string output_name() const override { + return "stage_inception_preprocess_output"; + } + + void AddToGraph(const Scope& scope, const Input& input) override; + + private: + int image_width_; + int image_height_; + bool is_quantized_; + Params params_; +}; + +} // namespace metrics +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3587878ba3cadd13eb0af4c004f4f98184daf5de --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc @@ -0,0 +1,123 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include +#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +tensorflow::string* g_test_image_file = nullptr; +} // namespace + +namespace tensorflow { +namespace metrics { + +namespace { + +using tensorflow::Status; +using tensorflow::Tensor; + +Status GetContents(const string& filename, string* output) { + std::ifstream input(filename, std::ios::binary); + const int kBufferSize = 2048; + char buffer[kBufferSize]; + while (true) { + input.read(buffer, kBufferSize); + output->append(buffer, input.gcount()); + if (!input.good()) { + if (input.eof()) return Status::OK(); + return Status(tensorflow::error::ABORTED, "Failed to read file."); + } + } +} + +TEST(InceptionPreprocessingTest, TestImagePreprocessQuantized) { + ASSERT_TRUE(g_test_image_file != nullptr); + string image_contents; + string image_path = *g_test_image_file; + auto status = GetContents(image_path, &image_contents); + ASSERT_TRUE(status.ok()) << status.error_message(); + const int width = 224; + const int height = 224; + const bool is_quantized = true; + InceptionPreprocessingStage preprocess_stage(width, height, is_quantized); + Scope scope = Scope::NewRootScope(); + preprocess_stage.AddToGraph(scope, image_contents); + TF_CHECK_OK(scope.status()); + + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + std::unique_ptr session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + std::vector outputs; + auto run_status = + session->Run({}, /*inputs*/ + {preprocess_stage.output_name()}, {}, /*target node names */ + &outputs); + TF_CHECK_OK(run_status); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(DT_UINT8, outputs[0].dtype()); + EXPECT_TRUE(outputs[0].shape().IsSameSize({1, 224, 224, 3})); +} + +TEST(InceptionPreprocessingTest, TestImagePreprocessFloat) { + ASSERT_TRUE(g_test_image_file != nullptr); + string image_contents; + string image_path = *g_test_image_file; + auto status = GetContents(image_path, &image_contents); + ASSERT_TRUE(status.ok()) << status.error_message(); + const int width = 224; + const int height = 224; + const bool is_quantized = false; + InceptionPreprocessingStage preprocess_stage(width, height, is_quantized); + Scope scope = Scope::NewRootScope(); + preprocess_stage.AddToGraph(scope, image_contents); + TF_CHECK_OK(scope.status()); + + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + std::unique_ptr session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + std::vector outputs; + auto run_status = + session->Run({}, /*inputs*/ + {preprocess_stage.output_name()}, {}, /*target node names */ + &outputs); + TF_CHECK_OK(run_status); + EXPECT_EQ(1, outputs.size()); + EXPECT_EQ(DT_FLOAT, outputs[0].dtype()); + EXPECT_TRUE(outputs[0].shape().IsSameSize({1, 224, 224, 3})); +} + +} // namespace +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + g_test_image_file = new tensorflow::string(); + const std::vector flag_list = { + tensorflow::Flag("test_image", g_test_image_file, + "Path to image file for test."), + }; + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + CHECK(parse_result) << "Required test_model_file"; + ::tensorflow::port::InitMain(argv[0], &argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d2a427810f679db537236c5430873a81a62ef412 Binary files /dev/null and b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg differ diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..da4258f1c131076f564f0002a3cd99b221a18852 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +namespace { +Status ValidateInputsMatch(const OpInputList& input_tensors, + const tflite::Interpreter& interpreter) { + std::vector tflite_tensor_indices = interpreter.inputs(); + if (tflite_tensor_indices.size() != input_tensors.size()) { + return errors::InvalidArgument( + "size mismatch, interpreter size: ", tflite_tensor_indices.size(), + " actual: ", input_tensors.size()); + } + + for (int i = 0; i < input_tensors.size(); i++) { + const TfLiteTensor* tflite_tensor = + interpreter.tensor(tflite_tensor_indices[i]); + if (tflite_tensor == nullptr) { + return errors::InvalidArgument("Tensor is null at index: ", i); + } + + const Tensor& tensor = input_tensors[i]; + auto i_type = metrics::utils::GetTFDataType(tflite_tensor->type); + auto i_shape = metrics::utils::GetTFLiteTensorShape(*tflite_tensor); + if (i_type != tensor.dtype()) { + return errors::InvalidArgument("Data types mismatch for tensors: ", i, + " expected: ", i_type, + " got: ", tensor.dtype()); + } + + if (i_shape != tensor.shape()) { + return errors::InvalidArgument("Data shapes mismatch for tensors: ", i, + " expected: ", i_shape, + " got: ", tensor.shape()); + } + } + + return Status::OK(); +} + +} // namespace + +class RunTFLiteModelOp : public OpKernel { + public: + explicit RunTFLiteModelOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string model_file_path; + OP_REQUIRES_OK(ctx, ctx->GetAttr("model_file_path", &model_file_path)); + model_ = tflite::FlatBufferModel::BuildFromFile(model_file_path.data()); + OP_REQUIRES(ctx, model_, + errors::InvalidArgument( + "Model loading failed. Invalid model file path: ", + model_file_path)); + tflite::ops::builtin::BuiltinOpResolver resolver; + + tflite::InterpreterBuilder(*model_, resolver)(&interpreter_); + OP_REQUIRES(ctx, interpreter_, + errors::Internal("Interpreter creation failed.")); + } + + void Compute(OpKernelContext* context) override { + OpInputList input_tensors; + OP_REQUIRES_OK(context, context->input_list("model_input", &input_tensors)); + + OP_REQUIRES_OK(context, ValidateInputsMatch(input_tensors, *interpreter_)); + OpOutputList output_tensors; + OP_REQUIRES_OK(context, + context->output_list("model_output", &output_tensors)); + auto tfl_outputs = interpreter_->outputs(); + OP_REQUIRES(context, output_tensors.size() == tfl_outputs.size(), + errors::InvalidArgument( + "Invalid output size, expected: ", tfl_outputs.size(), + " got: ", output_tensors.size())); + for (int i = 0; i < output_tensors.size(); i++) { + DataType tfl_type = metrics::utils::GetTFDataType( + interpreter_->tensor(tfl_outputs[i])->type); + DataType otype = output_tensors.expected_output_dtype(i); + OP_REQUIRES( + context, tfl_type == otype, + errors::InvalidArgument("Invalid data type for output at index: ", i, + " expected: ", tfl_type, " got: ", otype)); + } + + auto allocation_status = interpreter_->AllocateTensors(); + OP_REQUIRES(context, allocation_status == kTfLiteOk, + errors::Internal("Unable to allocate tensors.")); + for (int i = 0; i < input_tensors.size(); i++) { + const int tfl_index = interpreter_->inputs()[i]; + TfLiteTensor* tflite_tensor = interpreter_->tensor(tfl_index); + auto tensor_bytes = input_tensors[i].tensor_data(); + OP_REQUIRES(context, tflite_tensor->bytes == tensor_bytes.size(), + errors::InvalidArgument( + "Size mismatch, expected: ", tflite_tensor->bytes, + " got: ", tensor_bytes.size())); + std::memcpy(tflite_tensor->data.raw, tensor_bytes.data(), + tensor_bytes.size()); + } + auto invocation_status = interpreter_->Invoke(); + OP_REQUIRES(context, invocation_status == kTfLiteOk, + errors::Internal("Interpreter invocation failed.")); + for (int i = 0; i < output_tensors.size(); i++) { + auto tfl_tensor = interpreter_->tensor(tfl_outputs[i]); + TensorShape shape = metrics::utils::GetTFLiteTensorShape(*tfl_tensor); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, output_tensors.allocate(i, shape, &output)); + auto tensor_bytes = output->tensor_data(); + OP_REQUIRES(context, tensor_bytes.size() == tfl_tensor->bytes, + errors::Internal("Invalid size")); + std::memcpy(const_cast(tensor_bytes.data()), tfl_tensor->data.raw, + tfl_tensor->bytes); + } + } + + private: + std::unique_ptr model_; + std::unique_ptr interpreter_; +}; + +REGISTER_KERNEL_BUILDER(Name("RunTFLiteModel").Device(DEVICE_CPU), + RunTFLiteModelOp); + +REGISTER_OP("RunTFLiteModel") + .Input("model_input: input_type") + .Output("model_output: output_type") + .Attr("model_file_path: string") + .Attr("input_type : list(type)") + .Attr("output_type: list(type)") + .SetShapeFn([](shape_inference::InferenceContext* c) { + // TODO(shashishekhar): Infer the correct shape based on output_type and + // maybe another attribute. + return shape_inference::UnknownShape(c); + }); + +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..88175984a090edfac048455c43757473ffc859ed --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc @@ -0,0 +1,200 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +tensorflow::string* g_test_model_file = nullptr; +} + +namespace tensorflow { +namespace { + +TEST(RunTfliteModelOpTest, ModelIsRun) { + ASSERT_TRUE(g_test_model_file != nullptr); + string test_model_file = *g_test_model_file; + ASSERT_FALSE(test_model_file.empty()); + + Scope scope = Scope::NewRootScope(); + TF_CHECK_OK(scope.status()); + // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y + // x = a+b+c, y=b+c+d + + std::vector graph_inputs = { + ops::Const(scope, 1.0f, {1, 8, 8, 3}), // a + ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b + ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c + ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d + }; + + std::vector input_data; + std::transform(graph_inputs.begin(), graph_inputs.end(), + std::back_inserter(input_data), [&scope](Input model_input) { + return ops::AsNodeOut(scope, model_input); + }); + + std::vector model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT, + DT_FLOAT}; + ::tensorflow::Node* ret; + auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel") + .Input(input_data) + .Attr("model_file_path", test_model_file) + .Attr("input_type", model_input_type) + .Attr("output_type", {DT_FLOAT, DT_FLOAT}); + + scope.UpdateBuilder(&builder); + scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); + TF_CHECK_OK(scope.status()); + + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + + std::unique_ptr session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + + std::vector outputs; + TF_CHECK_OK( + session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs)); + EXPECT_EQ(2, outputs.size()); + + for (const auto& tensor : outputs) { + EXPECT_TRUE(tensor.shape().IsSameSize({1, 8, 8, 3})); + } + auto output_x = outputs[0].flat(); + auto output_y = outputs[1].flat(); + EXPECT_EQ(1 * 8 * 8 * 3, output_x.size()); + EXPECT_EQ(1 * 8 * 8 * 3, output_y.size()); + for (int i = 0; i < output_x.size(); i++) { + EXPECT_NEAR(6.3f, output_x(i), 1e-6f); // a+b+c + EXPECT_NEAR(9.6f, output_y(i), 1e-6f); // b+c+d + } +} + +TEST(RunTfliteModelOpTest, NumInputsMismatch) { + ASSERT_TRUE(g_test_model_file != nullptr); + string test_model_file = *g_test_model_file; + ASSERT_FALSE(test_model_file.empty()); + + Scope scope = Scope::NewRootScope(); + TF_CHECK_OK(scope.status()); + // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y + // x = a+b+c, y=b+c+d + // Remove a from input. + + std::vector graph_inputs = { + ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b + ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c + ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d + }; + + std::vector input_data; + std::transform(graph_inputs.begin(), graph_inputs.end(), + std::back_inserter(input_data), [&scope](Input model_input) { + return ops::AsNodeOut(scope, model_input); + }); + + std::vector model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT}; + + ::tensorflow::Node* ret; + auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel") + .Input(input_data) + .Attr("model_file_path", test_model_file) + .Attr("input_type", model_input_type) + .Attr("output_type", {DT_FLOAT, DT_FLOAT}); + + scope.UpdateBuilder(&builder); + scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); + TF_CHECK_OK(scope.status()); + + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + std::unique_ptr session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + + std::vector outputs; + auto status = + (session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs)); + EXPECT_FALSE(status.ok()); +} + +TEST(RunTfliteModelOpTest, InputSizesMismatch) { + ASSERT_TRUE(g_test_model_file != nullptr); + string test_model_file = *g_test_model_file; + ASSERT_FALSE(test_model_file.empty()); + + Scope scope = Scope::NewRootScope(); + TF_CHECK_OK(scope.status()); + // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y + // x = a+b+c, y=b+c+d + // Set a to be invalid size. + std::vector graph_inputs = { + ops::Const(scope, 1.0f, {1, 8, 8, 4}), // a invalid size, + ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b + ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c + ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d + }; + + std::vector input_data; + std::transform(graph_inputs.begin(), graph_inputs.end(), + std::back_inserter(input_data), [&scope](Input model_input) { + return ops::AsNodeOut(scope, model_input); + }); + + std::vector model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT, + DT_FLOAT}; + ::tensorflow::Node* ret; + auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel") + .Input(input_data) + .Attr("model_file_path", test_model_file) + .Attr("input_type", model_input_type) + .Attr("output_type", {DT_FLOAT, DT_FLOAT}); + + scope.UpdateBuilder(&builder); + scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); + TF_CHECK_OK(scope.status()); + + GraphDef graph_def; + TF_CHECK_OK(scope.ToGraphDef(&graph_def)); + std::unique_ptr session(NewSession(SessionOptions())); + TF_CHECK_OK(session->Create(graph_def)); + + std::vector outputs; + auto status = + (session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs)); + EXPECT_FALSE(status.ok()); +} + +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { + g_test_model_file = new tensorflow::string(); + const std::vector flag_list = { + tensorflow::Flag("test_model_file", g_test_model_file, + "Path to test tflite model file."), + }; + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + CHECK(parse_result) << "Required test_model_file"; + ::tensorflow::port::InitMain(argv[0], &argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc new file mode 100644 index 0000000000000000000000000000000000000000..c96795d4994ae3bee88da6ac6d26033c981b8d6a --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h" + +#include + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" + +namespace tensorflow { +namespace metrics { +void RunTFLiteModelStage::AddToGraph(const Scope& scope, const Input& input) { + if (!scope.ok()) return; + Scope s = scope.WithOpName(name()); + + std::vector _data = {ops::AsNodeOut(s, input)}; + ::tensorflow::Node* ret; + auto builder = NodeBuilder(output_name(), "RunTFLiteModel") + .Input(_data) + .Attr("model_file_path", params_.model_file_path) + .Attr("input_type", params_.input_type) + .Attr("output_type", params_.output_type); + + s.UpdateBuilder(&builder); + s.UpdateStatus(builder.Finalize(s.graph(), &ret)); + if (!s.ok()) return; + s.UpdateStatus(s.DoShapeInference(ret)); + this->stage_output_ = ::tensorflow::Output(ret, 0); +} + +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h new file mode 100644 index 0000000000000000000000000000000000000000..90d12d6f424516859d6ca65c162663de44eeb391 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ + +#include + +#include "tensorflow/contrib/lite/tools/accuracy/stage.h" + +namespace tensorflow { +namespace metrics { +// Stage that loads and runs a TFLite model. +// Inputs: The input to TFLite model. +// Outputs: The output of running the TFLite model. +class RunTFLiteModelStage : public Stage { + public: + // The parameters for the stage. + struct Params { + string model_file_path; + std::vector output_shape; + std::vector input_type; + std::vector output_type; + }; + + explicit RunTFLiteModelStage(const Params& params) : params_(params) {} + + string name() const override { return "stage_run_tfl_model"; } + // TODO(shashishekhar): This stage can have multiple inputs and + // outputs, perhaps change the definition of stage. + string output_name() const override { return "stage_run_tfl_model_output"; } + + void AddToGraph(const Scope& scope, const Input& input) override; + + private: + Params params_; +}; + +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/stage.h b/tensorflow/contrib/lite/tools/accuracy/stage.h new file mode 100644 index 0000000000000000000000000000000000000000..8292ea2ec735dc6946a4516483b9b97e685e4949 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/stage.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_ + +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { +namespace metrics { + +// A stage in an evaluation pipeline. +// Each stage adds a subgraph to the pipeline. Stages can be chained +// together. +class Stage { + public: + Stage() = default; + Stage(const Stage&) = delete; + Stage& operator=(const Stage&) = delete; + + Stage(const Stage&&) = delete; + Stage& operator=(const Stage&&) = delete; + + // Adds a subgraph to given scope that takes in `input` as a parameter. + virtual void AddToGraph(const Scope& scope, const Input& input) = 0; + virtual ~Stage() {} + + // The name of the stage. + // Can be used by derived classes for naming the subscope for the stage + // graph. + virtual string name() const = 0; + + // The name of the output for the stage. + virtual string output_name() const = 0; + + const ::tensorflow::Output& Output() const { return stage_output_; } + + protected: + ::tensorflow::Output stage_output_; +}; +} // namespace metrics +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.cc b/tensorflow/contrib/lite/tools/accuracy/utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5493301fc4d781418cc5c7397bae02ecc155c56 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/utils.cc @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/accuracy/utils.h" + +#include + +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" + +namespace tensorflow { +namespace metrics { + +namespace utils { + +DataType GetTFDataType(TfLiteType tflite_type) { + switch (tflite_type) { + case kTfLiteFloat32: + return DT_FLOAT; + case kTfLiteUInt8: + return DT_UINT8; + default: + return DT_INVALID; + } +} + +TensorShape GetTFLiteTensorShape(const TfLiteTensor& tflite_tensor) { + TensorShape shape; + for (int i = 0; i < tflite_tensor.dims->size; i++) { + shape.AddDim(tflite_tensor.dims->data[i]); + } + return shape; +} + +Status ReadFileLines(const string& file_path, + std::vector* lines_output) { + if (!lines_output) { + return errors::InvalidArgument("Invalid output"); + } + std::vector lines; + std::ifstream stream(file_path, std::ios_base::in); + if (!stream) { + return errors::InvalidArgument("Unable to open file: ", file_path); + } + std::string line; + while (std::getline(stream, line)) { + lines_output->push_back(line); + } + return Status::OK(); +} + +Status GetTFliteModelInfo(const string& model_file_path, + ModelInfo* model_info) { + if (model_file_path.empty()) { + return errors::InvalidArgument("Invalid model file."); + } + struct stat stat_buf; + if (stat(model_file_path.c_str(), &stat_buf) != 0) { + int error_num = errno; + return errors::InvalidArgument("Invalid model file: ", model_file_path, + std::strerror(error_num)); + } + + std::unique_ptr model; + std::unique_ptr interpreter; + model = tflite::FlatBufferModel::BuildFromFile(model_file_path.data()); + tflite::ops::builtin::BuiltinOpResolver resolver; + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + return errors::InvalidArgument("Invalid model", model_file_path); + } + for (int i : interpreter->inputs()) { + TfLiteTensor* tensor = interpreter->tensor(i); + model_info->input_shapes.push_back(utils::GetTFLiteTensorShape(*tensor)); + model_info->input_types.push_back(utils::GetTFDataType(tensor->type)); + } + return Status::OK(); +} + +} // namespace utils +} // namespace metrics +} // namespace tensorflow diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.h b/tensorflow/contrib/lite/tools/accuracy/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..37cbad4d51fd0ddf700b14ead037ae4aeed4d82a --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/utils.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_ + +#include +#include + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace metrics { + +namespace utils { + +struct ModelInfo { + std::vector input_shapes; + std::vector input_types; +}; + +Status GetTFliteModelInfo(const string& model_file_path, ModelInfo* model_info); + +DataType GetTFDataType(TfLiteType tflite_type); + +TensorShape GetTFLiteTensorShape(const TfLiteTensor& tflite_tensor); + +Status ReadFileLines(const string& file_path, + std::vector* lines_output); +} // namespace utils +} // namespace metrics +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_ diff --git a/tensorflow/contrib/lite/tools/accuracy/utils_test.cc b/tensorflow/contrib/lite/tools/accuracy/utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..727eba21b6c6005d367130b23e31bc223508bc60 --- /dev/null +++ b/tensorflow/contrib/lite/tools/accuracy/utils_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/tools/accuracy/utils.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace { +tensorflow::string* g_test_model_file = nullptr; +} + +namespace tensorflow { +namespace metrics { +namespace utils { +namespace { + +TEST(UtilsTest, GetTFLiteModelInfoReturnsCorrectly) { + ASSERT_TRUE(g_test_model_file != nullptr); + string test_model_file = *g_test_model_file; + ASSERT_FALSE(test_model_file.empty()); + // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y + // x = a+b+c, y=b+c+d + // Input and outputs have shape : {1,8,8,3} + ModelInfo model_info; + auto status = GetTFliteModelInfo(test_model_file, &model_info); + TF_CHECK_OK(status); + ASSERT_EQ(4, model_info.input_shapes.size()); + ASSERT_EQ(4, model_info.input_types.size()); + + for (int i = 0; i < 4; i++) { + const TensorShape& shape = model_info.input_shapes[i]; + DataType dataType = model_info.input_types[i]; + EXPECT_TRUE(shape.IsSameSize({1, 8, 8, 3})); + EXPECT_EQ(DT_FLOAT, dataType); + } +} + +TEST(UtilsTest, GetTFliteModelInfoIncorrectFile) { + ModelInfo model_info; + auto status = GetTFliteModelInfo("non_existent_file", &model_info); + EXPECT_FALSE(status.ok()); +} + +} // namespace +} // namespace utils +} // namespace metrics +} // namespace tensorflow + +int main(int argc, char** argv) { + g_test_model_file = new tensorflow::string(); + const std::vector flag_list = { + tensorflow::Flag("test_model_file", g_test_model_file, + "Path to test tflite model file."), + }; + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + CHECK(parse_result) << "Required test_model_file"; + ::tensorflow::port::InitMain(argv[0], &argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/optimize/BUILD b/tensorflow/contrib/lite/tools/optimize/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..01fbce0ac79e7b3f69543db0a68c0610f3446858 --- /dev/null +++ b/tensorflow/contrib/lite/tools/optimize/BUILD @@ -0,0 +1,11 @@ +# TODO(suharshs): Write quantize_weights tests that use small exportable files. +# Then we can remove this file. +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc new file mode 100644 index 0000000000000000000000000000000000000000..0758514e394734ce2cf67783296684d5f47cadae --- /dev/null +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc @@ -0,0 +1,280 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h" + +#include +#include +#include +#include + +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { +namespace optimize { + +namespace { + +// The minimum number of elements a weights array must have to be quantized +// by this transformation. +// TODO(suharshs): Make this configurable. +const int kWeightsMinSize = 1024; + +// Nudge min and max so that floating point 0 falls exactly on a quantized +// value, returning the nudges scale and zero_point. +// +// Although this code originates from FakeQuantization in quantized training, +// we may deviate from that implementation as we please since we do not fine +// tune the weights with quantized training. +void GetQuantizationParams(const float min, const float max, + const int quant_min, const int quant_max, + QuantizationParametersT* quantization_params) { + // Adjust the boundaries to guarantee 0 is included. + const float quant_min_float = std::min(static_cast(quant_min), 0.0f); + const float quant_max_float = std::max(static_cast(quant_max), 0.0f); + const float scale = (max - min) / (quant_max_float - quant_min_float); + const float zero_point_from_min = quant_min_float - min / scale; + int64_t zero_point; + if (zero_point_from_min < quant_min_float) { + zero_point = static_cast(quant_min); + } else if (zero_point_from_min > quant_max_float) { + zero_point = static_cast(quant_max); + } else { + zero_point = static_cast(std::round(zero_point_from_min)); + } + quantization_params->scale = {scale}; + quantization_params->zero_point = {zero_point}; +} + +// Returns the number of elements in tensor. +uint64 NumElements(const TensorT* tensor) { + if (tensor->shape.empty()) { + LOG(FATAL) << "Tensor has no shape information."; + } + uint64 num_elements = 1; + for (const uint64 dim : tensor->shape) { + num_elements *= dim; + } + return num_elements; +} + +uint64 CountTensorConsumers(const ModelT* model, const SubGraphT* subgraph, + int32_t tensor_idx) { + uint64 count = 0; + for (int op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) { + const OperatorT* op = subgraph->operators[op_idx].get(); + if (op == nullptr) { + continue; + } + for (int i = 0; i < op->inputs.size(); ++i) { + if (op->inputs[i] == tensor_idx) { + count++; + } + } + } + return count; +} + +// Returns true if the Operator's weight tensor should be quantized. +bool GetQuantizableTensorFromOperator(const ModelT* model, const OperatorT* op, + TensorT** tensor, int32_t* tensor_idx, + int32_t* op_input_index) { + SubGraphT* subgraph = model->subgraphs.at(0).get(); + const BuiltinOperator op_code = + model->operator_codes[op->opcode_index]->builtin_code; + + if (op_code == BuiltinOperator_CONV_2D || + op_code == BuiltinOperator_DEPTHWISE_CONV_2D || + op_code == BuiltinOperator_FULLY_CONNECTED || + op_code == BuiltinOperator_SVDF) { + *op_input_index = 1; + } else if (op_code == BuiltinOperator_LSTM) { + // TODO(suharshs): Add RNN, and sequential/bidi versions. + *op_input_index = 2; + } else { + return false; + } + *tensor_idx = op->inputs[*op_input_index]; + + // TODO(suharshs): Support shared weights, i.e. If two tensors share the + // same weight array, things may break. (i.e. SSD object detection) + if (CountTensorConsumers(model, subgraph, *tensor_idx) != 1) { + LOG(INFO) << "Skipping quantization of tensor that is shared between " + "multiple multiple operations."; + return false; + } + + *tensor = subgraph->tensors[*tensor_idx].get(); + + if ((*tensor)->type != TensorType_FLOAT32) { + LOG(INFO) << "Skipping quantization of tensor that is not type float."; + return false; + } + const uint64 num_elements = NumElements(*tensor); + if (num_elements < kWeightsMinSize) { + LOG(INFO) << "Skipping quantization of tensor because it has fewer than " + << kWeightsMinSize << " elements (" << num_elements << ")."; + return false; + } + + return true; +} + +// Quantizes tensor using asymmetric quantization with the min and max elements +// of the tensor. This is needed to pass to Dequantize operations. +TfLiteStatus AsymmetricQuantizeTensor(ModelT* model, TensorT* tensor) { + BufferT* buffer = model->buffers[tensor->buffer].get(); + float* float_data = reinterpret_cast(buffer->data.data()); + const uint64 num_elements = NumElements(tensor); + LOG(INFO) << "Quantizing tensor with " << num_elements << " elements."; + + // Compute the quantization params. + float min_value = *std::min_element(float_data, float_data + num_elements); + float max_value = *std::max_element(float_data, float_data + num_elements); + GetQuantizationParams(min_value, max_value, 0, 255, + tensor->quantization.get()); + + // Quantize the buffer. + std::vector quantized_buffer; + quantized_buffer.resize(num_elements); + const double inverse_scale = 1. / tensor->quantization->scale[0]; + for (std::size_t i = 0; i < num_elements; i++) { + const float src_val = float_data[i]; + double scaled_val; + if (tensor->quantization->scale[0] == 0) { + scaled_val = tensor->quantization->zero_point[0]; + } else { + scaled_val = + tensor->quantization->zero_point[0] + inverse_scale * src_val; + } + uint8_t integer_val = static_cast(std::round(scaled_val)); + quantized_buffer[i] = integer_val; + } + model->buffers[tensor->buffer]->data = quantized_buffer; + + // Update the tensor type. + tensor->type = TensorType_UINT8; + + return kTfLiteOk; +} + +// Returns the index of the Dequantize op_code. +// If a Dequantize op_code doesn't exist, adds it and returns its index. +int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) { + for (int i = 0; i < model->operator_codes.size(); ++i) { + if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) { + return i; + } + } + model->operator_codes.push_back(std::make_unique()); + int op_code_idx = model->operator_codes.size() - 1; + model->operator_codes[op_code_idx]->builtin_code = BuiltinOperator_DEQUANTIZE; + // TODO(suharshs): How should the version be set in this op_code? + + // Return the index of the newly placed OperatorCodeT. + return op_code_idx; +} + +// Creates a Dequantize OperatorT object. +void MakeDequantizeOperator(ModelT* model, std::unique_ptr* op, + int32_t input, int32_t output) { + OperatorT* op_raw = new OperatorT; + op_raw->opcode_index = GetOrInsertDequantizeOpCodeIndex(model); + op_raw->inputs = {input}; + op_raw->outputs = {output}; + + op->reset(op_raw); +} + +// Create a new TensorT object. +void MakeTensor(const string& name, const std::vector& shape, + std::unique_ptr* tensor) { + TensorT* tensor_raw = new TensorT; + tensor_raw->name = name; + tensor_raw->shape = shape; + + tensor->reset(tensor_raw); +} + +} // namespace + +TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model) { + std::unique_ptr model; + model.reset(input_model->UnPack()); + + // TODO(suharshs): When models support multiple subgraphs, add support. + if (model->subgraphs.size() != 1) { + LOG(ERROR) << "Quantize weights tool only supports tflite models with one " + "subgraph."; + return kTfLiteError; + } + + SubGraphT* subgraph = model->subgraphs.at(0).get(); + + std::vector> new_operators; + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + + TensorT* tensor; + // The index of the weight tensor in subgraph->tensors. + int32_t tensor_idx; + int32_t op_input_idx; // The index of tensor_idx in the op->inputs. + // TODO(suharshs): Support hybrid ops that require symmetric quantization. + if (GetQuantizableTensorFromOperator(model.get(), op, &tensor, &tensor_idx, + &op_input_idx)) { + // Quantize the tensors. + TF_LITE_ENSURE_STATUS(AsymmetricQuantizeTensor(model.get(), tensor)); + + // Create a new tensor to be the output of the dequantize op. + std::unique_ptr dequantize_output; + MakeTensor(tensor->name + "_dequantize", tensor->shape, + &dequantize_output); + int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr dequantize_op; + MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx, + dequantize_output_idx); + + // Update the op_input of tensor_idx to dequantize_output_idx. + op->inputs[op_input_idx] = dequantize_output_idx; + // Insert the updated op. + new_operators.push_back(std::move(subgraph->operators[i])); + + // Insert the newly created Dequantize operation. + new_operators.push_back(std::move(dequantize_op)); + } else { + // If this tensor wasn't quantizable, just copy the op over as-is. + new_operators.push_back(std::move(subgraph->operators[i])); + } + } + // At this point all unique_ptrs in the original operators are invalid, and + // we need to replace it with the new_operators vector. + subgraph->operators = std::move(new_operators); + + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + + return kTfLiteOk; +} + +} // namespace optimize +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h new file mode 100644 index 0000000000000000000000000000000000000000..a408c1662de56ba679cd46b9e3a15a45e5c752c8 --- /dev/null +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h @@ -0,0 +1,38 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ + +#include +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { +namespace optimize { + +// Quantizes input_model and populates the provided builder with the new model. +// +// A tflite::Model can be obtained from the builder with: +// const uint8_t* buffer = builder->GetBufferPointer(); +// tflite::Model* model = GetModel(buffer); +TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model); + +} // namespace optimize +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_ diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e0676e5ff06802d50d218e7cd7c661768a6052c --- /dev/null +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h" + +#include + +#include "flatbuffers/flexbuffers.h" +#include +#include +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { +namespace optimize { +namespace { + +class QuantizeWeightsTest : public ::testing::Test { + protected: + int GetElementsNum(const TensorT* tensor) { + int tensor_size = 1; + for (const int dim : tensor->shape) { + tensor_size *= dim; + } + return tensor_size; + } + + const OperatorT* GetOpWithOutput(const SubGraphT* subgraph, + int32_t output_tensor_idx) { + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + if (std::find(op->outputs.begin(), op->outputs.end(), + output_tensor_idx) != op->outputs.end()) { + return op; + } + } + return nullptr; + } + + void CheckWeights(const Model* model_packed) { + std::unique_ptr model; + model.reset(model_packed->UnPack()); + + SubGraphT* subgraph = model->subgraphs.at(0).get(); + + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + const BuiltinOperator op_code = + model->operator_codes[op->opcode_index]->builtin_code; + + // These are the operations that should be quantized. + int32_t tensor_idx; + if (op_code == BuiltinOperator_CONV_2D || + op_code == BuiltinOperator_DEPTHWISE_CONV_2D || + op_code == BuiltinOperator_FULLY_CONNECTED) { + tensor_idx = op->inputs[1]; + } else if (op_code == BuiltinOperator_LSTM) { + // TODO(suharshs): Add tests for LSTMs. + tensor_idx = op->inputs[1]; + } else { + continue; + } + const TensorT* tensor = subgraph->tensors[tensor_idx].get(); + int tensor_size = GetElementsNum(tensor); + // If the tensor_size is less than 1024 we expect the tensor to remain + // unquantized. + if (tensor_size < 1024) { + ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name; + const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx); + // The weight tensor should not come from a dequantize op. + ASSERT_TRUE(preceding_op == nullptr); + } else { + // The input to the op should still be float. + ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name; + const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx); + ASSERT_TRUE(preceding_op != nullptr); + // The float input should be the dequantize output. + ASSERT_TRUE( + model->operator_codes[preceding_op->opcode_index]->builtin_code == + BuiltinOperator_DEQUANTIZE); + // Finally, ensure that the input to the dequantize operation is + // quantized. + ASSERT_TRUE(subgraph->tensors[preceding_op->inputs[0]]->type == + TensorType_UINT8); + // TODO(suharshs): Add more rigorous testing for the numerical values in + // the tensors. + } + } + } +}; + +TEST_F(QuantizeWeightsTest, SimpleTest) { + string model_path = + "third_party/tensorflow/contrib/lite/tools/optimize/testdata/" + "mobilenet_v1_0.25_128.tflite"; + std::unique_ptr input_fb = + FlatBufferModel::BuildFromFile(model_path.data()); + const Model* input_model = input_fb->GetModel(); + + flatbuffers::FlatBufferBuilder builder; + EXPECT_EQ(QuantizeWeights(&builder, input_model), kTfLiteOk); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + + CheckWeights(output_model); +} + +// TODO(suharshs): Add tests that run the resulting model. + +} // namespace +} // namespace optimize +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: FLAGS_logtostderr = true; + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index a328670526089988c181a8e1146c911309640009..bbf5d3f30c9f7fd0cbe2ad78da15ff3eb34ae2c5 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -2532,7 +2532,8 @@ def sparse_recall_at_top_k(labels, name=name_scope) -def _compute_recall_at_precision(tp, fp, fn, precision, name): +def _compute_recall_at_precision(tp, fp, fn, precision, name, + strict_mode=False): """Helper function to compute recall at a given `precision`. Args: @@ -2541,17 +2542,42 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name): fn: The number of false negatives. precision: The precision for which the recall will be calculated. name: An optional variable_scope name. + strict_mode: If true and there exists a threshold where the precision is + no smaller than the target precision, return the corresponding recall at + the threshold. Otherwise, return 0. If false, find the threshold where the + precision is closest to the target precision and return the recall at the + threshold. Returns: The recall at a given `precision`. """ precisions = math_ops.div(tp, tp + fp + _EPSILON) - tf_index = math_ops.argmin( - math_ops.abs(precisions - precision), 0, output_type=dtypes.int32) + if not strict_mode: + tf_index = math_ops.argmin( + math_ops.abs(precisions - precision), 0, output_type=dtypes.int32) + # Now, we have the implicit threshold, so compute the recall: + return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, + name) + else: + # We aim to find the threshold where the precision is minimum but no smaller + # than the target precision. + # The rationale: + # 1. Compute the difference between precisions (by different thresholds) and + # the target precision. + # 2. Take the reciprocal of the values by the above step. The intention is + # to make the positive values rank before negative values and also the + # smaller positives rank before larger positives. + tf_index = math_ops.argmax( + math_ops.div(1.0, precisions - precision + _EPSILON), + 0, + output_type=dtypes.int32) + + def _return_good_recall(): + return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, + name) - # Now, we have the implicit threshold, so compute the recall: - return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, - name) + return control_flow_ops.cond(precisions[tf_index] >= precision, + _return_good_recall, lambda: .0) def recall_at_precision(labels, @@ -2561,7 +2587,8 @@ def recall_at_precision(labels, num_thresholds=200, metrics_collections=None, updates_collections=None, - name=None): + name=None, + strict_mode=False): """Computes `recall` at `precision`. The `recall_at_precision` function creates four local variables, @@ -2593,6 +2620,11 @@ def recall_at_precision(labels, updates_collections: An optional list of collections that `update_op` should be added to. name: An optional variable_scope name. + strict_mode: If true and there exists a threshold where the precision is + above the target precision, return the corresponding recall at the + threshold. Otherwise, return 0. If false, find the threshold where the + precision is closest to the target precision and return the recall at the + threshold. Returns: recall: A scalar `Tensor` representing the recall at the given @@ -2621,10 +2653,11 @@ def recall_at_precision(labels, predictions, labels, thresholds, weights) recall = _compute_recall_at_precision(values['tp'], values['fp'], - values['fn'], precision, 'value') + values['fn'], precision, 'value', + strict_mode) update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'], update_ops['fn'], precision, - 'update_op') + 'update_op', strict_mode) if metrics_collections: ops.add_to_collections(metrics_collections, recall) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 1c2c17960aa80e454059311a01f72e7d705ed67e..024bd54912b655a7d3213da81b620f23369aef36 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -3467,6 +3467,60 @@ class RecallAtPrecisionTest(test.TestCase): self.assertAlmostEqual(target_recall, sess.run(update_op)) self.assertAlmostEqual(target_recall, recall.eval()) + def _test_strict_mode(self, strict_mode, target_precision, expected_recall): + num_thresholds = 11 + predictions_values = [.2, .3, .5, .6, .7, .8, .9, .9, .9, .1] + labels_values = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1] + # Resulting thresholds and the corresponding precision and recall values at + # each threshold: + # Thresholds [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9] + # precisions: [0.3 0.2 0.1 0 0 0 0 0 0] + # recalls: [1.0 0.7 0.3 0 0 0 0 0 0] + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + recall, update_op = metrics.recall_at_precision( + labels, + predictions, + num_thresholds=num_thresholds, + precision=target_precision, + strict_mode=strict_mode) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(expected_recall, sess.run(update_op)) + self.assertAlmostEqual(expected_recall, recall.eval()) + + def testStrictMode_Off(self): + # strict_mode is turned off and return the recall at the threshold where the + # precision (0.3) is closest to target precision (0.9). The recall + # corresponding to the threshold is 1.0. + self._test_strict_mode( + strict_mode=False, target_precision=0.9, expected_recall=1.0) + + def testStrictMode_OnAndFail(self): + # strict_mode is turned on and we fail to reach the target precision at any + # threshold. + # Target precision: 0.9 + # Diff: [-0.6 -0.7 -0.8 -0.9 -0.9 -0.9 -0.9 -0.9 -0.9] + # Reciprocal: [-1.6 -1.4 -1.3 -1.1 -1.1 -1.1 -1.1 -1.1 -1.1] + # Max index: 3 and corresponding precision is: 0 which is smaller than + # target precsion 0.9. As a result, the expected recall is 0. + self._test_strict_mode( + strict_mode=True, target_precision=0.9, expected_recall=.0) + + def testStrictMode_OnAndSucceed(self): + # strict_mode is on and we can reach the target precision at certain + # threshold. + # Target precision: 0.2 + # Diff: [0.1 0 -0.1 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2] + # Reciprocal: [10 infty -10.0 -5.0 -5.0 -5.0 -5.0 -5.0 -5.0] + # Max index: 1 and corresponding precision is: 0.2 which is no smaller than + # target precsion 0.2. In this case, we return the recall at index 1, which + # is 2.0/3 (0.7). + self._test_strict_mode( + strict_mode=True, target_precision=0.2, expected_recall=2.0 / 3) + class PrecisionAtRecallTest(test.TestCase): diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index a5267fd90482287a65a4c38ae257a0af349523e8..15d95896d96543343fdee2a6423407a1056e1063 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -53,7 +53,7 @@ The pruning library allows for specification of the following hyper parameters: | weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. | | threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | -| nbins | integer | 256 | Number of bins to use for histogram computation | +| nbins | integer | 256 | Number of bins to use for histogram computation. Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. | | block_height|integer | 1 | Number of rows in a block for block sparse matrices| | block_width |integer | 1 | Number of cols in a block for block sparse matrices| | block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)| diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index b50a372e9d7ebd348b31c6fd183d125a7e1b012f..91b0bb7f6003c047e4dcf342695f433edbc11614 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -235,19 +235,18 @@ def compute_cdf_from_histogram(values, value_range, **kwargs): def compute_cdf(values, value_range, **kwargs): """Returns the normalized cumulative distribution of the given values tensor. - Uses tf.while_loop to directly compute the cdf of the values. Number of bins - for histogram is fixed at _NBINS=255 + Uses tf.while_loop to directly compute the cdf of the values. Args: values: Numeric `Tensor`. value_range: Shape [2] `Tensor` of same `dtype` as `values` - **kwargs: keyword arguments: name + **kwargs: keyword arguments: nbins, name Returns: A 1-D `Tensor` holding normalized cdf of values. """ - nbins = _NBINS + nbins = kwargs.get('nbins', _NBINS) name = kwargs.get('name', None) with ops.name_scope(name, 'cdf', [values, value_range, nbins]): values = ops.convert_to_tensor(values, name='values') @@ -281,7 +280,7 @@ def compute_cdf(values, value_range, **kwargs): cdf = math_ops.add( cdf, array_ops.one_hot( - loop_count, depth=_NBINS, on_value=temp, off_value=0.0)) + loop_count, depth=nbins, on_value=temp, off_value=0.0)) return [loop_count + 1, cdf] _, cdf = control_flow_ops.while_loop( diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 781621dba05c1da6e914011b28eed9928a1e094a..ad7d7cfa6e1a4d2cf5795d885a4f7c5d4d3834bf 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -31,6 +31,7 @@ from tensorflow.contrib.opt.python.training.model_average_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * +from tensorflow.contrib.opt.python.training.reg_adagrad_optimizer import * from tensorflow.contrib.opt.python.training.shampoo import * from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from tensorflow.contrib.opt.python.training.powersign import * @@ -65,6 +66,7 @@ _allowed_symbols = [ 'ModelAverageCustomGetter', 'GGTOptimizer', 'ShampooOptimizer', + 'RegAdagradOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py index 5790d8a3f1650c791ea12e24d88311c658df8652..61d8b94eca27427754cb2806f33d95e5643c660f 100644 --- a/tensorflow/contrib/opt/python/training/adamax_test.py +++ b/tensorflow/contrib/opt/python/training/adamax_test.py @@ -74,7 +74,7 @@ class AdaMaxOptimizerTest(test.TestCase): def doTestSparse(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. zero_slots = lambda: np.zeros((3), dtype=dtype.as_numpy_dtype) m0, v0, m1, v1 = zero_slots(), zero_slots(), zero_slots(), zero_slots() @@ -142,7 +142,7 @@ class AdaMaxOptimizerTest(test.TestCase): def testSparseRepeatedIndices(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): repeated_index_update_var = variables.Variable( [[1.0], [2.0]], dtype=dtype) aggregated_update_var = variables.Variable( @@ -233,7 +233,7 @@ class AdaMaxOptimizerTest(test.TestCase): opt.get_slot(var=var0, name="m").name) def testBasic(self): - with self.test_session(): + with self.cached_session(): self.doTestBasic(use_resource=False) @test_util.run_in_graph_and_eager_modes(reset_test=True) @@ -242,7 +242,7 @@ class AdaMaxOptimizerTest(test.TestCase): def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) @@ -278,7 +278,7 @@ class AdaMaxOptimizerTest(test.TestCase): def testSharing(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) diff --git a/tensorflow/contrib/opt/python/training/external_optimizer_test.py b/tensorflow/contrib/opt/python/training/external_optimizer_test.py index 953586ee70cd4137295dd254bfb2d37cab0bcfe4..999710301698406e3167f202a22ddb70f1850e07 100644 --- a/tensorflow/contrib/opt/python/training/external_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/external_optimizer_test.py @@ -85,7 +85,7 @@ class ExternalOptimizerInterfaceTest(TestCase): optimizer = MockOptimizerInterface(loss) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) @@ -107,7 +107,7 @@ class ExternalOptimizerInterfaceTest(TestCase): optimizer = MockOptimizerInterface(loss) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) initial_vector_val = sess.run(vector) @@ -164,7 +164,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface( self._objective(x), method=method, options=options) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) @@ -176,7 +176,7 @@ class ScipyOptimizerInterfaceTest(TestCase): x = variables.Variable(array_ops.zeros(dimension)) optimizer = external_optimizer.ScipyOptimizerInterface(self._objective(x)) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) @@ -242,7 +242,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface( loss, equalities=equalities, inequalities=inequalities, method='SLSQP') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) self.assertAllClose(np.ones(2), sess.run(vector)) @@ -260,7 +260,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface( loss, var_to_bounds=var_to_bounds) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) self.assertAllClose(np.ones(2), sess.run(vector)) @@ -277,7 +277,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface( loss, var_to_bounds=var_to_bounds) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) self.assertAllClose([0., 2.], sess.run(vector)) @@ -293,7 +293,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface( loss, method='SLSQP') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) optimizer.minimize(sess) method = optimizer.optimizer_kwargs.get('method') @@ -312,7 +312,7 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer = external_optimizer.ScipyOptimizerInterface(loss, method='SLSQP') - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) initial_vector_val = sess.run(vector) diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py index 1d2a79957bb4a2eeed45a5221a8e79f480f72a5a..1775edabb33294d0420d2836c739cff58a78fb5b 100644 --- a/tensorflow/contrib/opt/python/training/ggt_test.py +++ b/tensorflow/contrib/opt/python/training/ggt_test.py @@ -171,7 +171,7 @@ class GGTOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) def testBasic(self): - with self.test_session(): + with self.cached_session(): self.doTestBasic(use_resource=False) @test_util.run_in_graph_and_eager_modes(reset_test=True) diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py index d94249b994ac8cb4eda604feaafc037474764d8f..b76db763da0a2edbc8fb4703d3b2877e265003f7 100644 --- a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py @@ -31,7 +31,7 @@ class LARSOptimizerTest(test.TestCase): def testLARSGradientOneStep(self): for _ in range(10): for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session() as sess: + with self.cached_session() as sess: shape = [3, 3] var_np = np.ones(shape) grad_np = np.ones(shape) @@ -77,7 +77,7 @@ class LARSOptimizerTest(test.TestCase): def testLARSGradientMultiStep(self): for _ in range(10): for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session() as sess: + with self.cached_session() as sess: shape = [3, 3] var_np = np.ones(shape) grad_np = np.ones(shape) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py index a16857db7d55b7ff95c9e88c655c1be21da1c986..dc4c462ce47bcf4d2f7fb368f0015c50fc169da3 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py @@ -53,7 +53,7 @@ class AdamOptimizerTest(test.TestCase): def testSparse(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) @@ -109,7 +109,7 @@ class AdamOptimizerTest(test.TestCase): def testSparseRepeatedIndices(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): repeated_index_update_var = variables.Variable( [[1.0], [2.0]], dtype=dtype) aggregated_update_var = variables.Variable( diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py index d15716f6f612d21b14cf8028833fa3c87c1d1b50..f22e7245285a8b2716645f9789eb5997928a22d2 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py @@ -165,7 +165,7 @@ class MovingAverageOptimizerTest(test.TestCase): self.assertLess(avg_val1[i], orig_val1[i]) def testFailWhenSaverCreatedBeforeInitialized(self): - with self.test_session(): + with self.cached_session(): var = variables.Variable([1.0], name='var', dtype=dtypes.float32) opt = moving_average_optimizer.MovingAverageOptimizer( gradient_descent.GradientDescentOptimizer(learning_rate=2.0)) @@ -187,7 +187,7 @@ class MovingAverageOptimizerTest(test.TestCase): self.apply_gradients_called = True return super(WrapperOptimizer, self).apply_gradients(*args, **kwargs) - with self.test_session() as sess: + with self.cached_session() as sess: var = variables.Variable([1.2], name='var', dtype=dtypes.float32) loss = var ** 2 wrapper_opt = WrapperOptimizer(learning_rate=2.0) diff --git a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py index 618d8eb18d2e9b738d2c2f5b8e563aeffdf82988..904aa9ab13c390349b6fec20a14d455eb2761d5c 100644 --- a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py +++ b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py @@ -34,7 +34,7 @@ class MultitaskOptimizerWrapperTest(test.TestCase): """ def testWrapper(self): - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32) var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32) grads0 = constant_op.constant([0.1, 0.1], dtype=dtypes.float32) @@ -92,7 +92,7 @@ class MultitaskOptimizerWrapperTest(test.TestCase): self.evaluate(slot1)) def testGradientClipping(self): - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32) var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32) var2 = variables.Variable([3.0, 4.0], dtype=dtypes.float32) diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py index 825c08a09a05894df1656a9bb6981f1862195244..85e05ce71cec6ef897cadb7d123e630febb3c064 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py @@ -53,7 +53,7 @@ class NadamOptimizerTest(test.TestCase): def doTestSparse(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) @@ -106,7 +106,7 @@ class NadamOptimizerTest(test.TestCase): def doTestBasic(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): # Initialize variables for numpy implementation. m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) diff --git a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py index ea56e1646a0811ab065105cd260a760b5b718354..c09e2ac76d469147dcaaba8ddaf56eff23e25bca 100644 --- a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py @@ -36,7 +36,7 @@ class RegAdagradOptimizerTest(test.TestCase): def doTestBasic(self, use_locking=False, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): if use_resource: var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) @@ -73,7 +73,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = resource_variable_ops.ResourceVariable( [[1.0, 2.0], [3.0, 4.0]], dtype=dtype) x = constant_op.constant([[4.0], [5.0]], dtype=dtype) @@ -92,7 +92,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -116,7 +116,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSparseBasic(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([[1.0], [2.0]], dtype=dtype) var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) grads0 = ops.IndexedSlices( @@ -144,7 +144,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSparseRepeatedIndices(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): repeated_index_update_var = variables.Variable( [[1.0], [2.0]], dtype=dtype) aggregated_update_var = variables.Variable([[1.0], [2.0]], dtype=dtype) @@ -170,7 +170,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSparseRepeatedIndicesResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var_repeated = resource_variable_ops.ResourceVariable( [1.0, 2.0], dtype=dtype) loss_repeated = math_ops.reduce_sum( @@ -194,7 +194,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSparseStability(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): shape = [1, 6] var0 = variables.Variable( [[ @@ -230,7 +230,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSharing(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -263,7 +263,7 @@ class RegAdagradOptimizerTest(test.TestCase): np.array([2.715679168701172, 3.715679168701172]), var1.eval()) def testDynamicShapeVariable_Ok(self): - with self.test_session(): + with self.cached_session(): v = variable_scope.get_variable( "v", initializer=constant_op.constant(1.), validate_shape=False) self.assertFalse(v.shape.is_fully_defined()) @@ -274,7 +274,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSkipUpdatingSlots(self): iav = 0.130005 # A value that works with float16 for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([1.0, 2.0], dtype=dtype) var1 = variables.Variable([3.0, 4.0], dtype=dtype) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) @@ -306,7 +306,7 @@ class RegAdagradOptimizerTest(test.TestCase): def testSparseSkipUpdatingSlots(self): iav = 0.130005 # A value that works with float16 for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session(): + with self.cached_session(): var0 = variables.Variable([[1.0], [2.0]], dtype=dtype) var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) grads0 = ops.IndexedSlices( diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py index 2e0a202ae293664d85ece884a505096455cde73c..b3688ab1818ca779f3d362af10796542ed8f0e2f 100644 --- a/tensorflow/contrib/opt/python/training/shampoo_test.py +++ b/tensorflow/contrib/opt/python/training/shampoo_test.py @@ -52,7 +52,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np = np.random.rand(size) grad_np_2 = np.random.rand(size) - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable( 0, dtype=dtypes.int64, use_resource=use_resource_var) var = variables.Variable( @@ -103,7 +103,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np = np.random.rand(size[0], size[1]) grad_np_2 = np.random.rand(size[0], size[1]) - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable( 0, dtype=dtypes.int64, use_resource=use_resource_var) var = variables.Variable( @@ -162,7 +162,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np = np.random.rand(size[0], size[1], size[2]) grad_np_2 = np.random.rand(size[0], size[1], size[2]) - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable( 0, dtype=dtypes.int64, use_resource=use_resource_var) var = variables.Variable( @@ -240,7 +240,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np = np.random.rand(size) grad_np_2 = np.random.rand(size) - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable( 0, dtype=dtypes.int64, use_resource=use_resource_var) var = variables.Variable( @@ -294,7 +294,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np = np.random.rand(size[0], size[1]) grad_np_2 = np.random.rand(size[0], size[1]) - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable( 0, dtype=dtypes.int64, use_resource=use_resource_var) var = variables.Variable( @@ -365,7 +365,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase): replace=False)) grad_np_2 = np.random.rand(sample_size_2, size[1]) - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable( 0, dtype=dtypes.int64, use_resource=use_resource_var) var = variables.Variable( @@ -445,7 +445,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase): replace=False)) grad_np = np.random.rand(sample_size, size[1], size[2]) - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable( 0, dtype=dtypes.int64, use_resource=use_resource_var) var = variables.Variable( @@ -512,7 +512,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase): gbar_decay = 0.9 gbar_weight = 0.1 - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable( 0, dtype=dtypes.int64, use_resource=use_resource_var) var = variables.Variable( @@ -601,7 +601,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase): mat_g3_a = np.eye(size[2]) mat_g3 = np.zeros_like(mat_g3_a) - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable( 0, dtype=dtypes.int64, use_resource=use_resource_var) var = variables.Variable( @@ -672,7 +672,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase): mat_g3_a = np.eye(size[2]) mat_g3 = np.zeros_like(mat_g3_a) - with self.test_session() as sess: + with self.cached_session() as sess: global_step = variables.Variable( 0, dtype=dtypes.int64, use_resource=use_resource_var) var = variables.Variable( diff --git a/tensorflow/contrib/opt/python/training/sign_decay_test.py b/tensorflow/contrib/opt/python/training/sign_decay_test.py index c31cb924eacfc8feea6bbd1f5c9ae903442b04b1..3a84789afd77f5c068501ddcfa96287503e87f60 100644 --- a/tensorflow/contrib/opt/python/training/sign_decay_test.py +++ b/tensorflow/contrib/opt/python/training/sign_decay_test.py @@ -66,7 +66,7 @@ class SignDecaysTest(test.TestCase): linear_decay_fn = sign_decay.get_linear_decay_fn(num_training_steps) for step in range(0, 1000, 100): - with self.test_session(): + with self.cached_session(): tf_decayed = linear_decay_fn(step).eval() py_decayed = py_linear_decay_fn(num_training_steps)(step) self.assertAlmostEqual(tf_decayed, py_decayed, places=4) @@ -78,7 +78,7 @@ class SignDecaysTest(test.TestCase): num_training_steps, num_periods=5, zero_after=2) for step in range(0, 1000, 100): - with self.test_session(): + with self.cached_session(): tf_decayed = cosine_decay_fn(step).eval() py_decayed = py_cosine_decay_fn(num_training_steps)(step) self.assertAlmostEqual(tf_decayed, py_decayed, places=4) @@ -95,7 +95,7 @@ class SignDecaysTest(test.TestCase): num_training_steps, num_periods=5, zero_after=2) for step in range(0, 1000, 100): - with self.test_session(): + with self.cached_session(): tf_decayed = restart_decay_fn(step).eval() py_decayed = py_restart_decay_fn(num_training_steps)(step) self.assertAlmostEqual(tf_decayed, py_decayed, places=4) diff --git a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py index fdda86b0b53879d891769747f5b211257f3b3fbd..ff0ea8d766934ed98ec35c89a642a34f794415f3 100644 --- a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py @@ -158,7 +158,7 @@ class VariableClippingOptimizerTest(test.TestCase): def testDenseLocal(self): for dtype in [dtypes.float32, dtypes.float64, dtypes.half]: - with self.test_session(): + with self.cached_session(): var0, var1, update_op = self._setupDense(False, dtype) self._assertDenseCorrect(var0, var1, update_op) @@ -171,7 +171,7 @@ class VariableClippingOptimizerTest(test.TestCase): def testSparseLocal(self): for dtype in [dtypes.float64, dtypes.float32, dtypes.half]: - with self.test_session(): + with self.cached_session(): var0, var1, update_op = self._setupSparse(False, dtype) self._assertSparseCorrect(var0, var1, update_op) diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index b9cf40eb7b2d11c98b93c51213145ca4e2670318..29acfc602e7ffdb5fa72b69f9bed0a405ba60693 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -26,6 +26,7 @@ from tensorflow.python.training import adam from tensorflow.python.training import momentum as momentum_opt from tensorflow.python.training import optimizer from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.ops import array_ops class DecoupledWeightDecayExtension(object): @@ -159,8 +160,8 @@ class DecoupledWeightDecayExtension(object): def _decay_weights_sparse_op(self, var, indices, scatter_add): if not self._decay_var_list or var in self._decay_var_list: - return scatter_add(var, indices, -self._weight_decay * var, - self._use_locking) + update = -self._weight_decay * array_ops.gather(var, indices) + return scatter_add(var, indices, update, self._use_locking) return control_flow_ops.no_op() # Here, we overwrite the apply functions that the base optimizer calls. diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py index 631d4f44dfb646541244bfe1d15136dd29f02703..04b1552b61ae45cb8370e94a0b8988913600708d 100644 --- a/tensorflow/contrib/optimizer_v2/adam.py +++ b/tensorflow/contrib/optimizer_v2/adam.py @@ -40,15 +40,14 @@ class AdamOptimizer(optimizer_v2.OptimizerV2): Initialization: - $$m_0 := 0 (Initialize initial 1st moment vector)$$ - $$v_0 := 0 (Initialize initial 2nd moment vector)$$ - $$t := 0 (Initialize timestep)$$ - + $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$ + $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$ + $$t := 0 \text{(Initialize timestep)}$$ The update rule for `variable` with gradient `g` uses an optimization described at the end of section2 of the paper: $$t := t + 1$$ - $$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$ + $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$ $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$ $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$ diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py index 67a8f59c3c03d01a5957a9eff8bd026e70770a45..c3db71359c734d59afc1011d8587a16a82f14b65 100644 --- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py +++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py @@ -178,7 +178,8 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output): # TODO(drpng): just use Update so that we don't carry over the gradients? """Sets the output to be zero at the end of the sequence.""" # output is batch major. - batch_size, max_time, vector_size = tf_output.shape + shape = array_ops.shape(tf_output) + batch_size, max_time, vector_size = shape[0], shape[1], shape[2] output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size]) output_time = array_ops.reshape(output_time, [batch_size, max_time]) lengths = array_ops.tile( @@ -278,11 +279,16 @@ def functional_rnn(cell, inputs, sequence_length=None, if initial_state is None: initial_state = cell.zero_state(batch_size, dtype) func_cell = _FunctionalRnnCell(cell, inputs, initial_state) + if sequence_length is not None: + max_length = math_ops.reduce_max(sequence_length) + else: + max_length = None extended_acc_state, extended_final_state = recurrent.Recurrent( theta=func_cell.theta, state0=func_cell.extended_initial_state, inputs=inputs, cell_fn=func_cell.cell_step, + max_input_length=max_length, use_tpu=use_tpu) tf_output, tf_state = _PostProcessOutput( extended_acc_state, extended_final_state, func_cell, diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc index cf26e3cae7e9247e387ee8294c4c0d5de8781d39..a690d9b129a4d52a540bf41636c8f85497f3551b 100644 --- a/tensorflow/contrib/session_bundle/session_bundle.cc +++ b/tensorflow/contrib/session_bundle/session_bundle.cc @@ -138,10 +138,10 @@ Status RunRestoreOp(const RunOptions& run_options, const StringPiece export_dir, Tensor variables_tensor = CreateStringTensor(GetVariablesFilename(export_dir)); std::vector> inputs = { - {variables_filename_const_op_name.ToString(), variables_tensor}}; + {string(variables_filename_const_op_name), variables_tensor}}; AddAssetsTensorsToInputs(export_dir, asset_files, &inputs); RunMetadata run_metadata; - return session->Run(run_options, inputs, {}, {restore_op_name.ToString()}, + return session->Run(run_options, inputs, {}, {string(restore_op_name)}, nullptr /* outputs */, &run_metadata); } @@ -152,7 +152,7 @@ Status RunInitOp(const RunOptions& run_options, const StringPiece export_dir, std::vector> inputs; AddAssetsTensorsToInputs(export_dir, asset_files, &inputs); RunMetadata run_metadata; - return session->Run(run_options, inputs, {}, {init_op_name.ToString()}, + return session->Run(run_options, inputs, {}, {string(init_op_name)}, nullptr /* outputs */, &run_metadata); } @@ -251,15 +251,14 @@ Status LoadSessionBundleFromPathUsingRunOptions(const SessionOptions& options, auto log_and_count = [&](const string& status_str) { LOG(INFO) << "Loading SessionBundle: " << status_str << ". Took " << load_latency_microsecs << " microseconds."; - load_attempt_count->GetCell(export_dir.ToString(), status_str) - ->IncrementBy(1); + load_attempt_count->GetCell(string(export_dir), status_str)->IncrementBy(1); }; if (status.ok()) { log_and_count(kLoadAttemptSuccess); } else { log_and_count(kLoadAttemptFail); } - load_latency->GetCell(export_dir.ToString()) + load_latency->GetCell(string(export_dir)) ->IncrementBy(load_latency_microsecs); return status; } diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 22d6e499d2b6987204dba23be453e9d944057c5f..652f709fe222d9938742d24d40f633fe156202d8 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -534,10 +534,11 @@ py_library( py_test( name = "random_forest_test", - size = "medium", + size = "large", srcs = ["client/random_forest_test.py"], srcs_version = "PY2AND3", tags = [ + "noasan", "nomac", # b/63258195 "notsan", ], diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc index d43884481afbbbc988d6eb80e01e49663df6914b..99c58003912b56ed0948ea2589dd841c74ad5f5c 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc @@ -130,7 +130,11 @@ void TensorDataSet::RandomSample(int example, num_total_features += num_sparse; } } - int rand_feature = rng_->Uniform(num_total_features); + int rand_feature = 0; + { + mutex_lock lock(mu_); + rand_feature = rng_->Uniform(num_total_features); + } if (rand_feature < available_features_.size()) { // it's dense. *feature_id = available_features_[rand_feature]; *type = input_spec_.GetDenseFeatureType(rand_feature); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h index 95f75b4d7e6a961edf6b3da1dc1712e7ddaacf31..4945b53007e8bd288cfc7aaa31c55c6b88fce646 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h @@ -25,6 +25,7 @@ #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/mutex.h" namespace tensorflow { namespace tensorforest { @@ -120,6 +121,8 @@ class TensorDataSet { int32 split_sampling_random_seed_; std::unique_ptr single_rand_; std::unique_ptr rng_; + // Mutex for using random number generator. + mutable mutex mu_; }; } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index a0fc3e43a9018761181ca92b4935679b5b180f71..122a67a4074199094824f839f638365dfbf3d007 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -279,6 +279,7 @@ tf_cuda_library( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", + "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//tensorflow/core:gpu_runtime", "//tensorflow/core:graph", diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 0f5abe68986c77e53d7e05ea03f5cdba63b242bb..c98b07ad8b921e18da85aa90576d0f4aa46cda94 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT #include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 71b0d487982f16a6c7c34abdeba067a645bb871d..21c0c30c1982e42f0164dd91e23fa13809c3a19b 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -32,6 +32,7 @@ py_test( name = "predict_test", timeout = "long", # Moderate but for asan srcs = ["predict_test.py"], + data = ["data/period_trend.csv"], srcs_version = "PY2AND3", tags = [ "no_windows", # TODO: needs investigation on Windows diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py index 71621abc7190fae9973f78522e23f03d43e342c6..1226433625a79baca17f3bb052f79401fa7e7dd9 100644 --- a/tensorflow/contrib/timeseries/examples/known_anomaly.py +++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py @@ -41,7 +41,7 @@ _MODULE_PATH = path.dirname(__file__) _DATA_FILE = path.join(_MODULE_PATH, "data/changepoints.csv") -def state_space_esitmator(exogenous_feature_columns): +def state_space_estimator(exogenous_feature_columns): """Constructs a StructuralEnsembleRegressor.""" def _exogenous_update_condition(times, features): @@ -68,7 +68,7 @@ def state_space_esitmator(exogenous_feature_columns): 4, 64) -def autoregressive_esitmator(exogenous_feature_columns): +def autoregressive_estimator(exogenous_feature_columns): input_window_size = 8 output_window_size = 2 return ( @@ -169,10 +169,10 @@ def main(unused_argv): "Please install matplotlib to generate a plot from this example.") make_plot("Ignoring a known anomaly (state space)", *train_and_evaluate_exogenous( - estimator_fn=state_space_esitmator)) + estimator_fn=state_space_estimator)) make_plot("Ignoring a known anomaly (autoregressive)", *train_and_evaluate_exogenous( - estimator_fn=autoregressive_esitmator, train_steps=3000)) + estimator_fn=autoregressive_estimator, train_steps=3000)) pyplot.show() diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py index 8c64f2e186a1aab0235f7cfbf1a942b872edd93b..57ccf8f260f41f82d58b43d0cade7af9a26865f5 100644 --- a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py +++ b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py @@ -28,7 +28,7 @@ class KnownAnomalyExampleTest(test.TestCase): def test_shapes_and_variance_structural_ar(self): (times, observed, all_times, mean, upper_limit, lower_limit, anomaly_locations) = known_anomaly.train_and_evaluate_exogenous( - train_steps=1, estimator_fn=known_anomaly.autoregressive_esitmator) + train_steps=1, estimator_fn=known_anomaly.autoregressive_estimator) self.assertAllEqual( anomaly_locations, [25, 50, 75, 100, 125, 150, 175, 249]) @@ -40,7 +40,7 @@ class KnownAnomalyExampleTest(test.TestCase): def test_shapes_and_variance_structural_ssm(self): (times, observed, all_times, mean, upper_limit, lower_limit, anomaly_locations) = known_anomaly.train_and_evaluate_exogenous( - train_steps=50, estimator_fn=known_anomaly.state_space_esitmator) + train_steps=50, estimator_fn=known_anomaly.state_space_estimator) self.assertAllEqual( anomaly_locations, [25, 50, 75, 100, 125, 150, 175, 249]) diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index 2cc17d6d928370afbb0e3b1e89252f7a687c27d3..bf807af68bc0fd107850477eb0b47a101d77a046 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -119,7 +119,9 @@ message OptimizationParameters { // Whether to use gradient accumulation (do two passes over the input // gradients: one to accumulate them into a temporary array and another to - // apply them using the actual optimization algorithm). + // apply them using the actual optimization algorithm). This feature is + // experimental -- it has not been fully verified and may cause training + // crashes and/or failures. bool use_gradient_accumulation = 15; // Optimization algorithm parameters; which field is selected determines which diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index a5e8277ba532b3f7c41880df23c0162f80163890..87b900574c36e7d38724fdb5d17fb907a5dc4ba7 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -111,17 +111,24 @@ def reset_tpu_sessions(): # Work-around dependency cycle between DistributionStrategy and TPU lib. -def TPUDistributionStrategy(tpu_cluster_resolver=None): # pylint: disable=invalid-name +def TPUDistributionStrategy(tpu_cluster_resolver=None, num_cores=None): # pylint: disable=invalid-name """Construct a TPUDistributionStrategy.""" from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top - # TODO -- remove this when TPUStrategy API is consistent (b/112705069) + # TODO(b/112705069): Remove this when TPUStrategy API is consistent. + # We are including this for (a) backwards compatibility for open sourced + # releases of TensorFlow and (b) to work around a circular dependency + # where keras_support and tpu_strategy depends on each other. Once we release + # a final version and remove support for the old API, this will be deleted. + # (See bug above for more details) if tpu_cluster_resolver is None: tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('') args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__) - if len(args) == 3: + if len(args) == 4: logging.info('Detected new TPUStrategy API.') - return tpu_strategy.TPUStrategy(tpu_cluster_resolver, steps_per_run=1) + return tpu_strategy.TPUStrategy(tpu_cluster_resolver, + steps_per_run=1, + num_cores=num_cores) else: logging.info('Detected old TPUStrategy API.') strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8) @@ -612,7 +619,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager): 'currently requires static shapes. The provided ' 'dataset only has a partially defined shape. ' '(Dimension %d of output tensor %d is not statically known ' - 'for output shapes: %s.%s)' % (i, j, dataset.output_shapes, hint)) + 'for output shapes: %s.%s)' % (j, i, dataset.output_shapes, hint)) @property def dummy_x(self): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 7fa06d6d560a4b6ffa6d9a3fd0fa208b4c60ee7f..3c735a0b85db6e26cb5694b2fc822c9d6e0b2dec 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -42,9 +42,9 @@ _BLACKLISTED_OPS = set([ "Placeholder", ]) -# These operations will currently fail to compile, but we should be able to -# support them eventually via CPU offload or extending our operation set. -_NOT_IMPLEMENTED_OPS = set([ +# XLA doesn't currently support reading of intermediate tensors, thus some ops +# are not supported. +_UNSUPPORTED_OPS = set([ "AudioSummary", "AudioSummaryV2", "HistogramSummary", @@ -149,6 +149,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._gradient_colocation_stack = [] self._host_compute_core = [] self._name = name + self._name_as_bytes = compat.as_bytes(name) self._unsupported_ops = [] self._pivot = pivot self._replicated_vars = {} @@ -323,16 +324,13 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): return self._host_compute_core def AddOp(self, op): - self._AddOpInternal(op) - - def _AddOpInternal(self, op): # pylint: disable=protected-access if op.type in _BLACKLISTED_OPS: logging.error("Operation of type %s (%s) is not supported on the TPU. " "Execution will fail if this op is used in the graph. " % (op.type, op.name)) - if op.type in _NOT_IMPLEMENTED_OPS: + if op.type in _UNSUPPORTED_OPS: self._unsupported_ops.append(op) if any(x.dtype._is_ref_dtype for x in op.inputs): @@ -342,7 +340,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): if _TPU_REPLICATE_ATTR in op.node_def.attr: raise ValueError("TPU computations cannot be nested") op._set_attr(_TPU_REPLICATE_ATTR, - attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) + attr_value_pb2.AttrValue(s=self._name_as_bytes)) if self._outside_compilation_cluster: op._set_attr( _OUTSIDE_COMPILATION_ATTR, @@ -356,11 +354,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # Remove any control edges from outer control flow contexts. These may cause # mismatched frame errors. - control_inputs, external_inputs = self._RemoveExternalControlEdges(op) + (internal_control_inputs, + external_control_inputs) = self._RemoveExternalControlEdges(op) if not op.inputs: # Add a control edge from the control pivot to this op. - if not control_inputs: + if not internal_control_inputs: # pylint: disable=protected-access op._add_control_input(self.GetControlPivot()) # pylint: enable=protected-access @@ -371,19 +370,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): if real_x != x: op._update_input(index, real_x) # pylint: disable=protected-access - if external_inputs: + if external_control_inputs: # Use an identity to pull control inputs as data inputs. Note that we # ignore ops which don't have outputs. TODO(phawkins): fix that. with ops.control_dependencies(None): self.Enter() - external_inputs = [ + external_control_inputs = [ array_ops.identity(x.outputs[0]).op - for x in external_inputs + for x in external_control_inputs if x.outputs ] self.Exit() # pylint: disable=protected-access - op._add_control_inputs(external_inputs) + op._add_control_inputs(external_control_inputs) # pylint: enable=protected-access # Mark op's outputs as seen by this context and any outer contexts. @@ -399,6 +398,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._outer_context.AddInnerOp(op) def AddValue(self, val): + """Add `val` to the current context and its outer context recursively.""" if val.name in self._values: # Use the real value if it comes from outer context. result = self._external_values.get(val.name) @@ -415,7 +415,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): return result def AddInnerOp(self, op): - self._AddOpInternal(op) + self.AddOp(op) if self._outer_context: self._outer_context.AddInnerOp(op) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 2e4050bd997aee86bab568a1834d0c6891986f5e..1ff04f5c2661d2b9ec1236ec517e700d9e55e976 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -804,11 +804,14 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( per_host_sharded_inputs.append(flattened_inputs) if inputs_structure_recorder.flattened_input_dims: + input_partition_dims = inputs_structure_recorder.flattened_input_dims + if signals: + input_partition_dims += [None] * len(signals) # pylint: disable=protected-access infeed_queue = tpu_feed._PartitionedInfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0]), host_id=host_id, - input_partition_dims=inputs_structure_recorder.flattened_input_dims, + input_partition_dims=input_partition_dims, device_assignment=ctx.device_assignment) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( per_host_sharded_inputs) @@ -2821,8 +2824,6 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - num_cores = ctx.num_cores - (single_tpu_predict_step, host_calls, captured_scaffold_fn, captured_predict_hooks ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) @@ -2841,7 +2842,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): (dummy_predict_op,) = tpu.shard( multi_tpu_predict_steps_on_single_shard, inputs=[], - num_shards=num_cores, + num_shards=ctx.num_replicas, outputs_from_all_shards=False, device_assignment=ctx.device_assignment) diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py index 81278ea82cc17a4df86b181aa4da794d67973eea..afeef978f31627ba8f925efc14106ce9a0c3b561 100644 --- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py +++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py @@ -108,7 +108,7 @@ class BatchSequencesWithStatesTest(test.TestCase): expected_seq4_batch1, expected_seq4_batch2, key=None, make_keys_unique=False): - with self.test_session() as sess: + with self.cached_session() as sess: next_batch = sqss.batch_sequences_with_states( input_key=key if key is not None else self.key, input_sequences=self.sequences, @@ -332,7 +332,7 @@ class BatchSequencesWithStatesTest(test.TestCase): "seq4": self.sequences["seq4"], } - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, ".*should be a multiple of: 3, but saw " "value: 4. Consider setting pad=True."): diff --git a/tensorflow/contrib/training/python/training/bucket_ops_test.py b/tensorflow/contrib/training/python/training/bucket_ops_test.py index 504f1fcd417f99a8aaa72504f1852e523da1a4c9..b259e0ee83f9f4231111e25caea0e60437930994 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops_test.py +++ b/tensorflow/contrib/training/python/training/bucket_ops_test.py @@ -112,7 +112,7 @@ class BucketTest(test.TestCase): self.assertAllEqual( [[32], [32, None], [32, 3], [None, None]], [out.get_shape().as_list() for out in bucketed_dynamic[1]]) - with self.test_session() as sess: + with self.cached_session() as sess: for v in range(32): self.enqueue_inputs(sess, { self.scalar_int_feed: v, @@ -162,7 +162,7 @@ class BucketTest(test.TestCase): self.assertAllEqual( [[None], [None, None], [None, 3], [None, None]], [out.get_shape().as_list() for out in bucketed_dynamic[1]]) - with self.test_session() as sess: + with self.cached_session() as sess: for v in range(15): self.enqueue_inputs(sess, { self.scalar_int_feed: v, @@ -204,7 +204,7 @@ class BucketTest(test.TestCase): self.assertAllEqual( [[32], [32, None], [32, 3], [None, None]], [out.get_shape().as_list() for out in bucketed_dynamic[1]]) - with self.test_session() as sess: + with self.cached_session() as sess: for v in range(64): self.enqueue_inputs(sess, { self.scalar_int_feed: v, @@ -286,7 +286,7 @@ class BucketTest(test.TestCase): self.assertAllEqual( [[32], [32, None], [32, 3]], [out.get_shape().as_list() for out in bucketed_dynamic[1]]) - with self.test_session() as sess: + with self.cached_session() as sess: for v in range(128): self.enqueue_inputs(sess, { self.scalar_int_feed: v, @@ -405,7 +405,7 @@ class BucketBySequenceLengthTest(test.TestCase): num_pairs_to_enqueue - (batch_size - 1) * num_buckets, num_pairs_dequeued) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() # Feed the inputs, then close the input thread. diff --git a/tensorflow/contrib/training/python/training/evaluation_test.py b/tensorflow/contrib/training/python/training/evaluation_test.py index c36d00e8425ccbfe9338b50fc492dc1334d59731..ec47fe5d97e4709904581193842e028ea2e1a629 100644 --- a/tensorflow/contrib/training/python/training/evaluation_test.py +++ b/tensorflow/contrib/training/python/training/evaluation_test.py @@ -67,7 +67,7 @@ class CheckpointIteratorTest(test.TestCase): global_step = variables.get_or_create_global_step() saver = saver_lib.Saver() # Saves the global step. - with self.test_session() as session: + with self.cached_session() as session: session.run(variables_lib.global_variables_initializer()) save_path = os.path.join(checkpoint_dir, 'model.ckpt') saver.save(session, save_path, global_step=global_step) diff --git a/tensorflow/contrib/training/python/training/resample_test.py b/tensorflow/contrib/training/python/training/resample_test.py index 774241a816452cf56dbd609c814d4ee57da3ac11..8665a24883b718314450b5dc53be471b435681d0 100644 --- a/tensorflow/contrib/training/python/training/resample_test.py +++ b/tensorflow/contrib/training/python/training/resample_test.py @@ -44,7 +44,7 @@ class ResampleTest(test.TestCase): ([3], [0, 0, 0]), ([0, 1, 2, 3], [1, 2, 2, 3, 3, 3]), ] - with self.test_session() as sess: + with self.cached_session() as sess: for inputs, expected in cases: array_inputs = numpy.array(inputs, dtype=numpy.int32) actual = sess.run(resample._repeat_range(array_inputs)) @@ -65,7 +65,7 @@ class ResampleTest(test.TestCase): init = control_flow_ops.group(variables.local_variables_initializer(), variables.global_variables_initializer()) - with self.test_session() as s: + with self.cached_session() as s: s.run(init) # initialize # outputs @@ -112,7 +112,7 @@ class ResampleTest(test.TestCase): init = control_flow_ops.group(variables.local_variables_initializer(), variables.global_variables_initializer()) expected_sum_op = math_ops.reduce_sum(vals) - with self.test_session() as s: + with self.cached_session() as s: s.run(init) expected_sum = n * s.run(expected_sum_op) @@ -147,7 +147,7 @@ class ResampleTest(test.TestCase): resampled = resample.resample_at_rate([vals], rates) - with self.test_session() as s: + with self.cached_session() as s: rs, = s.run(resampled, { vals: list(range(count)), rates: numpy.zeros( diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py index bf7fb4fd48574d3db0d3e3de1161cbb244580b63..1aeff7dc80d21bcaadf9ca096eaea147ec2380ac 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops_test.py +++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py @@ -146,7 +146,7 @@ class StratifiedSampleTest(test.TestCase): for illegal_label in illegal_labels: # Run session that should fail. - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors_impl.InvalidArgumentError): sess.run([val_tf, lbl_tf], feed_dict={label_ph: illegal_label, @@ -154,7 +154,7 @@ class StratifiedSampleTest(test.TestCase): for illegal_prob in illegal_probs: # Run session that should fail. - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaises(errors_impl.InvalidArgumentError): sess.run([prob_tf], feed_dict={label_ph: valid_labels, @@ -172,7 +172,7 @@ class StratifiedSampleTest(test.TestCase): summary_op = logging_ops.merge_summary( ops.get_collection(ops.GraphKeys.SUMMARIES)) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(coord=coord) @@ -197,7 +197,7 @@ class StratifiedSampleTest(test.TestCase): batch_size, init_probs=[0, .3, 0, .7, 0], enqueue_many=True) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(coord=coord) @@ -228,7 +228,7 @@ class StratifiedSampleTest(test.TestCase): # Run graph to make sure there are no shape-related runtime errors. for vals, labels in legal_input_pairs: - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([val_tf, labels_tf], feed_dict={vals_ph: vals, labels_ph: labels}) @@ -253,7 +253,7 @@ class StratifiedSampleTest(test.TestCase): self.assertEqual(len(val_list), len(val_input_batch)) self.assertTrue(isinstance(lbls, ops.Tensor)) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(coord=coord) @@ -283,7 +283,7 @@ class StratifiedSampleTest(test.TestCase): # Run session and keep track of how frequently the labels and values appear. data_l = [] label_l = [] - with self.test_session() as sess: + with self.cached_session() as sess: # Need to initialize variables that keep running total of classes seen. variables.global_variables_initializer().run() @@ -374,7 +374,7 @@ class RejectionSampleTest(test.TestCase): 'rejection_sample/prob_with_checks:0') # Run session that should fail. - with self.test_session() as sess: + with self.cached_session() as sess: for illegal_prob in [-0.1, 1.1]: with self.assertRaises(errors_impl.InvalidArgumentError): sess.run(prob_tensor, feed_dict={prob_ph: illegal_prob}) @@ -393,7 +393,7 @@ class RejectionSampleTest(test.TestCase): sample = sampling_ops.rejection_sample(tensor_list, accept_prob_fn, batch_size) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(coord=coord) diff --git a/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py b/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py index ca78c0029ee18692445980f599eefa781126d3aa..73ad859ab34fda38b5e8bcc7076be6c8e5672886 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py +++ b/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py @@ -59,7 +59,7 @@ class SamplingOpsThreadingTest(test.TestCase): out_tensor = queue.dequeue() # Run the multi-threaded session. - with self.test_session() as sess: + with self.cached_session() as sess: # Need to initialize variables that keep running total of classes seen. variables.global_variables_initializer().run() diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py index 7aebd9d9fe94f3f668a95ed0303703e7f2558cb8..8932b905c91df918d53de9495f7a05410b7e5405 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py @@ -36,7 +36,7 @@ from tensorflow.python.platform import test class SequenceQueueingStateSaverTest(test.TestCase): def testSequenceInputWrapper(self): - with self.test_session(): + with self.cached_session(): length = 3 key = "key" padded_length = 4 @@ -54,7 +54,7 @@ class SequenceQueueingStateSaverTest(test.TestCase): self.assertTrue(isinstance(input_wrapper.context["context1"], ops.Tensor)) def testStateSaverWithTwoSimpleSteps(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size_value = 2 batch_size = constant_op.constant(batch_size_value) num_unroll = 2 @@ -159,7 +159,7 @@ class SequenceQueueingStateSaverTest(test.TestCase): self.assertEqual(0, state_saver.barrier.ready_size().eval()) def testStateSaverFailsIfPaddedLengthIsNotMultipleOfNumUnroll(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = constant_op.constant(32) num_unroll = 17 bad_padded_length = 3 @@ -194,7 +194,7 @@ class SequenceQueueingStateSaverTest(test.TestCase): }) def _testStateSaverFailsIfCapacityTooSmall(self, batch_size): - with self.test_session() as sess: + with self.cached_session() as sess: num_unroll = 2 length = array_ops.placeholder(dtypes.int32) key = array_ops.placeholder(dtypes.string) @@ -243,7 +243,7 @@ class SequenceQueueingStateSaverTest(test.TestCase): self._testStateSaverFailsIfCapacityTooSmall(batch_size) def testStateSaverFailsIfInconsistentPaddedLength(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = constant_op.constant(32) num_unroll = 17 length = array_ops.placeholder(dtypes.int32) @@ -282,7 +282,7 @@ class SequenceQueueingStateSaverTest(test.TestCase): def testStateSaverFailsIfInconsistentWriteState(self): # TODO(b/26910386): Identify why this infrequently causes timeouts. - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = constant_op.constant(1) num_unroll = 17 length = array_ops.placeholder(dtypes.int32) @@ -326,7 +326,7 @@ class SequenceQueueingStateSaverTest(test.TestCase): def testStateSaverWithManyInputsReadWriteThread(self): batch_size_value = 32 num_proc_threads = 100 - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = constant_op.constant(batch_size_value) num_unroll = 17 length = array_ops.placeholder(dtypes.int32) @@ -490,7 +490,7 @@ class SequenceQueueingStateSaverTest(test.TestCase): self.assertGreater(processed_count[0], 2 * 20 * batch_size_value) def testStateSaverProcessesExamplesInOrder(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size_value = 32 batch_size = constant_op.constant(batch_size_value) num_unroll = 17 @@ -563,7 +563,7 @@ class SequenceQueueingStateSaverTest(test.TestCase): self.assertEqual(get_ready_size.eval(), 0) def testStateSaverCanHandleVariableBatchsize(self): - with self.test_session() as sess: + with self.cached_session() as sess: batch_size = array_ops.placeholder(dtypes.int32) num_unroll = 17 length = array_ops.placeholder(dtypes.int32) diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py index 4a46e9a49ef203384e36698f81d6cbe3a3881ef8..3269d5fef2080ce23f07b17cdc69ae878de9837e 100644 --- a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py +++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py @@ -62,7 +62,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase): def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters): """Get an array with learning rate values from the consecutive steps using current tensorflow implementation.""" - with self.test_session(): + with self.cached_session(): step = placeholder(dtypes.int32) decay = sgdr_decay(lr, step, initial_period_steps, t_mul) @@ -76,7 +76,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase): """Compare values generated by tensorflow implementation to the values generated by the original implementation (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py).""" - with self.test_session(): + with self.cached_session(): lr = 10.0 init_steps = 2 t_mul = 3 @@ -92,7 +92,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase): def testMDecay(self): """Test m_mul argument. Check values for learning rate at the beginning of the first, second, third and fourth period. """ - with self.test_session(): + with self.cached_session(): step = placeholder(dtypes.int32) lr = 0.1 @@ -121,7 +121,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase): def testCos(self): """Check learning rate values at the beginning, in the middle and at the end of the period.""" - with self.test_session(): + with self.cached_session(): step = placeholder(dtypes.int32) lr = 0.2 t_e = 1000 diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py index df0a186f4f6963d7e874bb4ab74a8db7e10a52ee..d9b0511a98fea909079ea53e4b95c2082f015f39 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py @@ -79,7 +79,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() queue_handle, value = iterator.get_next() enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0, 0, 0]], sess.run(value)) value_1, _ = sess.run([value, enqueue_negative]) self.assertAllEqual([[1, 0, 0]], value_1) @@ -101,7 +101,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() queue_handle, value = iterator.get_next() enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual([0], sess.run(value)) value_1, _ = sess.run([value, enqueue_negative]) self.assertEqual([1], value_1) @@ -126,7 +126,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]], array_ops.expand_dims( value[0], axis=0)) - with self.test_session() as sess: + with self.cached_session() as sess: value_0, _ = sess.run([value, enqueue_negative]) self.assertAllEqual([0, 1], value_0) value_1, _ = sess.run([value, enqueue_zeroth]) @@ -147,7 +147,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i) for i in range(1000) ] - with self.test_session() as sess: + with self.cached_session() as sess: value_0, _ = sess.run((value, enqueue_many_more)) self.assertEqual([0], value_0) rest = [] @@ -174,7 +174,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() queue_handle, value = iterator.get_next() enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1) - with self.test_session() as sess: + with self.cached_session() as sess: i = 0 while i < 4: received, _ = sess.run((value, enqueue)) @@ -199,7 +199,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): batch_size=1, padded_shapes=[2])) iterator = dataset.make_one_shot_iterator() _, value = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesOpError( r"Incompatible input shapes at component 0 between " r"input dataset this dataset: \[3\] vs. \[2\]"): @@ -224,7 +224,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): np.array( [[1]], dtype=np.int32)) - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesOpError( "mismatched number of tensors. Queue expects 1 tensors but " "tried to insert 2"): @@ -274,7 +274,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): with ops.control_dependencies([enqueue_rest_op]): calc = array_ops.identity(value_head) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc)) self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc)) self.assertAllEqual([[6, 6]], sess.run(calc)) @@ -304,7 +304,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase): iterator = dataset.make_one_shot_iterator() _, (unused_count, padded_value) = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]], sess.run(padded_value)) self.assertAllEqual([[6] * 6], sess.run(padded_value)) diff --git a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py index 94cf7788b2bd3bc3fe87eefd599ce88de03042af..3b524ac8c76ebc566eb3cf3e75448037f45e4b66 100644 --- a/tensorflow/contrib/training/python/training/training_test.py +++ b/tensorflow/contrib/training/python/training/training_test.py @@ -62,7 +62,7 @@ class ClipGradsTest(test.TestCase): clipped_gradients_to_variables = training.clip_gradient_norms( gradients_to_variables, 3.0) - with self.test_session() as session: + with self.cached_session() as session: session.run(variables_lib2.global_variables_initializer()) self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval()) self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval()) @@ -75,7 +75,7 @@ class ClipGradsTest(test.TestCase): clipped_gradients_to_variables = training.clip_gradient_norms_fn(3.0)( gradients_to_variables) - with self.test_session() as session: + with self.cached_session() as session: session.run(variables_lib2.global_variables_initializer()) self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval()) self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval()) @@ -122,7 +122,7 @@ class CreateTrainOpTest(test.TestCase): moving_variance = variables_lib.get_variables_by_name('moving_variance')[ 0] - with self.test_session() as session: + with self.cached_session() as session: # Initialize all variables session.run(variables_lib2.global_variables_initializer()) mean, variance = session.run([moving_mean, moving_variance]) @@ -155,7 +155,7 @@ class CreateTrainOpTest(test.TestCase): moving_variance = variables_lib.get_variables_by_name('moving_variance')[ 0] - with self.test_session() as session: + with self.cached_session() as session: # Initialize all variables session.run(variables_lib2.global_variables_initializer()) mean, variance = session.run([moving_mean, moving_variance]) @@ -186,7 +186,7 @@ class CreateTrainOpTest(test.TestCase): global_step = variables_lib.get_or_create_global_step() - with self.test_session() as session: + with self.cached_session() as session: # Initialize all variables session.run(variables_lib2.global_variables_initializer()) @@ -209,7 +209,7 @@ class CreateTrainOpTest(test.TestCase): global_step = variables_lib.get_or_create_global_step() - with self.test_session() as session: + with self.cached_session() as session: # Initialize all variables session.run(variables_lib2.global_variables_initializer()) @@ -535,7 +535,7 @@ class TrainTest(test.TestCase): train_biases = training.create_train_op( total_loss, optimizer, variables_to_train=[biases]) - with self.test_session() as session: + with self.cached_session() as session: # Initialize the variables. session.run(variables_lib2.global_variables_initializer()) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 60830b7d609648f81d7be88c0f394eb8196f8203..51225f34bcd62dc20fb83caca3347f9ca66ebabf 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -375,6 +375,7 @@ cc_library( ":lib_platform", ":platform_base", "//tensorflow/core/platform/default/build_config:port", + "@com_google_absl//absl/base", "@snappy", ], ) @@ -2267,6 +2268,8 @@ cc_library( srcs = if_android([ "lib/gif/gif_io.cc", "platform/gif.h", + "lib/strings/strcat.h", + "lib/strings/numbers.h", ]), hdrs = [ "lib/bfloat16/bfloat16.h", @@ -2704,12 +2707,13 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/allocator_retry.h", "common_runtime/base_collective_executor.h", "common_runtime/bfc_allocator.h", - "common_runtime/broadcaster.h", + "common_runtime/hierarchical_tree_broadcaster.h", "common_runtime/buf_rendezvous.h", "common_runtime/build_graph_options.h", "common_runtime/collective_executor_mgr.h", "common_runtime/collective_param_resolver_local.h", "common_runtime/collective_rma_local.h", + "common_runtime/collective_util.h", "common_runtime/constant_folding.h", "common_runtime/copy_tensor.h", "common_runtime/costmodel_manager.h", @@ -2740,6 +2744,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/stats_publisher_interface.h", "common_runtime/step_stats_collector.h", "common_runtime/threadpool_device.h", + "common_runtime/tracing_device.h", "common_runtime/visitable_allocator.h", "common_runtime/process_state.h", "common_runtime/pool_allocator.h", @@ -2754,12 +2759,12 @@ tf_cuda_library( "common_runtime/allocator_retry.cc", "common_runtime/base_collective_executor.cc", "common_runtime/bfc_allocator.cc", - "common_runtime/broadcaster.cc", "common_runtime/buf_rendezvous.cc", "common_runtime/build_graph_options.cc", "common_runtime/collective_executor_mgr.cc", "common_runtime/collective_param_resolver_local.cc", "common_runtime/collective_rma_local.cc", + "common_runtime/collective_util.cc", "common_runtime/constant_folding.cc", "common_runtime/copy_tensor.cc", "common_runtime/costmodel_manager.cc", @@ -2774,6 +2779,7 @@ tf_cuda_library( "common_runtime/function.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", + "common_runtime/hierarchical_tree_broadcaster.cc", "common_runtime/local_device.cc", "common_runtime/lower_if_op.cc", "common_runtime/memory_types.cc", @@ -3660,10 +3666,10 @@ tf_cc_tests_gpu( ) tf_cc_tests_gpu( - name = "broadcaster_test", + name = "hierarchical_tree_broadcaster_test", size = "small", srcs = [ - "common_runtime/broadcaster_test.cc", + "common_runtime/hierarchical_tree_broadcaster_test.cc", ], linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags(), diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt index b90f5473c89cbe3afe38f0283025e7273817d0e4..6341eeda3266651f17360be692e89c9dd33cd9d9 100644 --- a/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt @@ -82,7 +82,7 @@ END } summary: "Update \'*var\' according to the Adam algorithm." description: <