diff --git a/RELEASE.md b/RELEASE.md index 1613a37be3379b18281dafa04857687d7899097c..3d497dbaa965d2cf239cab8360109bf5804b6f6e 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -9,6 +9,9 @@ for LSTMs and stacked LSTMs. This bug fix follows recommendations from published literature, but is a behavioral change. State dropout behavior may be customized via the new `dropout_state_filter_visitor` argument. +* Removed `tf.contrib.training.python_input`. The same behavior, in a more + flexible and reproducible package, is available via the new + `tf.contrib.data.Dataset.from_generator` method! # Release 1.3.0 diff --git a/WORKSPACE b/WORKSPACE index 5e9b991fccaa8d11e9233c6c2db08d4798168796..a0fe67bf3189c1156c524aced5210e466e1d8f12 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "bc41b80486413aaa551860fc37471dbc0666e1dbb5236fb6177cb83b0c105846", - strip_prefix = "rules_closure-dec425a4ff3faf09a56c85d082e4eed05d8ce38f", + sha256 = "25f5399f18d8bf9ce435f85c6bbf671ec4820bc4396b3022cc5dc4bc66303609", + strip_prefix = "rules_closure-0.4.2", urls = [ - "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dec425a4ff3faf09a56c85d082e4eed05d8ce38f.tar.gz", # 2017-06-02 - "https://github.com/bazelbuild/rules_closure/archive/dec425a4ff3faf09a56c85d082e4eed05d8ce38f.tar.gz", + "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz", # 2017-08-29 + "https://github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz", ], ) diff --git a/configure.py b/configure.py index 1a0f71ed94358ee8ea81a13b192d4131d6133f66..ef5051d275e5b1a7c1d06e6cfec3dc00fc1305e2 100644 --- a/configure.py +++ b/configure.py @@ -143,7 +143,7 @@ def run_shell(cmd, allow_non_zero=False): def cygpath(path): """Convert path from posix to windows.""" - return run_shell(['cygpath', '-m', path]) + return os.path.abspath(path).replace('\\', '/') def get_python_path(environ_cp, python_bin_path): @@ -196,7 +196,7 @@ def setup_python(environ_cp, bazel_version): environ_cp['PYTHON_BIN_PATH'] = '' # Convert python path to Windows style before checking lib and version - if is_cygwin(): + if is_windows() or is_cygwin(): python_bin_path = cygpath(python_bin_path) # Get PYTHON_LIB_PATH @@ -219,7 +219,7 @@ def setup_python(environ_cp, bazel_version): python_major_version = get_python_major_version(python_bin_path) # Convert python path to Windows style before writing into bazel.rc - if is_cygwin(): + if is_windows() or is_cygwin(): python_lib_path = cygpath(python_lib_path) # Set-up env variables used by python_configure.bzl @@ -600,7 +600,7 @@ def set_tf_cuda_version(environ_cp): # Find out where the CUDA toolkit is installed default_cuda_path = _DEFAULT_CUDA_PATH - if is_cygwin(): + if is_windows() or is_cygwin(): default_cuda_path = cygpath( environ_cp.get('CUDA_PATH', _DEFAULT_CUDA_PATH_WIN)) elif is_linux(): @@ -660,7 +660,7 @@ def set_tf_cunn_version(environ_cp): # unusable. Going through one more level of expansion to handle that. cudnn_install_path = os.path.realpath( os.path.expanduser(cudnn_install_path)) - if is_cygwin(): + if is_windows() or is_cygwin(): cudnn_install_path = cygpath(cudnn_install_path) if is_windows(): @@ -685,10 +685,13 @@ def set_tf_cunn_version(environ_cp): ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' cudnn_path_from_ldconfig = run_shell([ldconfig_bin, '-p']) cudnn_path_from_ldconfig = re.search('.*libcudnn.so .* => (.*)', - cudnn_path_from_ldconfig).group(1) - if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)): - cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig) - break + cudnn_path_from_ldconfig) + if cudnn_path_from_ldconfig: + cudnn_path_from_ldconfig = cudnn_path_from_ldconfig.group(1) + if os.path.exists('%s.%s' % (cudnn_path_from_ldconfig, + tf_cudnn_version)): + cudnn_install_path = os.path.dirname(cudnn_path_from_ldconfig) + break # Reset and Retry print( diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 4dc84889137bed48a0145cf34757d1970708a5a7..20a6c7910757f9b11fe4de601de97e4d1202e1ee 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -9,6 +9,9 @@ licenses(["notice"]) # Apache 2.0 exports_files([ "LICENSE", "ACKNOWLEDGMENTS", + # The leakr files are used by //third_party/cloud_tpu. + "leakr_badwords.dic", + "leakr_badfiles.dic", ]) # Config setting for determining if we are building for Android. @@ -287,12 +290,14 @@ filegroup( "//tensorflow/contrib/decision_trees/proto:all_files", "//tensorflow/contrib/distributions:all_files", "//tensorflow/contrib/eager/python:all_files", + "//tensorflow/contrib/estimator:all_files", "//tensorflow/contrib/factorization:all_files", "//tensorflow/contrib/factorization/kernels:all_files", "//tensorflow/contrib/ffmpeg:all_files", "//tensorflow/contrib/ffmpeg/default:all_files", "//tensorflow/contrib/framework:all_files", "//tensorflow/contrib/fused_conv:all_files", + "//tensorflow/contrib/gan:all_files", "//tensorflow/contrib/graph_editor:all_files", "//tensorflow/contrib/grid_rnn:all_files", "//tensorflow/contrib/hooks:all_files", @@ -320,6 +325,7 @@ filegroup( "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", "//tensorflow/contrib/predictor:all_files", + "//tensorflow/contrib/receptive_field:all_files", "//tensorflow/contrib/reduce_slice_ops:all_files", "//tensorflow/contrib/remote_fused_graph/pylib:all_files", "//tensorflow/contrib/resampler:all_files", @@ -339,6 +345,7 @@ filegroup( "//tensorflow/contrib/staging:all_files", "//tensorflow/contrib/stat_summarizer:all_files", "//tensorflow/contrib/stateless:all_files", + "//tensorflow/contrib/summary:all_files", "//tensorflow/contrib/tensor_forest:all_files", "//tensorflow/contrib/tensor_forest/hybrid:all_files", "//tensorflow/contrib/tensor_forest/kernels/v4:all_files", @@ -401,6 +408,7 @@ filegroup( "//tensorflow/python/eager:all_files", "//tensorflow/python/estimator:all_files", "//tensorflow/python/feature_column:all_files", + "//tensorflow/python/keras:all_files", "//tensorflow/python/kernel_tests:all_files", "//tensorflow/python/kernel_tests/distributions:all_files", "//tensorflow/python/ops/distributions:all_files", @@ -458,6 +466,7 @@ cc_binary( "//tensorflow:darwin": [ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file "//tensorflow/c:exported_symbols.lds", + "-Wl,-install_name,@rpath/libtensorflow.so", ], "//tensorflow:windows": [], "//tensorflow:windows_msvc": [], diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index df2ae229db531b33c067fc64ad81f8b9f501fdb3..1822e235eba3f9919f2d3e19c628fc7160dd1977 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -45,8 +45,13 @@ tf_cuda_library( tf_cuda_library( name = "c_api", - srcs = ["c_api.cc"], - hdrs = ["c_api.h"], + srcs = [ + "c_api.cc", + "c_api_function.cc", + ], + hdrs = [ + "c_api.h", + ], copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ @@ -61,6 +66,7 @@ tf_cuda_library( "//tensorflow/cc:ops", "//tensorflow/cc:grad_ops", "//tensorflow/cc:scope_internal", + "//tensorflow/cc:while_loop", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -156,6 +162,21 @@ tf_cc_test( ], ) +tf_cc_test( + name = "c_api_function_test", + size = "small", + srcs = ["c_api_function_test.cc"], + deps = [ + ":c_api", + ":c_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "while_loop_test", size = "small", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 1ea70f0598124db5cff8db2b9543d06387548fa5..334f867e47800507760eaa71dce91186f646f72d 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/cc/ops/while_loop.h" #include "tensorflow/cc/saved_model/loader.h" #endif #include "tensorflow/c/c_api_internal.h" @@ -164,22 +165,6 @@ void deallocate_buffer(void* data, size_t len, void* arg) { tensorflow::cpu_allocator()->DeallocateRaw(data); } -Status MessageToBuffer(const tensorflow::protobuf::Message& in, - TF_Buffer* out) { - if (out->data != nullptr) { - return InvalidArgument("Passing non-empty TF_Buffer is invalid."); - } - const auto proto_size = in.ByteSizeLong(); - void* buf = tensorflow::port::Malloc(proto_size); - in.SerializeToArray(buf, proto_size); - out->data = buf; - out->length = proto_size; - out->data_deallocator = [](void* data, size_t length) { - tensorflow::port::Free(data); - }; - return Status::OK(); -} - } // namespace TF_Tensor::~TF_Tensor() { buffer->Unref(); } @@ -389,6 +374,65 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, status->status = Reset(opt->options, container_names); } +// This traverses the specified nodes in topological order to verify there are +// no cycles. Starting with inputless nodes, it visits nodes whose inputs have +// all been visited, and counts the total number of visited nodes. If there is a +// cycle, nodes in the cycle will never be visited, and the visited count will +// be less than the total node count. +Status ValidateNoCycles(const Graph& g) { + // TODO(nolivia): check this on a subset of the graph instead of all of it. + int total_num_nodes = g.num_node_ids(); + // A node is ready when all of its inputs have been visited. + std::vector ready; + std::vector pending_count(total_num_nodes, 0); + + for (int i = 0; i < total_num_nodes; ++i) { + const Node* n = g.FindNodeId(i); + if (n == nullptr) continue; + pending_count[i] = n->in_edges().size(); + if (n->IsMerge()) { + // While-loop cycles are legal cycles so we manually adjust the + // pending_count to make sure that the loop is visited. + for (const Edge* e : n->in_edges()) { + if (!e->IsControlEdge() && e->src()->IsNextIteration()) { + pending_count[i]--; + } + } + } + if (pending_count[i] == 0) { + ready.push_back(n); + } + } + + int processed = 0; + while (!ready.empty()) { + const Node* node = ready.back(); + ready.pop_back(); + ++processed; + + for (const Edge* out : node->out_edges()) { + const int output_id = out->dst()->id(); + pending_count[output_id]--; + if (pending_count[output_id] == 0) { + ready.push_back(out->dst()); + } + } + } + + if (processed < total_num_nodes) { + std::vector nodes_in_cycle; + for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; + ++i) { + if (pending_count[i] != 0) { + nodes_in_cycle.push_back(g.FindNodeId(i)->name()); + } + } + return errors::InvalidArgument( + "Graph is invalid, contains a cycle with ", total_num_nodes - processed, + " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); + } + return Status::OK(); +} } // namespace } // namespace tensorflow @@ -558,6 +602,27 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, dimvec.size(), base, size, DeleteArray, base); } +Status MessageToBuffer(const tensorflow::protobuf::Message& in, + TF_Buffer* out) { + if (out->data != nullptr) { + return InvalidArgument("Passing non-empty TF_Buffer is invalid."); + } + const size_t proto_size = in.ByteSizeLong(); + void* buf = tensorflow::port::Malloc(proto_size); + if (buf == nullptr) { + return tensorflow::errors::ResourceExhausted( + "Failed to allocate memory to serialize message of type '", + in.GetTypeName(), "' and size ", proto_size); + } + in.SerializeToArray(buf, proto_size); + out->data = buf; + out->length = proto_size; + out->data_deallocator = [](void* data, size_t length) { + tensorflow::port::Free(data); + }; + return Status::OK(); +} + // Helpers for loading a TensorFlow plugin (a .so file). Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); @@ -831,6 +896,30 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, return attr; } +TensorId ToTensorId(const TF_Output& output) { + return TensorId(output.oper->node.name(), output.index); +} + +#ifndef __ANDROID__ +std::vector OutputsFromTFOutputs(TF_Output* tf_outputs, + int n) { + std::vector outputs(n); + for (int i = 0; i < n; ++i) { + outputs[i] = + tensorflow::Output(&tf_outputs[i].oper->node, tf_outputs[i].index); + } + return outputs; +} + +void TFOutputsFromOutputs(const std::vector& outputs, + TF_Output* tf_outputs) { + for (int i = 0; i < outputs.size(); i++) { + tf_outputs[i].oper = ToOperation(outputs[i].node()); + tf_outputs[i].index = outputs[i].index(); + } +} +#endif // __ANDROID__ + } // namespace // Shape functions ----------------------------------------------------------- @@ -1721,14 +1810,6 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, opts->opts.prefix = prefix; } -namespace { - -TensorId ToTensorId(const TF_Output& output) { - return TensorId(output.oper->node.name(), output.index); -} - -} // namespace - void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst) { @@ -1812,6 +1893,11 @@ void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, // While loop functions ------------------------------------------------------- namespace { + +#ifndef __ANDROID__ + +// Creates a placeholder representing an input to the cond or body graph. +// TODO(skyewm): remove these from final graph bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name, TF_Output* input, TF_Status* status) { TF_OperationDescription* desc = TF_NewOperation(g, "Placeholder", name); @@ -1823,130 +1909,50 @@ bool CreateInput(const TF_Output& parent_input, TF_Graph* g, const char* name, return true; } -bool CreateEnter(TF_Graph* g, const char* node_name, const char* frame_name, - const TF_Output& input, TF_Output* enter, TF_Status* status) - EXCLUSIVE_LOCKS_REQUIRED(g->mu) { - TF_OperationDescription* desc = TF_NewOperationLocked(g, "Enter", node_name); - TF_AddInput(desc, input); - TF_SetAttrString(desc, "frame_name", frame_name, strlen(frame_name)); - TF_Operation* oper = TF_FinishOperationLocked(desc, status); - if (!status->status.ok()) return false; - *enter = {oper, 0}; - return true; -} - -bool CreateMerge(TF_Graph* g, const char* name, const TF_Output& input, - const char* backedge_name, int backedge_index, - TF_Output* merge, TF_Status* status) - EXCLUSIVE_LOCKS_REQUIRED(g->mu) { - TF_OperationDescription* desc = TF_NewOperationLocked(g, "Merge", name); - - // The merge nodes accept the while loop's back edges as an input. Use the - // underlying NodeBuilder API directly to create an input to the - // not-yet-created back edge. - std::vector input_list; - input_list.push_back(NodeBuilder::NodeOut(&input.oper->node, input.index)); - // All merge inputs must have same type - DataType type = input.oper->node.output_type(input.index); - input_list.push_back( - NodeBuilder::NodeOut(backedge_name, backedge_index, type)); - - desc->node_builder.Input(input_list); - - TF_Operation* oper = TF_FinishOperationLocked(desc, status); - if (!status->status.ok()) return false; - *merge = {oper, 0}; - return true; -} - -bool CreateSwitch(TF_Graph* g, const char* name, const TF_Output& input, - const TF_Output& predicate, TF_Output* switch_true, - TF_Output* switch_false, TF_Status* status) - EXCLUSIVE_LOCKS_REQUIRED(g->mu) { - TF_OperationDescription* desc = TF_NewOperationLocked(g, "Switch", name); - TF_AddInput(desc, input); - TF_AddInput(desc, predicate); - TF_Operation* oper = TF_FinishOperationLocked(desc, status); - if (!status->status.ok()) return false; - *switch_false = {oper, 0}; - *switch_true = {oper, 1}; - return true; -} - -bool CreateNext(TF_Graph* g, const char* name, const TF_Output& input, - TF_Output* next, TF_Status* status) - EXCLUSIVE_LOCKS_REQUIRED(g->mu) { - TF_OperationDescription* desc = - TF_NewOperationLocked(g, "NextIteration", name); - TF_AddInput(desc, input); - TF_Operation* oper = TF_FinishOperationLocked(desc, status); - if (!status->status.ok()) return false; - *next = {oper, 0}; - return true; -} - -bool CreateExit(TF_Graph* g, const char* name, const TF_Output& input, - TF_Output* exit, TF_Status* status) - EXCLUSIVE_LOCKS_REQUIRED(g->mu) { - TF_OperationDescription* desc = TF_NewOperationLocked(g, "Exit", name); - TF_AddInput(desc, input); - TF_Operation* oper = TF_FinishOperationLocked(desc, status); - if (!status->status.ok()) return false; - *exit = {oper, 0}; - return true; -} - -class ScopedImportGraphDefOptions { - public: - ScopedImportGraphDefOptions() { opts_ = TF_NewImportGraphDefOptions(); } - ~ScopedImportGraphDefOptions() { TF_DeleteImportGraphDefOptions(opts_); } - - TF_ImportGraphDefOptions* get() const { return opts_; } - - private: - TF_ImportGraphDefOptions* opts_; - - TF_DISALLOW_COPY_AND_ASSIGN(ScopedImportGraphDefOptions); -}; - // Copies `src_graph` into `dst_graph`. Any node in `src_graph` with input -// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. -// `prefix` will be prepended to copied node names. `return_nodes` are nodes -// in `src_graph`, and the new corresponding nodes in `dst_graph` will be -// returned. `return_nodes` should be preallocated to size `nreturn_nodes`. -bool CopyGraph(TF_Graph* src_graph, TF_Graph* dst_graph, - const TF_Output* src_inputs, - const std::vector& dst_inputs, const char* prefix, - const TF_Output* nodes_to_return, int nreturn_nodes, - TF_Output* return_nodes, TF_Status* s) - EXCLUSIVE_LOCKS_REQUIRED(dst_graph->mu) { +// `src_inputs[i]` will have that input replaced with `dst_inputs[i]`. `prefix` +// will be prepended to copied node names. `control_deps` are nodes in +// `dst_graph` that the copied `src_graph` nodes will have control dependencies +// on. `return_nodes` are nodes in `src_graph`, and the new corresponding nodes +// in `dst_graph` will be returned. `return_nodes` must be non-null. +Status CopyGraph(Graph* src_graph, Graph* dst_graph, + tensorflow::ShapeRefiner* dst_refiner, + const TF_Output* src_inputs, + const std::vector& dst_inputs, + const tensorflow::string& prefix, + const std::vector& control_deps, + const TF_Output* nodes_to_return, int nreturn_nodes, + std::vector* return_nodes) { + DCHECK(return_nodes != nullptr); GraphDef gdef; - src_graph->graph.ToGraphDef(&gdef); + src_graph->ToGraphDef(&gdef); - ScopedImportGraphDefOptions opts; - TF_ImportGraphDefOptionsSetPrefix(opts.get(), prefix); + tensorflow::ImportGraphDefOptions opts; + opts.prefix = prefix; for (int i = 0; i < dst_inputs.size(); ++i) { - TensorId src = ToTensorId(src_inputs[i]); - TF_ImportGraphDefOptionsAddInputMapping(opts.get(), src.first.data(), - src.second, dst_inputs[i]); + opts.input_map[ToTensorId(src_inputs[i])] = + TensorId(dst_inputs[i].node()->name(), dst_inputs[i].index()); } - opts.get()->opts.skip_mapped_nodes = true; + opts.skip_mapped_nodes = true; - // We use the pivot node to control constants in `src_graph` - TF_Operation* pivot = dst_inputs[0].oper; - TF_ImportGraphDefOptionsAddControlDependency(opts.get(), pivot); + for (const tensorflow::Operation& op : control_deps) { + opts.control_dependencies.push_back(op.node()->name()); + } for (int i = 0; i < nreturn_nodes; ++i) { - TF_ImportGraphDefOptionsAddReturnOutput( - opts.get(), nodes_to_return[i].oper->node.name().c_str(), - nodes_to_return[i].index); + opts.return_tensors.push_back(ToTensorId(nodes_to_return[i])); } - GraphImportGraphDefLocked(dst_graph, gdef, opts.get(), return_nodes, - nreturn_nodes, s); - if (TF_GetCode(s) != TF_OK) return false; - return true; + // TOOD(skyewm): change to OutputTensor + std::vector> return_tensors; + TF_RETURN_IF_ERROR( + ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors)); + + for (const auto& pair : return_tensors) { + return_nodes->emplace_back(pair.first, pair.second); + } + return Status::OK(); } bool ValidateConstWhileParams(const TF_WhileParams& params, TF_Status* s) { @@ -1982,6 +1988,8 @@ bool ValidateInputWhileParams(const TF_WhileParams& params, TF_Status* s) { return true; } +#endif // __ANDROID__ + void FreeWhileResources(const TF_WhileParams* params) { TF_DeleteGraph(params->cond_graph); TF_DeleteGraph(params->body_graph); @@ -1999,6 +2007,13 @@ TF_WhileParams EmptyWhileParams() { TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Creating while loops is not supported in Android. File a bug at " + "https://github.com/tensorflow/tensorflow/issues if this feature is " + "important to you"); + return EmptyWhileParams(); +#else if (ninputs == 0) { status->status = InvalidArgument("TF_NewWhile() must be passed at least one input"); @@ -2039,8 +2054,10 @@ TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, int ninputs, return EmptyWhileParams(); } return params; +#endif // __ANDROID__ } +#ifndef __ANDROID__ namespace { // TODO(skyewm): make nodes in while loop unfetchable like in Python version @@ -2050,113 +2067,90 @@ void TF_FinishWhileHelper(const TF_WhileParams* params, TF_Status* status, TF_Graph* parent = params->cond_graph->parent; TF_Output* parent_inputs = params->cond_graph->parent_inputs; - int n = params->ninputs; + int num_loop_vars = params->ninputs; mutex_lock l(parent->mu); - // Create Enter nodes - std::vector enter_nodes(n); - for (int i = 0; i < n; ++i) { - if (!CreateEnter(parent, StrCat(params->name, "/enter", i).c_str(), - params->name, parent_inputs[i], &enter_nodes[i], status)) { - return; - } - } - - // Create Merge nodes - std::vector merge_nodes(n); - for (int i = 0; i < n; ++i) { - if (!CreateMerge(parent, StrCat(params->name, "/merge", i).c_str(), - enter_nodes[i], StrCat(params->name, "/next", i).c_str(), - 0, &merge_nodes[i], status)) { - return; - } - } - - // Copy cond_graph to parent and replace input placeholders with merge node - // outputs, and get handle to new cond output - tensorflow::string cond_prefix = StrCat(params->name, "/cond"); - TF_Output cond_output; - if (!CopyGraph(params->cond_graph, parent, params->cond_inputs, merge_nodes, - cond_prefix.c_str(), ¶ms->cond_output, 1, &cond_output, - status)) { - return; - } - - // Create Switch nodes - std::vector switch_trues(n); - std::vector switch_falses(n); - for (int i = 0; i < n; ++i) { - if (!CreateSwitch(parent, StrCat(params->name, "/switch", i).c_str(), - merge_nodes[i], cond_output, &switch_trues[i], - &switch_falses[i], status)) { - return; - } - } - - // Copy body_graph to parent, replace input placeholders with switch node - // true outputs, and get handles to new body outputs - tensorflow::string body_prefix = StrCat(params->name, "/body"); - std::vector body_outputs(n); - if (!CopyGraph(params->body_graph, parent, params->body_inputs, switch_trues, - body_prefix.c_str(), params->body_outputs, n, - body_outputs.data(), status)) { - return; - } - - // Create Next nodes - std::vector next_nodes(n); - for (int i = 0; i < n; ++i) { - if (!CreateNext(parent, StrCat(params->name, "/next", i).c_str(), - body_outputs[i], &next_nodes[i], status)) { - return; - } - } - - // Create Exit nodes (which are the outputs of the while loop) - for (int i = 0; i < n; ++i) { - if (!CreateExit(parent, StrCat(params->name, "/exit", i).c_str(), - switch_falses[i], &outputs[i], status)) { - return; - } + // 'cond_fn' copies the cond graph into the parent graph. + tensorflow::ops::CondGraphBuilderFn cond_fn = + [params, parent](const tensorflow::Scope& scope, + const std::vector& inputs, + tensorflow::Output* output) { + DCHECK_EQ(scope.graph(), &parent->graph); + std::vector cond_output; + TF_RETURN_IF_ERROR(CopyGraph( + ¶ms->cond_graph->graph, &parent->graph, &parent->refiner, + params->cond_inputs, inputs, scope.impl()->name(), + scope.impl()->control_deps(), ¶ms->cond_output, + /* nreturn_nodes */ 1, &cond_output)); + *output = cond_output[0]; + return Status::OK(); + }; + + // 'body_fn' copies the body graph into the parent graph. + tensorflow::ops::BodyGraphBuilderFn body_fn = + [params, parent, num_loop_vars]( + const tensorflow::Scope& scope, + const std::vector& inputs, + std::vector* outputs) { + DCHECK_EQ(scope.graph(), &parent->graph); + TF_RETURN_IF_ERROR( + CopyGraph(¶ms->body_graph->graph, &parent->graph, + &parent->refiner, params->body_inputs, inputs, + scope.impl()->name(), scope.impl()->control_deps(), + params->body_outputs, num_loop_vars, outputs)); + return Status::OK(); + }; + + // Create the while loop using an internal scope. + tensorflow::Scope scope = + NewInternalScope(&parent->graph, &status->status, &parent->refiner) + .NewSubScope(params->name); + + const int first_new_node_id = parent->graph.num_node_ids(); + + tensorflow::OutputList loop_outputs; + status->status = tensorflow::ops::BuildWhileLoop( + scope, OutputsFromTFOutputs(parent_inputs, num_loop_vars), cond_fn, + body_fn, params->name, &loop_outputs); + + // Update name_map with newly-created ops. + // TODO(skyewm): right now BuildWhileLoop() may alter the graph if it returns + // a bad status. Once we fix this, we may want to return early instead of + // executing the following code. + for (int i = first_new_node_id; i < parent->graph.num_node_ids(); ++i) { + Node* new_node = parent->graph.FindNodeId(i); + if (new_node == nullptr) continue; + parent->name_map[new_node->name()] = new_node; + } + + // Populate 'outputs'. + DCHECK_LE(loop_outputs.size(), num_loop_vars); + for (int i = 0; i < loop_outputs.size(); ++i) { + outputs[i] = {ToOperation(loop_outputs[i].node()), loop_outputs[i].index()}; } } } // namespace +#endif // __ANDROID__ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status, TF_Output* outputs) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "Creating while loops is not supported in Android. File a bug at " + "https://github.com/tensorflow/tensorflow/issues if this feature is " + "important to you"); +#else // If it appears the caller created or modified `params`, don't free resources if (!ValidateConstWhileParams(*params, status)) return; TF_FinishWhileHelper(params, status, outputs); FreeWhileResources(params); +#endif // __ANDROID__ } void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); } -#ifndef __ANDROID__ -namespace { - -void OutputsFromTFOutputs(TF_Output* tf_outputs, int n, TF_Status* status, - std::vector* outputs) { - outputs->resize(n); - for (int i = 0; i < n; i++) { - const TF_Output& tf_output = tf_outputs[i]; - (*outputs)[i] = tensorflow::Output(&tf_output.oper->node, tf_output.index); - } -} - -void TFOutputsFromOutputs(const std::vector& outputs, - TF_Output* tf_outputs) { - for (int i = 0; i < outputs.size(); i++) { - tf_outputs[i].oper = ToOperation(outputs[i].node()); - tf_outputs[i].index = outputs[i].index(); - } -} - -} // namespace -#endif // __ANDROID__ - void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy) { #ifdef __ANDROID__ @@ -2165,25 +2159,22 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, "https://github.com/tensorflow/tensorflow/issues if this feature is " "important to you"); #else - std::vector y_arg; - std::vector x_arg; + std::vector y_arg = OutputsFromTFOutputs(y, ny); + std::vector x_arg = OutputsFromTFOutputs(x, nx); std::vector dy_arg; - OutputsFromTFOutputs(y, ny, status, &y_arg); - OutputsFromTFOutputs(x, nx, status, &x_arg); { // We need to hold on to the lock while we have a scope that uses TF_Graph. mutex_lock graph_lock(g->mu); - const int max_node_id_before = g->graph.num_node_ids(); + const int first_new_node_id = g->graph.num_node_ids(); tensorflow::Scope scope = NewInternalScope(&g->graph, &status->status, &g->refiner) .NewSubScope("gradients"); if (dx != nullptr) { - std::vector dx_arg; - OutputsFromTFOutputs(dx, ny, status, &dx_arg); + std::vector dx_arg = OutputsFromTFOutputs(dx, ny); status->status = AddSymbolicGradients(scope, y_arg, x_arg, dx_arg, &dy_arg); } else { @@ -2192,7 +2183,7 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, // Update g->name_map with the name_map from the scope, which will contain // the new gradient ops. - for (int i = max_node_id_before; i < g->graph.num_node_ids(); ++i) { + for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) { Node* n = g->graph.FindNodeId(i); if (n == nullptr) continue; g->name_map[n->name()] = n; @@ -2319,6 +2310,12 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { const Graph& graph = session->graph->graph; const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { + status->status = tensorflow::ValidateNoCycles(session->graph->graph); + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + GraphDef graph_def; *graph_def.mutable_versions() = graph.versions(); // Fill graph_def with nodes with ids in the range diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 43b5078013731ae8099c2bcc9f1ee41ccd99f035..ee110d88cea50614515b3b3f42af1db1aaee9012 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -357,6 +357,14 @@ typedef struct TF_Output { int index; // The index of the output within oper. } TF_Output; +// TF_Function is a grouping of operations with defined inputs and outputs. +// Once created and added to graphs, functions can be invoked by creating an +// operation whose operation type matches the function name. +typedef struct TF_Function TF_Function; + +// Function definition options. TODO(iga): Define and implement +typedef struct TF_FunctionOptions TF_FunctionOptions; + // Sets the shape of the Tensor referenced by `output` in `graph` to // the shape described by `dims` and `num_dims`. // @@ -914,6 +922,15 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status); +// Add `function` to graph `g`. Once `function` is added to `g`, +// it can be called by creating an operation using the function's name. +// +// If successful, status is set to OK and function is added to g +// Otherwise, status is set to the encountered error and g is unmodified +TF_CAPI_EXPORT extern void TF_GraphAddFunction(TF_Graph* g, + const TF_Function* function, + TF_Status* status); + // Note: The following function may fail on very large protos in the future. TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, @@ -1001,6 +1018,105 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx, TF_Output* dx, TF_Status* status, TF_Output* dy); +// Create a TF_Function from a TF_Graph +// +// Params: +// fn_body - the graph whose operations (or subset of whose operations) will be +// converted to TF_Function. +// fn_name - the name of the new TF_Function. Should match the operation +// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct +// from other operation names (at least those registered in graphs +// where this function will be used). +// TODO(iga): Allow null in here and have C API come up with +// a unique name with high probability (similarly to +// _create_hash_str in function.py) +// num_opers - `num_opers` contains the number of elements in the `opers` array +// or a special value of -1 meaning that no array is given. +// The distinction between an empty array of operations and no +// array of operations is necessary to distinguish the case of +// creating a function with no body (e.g. identity or permutation) +// and the case of creating a function whose body contains all +// the nodes in the graph (except for the automatic skipping, see +// below). +// opers - Array of operations to become the body of the function or null. +// - If no array is given (`num_opers` = -1), all the +// operations in `fn_body` will become part of the function +// except operations referenced in `inputs`. These operations +// must have a single output (these operations are typically +// placeholders created for the sole purpose of representing +// an input. We can relax this constraint if there are +// compelling use cases). +// - If an array is given (`num_opers` >= 0), all operations +// in it will become part of the function. In particular, no +// automatic skipping of dummy input operations is performed. +// ninputs - number of elements in `inputs` array +// inputs - array of TF_Outputs that specify the inputs to the function. +// If `ninputs` is zero (the function takes no inputs), `inputs` +// can be null. The names used for function inputs are normalized +// names of the operations (usually placeholders) pointed to by +// `inputs`. These operation names should start with a letter. +// Normalization will convert all letters to lowercase and +// non-alphanumeric characters to '_' to make resulting names match +// the "[a-z][a-z0-9_]*" pattern for operation argument names. +// `inputs` cannot contain the same tensor twice. +// noutputs - number of elements in `outputs` array +// outputs - array of TF_Outputs that specify the outputs of the function. +// If `noutputs` is zero (the function returns no outputs), `outputs` +// can be null. `outputs` can contain the same tensor more than once. +// output_names - The names of the function's outputs. `output_names` array +// must either have the same length as `outputs` +// (i.e. `noutputs`) or be null. In the former case, +// the names should match the regular expression for ArgDef +// names - "[a-z][a-z0-9_]*". In the latter case, +// names for outputs will be generated automatically. +// opts - various options for the function, e.g. XLA's inlining control. +// status - Set to OK on success and an appropriate error on failure. +// +// Note that when the same TF_Output is listed as both an input and an output, +// the corresponding function's output will equal to this input, +// instead of the original node's output. +// +// Callers must also satisfy the following constraints: +// - `inputs` cannot refer to TF_Outputs within a control flow context. For +// example, one cannot use the output of "switch" node as input. +// - No TF_Output of a function (inside any of `inputs`, `outputs`, `fn_body`) +// is allowed to have a reference type. Reference types are not exposed +// through C API and are being deprecated. +// - Every node in the function's body must have all of its inputs (including +// control inputs). In other words, for every node in the body, each input +// must be either listed in `inputs` or must come from another node in +// the body. In particular, it is an error to have a control edge going from +// a node outside of the body into a node in the body. This applies to control +// edges going from nodes referenced in `inputs` to nodes in the body when +// the former nodes are not in the body (automatically skipped or not +// included in explicitly specified body). +// +// Returns: +// On successful, a newly created TF_Function instance. It must be deleted by +// calling TF_DeleteFunction. +// +// On failure, null. +// +// TODO(iga): Add input_names argument and get output_names working (they are +// currently ignored) +TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( + const TF_Graph* fn_body, const char* fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + const TF_FunctionOptions* opts, TF_Status* status); + +// Write out a serialized representation of `func` (as a FunctionDef protocol +// message) to `output_func_def` (allocated by TF_NewBuffer()). +// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() +// is called. +// +// May fail on very large graphs in the future. +TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func, + TF_Buffer* output_func_def, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function*); + // TODO(josh11b): Register OpDef, available to all operations added // to this graph. diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4c6397d0b4d34b4745f0f5115426b166354f570 --- /dev/null +++ b/tensorflow/c/c_api_function.cc @@ -0,0 +1,496 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api_internal.h" + +#include +#include +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace { + +// Class that maintains a one-to-one original node name -> new node name +// mapping. We normalize the names used as input and output arguments to match +// regexp "[a-z][a-z0-9_]*" specified in definition of ArgDef.name. +// Once we rename them, we risk creating a name collision with the other +// node names, so if necessary we add a suffix to make +// names unique. If we have an input named "A" and a node in the function +// body named "a", they will be renamed to "a" and "a_0". +class NodeNameMapping { + public: + NodeNameMapping() = default; + + // Normalize the input/output name and make it unique. + string GetIOName(const string& name); + + // Make the node name unique. + string Uniquify(const string& name); + + // Look up how a node name was previously normalized/uniquified. + // Returns empty if name was never seen. + string Lookup(const string& name) const; + + private: + string UniquifyHelper(const string& name) const; + static string Normalize(string name); + + // The normalized/uniquified names already used as + // input names (in signature), output names (in signature), and node names + // (in node_def). + // This is a superset of values in name_mapping_. + std::unordered_set used_names_; + // Mapping from original node name from the graph to the normalized + // and uniqified version of it. + std::unordered_map name_mapping_; +}; + +string NodeNameMapping::Normalize(string name) { + // Convert letters to lowercase and non-alphanumeric characters to '_'. + if (name.empty()) return "unknown"; + const int n = name.size(); + for (int i = 0; i < n; ++i) { + char c = name[i]; + if (isalnum(c)) { + if (isupper(c)) { + name[i] = tolower(c); + } + } else { + name[i] = '_'; + } + } + + // Find the first letter and start with it. + int i = 0; + for (; i < n; ++i) { + if (isalpha(name[i])) break; + } + + // Return "unknown" if none of the name's chars were letters. + return i == n ? "unknown" : name.substr(i); +} + +string NodeNameMapping::UniquifyHelper(const string& name) const { + // If the name hasn't been used yet, use it as-is. + if (used_names_.find(name) == used_names_.end()) return name; + // Add a suffix to name to make it unique. + for (int i = 0;; ++i) { + const string candidate = strings::StrCat(name, "_", i); + if (used_names_.find(candidate) == used_names_.end()) return candidate; + } +} + +string NodeNameMapping::GetIOName(const string& name) { + const string& input_name = UniquifyHelper(Normalize(name)); + // Record that we used this name, but don't add it to name_mapping_ + // since this name is not for a node. + used_names_.insert(input_name); + return input_name; +} + +string NodeNameMapping::Uniquify(const string& name) { + const string uniqued = UniquifyHelper(name); + name_mapping_[name] = uniqued; + used_names_.insert(uniqued); + return uniqued; +} + +string NodeNameMapping::Lookup(const string& name) const { + const auto iter = name_mapping_.find(name); + if (iter == name_mapping_.end()) return string(); + return iter->second; +} + +Status ValidateNoRefOutputs(const Node* node) { + for (int i = 0; i < node->num_outputs(); ++i) { + const DataType& dt = node->output_type(i); + if (IsRefType(dt)) { + return errors::InvalidArgument("Output ", i, " of node '", node->name(), + "' has a reference " + "type ", + DataTypeString(dt)); + } + } + return Status::OK(); +} + +Status FillFunctionBody( + const string& fn_name, const NodeNameMapping& node_names, + const std::vector& body_nodes, + const std::unordered_map& tensor_renaming, + FunctionDef* fdef) { + std::vector in_edges; + std::vector control_edges; + for (const Node* node : body_nodes) { + NodeDef* node_def = fdef->add_node_def(); + // First, copy the node_def as is. We will patch it next. + *node_def = node->def(); + if (!node->assigned_device_name().empty()) { + node_def->set_device(node->assigned_device_name()); + } + node_def->set_name(node_names.Lookup(node->name())); + + // Input names must be set based on nested names in tensor_renaming. + // Clear the flat input names we got from the original node_def + // from the graph. + node_def->clear_input(); + + // Collect regular and control inputs. Regular inputs are indexed + // by the index at which they come into the `node`. Control inputs + // don't follow any order. + in_edges.clear(); + in_edges.resize(node->num_inputs(), nullptr); + control_edges.clear(); + for (const Edge* edge : node->in_edges()) { + if (edge->src()->IsSource()) continue; + if (edge->IsControlEdge()) { + control_edges.push_back(edge); + } else { + in_edges[edge->dst_input()] = edge; + } + } + + // Add regular inputs. + for (size_t i = 0; i < in_edges.size(); ++i) { + const Edge* edge = in_edges[i]; + string original_input_name; + if (edge == nullptr) { + // A backedge might not appear as a regular Edge, but be only present + // in the node_def. Such edges are referred to as requested_inputs(). + if (i >= node->requested_inputs().size()) { + return errors::InvalidArgument( + "Graph to be converted to function appears to be malformed. ", + "Node ", node->name(), " is missing input edge ", i); + } + original_input_name = + ParseTensorName(node->requested_inputs()[i]).ToString(); + } else { + original_input_name = + strings::StrCat(edge->src()->name(), ":", edge->src_output()); + } + + const auto iter = tensor_renaming.find(original_input_name); + if (iter == tensor_renaming.end()) { + return errors::InvalidArgument( + "Input ", i, ", '", original_input_name, "', of node '", + node->name(), "' in function '", fn_name, + "' is not available. You might need to include it in inputs " + "or include its source node in the body"); + } + node_def->add_input(iter->second); + } + + // Add control inputs. + for (const Edge* edge : control_edges) { + // Add this control input only if the src node is in the body. + const string normalized = node_names.Lookup(edge->src()->name()); + // If we did not find a name for the source of control edge, this + // source must be outside of the body. Raise an error. + if (normalized.empty()) { + return errors::InvalidArgument( + "The source of control edge ", edge->DebugString(), + " is not in the body. Encountered while creating function '", + fn_name, "'"); + } + node_def->add_input(strings::StrCat("^", normalized)); + } + } + return Status::OK(); +} + +// Graph to FunctionDef conversion. This code is closely modeled on the Python +// code in third_party/tensorflow/python/framework/function.py. +Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, + const std::vector& body_nodes, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& output_names, + FunctionDef* fdef) { + fdef->mutable_signature()->set_name(fn_name); + + // Keep track of names we used and how we normalized them. + NodeNameMapping node_names; + + // Mapping from original names of tensors (i.e. ":") to the + // name we used in the function: + // - For input tensors: + // {flat_tensor_name -> normalized_name_of_src_node} + // e.g. {In:3 -> in} + // - For tensors produced by nodes in function's body: + // {flat_tensor_name -> nested_tensor_name} + // e.g. {Add:3 -> add_0:z:1} + std::unordered_map tensor_renaming; + + // Fill inputs in function's signature. + for (size_t i = 0; i < inputs.size(); ++i) { + const Node* node = inputs[i].node; + int idx = inputs[i].index; + OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg(); + argdef->set_type(node->output_type(idx)); + const string& input_name = node_names.GetIOName(node->name()); + argdef->set_name(input_name); + tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name; + } + + // Fill outputs in function's signature. + for (size_t i = 0; i < outputs.size(); ++i) { + const Node* node = outputs[i].node; + int idx = outputs[i].index; + OpDef::ArgDef* argdef = fdef->mutable_signature()->add_output_arg(); + argdef->set_type(node->output_type(idx)); + argdef->set_name(node_names.GetIOName(node->name())); + } + + // Populate tensor_renaming and node_names. + // Generate the new output names for every node in the function. + // The NodeDefs in FunctionDefs use a different naming scheme for + // their inputs than the NodeDefs in a graph (see the comment for + // FunctionDef.node_def in function.proto). We do the + // graph tensor name -> function tensor name conversion for every + // possible input (i.e. every node's outputs) and store the result + // in tensor_renaming. + for (const Node* node : body_nodes) { + // Make sure node_name does not collide with an input or output name. + const string& node_name = node_names.Uniquify(node->name()); + // For each output_arg in the op_def, the output_ranges + // map will have [start, end] range of indices that this arg produces + // among all the output tensors of this op. + NameRangeMap output_ranges; + TF_RETURN_IF_ERROR( + NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges)); + for (const auto& output : output_ranges) { + const string& output_name = output.first; + int index_start = output.second.first; + int index_end = output.second.second; + for (int i = index_start; i < index_end; ++i) { + const string& original_name = strings::StrCat(node->name(), ":", i); + const string& new_name = + strings::StrCat(node_name, ":", output_name, ":", i - index_start); + // Record the mapping if this tensor is not already mapped. + // Tensor can be already mapped if it is used as an input. + if (tensor_renaming.find(original_name) == tensor_renaming.end()) { + tensor_renaming[original_name] = new_name; + } + } + } + } + + TF_RETURN_IF_ERROR( + FillFunctionBody(fn_name, node_names, body_nodes, tensor_renaming, fdef)); + + // Remap return values. + for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { + const string& ret_name = fdef->signature().output_arg(r).name(); + + // We convert this flat tensor name to the nested value + // (e.g. `add:z:1`) that we stored in tensor_renaming. + const string& return_value = + strings::StrCat(outputs[r].node->name(), ":", outputs[r].index); + const auto iter = tensor_renaming.find(return_value); + if (iter == tensor_renaming.end()) { + return errors::InvalidArgument( + "TF_Output ", return_value, " is neither in the function body ", + "nor among function inputs. Encountered while creating function '", + fn_name, "'"); + } + (*fdef->mutable_ret())[ret_name] = iter->second; + } + + return Status::OK(); +} + +// Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and +// does various checks while doing so. `input_nodes` will contain the same +// information as input_tensors just in a different structure to make +// following processing easier. TODO(iga): Simplify this nested structure. +Status ProcessInputs( + const TF_Graph* fn_body, const char* fn_name, int ninputs, + const TF_Output* inputs, std::vector* input_tensors, + std::unordered_map>* input_nodes) + EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { + input_tensors->reserve(ninputs); + for (int i = 0; i < ninputs; ++i) { + const Node& node = inputs[i].oper->node; + int idx = inputs[i].index; + + TF_RETURN_WITH_CONTEXT_IF_ERROR( + fn_body->graph.IsValidOutputTensor(&node, idx), + "Encountered while processing input ", i, " into function '", fn_name, + "'"); + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(&node), + "Encountered while processing input ", i, + " into function '", fn_name, "'"); + + input_tensors->emplace_back(&node, idx); + + const auto& iter = input_nodes->find(&node); + if (iter == input_nodes->end()) { + input_nodes->insert({&node, {idx}}); + } else { + auto& indices = iter->second; + if (std::find(indices.begin(), indices.end(), idx) != indices.end()) { + return errors::InvalidArgument( + "TF_Output ", node.name(), ":", idx, + " appears more than once in the input list"); + } + indices.push_back(idx); + } + } + return Status::OK(); +} + +// Converts `noutputs` and `outputs` into `outputs_tensors` and does various +// checks while doing so. +Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, + int noutputs, const TF_Output* outputs, + std::vector* output_tensors) + EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { + output_tensors->reserve(noutputs); + for (int i = 0; i < noutputs; ++i) { + const Node& node = outputs[i].oper->node; + int idx = outputs[i].index; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + fn_body->graph.IsValidOutputTensor(&node, idx), + "Encountered while processing output ", i, " from function '", fn_name, + "'"); + output_tensors->emplace_back(&node, idx); + } + return Status::OK(); +} + +// Populates `body_nodes` with the nodes that will become function's body. +// Performs various checks. +Status ComputeBodyNodes( + const TF_Graph* fn_body, const char* fn_name, int num_opers, + const TF_Operation* const* opers, + const std::unordered_map>& input_nodes, + std::vector* body_nodes) + EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { + if (num_opers == -1) { + for (const Node* node : fn_body->graph.op_nodes()) { + const auto& iter = input_nodes.find(node); + if (iter == input_nodes.end()) { + // This node is not referenced in inputs. Add it to the body. + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node), + "Encountered while creating function '", + fn_name, "'"); + body_nodes->push_back(node); + } else { + // This node is referenced in inputs. Currently, we place an + // artificial restriction and require that when num_opers=-1, such + // nodes must have a single output. + if (node->num_outputs() != 1) { + return errors::InvalidArgument( + "When `num_opers` is set to -1, nodes referenced in `inputs` " + "must have a single output. Node ", + node->name(), " has ", node->num_outputs(), + " outputs. Encountered while creating function '", fn_name, "'"); + } + } + } + } else { + body_nodes->reserve(num_opers); + for (int i = 0; i < num_opers; ++i) { + const Node* node = &opers[i]->node; + TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNoRefOutputs(node), + "Encountered while creating function '", + fn_name, "'"); + body_nodes->push_back(node); + } + } + return Status::OK(); +} + +} // anonymous namespace +} // namespace tensorflow + +using tensorflow::Node; +using tensorflow::string; + +TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, + int num_opers, const TF_Operation* const* opers, + int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, + const char* const* output_names, + const TF_FunctionOptions* opts, + TF_Status* status) { + tensorflow::mutex_lock l(*const_cast(&fn_body->mu)); + + // Process inputs. + std::vector input_tensors; + std::unordered_map> input_nodes; + status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs, + &input_tensors, &input_nodes); + if (!status->status.ok()) return nullptr; + + // Process outputs. + std::vector output_tensors; + status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs, + outputs, &output_tensors); + if (!status->status.ok()) return nullptr; + + // Process output names. + std::vector output_names_vec; + if (output_names) { + output_names_vec.reserve(noutputs); + for (int i = 0; i < noutputs; ++i) { + output_names_vec.push_back(string(output_names[i])); + } + } + + // Compute body nodes. + std::vector body_nodes; + status->status = tensorflow::ComputeBodyNodes( + fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes); + if (!status->status.ok()) return nullptr; + + // Do the actual function creation. + TF_Function* tf_function = new TF_Function(); + status->status = tensorflow::GraphToFunctionDef( + fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors, + output_names_vec, tf_function->fdef_lib.add_function()); + if (!status->status.ok()) { + TF_DeleteFunction(tf_function); + return nullptr; + } + return tf_function; +} + +void TF_GraphAddFunction(TF_Graph* g, const TF_Function* function, + TF_Status* status) { + tensorflow::mutex_lock l(g->mu); + + // At the moment, we have only one function and no gradients in fdef_lib. + // This makes the following operation atomic. + // TODO(iga): Add an atomic version of AddFunctionLibrary when we support + // gradients + status->status = g->graph.AddFunctionLibrary(function->fdef_lib); +} + +void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def, + TF_Status* status) { + DCHECK_EQ(1, func->fdef_lib.function_size()); + status->status = MessageToBuffer(func->fdef_lib.function(0), output_func_def); +} + +void TF_DeleteFunction(TF_Function* function) { delete function; } diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9dd38ea15fa49f1fec5f86f9dd2353b1b8398ba --- /dev/null +++ b/tensorflow/c/c_api_function_test.cc @@ -0,0 +1,1039 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api.h" + +#include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// Specification for expected input/output and its type. +// DataType value of DT_INVALID signifies that we don't want to +// check the data type. +typedef std::pair IOSpec; + +std::vector M(const std::initializer_list& names) { + std::vector v; + for (const string& name : names) { + v.push_back(IOSpec(name, DT_INVALID)); + } + return v; +} + +// Specification for an expected edge. +// src is either: +// - input name (as it appears in FunctionDef) +// - name of output tensor (in nested "add:z:0" format) +// dst is either: +// - output name (as it appears in FunctionDef) +// - : (this looks the same as +// output tensor naming, but it the index is actually an input index) +struct EdgeSpec : public std::pair { + typedef std::pair Base; + + // Inherit the set of constructors + using Base::pair; + + string ToString() const { return strings::StrCat(first, "->", second); } +}; + +class CApiFunctionTest : public ::testing::Test { + protected: + CApiFunctionTest() + : s_(TF_NewStatus()), + func_graph_(TF_NewGraph()), + host_graph_(TF_NewGraph()), + func_(nullptr) {} + + void SetUp() override {} + + ~CApiFunctionTest() override { + TF_DeleteFunction(func_); + TF_DeleteGraph(host_graph_); + TF_DeleteGraph(func_graph_); + TF_DeleteStatus(s_); + } + + void Run(const std::vector>& inputs, + TF_Operation* output, int32_t expected_result) { + Run(inputs, {{output, 0}}, {expected_result}); + } + + // Run the host graph, which now contains a function and check that + // outputs are as expected. + // 'T' stands for 'tensor' since the outputs are tensors, not scalars. + void RunT(const std::vector>& inputs, + std::initializer_list outputs, + const std::vector>& expected_results) { + // Create a session for this graph + CSession csession(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run + csession.SetInputs(inputs); + csession.SetOutputs(outputs); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Check results + for (int i = 0; i < expected_results.size(); ++i) { + TF_Tensor* out = csession.output_tensor(i); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(1, TF_NumDims(out)); + CompareInt32Tensor(expected_results[i], out); + } + } + + // Run the host graph, which now contains a function and check that + // outputs are as expected. + void Run(const std::vector>& inputs, + std::initializer_list outputs, + const std::vector& expected_results) { + // Create a session for this graph. + CSession csession(host_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + csession.SetInputs(inputs); + csession.SetOutputs(outputs); + csession.Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + for (int i = 0; i < expected_results.size(); ++i) { + TF_Tensor* out = csession.output_tensor(i); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out)); + int32_t* output_contents = static_cast(TF_TensorData(out)); + EXPECT_EQ(expected_results[i], *output_contents); + } + } + + void CompareInt32Tensor(const std::vector& expected, TF_Tensor* t) { + int32_t* data = static_cast(TF_TensorData(t)); + size_t size = TF_TensorByteSize(t); + ASSERT_EQ(expected.size() * sizeof(int32_t), size); + for (int i = 0; i < expected.size(); ++i) { + ASSERT_EQ(expected[i], data[i]) << "Different data at index " << i; + } + } + + std::vector ToOutput(const std::vector ops) { + std::vector out; + for (auto op : ops) { + out.push_back({op, 0}); + } + return out; + } + + void Define(int num_opers, const std::vector& opers, + const std::vector& inputs, + const std::vector& outputs, + const char** output_names, bool expect_failure = false) { + DefineT(num_opers, opers, ToOutput(inputs), ToOutput(outputs), output_names, + expect_failure); + } + + // An explicit `num_opers` is needed so that we can distinguish between the + // case of no operations specified (-1) and the case of an empty set of + // operations specified (0). + void DefineT(int num_opers, const std::vector& opers, + const std::vector& inputs, + const std::vector& outputs, const char** output_names, + bool expect_failure = false) { + ASSERT_EQ(func_, nullptr); + func_ = TF_GraphToFunction(func_graph_, func_name_, num_opers, + num_opers == -1 ? nullptr : opers.data(), + inputs.size(), inputs.data(), outputs.size(), + outputs.data(), output_names, + /*opts=*/nullptr, s_); + if (expect_failure) { + ASSERT_EQ(func_, nullptr); + return; + } + + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(func_, nullptr); + TF_GraphAddFunction(host_graph_, func_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + TF_Operation* Use(const std::vector& inputs) { + return UseT(ToOutput(inputs)); + } + + TF_Operation* UseT(const std::vector& inputs) { + TF_Operation* op; + UseHelper(inputs, &op); + return op; + } + + // All the *Helper methods are used as a workaround for the restrictions that + // one cannot call ASSERT_* methods in non-void-returning functions (when + // exceptions are disabled during compilation) + void UseHelper(const std::vector& inputs, TF_Operation** op) { + TF_OperationDescription* desc = + TF_NewOperation(host_graph_, func_name_, func_node_name_); + for (auto input : inputs) { + TF_AddInput(desc, input); + } + // Set device to CPU because some ops inside the function might not be + // available on GPU. + TF_SetDevice(desc, "/cpu:0"); + *op = TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_NE(*op, nullptr); + } + + FunctionDef fdef() { + tensorflow::FunctionDef fdef; + EXPECT_TRUE(GetFunctionDef(func_, &fdef)); + return fdef; + } + + // logging utility + template + string ToString(const Container& v) { + std::stringstream ss; + ss << "{"; + size_t i = 0; + for (const auto& e : v) { + if (i != 0) { + ss << ", "; + } + ss << e.ToString(); + ++i; + } + ss << "}"; + return ss.str(); + } + + void VerifyFDefNodes(const tensorflow::FunctionDef& fdef, + const std::unordered_set& nodes) { + ASSERT_EQ(nodes.size(), fdef.node_def_size()) + << "Got unexpected number of nodes. Expected: [" + << str_util::Join(nodes, ", ") + << "] Actual nodes in fdef: " << fdef.DebugString(); + for (const NodeDef& node_def : fdef.node_def()) { + ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end()) + << "Got unexpected node: " << node_def.name() + << " in fdef: " << fdef.DebugString(); + } + } + + void VerifyFDefInputs(const tensorflow::FunctionDef& fdef, + const std::vector& inputs) { + const OpDef& signature = fdef.signature(); + ASSERT_EQ(inputs.size(), signature.input_arg_size()); + for (int i = 0; i < inputs.size(); ++i) { + const OpDef::ArgDef& arg = signature.input_arg(i); + const IOSpec& in = inputs[i]; + if (in.second != DT_INVALID) { + ASSERT_EQ(arg.type(), in.second) + << "Got unexpected type for input " << i + << ". fdef: " << fdef.DebugString(); + } + ASSERT_EQ(arg.name(), in.first) << "Got unexpected name for input " << i + << ". fdef: " << fdef.DebugString(); + } + } + + void VerifyFDefOutputs(const tensorflow::FunctionDef& fdef, + const std::vector& outputs) { + const OpDef& signature = fdef.signature(); + ASSERT_EQ(outputs.size(), signature.output_arg_size()); + for (int i = 0; i < outputs.size(); ++i) { + const OpDef::ArgDef& arg = signature.output_arg(i); + const IOSpec& out = outputs[i]; + if (out.second != DT_INVALID) { + ASSERT_EQ(arg.type(), out.second) + << "Got unexpected type for output " << i + << ". fdef: " << fdef.DebugString(); + } + ASSERT_EQ(arg.name(), out.first) << "Got unexpected name for output " << i + << ". fdef: " << fdef.DebugString(); + } + } + + void VerifyFDefEdges( + const tensorflow::FunctionDef& fdef, + const std::vector& e_edges, // expected edges + const std::vector& c_edges, // expected ctrl edges + bool is_exact_edges = true) { + // Build a set of edges from fdef + std::set a_edges; // actual edges + // Get edges from inputs to body nodes and between body nodes + for (const NodeDef& node_def : fdef.node_def()) { + for (int i = 0; i < node_def.input_size(); ++i) { + const string& in = node_def.input(i); + const auto& v = + a_edges.insert({in, strings::StrCat(node_def.name(), ":", i)}); + ASSERT_TRUE(v.second) << "Duplicate edge " << in << " -> " + << strings::StrCat(node_def.name(), ":", i) + << ". fdef: " << fdef.DebugString(); + } + } + // Get edges from body nodes to outputs and from inputs to outputs + for (const OpDef::ArgDef& arg : fdef.signature().output_arg()) { + const auto& iter = fdef.ret().find(arg.name()); + if (iter != fdef.ret().end()) { + const auto& v = a_edges.insert({iter->second, arg.name()}); + ASSERT_TRUE(v.second) << "Duplicate edge " << iter->second << " -> " + << arg.name() << ". fdef: " << fdef.DebugString(); + } else { + const auto& v = a_edges.insert({arg.name(), arg.name()}); + ASSERT_TRUE(v.second) << "Duplicate edge " << arg.name() << " -> " + << arg.name() << ". fdef: " << fdef.DebugString(); + } + } + + // Verify edges + for (const EdgeSpec& e : e_edges) { + ASSERT_TRUE(a_edges.find(e) != a_edges.end()) + << "Failed to find expected edge " << e.ToString() + << " in fdef: " << fdef.DebugString(); + } + + // If caller specified all edges, check that we have seen all + if (is_exact_edges) { + ASSERT_EQ(e_edges.size() + c_edges.size(), a_edges.size()) + << "Expected edges: " << ToString(e_edges) + << " Expected Control edges: " << ToString(c_edges) + << " Actual edges: " << ToString(a_edges) + << " in fdef: " << fdef.DebugString(); + } + } + + void VerifyFDef(const std::unordered_set& nodes, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& e_edges, // expected edges + const std::vector& c_edges, // expected ctrl edges + bool is_exact_edges = true) { + tensorflow::FunctionDef fdef; + ASSERT_TRUE(GetFunctionDef(func_, &fdef)); + VerifyFDefNodes(fdef, nodes); + VerifyFDefInputs(fdef, inputs); + VerifyFDefOutputs(fdef, outputs); + VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges); + } + + const char* func_name_ = "MyFunc"; + const char* func_node_name_ = "MyFunc_0"; + TF_Status* s_; + TF_Graph* func_graph_; + TF_Graph* host_graph_; + TF_Function* func_; + + // Workaround for not being able to initialize empty map using {} + std::unordered_set empty_; +}; + +TEST_F(CApiFunctionTest, OneOp_ZeroInputs_OneOutput) { + /* + * constant + * | + * v + */ + // Define + TF_Operation* c = ScalarConst(10, func_graph_, s_, "scalar10"); + Define(-1, {}, {}, {c}, nullptr); + + // Use, run, and verify + TF_Operation* func_op = Use({}); + Run({}, func_op, 10); + VerifyFDef({"scalar10_0"}, {}, {{"scalar10", DT_INT32}}, + {{"scalar10_0:output:0", "scalar10"}}, {}); +} + +TEST_F(CApiFunctionTest, OneOp_OneInput_OneOutput) { + /* + * | + * v + * negate + * | + * v + */ + // Define + TF_Operation* feed = Placeholder(func_graph_, s_); + TF_Operation* neg = Neg(feed, func_graph_, s_); + Define(-1, {}, {feed}, {neg}, nullptr); + + // Use, run, and verify + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, -3); + VerifyFDef({"neg_0"}, {{"feed", DT_INT32}}, {{"neg", DT_INT32}}, + {{"feed", "neg_0:0"}, {"neg_0:y:0", "neg"}}, {}); +} + +TEST_F(CApiFunctionTest, ZeroOps_Identity) { + /* + * | + * | + * | + * v + */ + // Define + TF_Operation* feed = Placeholder(func_graph_, s_); + Define(-1, {}, {feed}, {feed}, nullptr); + + // Use, run, and verify + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 3); + VerifyFDef(empty_, {{"feed", DT_INT32}}, {{"feed_0", DT_INT32}}, + {{"feed", "feed_0"}}, {}); +} + +TEST_F(CApiFunctionTest, ZeroOps_Permutation) { + /* + * | | + * \ / + * \/ + * x + * /\ + * / \ + * | | + * v v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + Define(-1, {}, {feed1, feed2}, {feed2, feed1}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {3, 2}); + VerifyFDef(empty_, M({{"feed1"}, {"feed2"}}), M({{"feed2_0"}, {"feed1_0"}}), + {{"feed1", "feed1_0"}, {"feed2", "feed2_0"}}, {}); +} + +TEST_F(CApiFunctionTest, OneOp_TwoInputs_OneOutput) { + /* + * | | + * v v + * add + * | + * v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + Define(-1, {}, {feed1, feed2}, {add}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), + {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, {}); +} + +TEST_F(CApiFunctionTest, OneOp_TwoInputs_ZeroOutputs) { + /* + * | | + * v v + * add + * + * (output ignored) + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + Add(feed1, feed2, func_graph_, s_); + Define(-1, {}, {feed1, feed2}, {}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + Use({two, func_feed}); + VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), {}, + {{"feed1", "add:0"}, {"feed2", "add:1"}}, {}); +} + +TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_OneOutput) { + /* + * | | | + * v v / + * add1 / + * | | + * v v + * add2 + * | + * v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3"); + TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1"); + TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2"); + Define(-1, {}, {feed1, feed2, feed3}, {add2}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); + TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, ten, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 10 + 3); + VerifyFDef({"add1", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}), + M({{"add2"}}), + {{"feed1", "add1:0"}, + {"feed2", "add1:1"}, + {"add1:sum:0", "add2_0:0"}, + {"feed3", "add2_0:1"}, + {"add2_0:sum:0", "add2"}}, + {}); +} + +TEST_F(CApiFunctionTest, OneOp_TwoInputs_TwoDuplicateOutputs) { + /* + * | | + * v v + * add + * | + * +-+-+ + * | | + * v v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + Define(-1, {}, {feed1, feed2}, {add, add}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {5, 5}); + VerifyFDef({"add_1"}, M({{"feed1"}, {"feed2"}}), M({{"add"}, {"add_0"}}), + {{"feed1", "add_1:0"}, + {"feed2", "add_1:1"}, + {"add_1:sum:0", "add"}, + {"add_1:sum:0", "add_0"}}, + {}); +} + +TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_TwoOutputs) { + /* + * | | | + * v v / + * add / + * | | + * +-+ | + * | | | + * | v v + * | add + * | | + * v v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3"); + TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1"); + TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2"); + Define(-1, {}, {feed1, feed2, feed3}, {add1, add2}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); + TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, ten, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, {{func_op, 0}, {func_op, 1}}, {12, 15}); + VerifyFDef({"add1_0", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}), + M({{"add1"}, {"add2"}}), + {{"feed1", "add1_0:0"}, + {"feed2", "add1_0:1"}, + {"add1_0:sum:0", "add2_0:0"}, + {"feed3", "add2_0:1"}, + {"add1_0:sum:0", "add1"}, + {"add2_0:sum:0", "add2"}}, + {}); +} + +TEST_F(CApiFunctionTest, FromSubsetOfOps) { + /* + * | | | + * v v / + * add / + * | | + * +---+--+---+ + * Ops used | | | | + * for func | v v | + * | | add | + * +-------> | | | + * | v | + * | | + * +----------+ + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3"); + TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1"); + TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2"); + Define(1, {add2}, {add1, feed3}, {add2}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add2_0"}, M({{"add1"}, {"feed3"}}), M({{"add2"}}), + {{"add1", "add2_0:0"}, {"feed3", "add2_0:1"}, {"add2_0:sum:0", "add2"}}, + {}); +} + +TEST_F(CApiFunctionTest, UsingOneOutputOfSplit) { + /* + * feed + * | + * +---------+---+ + * | const0 | | + * | | | | + * | v / | + * | split | + * | | | | | + * | v | v | + * | | | + * +------+------+ + * | + * v + * + * Only the second output from split is used as function output + */ + // Define + TF_Operation* feed = Placeholder(func_graph_, s_); + TF_Operation* split = Split3(feed, func_graph_, s_); + DefineT(-1, {}, {{feed, 0}}, {{split, 1}}, nullptr); + + // Use, run, and verify + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed}); + RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}}, {{func_op, 0}}, + {{3, 4}}); + VerifyFDef({"split3_const0", "split3_0"}, M({{"feed"}}), M({{"split3"}}), + {{"split3_const0:output:0", "split3_0:0"}, + {"feed", "split3_0:1"}, + {"split3_0:output:1", "split3"}}, + {}); +} + +TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplit) { + /* + * feed + * | + * +---------+---+ + * | const0 | | + * | | | | + * | v / | + * | split | + * | | | | | + * | | v | | + * | | | | + * +---+-----+---+ + * | | + * v v + * + * Second output from split is not used as function output + */ + // Define + TF_Operation* feed = Placeholder(func_graph_, s_); + TF_Operation* split = Split3(feed, func_graph_, s_); + DefineT(-1, {}, {{feed, 0}}, {{split, 0}, {split, 2}}, nullptr); + + // Use, run, and verify + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed}); + RunT({{func_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}}, + {{func_op, 0}, {func_op, 1}}, {{1, 2}, {5, 6}}); + VerifyFDef({"split3_const0", "split3_1"}, M({{"feed"}}), + M({{"split3"}, {"split3_0"}}), + {{"split3_const0:output:0", "split3_1:0"}, + {"feed", "split3_1:1"}, + {"split3_1:output:0", "split3"}, + {"split3_1:output:2", "split3_0"}}, + {}); +} + +TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplitAsInputs) { + /* + * | + * v + * split + * | | | + * | v | + * | | + * +---+-----+---+ + * | | | | + * | v v | + * | add | + * | | | + * | | | + * +------+------+ + * | + * v + */ + // Define + TF_Operation* feed = Placeholder(func_graph_, s_); + TF_Operation* split = Split3(feed, func_graph_, s_); + TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + DefineT(1, {add}, {{split, 0}, {split, 2}}, {{add, 0}}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_, "two"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add_0"}, M({{"split3"}, {"split3_0"}}), M({{"add"}}), + {{"split3", "add_0:0"}, {"split3_0", "add_0:1"}, {"add_0:sum:0", "add"}}, + {}); +} + +TEST_F(CApiFunctionTest, NodesUsedInInputsMustHaveSingleOutput) { + /* + * | + * v + * split + * | | | + * | v | + * | | + * input --->| |<--- input + * | | + * v v + * add + * | + * | + * v + */ + // Define + TF_Tensor* tensor_123 = Int32Tensor({1, 2, 3}); + TF_Operation* c = Const(tensor_123, func_graph_, s_, "const_array"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* split = Split3(c, func_graph_, s_); + TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + DefineT(-1, {}, {{split, 0}, {split, 2}}, {{add, 0}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("When `num_opers` is set to -1, nodes referenced in " + "`inputs` must have a single output. Node split3 has " + "3 outputs. Encountered while creating function 'MyFunc'"), + string(TF_Message(s_))); + + TF_DeleteTensor(tensor_123); +} + +TEST_F(CApiFunctionTest, FunctionWithWhileLoop) { + // Inputs to the while loop and the function as a whole + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + + // Outputs of the while loop corresponding to the two inputs above + // The first one will the function's output + std::vector outputs; + + // Add while loop to func_graph_ + { + // The inputs to the while loop + std::vector inputs = {{feed1, 0}, {feed2, 0}}; + std::unique_ptr params(new TF_WhileParams( + TF_NewWhile(func_graph_, &inputs[0], inputs.size(), s_))); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params->name = "test_loop"; + + // Initialize outputs so we can easily detect errors/bugs + outputs.resize(2, {nullptr, -1}); + + // Create loop: while (input1 < input2) input1 += input2 + 1 + TF_Operation* less_than = LessThan( + params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params->cond_output = {less_than, 0}; + + TF_Operation* add1 = Add(params->body_inputs[0], params->body_inputs[1], + params->body_graph, s_, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* one = ScalarConst(1, params->body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* add2 = Add(add1, one, params->body_graph, s_, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params->body_outputs[0] = {add2, 0}; + params->body_outputs[1] = params->body_inputs[1]; + + // Finalize while loop + TF_FinishWhile(params.get(), s_, &outputs[0]); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + // Define function, use it in graph, and run + DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {outputs[0]}, nullptr); + TF_Operation* five = ScalarConst(5, host_graph_, s_, "five"); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({func_feed, five}); + Run({{func_feed, Int32Tensor(2)}}, func_op, 2 /*+=*/ + 5 + 1); + + // Verify input, output, and subset of edges in fdef. + // The subset of edges we verify is a chain between feed1 and output to + // make sure that the correct output is picked. + tensorflow::FunctionDef fdef; + ASSERT_TRUE(GetFunctionDef(func_, &fdef)); + VerifyFDefInputs(fdef, M({{"feed1"}, {"feed2"}})); + VerifyFDefOutputs(fdef, M({{"test_loop_exit"}})); + VerifyFDefEdges(fdef, + {{"feed1", "test_loop/Enter:0"}, + {"test_loop/Enter:output:0", "test_loop/Merge:0"}, + {"test_loop/Merge:output:0", "test_loop/Switch:0"}, + {"test_loop/Switch:output_false:0", "test_loop/Exit:0"}, + {"test_loop/Exit:output:0", "test_loop_exit"}}, + {}, false); +} + +TEST_F(CApiFunctionTest, ControlDependency) { + /* + * | | scalar + * | | . + * v v . <---- control dependency + * add < - + * | + * v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* five = ScalarConst(5, func_graph_, s_); + TF_Operation* add = + AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + Define(-1, {}, {feed1, feed2}, {add}, nullptr); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), + {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, + {{"scalar", "add_0"}}); +} + +TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) { + /* + * | | scalar + * | | . + * v v . <---- control dependency + * add < - + * | + * v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* five = ScalarConst(5, func_graph_, s_); + TF_Operation* add = + AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + Define(1, {add}, {feed1, feed2}, {add}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("The source of control edge [id=3 scalar:-1 -> add:-1] " + "is not in the body. Encountered while creating " + "function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) { + /* + * | |. + * | | . + * | | . + * v v . <---- control dependency + * add < - + * | + * v + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = + AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + Define(-1, {}, {feed1, feed2}, {add}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("The source of control edge [id=3 feed1:-1 -> add:-1] " + "is not in the body. Encountered while creating " + "function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) { + /* + * feed + * | + * +++ + * | | + * +---+-+---+ + * | | | | + * | v v | + * | add | + * | | | + * | | | + * +----+----+ + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* add = Add(feed1, feed1, func_graph_, s_); + Define(-1, {}, {feed1, feed1}, {add}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ( + string("TF_Output feed1:0 appears more than once in the input list"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) { + /* + * | | + * v v + * add + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does " + "not have output 2\n\tEncountered while processing " + "input 1 into function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, InvalidInputTensor_BadNodePtr) { + /* + * | | + * v v + * add + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + DefineT(-1, {}, {{feed1, 0}, {nullptr, 0}}, {{add, 0}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Node is null\n\tEncountered while processing input 1 " + "into function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) { + /* + * | | + * v v + * add + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does " + "not have output 3\n\tEncountered while processing " + "output 0 from function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) { + /* + * | | + * v v + * add + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + Add(feed1, feed2, func_graph_, s_); + DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{nullptr, 3}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Node is null\n\tEncountered while processing output 0 " + "from function 'MyFunc'"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, NodeMissingInput) { + /* + * input---> | | <----missing input + * v v + * body----> add + * | + * v + */ + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + DefineT(1, {add}, {{feed1, 0}}, {{add, 0}}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("Input 1, 'feed2:0', of node 'add' in function 'MyFunc' " + "is not available. You might need to include it in inputs " + "or include its source node in the body"), + string(TF_Message(s_))); +} + +TEST_F(CApiFunctionTest, OutputOpNotInBody) { + /* + * | | + * v v + * add scalar (scalar not included in body) + * | | + * v v (function has two outputs) + */ + // Define + TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1"); + TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); + TF_Operation* scalar = ScalarConst(2, func_graph_, s_); + TF_Operation* add = Add(feed1, feed2, func_graph_, s_); + Define(1, {add}, {feed1, feed2}, {add, scalar}, nullptr, true); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(string("TF_Output scalar:0 is neither in the function body nor " + "among function inputs. Encountered while creating " + "function 'MyFunc'"), + string(TF_Message(s_))); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index f7d25dce8f5573d257daf0ecc281bab1f9eca016..68c324f2b992df144db79fb392eb8262a283d250 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -130,6 +130,11 @@ struct TF_DeviceList { std::vector response; }; +struct TF_Function { + // Currently contains a single function and no gradients + tensorflow::FunctionDefLibrary fdef_lib; +}; + namespace tensorflow { class TensorCApi { @@ -141,7 +146,12 @@ class TensorCApi { } }; +Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); + TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); + +Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 0aa60fb45dda14fcf535c75e677da428698dfb3c..c4420290099ee10c89792210dad2604328296515 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -829,7 +829,7 @@ TEST(CAPI, ShapeInferenceError) { TF_Operation* vec3 = Const(vec3_tensor.get(), graph, status, "vec3"); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_Operation* add = Add(vec2, vec3, graph, status); + TF_Operation* add = AddNoCheck(vec2, vec3, graph, status); ASSERT_NE(TF_OK, TF_GetCode(status)); ASSERT_TRUE(add == nullptr); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 21603c1a07caf9e9fdcd53561a94fdf7756ec84d..9cd978c97eada2123950da6271886ee20f918d5f 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" using tensorflow::GraphDef; @@ -36,6 +38,23 @@ TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { return t; } +TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims, + const int32_t* values) { + int64_t num_values = 1; + for (int i = 0; i < num_dims; ++i) { + num_values *= dims[i]; + } + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, dims, num_dims, sizeof(int32_t) * num_values); + memcpy(TF_TensorData(t), values, sizeof(int32_t) * num_values); + return t; +} + +TF_Tensor* Int32Tensor(const std::vector& values) { + int64_t dims = values.size(); + return Int32Tensor(&dims, 1, values.data()); +} + TF_Tensor* Int32Tensor(int32_t v) { const int num_bytes = sizeof(int32_t); int32_t* values = new int32_t[1]; @@ -44,19 +63,40 @@ TF_Tensor* Int32Tensor(int32_t v) { &Int32Deallocator, nullptr); } -TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) { +// All the *Helper methods are used as a workaround for the restrictions that +// one cannot call ASSERT_* methods in non-void-returning functions (when +// exceptions are disabled during compilation) +void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name, + TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); TF_SetAttrType(desc, "dtype", TF_INT32); - return TF_FinishOperation(desc, s); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); } -TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, - const char* name) { +TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name) { + TF_Operation* op; + PlaceholderHelper(graph, s, name, &op); + return op; +} + +void ConstHelper(TF_Tensor* t, TF_Graph* graph, TF_Status* s, const char* name, + TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); TF_SetAttrTensor(desc, "value", t, s); - if (TF_GetCode(s) != TF_OK) return nullptr; + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_SetAttrType(desc, "dtype", TF_TensorType(t)); - return TF_FinishOperation(desc, s); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, + const char* name) { + TF_Operation* op; + ConstHelper(t, graph, s, name, &op); + return op; } TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, @@ -65,11 +105,39 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, return Const(tensor.get(), graph, s, name); } +void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, + const char* name, TF_Operation** op, bool check) { + TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); + TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; + TF_AddInputList(desc, add_inputs, 2); + *op = TF_FinishOperation(desc, s); + if (check) { + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); + } +} + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name) { + TF_Operation* op; + AddHelper(l, r, graph, s, name, &op, true); + return op; +} + +TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + TF_Operation* op; + AddHelper(l, r, graph, s, name, &op, false); + return op; +} + +TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, + TF_Graph* graph, TF_Operation* ctrl_op, + TF_Status* s, const char* name) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; TF_AddInputList(desc, add_inputs, 2); + TF_AddControlInput(desc, ctrl_op); return TF_FinishOperation(desc, s); } @@ -81,11 +149,20 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, return TF_FinishOperation(desc, s); } -TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { +void NegHelper(TF_Operation* n, TF_Graph* graph, TF_Status* s, + TF_Operation** op) { TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg"); TF_Output neg_input = {n, 0}; TF_AddInput(desc, neg_input); - return TF_FinishOperation(desc, s); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { + TF_Operation* op; + NegHelper(n, graph, s, &op); + return op; } TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, @@ -96,6 +173,32 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, return TF_FinishOperation(desc, s); } +void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s, + const char* name, TF_Operation** op) { + TF_Operation* zero = ScalarConst( + 0, graph, s, ::tensorflow::strings::StrCat(name, "_const0").c_str()); + TF_OperationDescription* desc = TF_NewOperation(graph, "Split", name); + TF_AddInput(desc, {zero, 0}); + TF_AddInput(desc, {input, 0}); + TF_SetAttrInt(desc, "num_split", 3); + TF_SetAttrType(desc, "T", TF_INT32); + // Set device to CPU since there is no version of split for int32 on GPU + // TODO(iga): Convert all these helpers and tests to use floats because + // they are usually available on GPUs. After doing this, remove TF_SetDevice + // call in c_api_function_test.cc + TF_SetDevice(desc, "/cpu:0"); + *op = TF_FinishOperation(desc, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); +} + +TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, + const char* name) { + TF_Operation* op; + Split3Helper(input, graph, s, name, &op); + return op; +} + bool IsPlaceholder(const tensorflow::NodeDef& node_def) { if (node_def.op() != "Placeholder" || node_def.name() != "feed") { return false; @@ -196,6 +299,18 @@ bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def) { return ret; } +bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def) { + TF_Status* s = TF_NewStatus(); + TF_Buffer* buffer = TF_NewBuffer(); + TF_FunctionToFunctionDef(func, buffer, s); + bool ret = TF_GetCode(s) == TF_OK; + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + if (ret) ret = func_def->ParseFromArray(buffer->data, buffer->length); + TF_DeleteBuffer(buffer); + TF_DeleteStatus(s); + return ret; +} + bool GetAttrValue(TF_Operation* oper, const char* attr_name, tensorflow::AttrValue* attr_value, TF_Status* s) { TF_Buffer* buffer = TF_NewBuffer(); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 0c0ba667bd0c3014efc6f0bd48ad0e63ccf4ee6e..a927739d462edfd25c9652cedbd0ab506991af45 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -33,6 +33,13 @@ typedef std::unique_ptr // Create a tensor with values of type TF_INT8 provided by `values`. TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); +// Create a tensor with values of type TF_INT32 provided by `values`. +TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims, + const int32_t* values); + +// Create 1 dimensional tensor with values from `values` +TF_Tensor* Int32Tensor(const std::vector& values); + TF_Tensor* Int32Tensor(int32_t v); TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, @@ -47,6 +54,13 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "add"); +TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "add"); + +TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, + TF_Graph* graph, TF_Operation* ctrl_op, + TF_Status* s, const char* name = "add"); + TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name = "add"); @@ -54,6 +68,10 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s); TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); +// Split `input` along the first dimention into 3 tensors +TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, + const char* name = "split3"); + bool IsPlaceholder(const tensorflow::NodeDef& node_def); bool IsScalarConst(const tensorflow::NodeDef& node_def, int v); @@ -66,6 +84,8 @@ bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def); +bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def); + bool GetAttrValue(TF_Operation* oper, const char* attr_name, tensorflow::AttrValue* attr_value, TF_Status* s); diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index b1baa5ce12527bb2faafd7f364c2e601f9faf565..e70539ceefa1e9b3b70be0ac2dd8acb431ed8caa 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -150,10 +151,11 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { return TF_SessionListDevices(ctx->session, status); } -TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t) { - return new TFE_TensorHandle( - tensorflow::TensorCApi::MakeTensor(t->dtype, t->shape, t->buffer), - nullptr); +TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { + tensorflow::Tensor tensor; + status->status = tensorflow::TF_TensorToTensor(t, &tensor); + if (!status->status.ok()) return nullptr; + return new TFE_TensorHandle(tensor, nullptr); } void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; } @@ -451,8 +453,9 @@ tensorflow::Status ValidateInputTypeAndPlacement( return tensorflow::errors::InvalidArgument( "cannot compute ", op->name, " as input #", i, " was expected to be a ", - tensorflow::DataType_Name(kernel->input_type(i)), " tensor but is a ", - tensorflow::DataType_Name(op->inputs[i].dtype()), " tensor"); + tensorflow::DataTypeString(kernel->input_type(i)), + " tensor but is a ", + tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor"); } } return tensorflow::Status::OK(); diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 476c9288f895bd02b0d830c8ab339bf2e2d09731..a54d206a3076c35e473239c7e4c310d977afc882 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -20,6 +20,25 @@ limitations under the License. #include "tensorflow/c/c_api.h" +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes.$a +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(COMPILER_MSVC) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // COMPILER_MSVC +#endif // SWIG + #ifdef __cplusplus extern "C" { #endif @@ -30,11 +49,11 @@ extern "C" { // TODO(ashankar): Merge with TF_Session? typedef struct TFE_Context TFE_Context; -extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, - TF_Status* status); -extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status); -extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, - TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, + TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status); +TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, + TF_Status* status); // A handle to a tensor on a device. // @@ -43,14 +62,15 @@ extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, // placed in memory of different devices or remote address spaces. typedef struct TFE_TensorHandle TFE_TensorHandle; -extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t); -extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); -extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); -extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); -extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index); -extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h); -extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, - TF_Status* status); +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, + TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); +TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); +TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); +TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index); +TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h); +TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, + TF_Status* status); // Create a new TFE_TensorHandle with the same contents as 'h' but placed // in the memory of the device name 'device_name'. @@ -58,10 +78,10 @@ extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, // that shares the underlying buffer. Otherwise, it currently requires at least // one of the source or destination devices to be CPU (i.e., for the source or // destination tensor to be placed in host memory). -extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, - TFE_Context* ctx, - const char* device_name, - TF_Status* status); +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, + TFE_Context* ctx, + const char* device_name, + TF_Status* status); // Description of the TensorFlow op to execute. // @@ -76,49 +96,49 @@ extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, // the additional sanity checks there seem unnecessary; typedef struct TFE_Op TFE_Op; -extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, - TF_Status* status); -extern void TFE_DeleteOp(TFE_Op* op); +TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, + TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); // TODO(ashankar): TFE_OpSetDevice and TFE_Execute should not have a TFE_Context // parameter. Instead, the TFE_Context should be captured when creating the // TFE_Op. -extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx, - const char* device_name, TF_Status* status); - -extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); - -extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, - unsigned char* is_list, TF_Status* status); - -extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, - const char* value); -extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value); -extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value); -extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, - unsigned char value); -extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, - TF_DataType value); +TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, TFE_Context* ctx, + const char* device_name, TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); + +TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, + unsigned char* is_list, TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, + const char* value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, + unsigned char value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, + TF_DataType value); // If the number of dimensions is unknown, `num_dims` must be set to // -1 and `dims` can be null. If a dimension is unknown, the // corresponding entry in the `dims` array must be -1. -extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, - const int64_t* dims, const int num_dims, - TF_Status* out_status); - -extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, - const char** value, int num_values); -extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, - const int64_t* values, int num_values); -extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, - const float* values, int num_values); -extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, - const unsigned char* values, int num_values); -extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, - const TF_DataType* values, int num_values); -extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, - const int64_t** dims, const int* num_dims, - int num_values, TF_Status* out_status); +TF_CAPI_EXPORT extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, + const int64_t* dims, const int num_dims, + TF_Status* out_status); + +TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, + const char** value, int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, + const int64_t* values, int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, + const float* values, int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, + const unsigned char* values, int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, + const TF_DataType* values, int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, + const int64_t** dims, const int* num_dims, + int num_values, TF_Status* out_status); // Execute the operation defined by 'op' and return handles to computed // tensors in 'retvals'. @@ -128,14 +148,14 @@ extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, // // On return, 'num_retvals' will be set to the actual number of outputs // returned by the operation. -extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, - int* num_retvals, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status); // Add a function (serialized FunctionDef protocol buffer) to ctx so // that it can be invoked using TFE_Execute. -extern void TFE_ContextAddFunctionDef(TFE_Context* ctx, - const char* serialized_function_def, - size_t size, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx, + const char* serialized_function_def, + size_t size, TF_Status* status); #ifdef __cplusplus } /* end extern "C" */ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 6f5c21c947247477e2b2dafd452d4828b604e570..72e0fe8a1565a9a717c01aed83044cab2dd2dfbc 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -34,8 +34,11 @@ TFE_TensorHandle* TestMatrixTensorHandle() { TF_Tensor* t = TF_AllocateTensor( TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TFE_TensorHandle* th = TFE_NewTensorHandle(t); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteTensor(t); + TF_DeleteStatus(status); return th; } @@ -383,7 +386,9 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, memcpy(TF_TensorData(t.get()), &value, TF_TensorByteSize(t.get())); std::unique_ptr - value_handle(TFE_NewTensorHandle(t.get()), TFE_DeleteTensorHandle); + value_handle(TFE_NewTensorHandle(t.get(), status), + TFE_DeleteTensorHandle); + if (TF_GetCode(status) != TF_OK) return nullptr; TFE_OpAddInput(op, value_handle.get(), status); if (TF_GetCode(status) != TF_OK) return nullptr; diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index adca6c762526a85f015560efb22d3de185e2ae6c..b8d36b894722304e2b5e97332cabd5bab3c6dbd4 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -20,7 +20,6 @@ limitations under the License. namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { - // TODO(skyewm): make sure cycles are prevented mutex_lock l(graph->mu); graph->graph.AddControlEdge(&input->node, &op->node); } diff --git a/tensorflow/c/version_script.lds b/tensorflow/c/version_script.lds index 455bd7362bb36d30af421a17f0e2f8e9ba66e02b..c352a1440d145c5ea62bbadb5bb5defb40bff537 100644 --- a/tensorflow/c/version_script.lds +++ b/tensorflow/c/version_script.lds @@ -1,7 +1,8 @@ VERS_1.0 { # Export symbols in c_api.h. global: - TF_*; + *TF_*; + *TFE_*; # Hide everything else. local: diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 369adee086d88903631c8455f5664c8fd8bb087f..d9071ba6e460b01746026db3215ff21ee52ac1b1 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -228,6 +228,26 @@ tf_cc_test( ], ) +cc_library_with_android_deps( + name = "while_loop", + srcs = ["ops/while_loop.cc"], + hdrs = ["ops/while_loop.h"], + android_deps = [ + "//tensorflow/core:android_tensorflow_lib", + ], + common_deps = [ + ":cc_ops", + ":cc_ops_internal", + ":ops", + ":scope", + ":scope_internal", + ], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + ], +) + cc_library( name = "grad_op_registry", srcs = ["framework/grad_op_registry.cc"], @@ -276,6 +296,7 @@ cc_library( ":cc_ops", ":cc_ops_internal", ":grad_op_registry", + ":gradients", ], alwayslink = 1, ) diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index 66a943410e2757ea5a5c55351c1fc20d5a5e3154..82469261e5ba43bfbdc7a0343ce6b651da46ccc1 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -78,6 +78,10 @@ class SymbolicGradientBuilder { const std::vector& grad_inputs, std::vector* grad_outputs); + // Returns a list mapping whether each node in the graph is reachable + // from outputs_. Keyed by node id. + std::vector GetReachableNodes(); + const Scope& scope_; const ops::GradOpRegistry* registry_; const std::vector& outputs_; @@ -143,11 +147,36 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad, return Status::OK(); } +std::vector SymbolicGradientBuilder::GetReachableNodes() { + std::vector reachable_nodes(scope_.graph()->num_node_ids(), false); + std::deque queue; + for (const Output& out : outputs_) { + if (!reachable_nodes[out.node()->id()]) { + queue.push_back(out.node()); + reachable_nodes[out.node()->id()] = true; + } + } + + while (!queue.empty()) { + Node* n = queue.front(); + queue.pop_front(); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) continue; + queue.push_back(e->src()); + reachable_nodes[e->src()->id()] = true; + } + } + return reachable_nodes; +} + Status SymbolicGradientBuilder::Initialize() { if (outputs_.size() != grad_inputs_.size()) { return errors::InvalidArgument( "Must specify a gradient input for each output."); } + std::vector reachable_nodes = GetReachableNodes(); + // TODO(theflofly) Check that inputs_ are reachable from + // outputs_ using reachable_nodes grad_outputs_->clear(); grad_outputs_->resize(inputs_.size()); // Populate `output_nodes_` from node ids in `outputs_`. @@ -188,12 +217,15 @@ Status SymbolicGradientBuilder::Initialize() { if (output_nodes_.find(n->id()) == output_nodes_.end()) { // Internal node: continue BFS along connected outputs. for (const Edge* e : n->out_edges()) { - if (e->IsControlEdge()) continue; - ++num_expected_backprops; + // If a node is not reachable from outputs_, + // we don't expect it to receive a backpropagated gradient. + // It will not be counted in num_expected_backprops. + if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue; if (visited.find(e->dst()) == visited.end()) { queue.push_back(e->dst()); visited.insert(e->dst()); } + ++num_expected_backprops; } } else { // Output node: stop BFS and update `num_expected_backprops` for diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 24af7d567b267332610eba2c8c8c57681fa0559b..032ab936235acfe58ffa711c0dd75a0eede7eb62 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -364,6 +364,73 @@ TEST_F(GradientsTest, MultipleNodeOutputGrads) { test::AsTensor({60, 61, 62, 63, 66, 66, 66, 67}, {4, 2})); } +TEST_F(GradientsTest, UnreachableEdgeGradOneOutput) { + auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE); + auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}); + auto x_assign = Assign(scope_test_, x, x_const); + + auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE); + auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}}); + auto y_assign = Assign(scope_test_, y, y_const); + + auto m1 = MatMul(scope_test_, x, y); + + auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE); + auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}}); + auto z_assign = Assign(scope_test_, z, z_const); + + auto m2 = MatMul(scope_test_, y, z); + + auto dm1 = Const(scope_test_, {{0.5}, {0.5}}); + + std::vector grad_outputs; + TF_ASSERT_OK( + AddSymbolicGradients(scope_test_, {m1}, {y}, {dm1}, &grad_outputs)); + + std::vector outputs; + test::GetTensors(scope_test_, {x_assign, y_assign, z_assign}, + {grad_outputs[0]}, &outputs); + // dz/dy = xT * dm1 + test::ExpectTensorNear( + outputs[0], test::AsTensor({2.5, 3.5, 4.5}, {3, 1}), 1e-5); +} + +TEST_F(GradientsTest, UnreachableEdgeGradTwoOutputs) { + auto x = Variable(scope_test_, {2, 3}, DT_DOUBLE); + auto x_const = Const(scope_test_, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}); + auto x_assign = Assign(scope_test_, x, x_const); + + auto y = Variable(scope_test_, {3, 1}, DT_DOUBLE); + auto y_const = Const(scope_test_, {{1.0}, {2.0}, {3.0}}); + auto y_assign = Assign(scope_test_, y, y_const); + + auto m1 = MatMul(scope_test_, x, y); + + auto z = Variable(scope_test_, {1, 3}, DT_DOUBLE); + auto z_const = Const(scope_test_, {{9.0, 10.0, 11.0}}); + auto z_assign = Assign(scope_test_, z, z_const); + + auto m2 = MatMul(scope_test_, y, z); + + auto dm1 = Const(scope_test_, {{0.5}, {0.5}}); + auto dm2 = + Const(scope_test_, {{0.5, 0.5, 0.5}, {0.6, 0.7, 0.8}, {0.6, 0.7, 0.9}}); + + std::vector grad_outputs; + TF_ASSERT_OK(AddSymbolicGradients(scope_test_, {m1, m2}, {y}, {dm1, dm2}, + &grad_outputs)); + + std::vector outputs; + test::GetTensors(scope_test_, {x_assign, y_assign, z_assign}, + {grad_outputs[0]}, &outputs); + + // the gradients from m1 and m2 will be summed to compute the gradient + // w.r.t y + // dz/dy = xT * dm1 + dm2 * zT + test::ExpectTensorNear( + outputs[0], test::AsTensor({17.5, 24.7, 26.8}, {3, 1}), 1e-5); +} + // StopGradientSingleOutputMultiEdgeTest tests combinations of valid and // 'NoGradient' (induced by StopGradient op) returned along multiple edges from // a single nodes output. diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 5cae5c64ad8620d2d552616aedadad14708aa8ec..0335f6357d0cb4bf0d586a17856bbf46f23d34d9 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -167,7 +167,8 @@ class Scope { // START_SKIP_DOXYGEN - /// Update the builder with properties accumulated in this scope. + /// Update the builder with properties accumulated in this scope. Does not set + /// status(). // TODO(skyewm): NodeBuilder is not part of public API void UpdateBuilder(NodeBuilder* builder) const; // END_SKIP_DOXYGEN @@ -215,12 +216,15 @@ class Scope { const std::vector& control_deps() const; - private: - friend class InternalScope; + // START_SKIP_DOXYGEN class Impl; - std::unique_ptr impl_; Impl* impl() { return impl_.get(); } const Impl* impl() const { return impl_.get(); } + // END_SKIP_DOXYGEN + + private: + friend class InternalScope; + std::unique_ptr impl_; explicit Scope(Impl*); }; diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index e2cc22af5d180eca0bdda0e06e6010995c917012..968c366550ef6f46557cd9b5662d9d0719b31531 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -43,6 +43,9 @@ class Scope::Impl { const std::shared_ptr& name_map, const std::shared_ptr& refiner); + const string& name() const { return name_; } + const std::vector& control_deps() const { return control_deps_; } + private: friend class Scope; @@ -98,6 +101,8 @@ class Scope::Impl { const std::vector control_deps_; + // The fully-qualified name of this scope (i.e. includes any parent scope + // names). const string name_ = ""; const string op_name_ = ""; const bool exit_on_error_ = false; diff --git a/tensorflow/cc/framework/testutil.cc b/tensorflow/cc/framework/testutil.cc index ca78f31db513f043d02594e100e549cb16e92795..57d573e3c5ad3a5068a945ec9391705b025fa7b6 100644 --- a/tensorflow/cc/framework/testutil.cc +++ b/tensorflow/cc/framework/testutil.cc @@ -36,5 +36,19 @@ void GetTensor(const Scope& scope, Output tensor, Tensor* out) { *out = outputs[0]; } +void GetTensors(const Scope& scope, const std::vector& assign_vars, + const OutputList& tensors, std::vector* out) { + ClientSession session(scope); + TF_CHECK_OK(session.Run(assign_vars, nullptr)); + TF_CHECK_OK(session.Run(tensors, out)); +} + +void GetTensor(const Scope& scope, const std::vector& assign_vars, + Output tensor, Tensor* out) { + std::vector outputs; + GetTensors(scope, assign_vars, {std::move(tensor)}, &outputs); + *out = outputs[0]; +} + } // end namespace test } // end namespace tensorflow diff --git a/tensorflow/cc/framework/testutil.h b/tensorflow/cc/framework/testutil.h index d027ad3744db895ef9e203f6f50fb5fe41687cb7..a3e19870ec847bcd4f0e0bf0e71dda724024d5d2 100644 --- a/tensorflow/cc/framework/testutil.h +++ b/tensorflow/cc/framework/testutil.h @@ -26,9 +26,21 @@ namespace test { void GetTensors(const Scope& scope, OutputList tensors, std::vector* out); +// Computes the outputs listed in 'tensors', returns the tensors in 'out'. +// assign_vars are extra outputs that should be run +// e.g. to assign values to variables. +void GetTensors(const Scope& scope, const std::vector& assign_vars, + const OutputList& tensors, std::vector* out); + /// Computes the output 'tensor', returning the resulting tensor in 'out'. void GetTensor(const Scope& scope, Output tensor, Tensor* out); +// Computes the output 'tensor', returning the resulting tensor in 'out'. +// assign_vars are extra outputs that should be run +// e.g. to assign values to variables. +void GetTensor(const Scope& scope, const std::vector& assign_vars, + Output tensor, Tensor* out); + } // namespace test } // namespace tensorflow diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 77a82fc8ca2853ebb16175dfc3861ce19a9690e4..d90654f2e9a89da56ef45d82b875c123d80f4633 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" namespace tensorflow { namespace ops { @@ -549,6 +550,209 @@ Status ConjGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Conj", ConjGrad); +// Integer division x / y, assuming x and y >=0, but treats x/0 = x +Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) { + return Div(scope, x, Maximum(scope, y, Const(scope, 1))); +} + +// Helper function for reduction ops. +// +// input_shape: 1-D Tensor, the shape of the Tensor being reduced. +// axes: 1-D Tensor, the reduction axes. +// Note that the reduction indices are in the range +// -rank(input_shape), rank(input_shape) +// returns a 1-D Tensor, the output shape as if keep_dims were set to True. +Output ReducedShapeHelper(const Scope& scope, const Output& input_shape, + const Output& reduction_axes) { + auto zero = Const(scope, 0); + auto one = Const(scope, 1); + + // Running example in comments + // input_shape = [2, 3, 5, 7] + // axes = [1, 2] + // The result (a shape after a reduction with keep_dims=True) + // [2, 1, 1, 7] + // + // We can treat each entry in axes as an index into input_shape that + // should be replaced by 1. + // We use DynamicStitch to do this. + + // input_rank = 4 + auto input_rank = Size(scope, input_shape); + + // Normalize any negative indices in the reduction_axes to positive + // values. + auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank); + + // This [0..input_rank) range of integers is used in DynamicStitch to + // first copy input_shape to the result. + // input_rank_range = [0, 1, 2, 3] + auto input_rank_range = Range(scope, zero, input_rank, one); + + // A 1-filled tensor with the same shape as axes. DynamicStitch will + // merge these 1s (using axes for indices) to the correct + // position in the result. + // axes_ones = [1, 1] + auto axes_ones = OnesLike(scope, axes); + + // using DynamicStitch: + // indices = { input_rank_range, axes } + // = { [0, 1, 2, 3], [1, 2] } + // data = { input_shape, axes_ones } + // = { [2, 3, 5, 7], [1, 1] } + // The input_rank_range entry in indices first replicates the + // input_shape to the result. + // The axes entry in indices then moves a 1 to each of its entries, + // resulting in + // [2, 1, 1, 7] + std::vector indices = {input_rank_range, axes}; + std::vector data = {input_shape, axes_ones}; + return DynamicStitch(scope, indices, data); +} + +// SumGradHelper returns the gradient for the Sum operator, and is used +// by SumGrad and MeanGrad. +Output SumGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs) { + // The partial derivative for any input along a "reduced" dimension + // is just 1, so we only need replicate the output gradient on such a + // dimension to its "expanded" shape. + // Running example: + // input is + // [[a, b, c], + // [d, e, f]] + // reduction_indices = [1] + // Sum = [a + b + c, d + e + f] + // if the gradient is [g1, g2] + // We want the propagated gradient to be + // [[g1, g1, g1], + // [g2, g2, g2]] + + // input_shape = [2, 3] + auto input_shape = Shape(scope, op.input(0)); + + // output_shape_kept_dims = [2, 1] + auto output_shape_kept_dims = + ReducedShapeHelper(scope, input_shape, op.input(1)); + + // This step "flips" any 1s with values from the input_shape, and + // replaces remaining entries with 1. This creates a shape that + // shows how much each dimension in the incoming gradient should be + // replicated. + // tile_scaling = [1, 3] + auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims); + + // grad = [[g1], [g2]] + auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims); + + // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]] + return Tile(scope, grad, tile_scaling); +} + +Status SumGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs)); + + // Stop propagation along reduction_indices + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Sum", SumGrad); + +Status MeanGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // The Mean gradient is just like the Sum gradient, except that + // all gradients are also divided by the size of reduced groups. + auto sum_grad = SumGradHelper(scope, op, grad_inputs); + + // The product of all entries in a tensor's shape is the total + // number of entries in the tensor. This step calculates + // n_input_entries/n_output_entries + // = group_size + auto input_shape = Shape(scope, op.input(0)); + auto output_shape = Shape(scope, op.output(0)); + auto zero = Const(scope, 0); + auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero), + Prod(scope, output_shape, zero)); + + // propagate sum_grad/group_size + grad_outputs->push_back( + Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type()))); + + // Stop propagation along reduction_indices + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Mean", MeanGrad); + +Status MinOrMaxGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + // The partial derivative for any input along a "reduced" dimension + // is 1 when it is the min (or max) and 0 everywhere else. So the + // gradient calculation is identical for both operators. + // + // There's a special case for propagating gradients when there are + // multiple minima (or maxima) - we choose to divide the gradient + // equally among all matching inputs. + // + // Please note this comment + // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063 + // for details. + + // Running example: + // input: [[5, 5, 5], + // [1, 2, -3]] + // reduction_indices: [1] + auto input = op.input(0); + auto reduction_indices = op.input(1); + + // [2, 3] + auto input_shape = Shape(scope, input); + + // [2, 1] + auto output_shape_kept_dims = + ReducedShapeHelper(scope, input_shape, reduction_indices); + + // for op=min (say) + // output = [5, -3] + // y = [[5], + // [-3]] + auto y = Reshape(scope, op.output(0), output_shape_kept_dims); + + // reshape([g1, g2], [2, 1]) = [[g1], + // [g2]] + auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims); + + // indicators = equal(y, input) + // = equal([[5], [[5, 5, 5], + // [-3]], [1, 2, -3]]) + // = [[1, 1, 1], + // [0, 0, 1]] + auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type()); + + // [[3], + // [1]] + auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices), + output_shape_kept_dims); + + // [[1/3, 1/3, 1/3], + // [0, 0, 1]] + auto scale = Div(scope, indicators, num_selected); + + // [[g1/3, g1/3, g1/3], + // [0, 0, g2]] + grad_outputs->push_back(Mul(scope, scale, grad)); + + // Stop propagation along reduction_indices + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Min", MinOrMaxGrad); +REGISTER_GRADIENT_OP("Max", MinOrMaxGrad); + // MatMulGrad helper function used to compute two MatMul operations // based on input matrix transposition combinations. Status MatMulGradHelper(const Scope& scope, const bool is_batch, diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 45060c33f3077ec54120ce284ef88ee631f44905..5b1558dd820862b18e486b347254c3a249bd016c 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -937,6 +937,73 @@ class NaryGradTest : public ::testing::Test { Scope scope_; }; +TEST_F(NaryGradTest, Sum) { + TensorShape x_shape({2, 3, 5, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = Sum(scope_, x, {1, -1}); + // y's shape is the result of reducing x along axes 1 and -1 (= 3) + TensorShape y_shape({2, 5}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + +TEST_F(NaryGradTest, Mean) { + TensorShape x_shape({2, 3, 5, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = Mean(scope_, x, {1, -1}); + // y's shape is the result of reducing x along axes 1 and -1 (= 3) + TensorShape y_shape({2, 5}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + +TEST_F(NaryGradTest, Min) { + TensorShape x_shape({2, 3}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = Min(scope_, x, {-1}); + // y's shape is the result of reducing x along axes -1 (= 1) + TensorShape y_shape({2}); + Tensor x_init_value = + test::AsTensor({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape); + RunTest(x, x_init_value, y, y_shape); +} + +TEST_F(NaryGradTest, Max) { + TensorShape x_shape({2, 3}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = Max(scope_, x, {-1}); + // y's shape is the result of reducing x along axes -1 (= 1) + TensorShape y_shape({2}); + Tensor x_init_value = + test::AsTensor({0.5f, 0.7f, 0.2f, 1.0f, 1.5f, -2.8f}, x_shape); + RunTest(x, x_init_value, y, y_shape); +} + +TEST_F(NaryGradTest, MinMulti) { + // Test gradient when there are multiple minima. + // Note that we cannot directly use a test Tensor with multiple + // minima, as the numeric estimator will calculate incorrect + // gradients when perturbing each entry in the Tensor (which then + // changes how many minima exist.) + // Instead, we use a single input that broadcast-multiplies a larger + // tensor with equal values, and apply reduce_min to the multiplied + // result. + TensorShape x_shape({1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x); + auto y = Min(scope_, all_same, {0}); + // y is a [3] shaped tensor reduced along dimension 0, so it is [1] shaped + TensorShape y_shape({1}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + +TEST_F(NaryGradTest, MaxMulti) { + TensorShape x_shape({1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto all_same = Mul(scope_, Const(scope_, {1.f, 1.f, 1.f}), x); + auto y = Max(scope_, all_same, {0}); + TensorShape y_shape({1}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + TEST_F(NaryGradTest, AddN) { TensorShape shape({3, 2, 5}); std::vector xs; diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 31ae7ca770abc634be8fb66447203c74c44134f1..affc1e1dbe6526bd468e07bc6803cbf9b7b54db2 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -36,7 +36,7 @@ class NNGradTest : public ::testing::Test { float max_error; TF_ASSERT_OK(ComputeGradientError(scope_, {x}, {x_shape}, {y}, {y_shape}, &max_error)); - EXPECT_LT(max_error, 1e-4); + EXPECT_LT(max_error, 2e-4); } void RunTest(const Output& x, const Tensor& x_init_value, const Output& y, @@ -44,7 +44,7 @@ class NNGradTest : public ::testing::Test { float max_error; TF_ASSERT_OK( ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error)); - EXPECT_LT(max_error, 1e-4); + EXPECT_LT(max_error, 2e-4); } void RunTest(const OutputList& xs, const std::vector& x_shapes, @@ -53,7 +53,7 @@ class NNGradTest : public ::testing::Test { float max_error; TF_ASSERT_OK( ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error)); - EXPECT_LT(max_error, 1e-4); + EXPECT_LT(max_error, 2e-4); } Scope scope_; diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc new file mode 100644 index 0000000000000000000000000000000000000000..27da77bbe068fd4be0eec40590a204fe6dedd235 --- /dev/null +++ b/tensorflow/cc/ops/while_loop.cc @@ -0,0 +1,223 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/while_loop.h" + +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/graph/node_builder.h" + +namespace tensorflow { +namespace ops { + +namespace { + +// Utility function for converting to internal C++ datatypes. +OutputTensor ToOutputTensor(const Output& output) { + return OutputTensor(output.node(), output.index()); +} + +// Utility function for converting to internal C++ datatypes. +std::vector ToOutputTensors(const std::vector& outputs) { + std::vector result(outputs.size()); + for (int i = 0; i < outputs.size(); ++i) { + result[i] = ToOutputTensor(outputs[i]); + } + return result; +} + +// Utility function for converting to internal C++ datatypes. +std::vector ToNodes(const std::vector& outputs) { + std::vector result(outputs.size()); + for (int i = 0; i < outputs.size(); ++i) { + result[i] = outputs[i].node(); + } + return result; +} + +// Manually generates the name of the `loop_var_idx`-th NextIteration node of a +// loop being constructed with `scope`. This is used to define the backedge +// before the NextIteration node is created. +string NextIterationName(const Scope& scope, int loop_var_idx) { + string result; + const string& prefix = scope.impl()->name(); + if (!prefix.empty()) strings::StrAppend(&result, prefix, "/"); + strings::StrAppend(&result, "NextIteration"); + if (loop_var_idx > 0) strings::StrAppend(&result, "_", loop_var_idx); + return result; +} + +// Creates the `loop_var_idx`-th Merge node of a loop being constructed with +// `scope`. `enter_output` is the `loop_var_idx`-th Enter node's output. +Status CreateMerge(const Scope& scope, int loop_var_idx, + const Output& enter_output, Output* merge_output) { + // The merge nodes accept the while loop's back edges as an input (i.e. the + // not-yet-created next iteration nodes). Use the underlying NodeBuilder API + // directly to create the back edge. + NodeBuilder::NodeOut enter_input(enter_output.node(), enter_output.index()); + + const int next_output_index = 0; + DataType dtype = enter_output.node()->output_type(0); + NodeBuilder::NodeOut next_input(NextIterationName(scope, loop_var_idx), + next_output_index, dtype); + + std::vector input_list({enter_input, next_input}); + const string unique_name = scope.GetUniqueNameForOp("Merge"); + NodeBuilder builder = NodeBuilder(unique_name, "Merge").Input(input_list); + scope.UpdateBuilder(&builder); + + Node* merge_node; + TF_RETURN_IF_ERROR(builder.Finalize(scope.graph(), &merge_node)); + TF_RETURN_IF_ERROR(scope.DoShapeInference(merge_node)); + *merge_output = Output(merge_node, 0); + return Status::OK(); +} + +// Creates the condition subgraph defined by `cond`. +Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, + const std::vector& inputs, Output* output) { + // The control dependency is for constants in the cond graph, and other ops + // that do not depend on the loop variables. This ensures that these ops are + // in the while loop frame (since they will indirectly depend on an Enter node + // defining the frame) and that they are executed once per loop iteration. + // + // TODO(skyewm): the control dep will be added to all nodes in the cond graph. + // This is at best unnecessary, and at worst may prevent different parts of + // different loop iterations from executing in parallel. + Scope cond_scope = + scope.NewSubScope("cond").WithControlDependencies(inputs[0]); + Output raw_cond_out; + TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out)); + if (raw_cond_out.type() != DT_BOOL) { + return errors::InvalidArgument( + "BuildWhileLoop: 'cond' argument must return a boolean output, got ", + DataTypeString(raw_cond_out.type())); + } + *output = LoopCond(scope, raw_cond_out).output; + return Status::OK(); +} + +// Create the bdoy subgraph defined by `body`. `outputs` must be non-null and +// empty. +Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, + const std::vector& inputs, + std::vector* outputs) { + DCHECK(outputs != nullptr); + DCHECK(outputs->empty()); + + // The control dependency is analogous to that in CreateCond(). + Scope body_scope = + scope.NewSubScope("body").WithControlDependencies(inputs[0]); + TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs)); + const size_t num_loop_vars = inputs.size(); + if (outputs->size() != num_loop_vars) { + return errors::InvalidArgument( + "BuildWhileLoop: 'body' argument expected to return ", num_loop_vars, + "outputs, got ", outputs->size()); + } + // TODO(skyewm): check output types/shapes + return Status::OK(); +} + +} // namespace + +// A while loop with a single loop variable looks like this: +// +// (output) +// ^ +---------------+ +// | | body subgraph +-------------+ +// Exit +---------------+ | +// ^ ^ | +// | | | +// Switch<--------+ v +// ^ | NextIteration +// | +------+--------+ | +// +---->| cond subgraph | | +// | +---------------+ | +// Merge<---------------------------+ +// ^ +// | +// Enter +// ^ +// | +// (input) +// +// If there are multiple loop variables, each of the control flow ops is +// duplicated for each loop variable. +// TODO(skyewm): link to public version of design doc +Status BuildWhileLoop(const Scope& scope, const std::vector& inputs, + const CondGraphBuilderFn& cond, + const BodyGraphBuilderFn& body, const string& frame_name, + OutputList* outputs) { + DCHECK(!inputs.empty()); + DCHECK(outputs != nullptr); + DCHECK(outputs->empty()); + + TF_RETURN_IF_ERROR(scope.status()); + const size_t num_loop_vars = inputs.size(); + + std::vector enter_outputs(num_loop_vars); + for (int i = 0; i < num_loop_vars; ++i) { + enter_outputs[i] = internal::Enter(scope, inputs[i], frame_name); + } + TF_RETURN_IF_ERROR(scope.status()); + + std::vector merge_outputs(num_loop_vars); + for (int i = 0; i < num_loop_vars; ++i) { + TF_RETURN_IF_ERROR( + CreateMerge(scope, i, enter_outputs[i], &merge_outputs[i])); + } + + Output cond_out; + TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out)); + + std::vector switch_trues(num_loop_vars); + std::vector switch_falses(num_loop_vars); + for (int i = 0; i < num_loop_vars; ++i) { + auto switch_i = Switch(scope, merge_outputs[i], cond_out); + switch_trues[i] = switch_i.output_true; + switch_falses[i] = switch_i.output_false; + } + TF_RETURN_IF_ERROR(scope.status()); + + std::vector body_outputs; + TF_RETURN_IF_ERROR(CreateBody(scope, body, switch_trues, &body_outputs)); + + std::vector next_outputs(num_loop_vars); + for (int i = 0; i < num_loop_vars; ++i) { + next_outputs[i] = NextIteration(scope, body_outputs[i]); + DCHECK_EQ(next_outputs[i].node()->name(), NextIterationName(scope, i)); + } + TF_RETURN_IF_ERROR(scope.status()); + + // Create the backedges from the NextIteration nodes to the Merge nodes. + for (int i = 0; i < num_loop_vars; ++i) { + const int merge_backedge_output_index = 1; + scope.graph()->AddEdge(next_outputs[i].node(), next_outputs[i].index(), + merge_outputs[i].node(), + merge_backedge_output_index); + } + + outputs->resize(num_loop_vars); + for (int i = 0; i < num_loop_vars; ++i) { + (*outputs)[i] = internal::Exit(scope, switch_falses[i]); + } + return scope.status(); +} + +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h new file mode 100644 index 0000000000000000000000000000000000000000..253d5d8935cf1632f06c1f3ce728a68fc85391bf --- /dev/null +++ b/tensorflow/cc/ops/while_loop.h @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ +#define THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { +namespace ops { + +// Function that takes cond graph inputs and returns cond graph boolean output. +// 'output' need not be set if an error is returned. +typedef std::function& inputs, + Output* output)> + CondGraphBuilderFn; + +// Function that takes body graph inputs and returns body graph outputs. +// 'outputs' need not be populated if an error is returned. +typedef std::function& inputs, + std::vector* outputs)> + BodyGraphBuilderFn; + +// Constructs a while loop. +// +// Arguments: +// * scope: used to construct the while loop. +// * inputs: the initial values of the loop variables. Must be non-empty. +// * cond: a function that builds the condition graph of the loop. Takes the +// current loop variables as inputs and returns a scalar boolean Output +// indicating whether the loop should continue. +// * body: a function that builds the body graph of the loop. Takes the current +// loop variables as inputs and returns the updated loop variables. +// * frame_name: the frame name to use for this while loop. This should be a +// unique name. This will be used as a prefix for created operations. +// * outputs: output param that returns final loop variable outputs in non-error +// case. Must be non-null and empty. +// +// Returns an error if the while loop could not be fully constructed. +// +// TODO(skyewm): clean up partially-constructed loop in error case +// TODO(skyewm): create public interface to this method +Status BuildWhileLoop(const Scope& scope, const std::vector& inputs, + const CondGraphBuilderFn& cond, + const BodyGraphBuilderFn& body, const string& frame_name, + OutputList* outputs); + +} // namespace ops +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index cf7ffc5af19f4247f71f03dbd090b4808b42a08f..e6862f0d9dd7ec05b4e0c4ba26ab5f16a7aa9ad7 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -52,6 +52,12 @@ class BinaryOpsTest(XLATestCase): def testFloatOps(self): for dtype in self.float_types: + self._testBinary( + lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001), + np.array([[[[-1, 2.00009999], [-3, 4.01]]]], dtype=dtype), + np.array([[[[-1.001, 2], [-3.00009, 4]]]], dtype=dtype), + expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) + self._testBinary( gen_math_ops._real_div, np.array([3, 3, -1.5, -8, 44], dtype=dtype), @@ -82,6 +88,12 @@ class BinaryOpsTest(XLATestCase): dtype(4), expected=np.array([[16], [81]], dtype=dtype)) + self._testBinary( + gen_math_ops._reciprocal_grad, + np.array([4, -3, -2, 1], dtype=dtype), + np.array([5, -6, 7, -8], dtype=dtype), + expected=np.array([-80, 54, -28, 8], dtype=dtype)) + self._testBinary( gen_math_ops._sigmoid_grad, np.array([4, 3, 2, 1], dtype=dtype), @@ -94,6 +106,12 @@ class BinaryOpsTest(XLATestCase): np.array([5, 6, 7, 8], dtype=dtype), expected=np.array([-160, -81, -28, -4], dtype=dtype)) + self._testBinary( + gen_math_ops._sqrt_grad, + np.array([4, 3, 2, 1], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array([0.625, 1, 1.75, 4], dtype=dtype)) + self._testBinary( gen_nn_ops._softplus_grad, np.array([4, 3, 2, 1], dtype=dtype), @@ -101,6 +119,13 @@ class BinaryOpsTest(XLATestCase): expected=np.array( [3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype)) + self._testBinary( + gen_nn_ops._softsign_grad, + np.array([4, 3, 2, 1], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array( + [0.11111111, 0.06122449, 0.03125, 0.01234568], dtype=dtype)) + self._testBinary( gen_math_ops._tanh_grad, np.array([4, 3, 2, 1], dtype=dtype), diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 6a0bdf1ed15097df3ba565d61d4d8c234aef1214..49c1699b6edc9d16bbba4578fdadd86a14d8c56c 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -829,6 +829,13 @@ TEST_F(OpTest, Abs) { }); } +TEST_F(OpTest, Acosh) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Acosh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Add) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); @@ -881,6 +888,30 @@ TEST_F(OpTest, Any) { }); } +TEST_F(OpTest, ApproximateEqual) { + Repeatedly([this]() { + auto dims = RandomDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Asinh) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Asinh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Atanh) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Atanh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, AvgPool) { Repeatedly([this]() { std::uniform_int_distribution random_int(1, 5); @@ -1372,6 +1403,20 @@ TEST_F(OpTest, DepthwiseConv2DBackpropFilter) { }); } +TEST_F(OpTest, Cos) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Cos").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Cosh) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Cosh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, DepthwiseConv2DBackpropInput) { Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); @@ -1459,11 +1504,11 @@ TEST_F(OpTest, DynamicStitch) { } while (size == 0); // Shuffle the range of indices that cover the output. - // TODO(phawkins): The documentation for DynamicStitch doesn't require that - // the indices cover all positions of the output. The XLA implementation - // does so require. However, the native TF implementation leaves undefined - // values if we don't cover everything, so we can't really test that case - // anyway. + // TODO(phawkins): The documentation for DynamicStitch doesn't require + // that the indices cover all positions of the output. The XLA + // implementation does so require. However, the native TF implementation + // leaves undefined values if we don't cover everything, so we can't + // really test that case anyway. std::vector indices(size); std::iota(indices.begin(), indices.end(), 0); std::shuffle(indices.begin(), indices.end(), generator()); @@ -1540,6 +1585,13 @@ TEST_F(OpTest, Exp) { }); } +TEST_F(OpTest, Expm1) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Expm1").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, ExpandDims) { Repeatedly([this]() { DataType type = Choose(kAllXlaTypes); @@ -1620,11 +1672,9 @@ TEST_F(OpTest, GreaterEqual) { TEST_F(OpTest, L2Loss) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); - // TODO(b/31644876): scalars currently crash. - return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss") - .RandomInput(type, RandomDims(1)) - .Attr("T", type)); + DataType type = DT_FLOAT; + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("L2Loss").RandomInput(type).Attr("T", type)); }); } @@ -1675,6 +1725,13 @@ TEST_F(OpTest, Log) { }); } +TEST_F(OpTest, Log1p) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Log1p").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, LogicalAnd) { Repeatedly([this]() { auto dims = BroadcastableDims(); @@ -2116,6 +2173,15 @@ TEST_F(OpTest, Reciprocal) { }); } +TEST_F(OpTest, ReciprocalGrad) { + Repeatedly([this]() { + std::vector dims = RandomDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); + }); +} TEST_F(OpTest, Relu) { Repeatedly([this]() { return ExpectTfAndXlaOutputsAreClose( @@ -2201,6 +2267,13 @@ TEST_F(OpTest, ReverseV2) { }); } +TEST_F(OpTest, Rint) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Rint").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Round) { Repeatedly([this]() { return ExpectTfAndXlaOutputsAreClose( @@ -2272,6 +2345,20 @@ TEST_F(OpTest, Sign) { }); } +TEST_F(OpTest, Sin) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Sin").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, Sinh) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Sinh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Size) { Repeatedly([this]() { DataType type = Choose({DT_INT32, DT_FLOAT}); @@ -2339,6 +2426,23 @@ TEST_F(OpTest, SoftplusGrad) { }); } +TEST_F(OpTest, Softsign) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Softsign").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + +TEST_F(OpTest, SoftsignGrad) { + Repeatedly([this]() { + std::vector dims = RandomDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftsignGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, SpaceToBatch) { Repeatedly([this]() { std::vector block_dims = RandomDims(4, 4, 0, 5); @@ -2496,6 +2600,16 @@ TEST_F(OpTest, Sqrt) { }); } +TEST_F(OpTest, SqrtGrad) { + Repeatedly([this]() { + auto dims = RandomDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad") + .RandomInput(DT_FLOAT, dims) + .RandomInput(DT_FLOAT, dims) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, SquaredDifference) { Repeatedly([this]() { auto dims = BroadcastableDims(); @@ -2655,6 +2769,13 @@ TEST_F(OpTest, StridedSliceGrad) { }); } +TEST_F(OpTest, Tan) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Tan").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Tanh) { Repeatedly([this]() { return ExpectTfAndXlaOutputsAreClose( diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index cfc2a0c85ecd591a8739291a8b5e6bdf50073ac0..b21f1998a5d351d4a86438236441be541eef42b0 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import unittest + import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -117,27 +119,61 @@ class UnaryOpsTest(XLATestCase): def testFloatOps(self): for dtype in self.float_types: + self._assertOpOutputMatchesExpected( + math_ops.acosh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array([0, 1.3169579, 1.76274717, 2.06343707], + dtype=dtype)) + + self._assertOpOutputMatchesExpected( + math_ops.asinh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array([0.88137359, 1.44363548, 1.81844646, 2.09471255], + dtype=dtype)) + + self._assertOpOutputMatchesExpected( + math_ops.atanh, + np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype), + expected=np.array([0.10033535, 0.20273255, 0.3095196, 0.42364893], + dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.ceil, np.array([[-1.7, 1.2]], dtype=dtype), expected=np.array([[-1, 2]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.cosh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array([1.54308063, 3.76219569, 10.067662, 27.30823284], + dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.exp, np.array([[-1, 1]], dtype=dtype), expected=np.array([[0.36787945, 2.7182817]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.expm1, + np.array([[-1, 1]], dtype=dtype), + expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.floor, np.array([[-1.7, 1.2]], dtype=dtype), expected=np.array([[-2, 1]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.is_finite, + np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], + dtype=dtype), + expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool)) + # Tests for tf.nn ops. self._assertOpOutputMatchesExpected( nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0)) - # TODO(b/31644876): enable this test case when fixed. - # self._assertOpOutputMatchesExpected(tf.nn.l2_loss, dtype(4), dtype(10)) + self._assertOpOutputMatchesExpected(nn_ops.l2_loss, dtype(4), dtype(8)) self._assertOpOutputMatchesExpected( nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10)) @@ -169,6 +205,12 @@ class UnaryOpsTest(XLATestCase): np.array([[1e-14, 1e-15, 0.6]], dtype=dtype), expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.rint, + np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], + [0.5, 1.5, 2.5, 3.5]], dtype=dtype), + expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], + dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.round, np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5], @@ -197,11 +239,23 @@ class UnaryOpsTest(XLATestCase): np.array([-300, -150, 0, 150, 300], dtype=dtype), expected=np.array([0, 0, 0.5, 1, 1], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.sinh, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array([1.17520119, 3.62686041, 10.01787493, 27.2899172], + dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.sqrt, np.array([[4, 9]], dtype=dtype), expected=np.array([[2, 3]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + math_ops.tan, + np.array([1, 2, 3, 4], dtype=dtype), + expected=np.array([1.55740772, -2.18503986, -0.14254654, 1.15782128], + dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.tanh, np.array( @@ -260,6 +314,12 @@ class UnaryOpsTest(XLATestCase): np.array([[-2, 0, 8]], dtype=dtype), expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + nn_ops.softsign, + np.array([[-2, -1, 0, 1, 2]], dtype=dtype), + expected=np.array([[-0.66666669, -0.5, 0, 0.5, 0.66666669]], + dtype=dtype)) + self._assertOpOutputMatchesExpected( math_ops.is_finite, np.array( @@ -294,6 +354,23 @@ class UnaryOpsTest(XLATestCase): np.array([[4, 3], [2, 1]], dtype=dtype), expected=np.array([[1, 1], [1, 1]], dtype=dtype)) + # TODO(phawkins): these tests fail unless fastmath optimizations + # are disabled. Use more robust IsInf/IsNaN detection and enable these + # tests. + @unittest.skip("test case fails in fast-math mode") + def testIsInfAndIsNan(self): + for dtype in self.float_types: + self._assertOpOutputMatchesExpected( + math_ops.is_inf, + np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], + dtype=dtype), + expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool)) + self._assertOpOutputMatchesExpected( + math_ops.is_nan, + np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], + dtype=dtype), + expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool)) + def testLogicalOps(self): self._assertOpOutputMatchesExpected( math_ops.logical_not, diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index af5753c2600f0b064b6aea4eba556054c38d8d9c..ddd912b87315f7943915153b5bf73531107af54d 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -38,7 +38,8 @@ string MakeUniquePath(string name) { // Remove illegal characters from `name`. for (int i = 0; i < name.size(); ++i) { - if (name[i] == '/') { + char ch = name[i]; + if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?') { name[i] = '_'; } } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index f5c3947009490c6a1c45e407cf834ff72173fd91..6e6c5dc17f5364bf6623dd07b57cf4797442bc3b 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -31,7 +31,6 @@ tf_kernel_library( "function_ops.cc", "gather_op.cc", "identity_op.cc", - "is_finite_op.cc", "l2loss_op.cc", "lrn_ops.cc", "matmul_op.cc", @@ -145,31 +144,35 @@ tf_kernel_library( ], ) -tf_kernel_library( +cc_library( name = "gather_op_kernel_float_int32", srcs = ["gather_op_kernel_float_int32.cc"], visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/core:framework_lite", - "//tensorflow/core/kernels:gather_functor", + "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/kernels:gather_functor_hdr", "//third_party/eigen3", ], + alwayslink = 1, ) -tf_kernel_library( +cc_library( name = "gather_op_kernel_float_int64", srcs = ["gather_op_kernel_float_int64.cc"], visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/core:framework_lite", - "//tensorflow/core/kernels:gather_functor", + "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/kernels:gather_functor_hdr", "//third_party/eigen3", ], + alwayslink = 1, ) -tf_kernel_library( +cc_library( name = "index_ops_kernel_argmax_float_1d", srcs = ["index_ops_kernel_argmax_float_1d.cc"], visibility = ["//visibility:public"], @@ -177,9 +180,10 @@ tf_kernel_library( "//tensorflow/core:framework_lite", "//third_party/eigen3", ], + alwayslink = 1, ) -tf_kernel_library( +cc_library( name = "index_ops_kernel_argmax_float_2d", srcs = ["index_ops_kernel_argmax_float_2d.cc"], visibility = ["//visibility:public"], @@ -187,6 +191,7 @@ tf_kernel_library( "//tensorflow/core:framework_lite", "//third_party/eigen3", ], + alwayslink = 1, ) # ----------------------------------------------------------------------------- diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index ded20a9a3cefd12c1f1fee0593c82165a8129f40..58538b45137b26ed5aa296eb6c1077e88aea72b9 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -102,11 +102,16 @@ XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs)))); XLA_MAKE_BINARY( RsqrtGrad, b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)), b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)), extend_dimensions)); +XLA_MAKE_BINARY(SqrtGrad, + b->Div(b->Mul(rhs, + XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), + lhs, extend_dimensions)); static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) { @@ -136,6 +141,11 @@ XLA_MAKE_BINARY(SoftplusGrad, b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)), XlaHelpers::One(b, input_type(1))))); +// softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2 +XLA_MAKE_BINARY(SoftsignGrad, + b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)), + b->Abs(rhs))))); + XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)), b->Mul(lhs, lhs)))); @@ -143,5 +153,24 @@ XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions)); #undef XLA_MAKE_BINARY +class ApproximateEqualOp : public XlaOpKernel { + public: + explicit ApproximateEqualOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance_)); + } + + // Computes the max of the scalar input x and 0. + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + auto result = b->Lt(b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))), + XlaHelpers::FloatLiteral(b, input_type(0), tolerance_)); + ctx->SetOutput(0, result); + } + + private: + float tolerance_; +}; +REGISTER_XLA_OP(Name("ApproximateEqual"), ApproximateEqualOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/is_finite_op.cc b/tensorflow/compiler/tf2xla/kernels/is_finite_op.cc deleted file mode 100644 index 788dcee54438ded815bf244d2c1f8cda1e902cf6..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/is_finite_op.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/bcast.h" - -namespace tensorflow { -namespace { - -class IsFiniteOp : public XlaOpKernel { - public: - explicit IsFiniteOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - - void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); - ctx->SetOutput(0, ctx->builder()->IsFinite(input)); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(IsFiniteOp); -}; - -REGISTER_XLA_OP(Name("IsFinite"), IsFiniteOp); - -} // anonymous namespace -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 626ddd17d394d4a2e1c014c3a280949a415dce94..6b8f5ec7b33cd448a7b06c5dfe4aac288e53e9c9 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -43,13 +43,42 @@ namespace { // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); + +// acosh(x) = log(x + sqrt(x^2 - 1)) +XLAJIT_MAKE_UNARY( + Acosh, + b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x), + XlaHelpers::One(b, input_type(0))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); +// asinh(x) = log(x + sqrt(x^2 + 1)) +XLAJIT_MAKE_UNARY( + Asinh, + b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x), + XlaHelpers::One(b, input_type(0))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); +// atanh(x) = 0.5 * log((1 + x) / (1 - x)) +XLAJIT_MAKE_UNARY( + Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x), + b->Sub(XlaHelpers::One(b, input_type(0)), x))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x)); XLAJIT_MAKE_UNARY(Cos, b->Cos(x)); +XLAJIT_MAKE_UNARY(Cosh, + b->Mul(b->Add(b->Exp(x), b->Exp(b->Neg(x))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); + +// TODO(b/34703906): use a more accurate implementation of expm1. +XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0)))); + XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); -// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. -XLAJIT_MAKE_UNARY(Sign, b->Sign(x)); +XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); +XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x), + XlaHelpers::FloatLiteral( + b, input_type(0), + std::numeric_limits::infinity()))); +XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x)); // Return 1/x XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); @@ -80,6 +109,12 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, b->Add(round_val, one), round_val); } +XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); +XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); + +XLAJIT_MAKE_UNARY(Rsqrt, + b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); + // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b, DataType dtype, @@ -87,16 +122,23 @@ static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b, auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); } - -XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x)); -XLAJIT_MAKE_UNARY(Rsqrt, - b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x)); + +// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. +XLAJIT_MAKE_UNARY(Sign, b->Sign(x)); +XLAJIT_MAKE_UNARY(Sinh, + b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); XLAJIT_MAKE_UNARY(Softplus, b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0))))); +// softsign(x) = x / (abs(x) + 1) +XLAJIT_MAKE_UNARY(Softsign, + b->Div(x, + b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0))))); XLAJIT_MAKE_UNARY(Sqrt, b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); XLAJIT_MAKE_UNARY(Square, b->Mul(x, x)); +XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x))); XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x)); #undef XLAJIT_MAKE_UNARY diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index b47c82f075a1b71dd355bd86ae7200360ab0f388..7977df81f5c93ebc3a53556844584ccd758e0f95 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -1,11 +1,11 @@ -# Legacy command line flags for the XLA libraries. +# Legacy command-line flags for the XLA libraries. # Please do not add more flags to this package. -# The XLA libraries were written in an environment that allowed command - line +# The XLA libraries were written in an environment that allowed command-line # flags to be scattered freely throughout the libraries. This model, while -# initially convenient, leads to a proliferation in unused commnd line flags in -# tests and binaries, and serious problems in servers, where one might wish +# initially convenient, leads to a proliferation in unused command-line flags +# in tests and binaries, and serious problems in servers, where one might wish # parameters to be different in independent RPC calls to the same routine. # # Please don't add more flags. If you're a library author, pass options and @@ -43,17 +43,38 @@ cc_test( cc_library( name = "debug_options_flags", - srcs = ["debug_options_flags.cc"], + srcs = [ + "debug_options_flags.cc", + "debug_options_parsers.h", + ], hdrs = ["debug_options_flags.h"], deps = [ ":parse_flags_from_env", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], ) +cc_test( + name = "debug_options_parsers_test", + size = "small", + srcs = [ + "debug_options_parsers.h", + "debug_options_parsers_test.cc", + ], + deps = + [ + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 913b8650d246941db37ef5b32c4e5272e4942430..8892bfbe929d168c602af24cfbb507256dc05328 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -17,6 +17,7 @@ limitations under the License. #include // NOLINT(build/c++11): only using std::call_once, not mutex. #include +#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" #include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -35,6 +36,7 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { flags->set_xla_llvm_enable_alias_scope_metadata(true); flags->set_xla_llvm_enable_noalias_metadata(true); flags->set_xla_llvm_enable_invariant_load_metadata(true); + flags->set_xla_llvm_disable_expensive_passes(false); flags->set_xla_backend_optimization_level(3); flags->set_xla_cpu_multi_thread_eigen(true); flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); @@ -66,7 +68,7 @@ void AllocateFlags() { }; }; - // Returns a lambda that is a custom "sub-parser" for xla_disable_hlo_passes. + // Custom "sub-parser" lambda for xla_disable_hlo_passes. auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) { std::vector disabled_passes = tensorflow::str_util::Split(comma_separated_values, ','); @@ -76,33 +78,25 @@ void AllocateFlags() { return true; }; - // Returns a lambda that is a custom "sub-parser" for - // xla_backend_extra_options. + // Custom "sub-parser" lambda for xla_backend_extra_options. auto setter_for_xla_backend_extra_options = [](string comma_separated_values) { - std::vector extra_options_parts = - tensorflow::str_util::Split(comma_separated_values, ','); auto* extra_options_map = flag_values->mutable_xla_backend_extra_options(); - - // The flag contains a comma-separated list of options; some options - // have arguments following "=", some don't. - for (const auto& part : extra_options_parts) { - size_t eq_pos = part.find_first_of('='); - if (eq_pos == string::npos) { - (*extra_options_map)[part] = ""; - } else { - string value = ""; - if (eq_pos + 1 < part.size()) { - value = part.substr(eq_pos + 1); - } - (*extra_options_map)[part.substr(0, eq_pos)] = value; - } - } - + impl::parse_xla_backend_extra_options(extra_options_map, + comma_separated_values); return true; }; + // Custom "sub-parser" lambda for xla_reduce_precision. + auto setter_for_xla_reduce_precision = + [](string reduce_precision_option_value) { + HloReducePrecisionOptions* option_proto = + flag_values->add_hlo_reduce_precision_options(); + return impl::parse_xla_reduce_precision_option( + option_proto, reduce_precision_option_value); + }; + flag_objects = new std::vector( {tensorflow::Flag( "xla_generate_hlo_graph", @@ -157,6 +151,13 @@ void AllocateFlags() { "In LLVM-based backends, enable the emission of " "!invariant.load metadata in " "the generated IR."), + tensorflow::Flag( + "xla_llvm_disable_expensive_passes", + bool_setter_for( + &DebugOptions::set_xla_llvm_disable_expensive_passes), + flag_values->xla_llvm_disable_expensive_passes(), + "In LLVM-based backends, disable a custom set of " + "expensive optimization passes."), tensorflow::Flag( "xla_backend_optimization_level", int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), @@ -242,7 +243,20 @@ void AllocateFlags() { setter_for_xla_backend_extra_options, "", "Extra options to pass to a backend; " "comma-separated list of 'key=val' strings (=val " - "may be omitted); no whitespace around commas.")}); + "may be omitted); no whitespace around commas."), + tensorflow::Flag("xla_reduce_precision", setter_for_xla_reduce_precision, + "", + "Directions for adding reduce-precision operations. " + "Format is 'LOCATION=E,M:OPS;NAMES' where LOCATION is " + "the class of locations in which to insert the " + "operations (e.g., 'OP_OUTPUTS'), E and M are the " + "exponent and matissa bit counts respectively, and " + "OPS and NAMES are comma-separated (no spaces) lists " + "of the operation types and names to which to attach " + "the reduce-precision operations. The NAMES string " + "and its preceding ';' may be omitted. This option " + "may be repeated to define multiple sets of added " + "reduce-precision operations.")}); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h new file mode 100644 index 0000000000000000000000000000000000000000..0c238e6a5decffb0339f428e4ea676944479cf1b --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h @@ -0,0 +1,151 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ + +#include +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { +namespace legacy_flags { +namespace impl { + +template +void parse_xla_backend_extra_options(T* extra_options_map, + string comma_separated_values) { + std::vector extra_options_parts = + tensorflow::str_util::Split(comma_separated_values, ','); + + // The flag contains a comma-separated list of options; some options + // have arguments following "=", some don't. + for (const auto& part : extra_options_parts) { + size_t eq_pos = part.find_first_of('='); + if (eq_pos == string::npos) { + (*extra_options_map)[part] = ""; + } else { + string value = ""; + if (eq_pos + 1 < part.size()) { + value = part.substr(eq_pos + 1); + } + (*extra_options_map)[part.substr(0, eq_pos)] = value; + } + } +} + +// The --xla_reduce_precision option has the format "LOCATION=E,M:OPS;NAME", +// where LOCATION is an HloReducePrecisionOptions::location, E and M are +// integers for the exponent and matissa bit counts respectively, and OPS and +// NAMES are comma-separated of the operation types and names to which to +// attach the reduce-precision operations. The OPS values are matches to the +// strings produced by HloOpcodeString, while the NAME values are arbitrary +// strings subject to the requirements that they not contain any of "=,:;". +// The NAME string (with its preceding semicolon) is optional. +inline bool parse_xla_reduce_precision_option( + HloReducePrecisionOptions* options, string option_string) { + // Split off "LOCATION" from remainder of string. + std::vector eq_split = + tensorflow::str_util::Split(option_string, '='); + if (eq_split.size() != 2) { + return false; + } + string& location = eq_split[0]; + if (location == "OP_INPUTS") { + options->set_location(HloReducePrecisionOptions::OP_INPUTS); + } else if (location == "OP_OUTPUTS") { + options->set_location(HloReducePrecisionOptions::OP_OUTPUTS); + } else if (location == "UNFUSED_OP_OUTPUTS") { + options->set_location(HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); + } else if (location == "FUSION_INPUTS_BY_CONTENT") { + options->set_location(HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT); + } else if (location == "FUSION_OUTPUTS_BY_CONTENT") { + options->set_location(HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT); + } else { + return false; + } + + // Split off "E,M" from remainder of string. + std::vector colon_split = + tensorflow::str_util::Split(eq_split[1], ':'); + if (colon_split.size() != 2) { + return false; + } + + // Split E and M, and parse. + std::vector bitsizes; + if (!tensorflow::str_util::SplitAndParseAsInts(colon_split[0], ',', + &bitsizes) || + bitsizes.size() != 2) { + return false; + } + options->set_exponent_bits(bitsizes[0]); + options->set_mantissa_bits(bitsizes[1]); + + // Split off OPS comma-separated list from remainder of string, if the + // remainder exists. + std::vector semicolon_split = + tensorflow::str_util::Split(colon_split[1], ';'); + if (semicolon_split.size() > 2) { + return false; + } + // The opcode values are either 'all' (meaning all opcodes), or matches to + // the strings returned by HloOpcodeString. An empty string is also + // interpreted as 'all', for convenience. Note that 'all' may not be part + // of a comma-separated list; it must stand alone. + string& opcode_string = semicolon_split[0]; + if (opcode_string == "" || opcode_string == "all") { + for (int i = 0; i < HloOpcodeCount(); i++) { + options->add_opcodes_to_suffix(i); + } + } else { + std::vector opcodes = + tensorflow::str_util::Split(opcode_string, ','); + for (const string& opcode : opcodes) { + bool found = false; + for (int i = 0; i < HloOpcodeCount(); i++) { + if (opcode == HloOpcodeString(static_cast(i))) { + options->add_opcodes_to_suffix(i); + found = true; + break; + } + } + if (!found) { + return false; + } + } + } + + // Process the NAMES string, if it exists. + if (semicolon_split.size() == 2) { + std::vector opnames = + tensorflow::str_util::Split(semicolon_split[1], ','); + for (const string& opname : opnames) { + if (opname.length() > 0) { + options->add_opname_substrings_to_suffix(opname); + } + } + } + + return true; +} + +} // namespace impl +} // namespace legacy_flags +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0ed788a9676fe9b1bd06fb3ceabf627c108a2c70 --- /dev/null +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers_test.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Test for parse_flags_from_env.cc + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h" + +#include +#include + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace legacy_flags { + +// Test that the xla_backend_extra_options flag is parsed correctly. +TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { + std::unordered_map test_map; + string test_string = "aa=bb,cc,dd=,ee=ff=gg"; + impl::parse_xla_backend_extra_options(&test_map, test_string); + EXPECT_EQ(test_map.size(), 4); + EXPECT_EQ(test_map.at("aa"), "bb"); + EXPECT_EQ(test_map.at("cc"), ""); + EXPECT_EQ(test_map.at("dd"), ""); + EXPECT_EQ(test_map.at("ee"), "ff=gg"); +} + +// Test that the xla_reduce_precision flag is parsed correctly. +TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStrings) { + HloReducePrecisionOptions proto; + string test_string = "OP_OUTPUTS=5,10:add,dot"; + EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS); + EXPECT_EQ(proto.exponent_bits(), 5); + EXPECT_EQ(proto.mantissa_bits(), 10); + EXPECT_EQ(proto.opcodes_to_suffix_size(), 2); + EXPECT_EQ(static_cast(proto.opcodes_to_suffix(0)), + HloOpcode::kAdd); + EXPECT_EQ(static_cast(proto.opcodes_to_suffix(1)), + HloOpcode::kDot); + EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 0); +} + +TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoStringsSemicolon) { + HloReducePrecisionOptions proto; + string test_string = "OP_OUTPUTS=5,10:add,dot;"; + EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_EQ(proto.location(), HloReducePrecisionOptions::OP_OUTPUTS); + EXPECT_EQ(proto.exponent_bits(), 5); + EXPECT_EQ(proto.mantissa_bits(), 10); + EXPECT_EQ(proto.opcodes_to_suffix_size(), 2); + EXPECT_EQ(static_cast(proto.opcodes_to_suffix(0)), + HloOpcode::kAdd); + EXPECT_EQ(static_cast(proto.opcodes_to_suffix(1)), + HloOpcode::kDot); + EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 0); +} + +TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionNoOpcodes) { + HloReducePrecisionOptions proto; + string test_string = "UNFUSED_OP_OUTPUTS=5,10:;foo,bar/baz"; + EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); + EXPECT_EQ(proto.exponent_bits(), 5); + EXPECT_EQ(proto.mantissa_bits(), 10); + EXPECT_EQ(proto.opcodes_to_suffix_size(), HloOpcodeCount()); + EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 2); + EXPECT_EQ(proto.opname_substrings_to_suffix(0), "foo"); + EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz"); +} + +TEST(DebugOptionsFlags, ParseXlaReducePrecisionOptionBoth) { + HloReducePrecisionOptions proto; + string test_string = "UNFUSED_OP_OUTPUTS=5,10:subtract;foo,bar/baz"; + EXPECT_TRUE(impl::parse_xla_reduce_precision_option(&proto, test_string)); + EXPECT_EQ(proto.location(), HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS); + EXPECT_EQ(proto.exponent_bits(), 5); + EXPECT_EQ(proto.mantissa_bits(), 10); + EXPECT_EQ(proto.opcodes_to_suffix_size(), 1); + EXPECT_EQ(static_cast(proto.opcodes_to_suffix(0)), + HloOpcode::kSubtract); + EXPECT_EQ(proto.opname_substrings_to_suffix_size(), 2); + EXPECT_EQ(proto.opname_substrings_to_suffix(0), "foo"); + EXPECT_EQ(proto.opname_substrings_to_suffix(1), "bar/baz"); +} + +} // namespace legacy_flags +} // namespace xla + +int main(int argc, char* argv[]) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc index 07bbcd802fef1d89f428717a4bd7d669d9f119c2..a3b4286f4c12bf39a44c63dd6e7d303a46a418c3 100644 --- a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc +++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc @@ -86,14 +86,14 @@ static const char kTestFlagString[] = "--single_quoted='single quoted \\\\ \n \"' " "--double_quoted=\"double quoted \\\\ \n '\\\"\" "; -// Test that the environent variable is parserd correctly. +// Test that the environent variable is parsed correctly. TEST(ParseFlagsFromEnv, Basic) { // Prepare environment. setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/); TestParseFlagsFromEnv("(flags in environment variable)"); } -// Test that a file named by the environent variable is parserd correctly. +// Test that a file named by the environent variable is parsed correctly. TEST(ParseFlagsFromEnv, File) { // environment variables where tmp dir may be specified. static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"}; diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 71995b2307e3a34ac5d1f3307ccea42b4cf230a5..6190bd624db65343154adfaae45c18c4b50c3cd7 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -94,14 +94,17 @@ Status Literal::CopyRange(const Literal& src_literal, TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size()); TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size()); + if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) { // If any of the two shapes are scalars, we can just call the StridedCopy() // directly, and we know we will be copying only one value. TF_RET_CHECK(copy_size.empty()); StridedCopy(dest_data, LinearIndex(dest_base), 0, src_data, src_literal.LinearIndex(src_base), 0, 1); - } else if (!ShapeUtil::HasZeroElements(dest_shape)) { - TF_RET_CHECK(!ShapeUtil::HasZeroElements(src_shape)); + } else if (!ShapeUtil::HasZeroElements(dest_shape) && + !ShapeUtil::HasZeroElements(src_shape)) { + // Perform copy if neither src literal nor dest literal has dimensions with + // zero element, otherwise it's a no-op. TF_RET_CHECK(src_base.size() == dest_base.size()); TF_RET_CHECK(src_base.size() == copy_size.size()); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 447c494bfca01cd268bd3f4a335984a33bafb045..64513459186b6b8a12381fd353e68e2198066e80 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -237,6 +237,9 @@ class Literal { // The src_literal and this literal must have the same primitive type, // src_base+copy_size must fit the source literal dimensions, as well as // dest_base+copy_size must fit the destination literal dimensions. + // Note: if either src_literal or this literal contains dimensions with zero + // element, then copy_size must be 0 in these dimensions while the + // corresponding base indices being 0. Status Copy(const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index a33c0fe09dd95734c7ff98ff4f7955c0124415b9..61ceac4f9a60e0bfaddd48ba7aa7a8e736a6dc14 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -698,7 +698,7 @@ TEST_F(LiteralUtilTest, Copy) { for (const auto& layout : layouts) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), dimensions, layout); - auto blank = Literal::CreateFromShape(shape); + auto source = Literal::CreateFromShape(shape); const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; @@ -707,15 +707,15 @@ TEST_F(LiteralUtilTest, Copy) { source->Set(indexes, ++seqnr); return true; }; - ShapeUtil::ForEachIndex(source->shape(), zero_base, dimensions, step, init_proc); + auto blank = Literal::CreateFromShape(shape); const int64 src_base[] = {3, 1, 5, 7}; const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(blank->Copy(*source, src_base, dest_base, copy_size)); + std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; @@ -730,6 +730,7 @@ TEST_F(LiteralUtilTest, Copy) { matched = (bval != 0 && bval == source->Get(source_indexes)); return matched; }; + ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, check_proc); EXPECT_TRUE(matched); @@ -749,6 +750,30 @@ TEST_F(LiteralUtilTest, CopyScalars) { EXPECT_EQ(vect->Get({4}), 17); } +TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { + const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0}); + const auto const_nine = Literal::CreateR1({9}); + const auto const_empty = Literal::CreateFromShape(empty_r1_shape); + + { + // Source contains dimension with zero elements. + const auto empty = Literal::CreateFromShape(empty_r1_shape); + auto nine = Literal::CreateR1({9}); + + TF_EXPECT_OK(nine->Copy(*empty, {0}, {0}, {0})); + EXPECT_TRUE(nine->Equal(*const_nine)); + } + + { + // Copy 0 element to destination with zero elements. + const auto empty = Literal::CreateFromShape(empty_r1_shape); + auto nine = Literal::CreateR1({9}); + + TF_EXPECT_OK(empty->Copy(*nine, {0}, {0}, {0})); + EXPECT_TRUE(empty->Equal(*const_empty)); + } +} + TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4b7bfa4eda733dbb795161127c6daebef9860444..98cc3401c14f93cc2f209806baa0d97f41582f4b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -41,6 +41,7 @@ cc_library( srcs = ["shape_inference.cc"], hdrs = ["shape_inference.h"], deps = [ + ":hlo", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -829,6 +830,8 @@ cc_library( ":call_graph", ":hlo", ":hlo_proto", + ":hlo_value", + ":liveness_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -844,6 +847,7 @@ cc_test( srcs = ["hlo_ordering_test.cc"], deps = [ ":hlo", + ":hlo_dataflow_analysis", ":hlo_ordering", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", @@ -1600,6 +1604,8 @@ cc_library( hdrs = ["hlo_verifier.h"], deps = [ ":hlo_pass", + ":shape_inference", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 59cf08be476fda2ea348c7a391d8b917e9bee2f2..74f8e3143d718e46e09e76bf9439b4f96b012226 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -120,6 +120,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleAdd(HloInstruction* add, HloInstruction* lhs, HloInstruction* rhs) override; + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleBroadcast(HloInstruction* broadcast) override; Status HandleConcatenate( @@ -239,9 +241,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status ReplaceWithNewInstruction( HloInstruction* old_instruction, std::unique_ptr new_instruction) { - VLOG(4) << "Replacing instruction:"; - VLOG(4) << " old: " << old_instruction->ToString(); - VLOG(4) << " new: " << new_instruction->ToString(); + VLOG(3) << "Replacing instruction:"; + VLOG(3) << " old: " << old_instruction->ToString(); + VLOG(3) << " new: " << new_instruction->ToString(); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( old_instruction, std::move(new_instruction))); changed_ = true; @@ -253,9 +255,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Returns the Status representing the result of the replace operation. Status ReplaceInstruction(HloInstruction* old_instruction, HloInstruction* new_instruction) { - VLOG(4) << "Replacing instruction:"; - VLOG(4) << " old: " << old_instruction->ToString(); - VLOG(4) << " new: " << new_instruction->ToString(); + VLOG(3) << "Replacing instruction:"; + VLOG(3) << " old: " << old_instruction->ToString(); + VLOG(3) << " new: " << new_instruction->ToString(); TF_RETURN_IF_ERROR( computation_->ReplaceInstruction(old_instruction, new_instruction)); changed_ = true; @@ -339,6 +341,20 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { + // If a bitcast feeds a bitcast, make it a single bitcast. + if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) { + return ReplaceWithNewInstruction( + bitcast, HloInstruction::CreateUnary( + bitcast->shape(), HloOpcode::kBitcast, + bitcast->mutable_operand(0)->mutable_operand(0))); + } + // All bitcasts can be eliminated (assuming layout constraints are + // satisified). + ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0)); + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. if (copy->operand(0)->opcode() == HloOpcode::kCopy) { @@ -514,11 +530,19 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) if (lhs->opcode() == HloOpcode::kDivide && rhs->opcode() == HloOpcode::kDivide) { + TF_ASSIGN_OR_RETURN( + const Shape a_times_d_shape, + ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, + lhs->operand(0), rhs->operand(1))); auto a_times_d = computation_->AddInstruction(HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs->mutable_operand(0), + a_times_d_shape, HloOpcode::kMultiply, lhs->mutable_operand(0), rhs->mutable_operand(1))); + TF_ASSIGN_OR_RETURN( + const Shape b_times_c_shape, + ShapeInference::InferBinaryOpShape(HloOpcode::kMultiply, + lhs->operand(1), rhs->operand(0))); auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), + b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), rhs->mutable_operand(0))); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary( @@ -527,8 +551,11 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, // (A / B) / C => A / (B * C) if (lhs->opcode() == HloOpcode::kDivide) { + TF_ASSIGN_OR_RETURN(const Shape b_times_c_shape, + ShapeInference::InferBinaryOpShape( + HloOpcode::kMultiply, lhs->operand(1), rhs)); auto b_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + b_times_c_shape, HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, @@ -537,8 +564,11 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, // A / (B / C) => (A*C) / B if (rhs->opcode() == HloOpcode::kDivide) { + TF_ASSIGN_OR_RETURN(const Shape a_times_c_shape, + ShapeInference::InferBinaryOpShape( + HloOpcode::kMultiply, lhs, rhs->operand(1))); auto a_times_c = computation_->AddInstruction(HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, rhs->mutable_operand(1))); + a_times_c_shape, HloOpcode::kMultiply, lhs, rhs->mutable_operand(1))); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index be71e03e985a285abafc2adf7219b6aca2a775b6..c442e2d0bc962a5e4aae9a563099e9584b41f201 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -199,21 +199,22 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { // Test that (A/B)/(C/D) is simplified to (A*D)/(B*C). TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123}); HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* param1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction::CreateParameter(1, r2f32, "param1")); HloInstruction* param2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction::CreateParameter(2, r2f32, "param2")); HloInstruction* param3 = builder.AddInstruction( HloInstruction::CreateParameter(3, r0f32, "param3")); HloInstruction* div0 = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, param1)); + HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1)); HloInstruction* div1 = builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param2, param3)); + HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param2, param3)); builder.AddInstruction( - HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div0, div1)); + HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -229,6 +230,8 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { EXPECT_THAT( computation->root_instruction(), op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2))); + EXPECT_TRUE( + ShapeUtil::Compatible(computation->root_instruction()->shape(), r2f32)); } // Test that A/exp(B) is simplified to A*exp(-B). @@ -995,7 +998,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { HloInstruction* zero = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR1({1., 2., 3.}))); builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {}), HloOpcode::kMaximum, reshape, zero)); + ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index b450e0c40074344778109ed2ba8b2238cff7940e..c0f3bcdc2218199288eaa3d0010ee70632c8f959 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -313,6 +313,49 @@ Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, return Status::OK(); } +bool CallGraph::IsFlattened() const { + for (const CallGraphNode& node : nodes_) { + if (node.context() == CallContext::kBoth) { + return false; + } + if (node.context() == CallContext::kSequential && + node.caller_callsites().size() > 1) { + return false; + } + } + return true; +} + +std::pair +CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, + HloInstruction* b) const { + // Lambda which returns the next instruction in the callee->caller chain in + // the call graph. This is the unique instruction which calls the computation + // containing 'instruction'. If more than one instruction calls the + // computation containing 'instruction' or no instructions call the + // computation then nullptr is returned. + auto next_caller = [this](HloInstruction* instruction) -> HloInstruction* { + const CallGraphNode& node = GetNode(instruction->parent()); + if (node.caller_callsites().size() != 1) { + return nullptr; + } + return node.caller_callsites()[0].instruction(); + }; + + // Iterate through the callee->caller chains and find the earliest common + // element. + for (HloInstruction* a_ancestor = a; a_ancestor != nullptr; + a_ancestor = next_caller(a_ancestor)) { + for (HloInstruction* b_ancestor = b; b_ancestor != nullptr; + b_ancestor = next_caller(b_ancestor)) { + if (a_ancestor->parent() == b_ancestor->parent()) { + return {a_ancestor, b_ancestor}; + } + } + } + return {nullptr, nullptr}; +} + string CallGraph::ToString() const { string out; Appendf(&out, "Call graph for module %s:\n", module_->name().c_str()); diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index a3297ff534f429279fd4674517db545f289af627..688c4085dfb4f47d3e08a4abee5e7b645f595b11 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -203,6 +203,39 @@ class CallGraph { return Dominates(computation, instruction->parent()); } + // Returns the nearest call graph ancestors of instructions 'a' and 'b' for + // which the ancestors are in the same computation. An instruction is an call + // graph ancestor of 'a' if the instruction calls the computation containing + // 'a' either directly or transitively. Degeneratively an instruction is an + // ancestor of itself. nullptr is returned if there is no common ancestor or + // if the caller chain of 'a' or 'b' diverges (has multiple callers) before + // the nearest common ancestor. + // + // Example: + // + // Entry computation: + // %x = Call(A, {Constant(42.0)}) + // %y = Call(B, {%x}) + // + // Computation A: + // %a = Negate(Param()) + // + // Computation B: + // %b = Exp(Param()); + // + // If called with %a and %b, this function would return (%x, %y). %x is an + // ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same + // computation. + std::pair NearestAncestorsInSameComputation( + HloInstruction* a, HloInstruction* b) const; + + // Returns whether the call graph is flattened. A call graph is flattened if + // every computation called in a sequential context (eg, kWhile or kCall) has + // zero or one callsite, and no computation is called from both a parallel and + // sequential context. The call graph of a module can be flattened with + // FlattenCallGraph. + bool IsFlattened() const; + string ToString() const; private: diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 3c22871b3bff193c27ee2eb639fe72306d532b97..4243d37a77e10dce950d421f87a16d56e4829e4c 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -97,6 +97,8 @@ TEST_F(CallGraphTest, SingletonComputation) { module->AddEntryComputation(MakeScalarComputation()); std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(1, call_graph->nodes().size()); + EXPECT_TRUE(call_graph->IsFlattened()); + const CallGraphNode& node = call_graph->GetNode(computation); EXPECT_EQ(computation, node.computation()); EXPECT_TRUE(node.callsites().empty()); @@ -169,6 +171,10 @@ TEST_F(CallGraphTest, SequentialComputations) { std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); + // The called computation is only called from one other computation, but there + // are multiple callsites. + EXPECT_FALSE(call_graph->IsFlattened()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); EXPECT_EQ(entry_computation, entry_node.computation()); EXPECT_EQ(CallContext::kSequential, entry_node.context()); @@ -206,6 +212,8 @@ TEST_F(CallGraphTest, ContextBothComputations) { std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(2, call_graph->nodes().size()); + EXPECT_FALSE(call_graph->IsFlattened()); + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); EXPECT_EQ(entry_computation, entry_node.computation()); EXPECT_EQ(2, entry_node.callsites().size()); @@ -273,6 +281,7 @@ TEST_F(CallGraphTest, ComplexGraph) { std::unique_ptr call_graph = CallGraph::Build(module.get()); EXPECT_EQ(5, call_graph->nodes().size()); + EXPECT_FALSE(call_graph->IsFlattened()); // Entry computation has one while instruction calling two computations // (cond_computation and a_computation). @@ -347,6 +356,78 @@ TEST_F(CallGraphTest, ComplexGraph) { EXPECT_TRUE(call_graph->Dominates(cond_computation, cond_computation)); } +TEST_F(CallGraphTest, ComplexGraphNearestAncestors) { + // Test NearestAncestorsInSameComputation on a call graph of a module with + // several computation called in various contexts. The call graph looks like: + // + // entry + // / | + // a | + // / | \ | + // b | cond + // \ | + // c + // + // Calls are made via kCall, kWhile, and kMap instructions. + auto module = CreateNewModule(); + HloComputation* cond_computation = + module->AddEmbeddedComputation(MakeConditionComputation()); + HloComputation* c_computation = + module->AddEmbeddedComputation(MakeScalarComputation()); + HloComputation* b_computation = module->AddEmbeddedComputation( + MakeMappingComputation(c_computation, /*callsites=*/1)); + HloInstruction* b_map = b_computation->root_instruction(); + + HloComputation* a_computation; + HloInstruction* a_call; + HloInstruction* a_while; + { + HloComputation::Builder builder(TestName() + ".a"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + a_call = builder.AddInstruction( + HloInstruction::CreateCall(kScalarShape, {param0}, c_computation)); + a_while = builder.AddInstruction(HloInstruction::CreateWhile( + kScalarShape, cond_computation, b_computation, a_call)); + a_computation = module->AddEmbeddedComputation(builder.Build()); + } + + HloComputation* entry_computation; + HloInstruction* entry_while; + { + HloComputation::Builder builder(TestName() + ".entry"); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, kScalarShape, "param0")); + entry_while = builder.AddInstruction(HloInstruction::CreateWhile( + kScalarShape, cond_computation, a_computation, param0)); + entry_computation = module->AddEntryComputation(builder.Build()); + } + + std::unique_ptr call_graph = CallGraph::Build(module.get()); + EXPECT_EQ(5, call_graph->nodes().size()); + + // Verify NearestAncestorsInSameComputation for various instructions in the + // module. + EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_call, a_call), + std::make_pair(a_call, a_call)); + + // c_computation is called from more than one site, so + // NearestAncestorsInSameComputation bails and returns nullptrs. + std::pair null_pair = {nullptr, nullptr}; + EXPECT_EQ(call_graph->NearestAncestorsInSameComputation( + b_map, c_computation->root_instruction()), + null_pair); + + EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, entry_while), + std::make_pair(entry_while, entry_while)); + EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(b_map, a_call), + std::make_pair(a_while, a_call)); + EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, a_call), + std::make_pair(a_while, a_call)); + EXPECT_EQ(call_graph->NearestAncestorsInSameComputation(a_while, b_map), + std::make_pair(a_while, a_while)); +} + TEST_F(CallGraphTest, VisitSingletonComputation) { // Test the call graph visitor with a call graph with a single node. auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 08eabd66828fb755471d1e3ce25627eef3edc5bc..141582c0690474d27cd6917dd6031d33004be5d0 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -59,10 +59,63 @@ CompilerFunctor::AllIntrinsics() { return intrinsics; } +/* Create filtered versions of the LLVM Pass Managers to filter out some +of the expensive passes. +Profiling: + learning/brain/google/xla/benchmarks:inception_cpu_benchmark + learning/brain/google/xla/benchmarks:cifarnet +pointed to LICM and IndVarSimplify as the hottest passes. +LICM is known to exhibit O(n^2) time in the number of instructions. +IndVarSimplify is slow due to SCEV. If loops are emitted in canonical form, +this pass is not necessary. +Disabling these as a starting point. +*/ +// TODO(b/64227304) Creating a custom pass pipeline will replace this. + +class FilteredFunctionPassManager : public llvm::legacy::FunctionPassManager { + public: + FilteredFunctionPassManager(llvm::Module* m, bool disable_expensive_passes) + : llvm::legacy::FunctionPassManager(m), + disable_expensive_passes_(disable_expensive_passes) {} + void add(llvm::Pass* p) override { + if (disable_expensive_passes_) { + llvm::StringRef PassName = p->getPassName(); + if (PassName.contains("LICM") || PassName.contains("IndVarSimplify") || + PassName.contains("LoopUnroll")) { + return; + } + } + llvm::legacy::FunctionPassManager::add(p); + } + + private: + bool disable_expensive_passes_; +}; + +class FilteredPassManager : public llvm::legacy::PassManager { + public: + explicit FilteredPassManager(bool disable_expensive_passes) + : disable_expensive_passes_(disable_expensive_passes) {} + void add(llvm::Pass* p) override { + if (disable_expensive_passes_) { + llvm::StringRef PassName = p->getPassName(); + if (PassName.contains("LICM") || PassName.contains("IndVarSimplify") || + PassName.contains("LoopUnroll")) { + return; + } + } + llvm::legacy::PassManager::add(p); + } + + private: + bool disable_expensive_passes_; +}; + llvm::object::OwningBinary CompilerFunctor:: operator()(llvm::Module& module) const { - llvm::legacy::PassManager module_passes; - llvm::legacy::FunctionPassManager function_passes(&module); + FilteredPassManager module_passes(disable_expensive_passes_); + FilteredFunctionPassManager function_passes(&module, + disable_expensive_passes_); VLOG(2) << "IR before optimizations"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module)); diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index 7187f14c96a8d0d064bf2673110d99b3b32fb7ce..8cdd049e7b773bdc455db627ff1749997d621ee4 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -44,6 +44,7 @@ class CompilerFunctor { explicit CompilerFunctor( llvm::TargetMachine* target_machine, const Disassembler* disassembler, int opt_level, bool optimize_for_size, bool enable_fast_math, + bool disable_expensive_passes, const VectorIntrinsics& available_intrinsics, LLVMCompiler::ModuleHook pre_optimization_hook = nullptr, LLVMCompiler::ModuleHook post_optimization_hook = nullptr) @@ -52,6 +53,7 @@ class CompilerFunctor { opt_level_(opt_level), optimize_for_size_(optimize_for_size), enable_fast_math_(enable_fast_math), + disable_expensive_passes_(disable_expensive_passes), available_intrinsics_(available_intrinsics), pre_optimization_hook_(pre_optimization_hook), post_optimization_hook_(post_optimization_hook) {} @@ -75,6 +77,7 @@ class CompilerFunctor { const unsigned opt_level_; const bool optimize_for_size_; const bool enable_fast_math_; + const bool disable_expensive_passes_; const VectorIntrinsics available_intrinsics_; LLVMCompiler::ModuleHook pre_optimization_hook_; LLVMCompiler::ModuleHook post_optimization_hook_; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 154223727904ebd3dc1c3aa5465e57310a48ec0b..8c7c2aa70eeb7f05bcc8931b4b3d1a25cb57ed80 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -198,8 +198,8 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { std::unordered_map hlo_to_profile_idx; CollectProfileCandidates profile_candidates_for_computation( &hlo_to_profile_idx); - TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( - &profile_candidates_for_computation)); + TF_RETURN_IF_ERROR( + computation->Accept(&profile_candidates_for_computation)); return hlo_to_profile_idx; } @@ -253,11 +253,11 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { Status CpuCompiler::RunHloPasses(HloModule* module) { // Optimization pipeline. HloPassPipeline pipeline("CPU"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(ShapeSizeBytesFunction()); ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), - HloReducePrecisionOptions::BEFORE_OP_FUSION); + ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding // where we will take this pass in future. @@ -292,10 +292,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), - HloReducePrecisionOptions::AFTER_OP_FUSION); - ReducePrecisionInsertion::AddPasses( - &pipeline, module->config().debug_options(), - HloReducePrecisionOptions::FUSION_BY_CONTENT); + ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( module->mutable_entry_computation_layout()); @@ -436,6 +433,10 @@ Status InitializeModuleHooks( StatusOr> CpuCompiler::Compile( std::unique_ptr module, se::StreamExecutor* stream_exec) { + const string timer_message = + "Compiling [" + module->name() + "] for CPU using JIT"; + ScopedLoggingTimer compiling_timer(timer_message, 1); + VLOG(1) << "Compiling: " << module->name(); TF_RET_CHECK(stream_exec != nullptr); std::call_once(llvm_command_line_options_initialized, @@ -451,11 +452,13 @@ StatusOr> CpuCompiler::Compile( auto llvm_context = MakeUnique(); auto llvm_module = MakeUnique("__compute_module", *llvm_context); + auto jit = MakeUnique( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_hook, post_optimization_ir_hook); llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); @@ -809,6 +812,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, target_machine.get(), &disassembler, opt_level, options::OptimizeForSizeRequested(module->config()), module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_llvm_disable_expensive_passes(), CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); llvm::object::OwningBinary object_file = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index 4d0e0f744ac4b02f7c4a74c5a341d6b9ce937967..20ee4f12e53a16b76d39d0151bc5b8ca4475f7ab 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -70,7 +70,7 @@ StatusOr ParallelizationPreparation::Run(HloModule* module) { while (CanOutlineWithUser(outline_candidate)) { HloInstruction* prior_candidate = outline_candidate; outline_candidate = *outline_candidate->users().begin(); - all_bitcasts |= outline_candidate->opcode() == HloOpcode::kBitcast; + all_bitcasts &= outline_candidate->opcode() == HloOpcode::kBitcast; if (std::any_of(outline_candidate->operands().begin(), outline_candidate->operands().end(), [&](const HloInstruction* operand) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index c5275ede651bb5e4a35a4e14a9baf966cc036040..06c94e19de2ea5244bd437128b2beb25d72d98c3 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -240,6 +240,13 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { compute_function_->addFnAttr(llvm::Attribute::OptimizeForSize); } + if (hlo_module_config_.debug_options().xla_enable_fast_math()) { + compute_function_->addFnAttr("unsafe-fp-math", "true"); + compute_function_->addFnAttr("no-infs-fp-math", "true"); + compute_function_->addFnAttr("no-nans-fp-math", "true"); + compute_function_->addFnAttr("no-signed-zeros-fp-math", "true"); + } + ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( /*Context=*/module_->getContext(), /*Name=*/"entry", diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index bef4ecd480d1fb0f593d95bac25f2184805f33cd..40fa3a67bdec3953003ba8f98f2a19a9082a82c5 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -241,7 +241,7 @@ Status Executor::Run() { completion_queue_.pop_front(); break; } - } while (1); + } while (true); TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment_->GetUniqueTopLevelSlice(instruction)); void* result_buffer = diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 573647534faa5c49d1a846a848d7b6457005a506..c3c11df090e88c3c24104b66d28b3b16f03baa80 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -176,6 +176,7 @@ CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, bool enable_fast_math, + bool disable_expensive_passes, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook) : target_machine_( @@ -190,12 +191,13 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, data_layout_(target_machine_->createDataLayout()), object_layer_( [] { return std::make_shared(); }), - compile_layer_(object_layer_, - CompilerFunctor(target_machine_.get(), &disassembler_, - opt_level, optimize_for_size, - enable_fast_math, GetAvailableIntrinsics(), - std::move(pre_optimization_hook), - std::move(post_optimization_hook))) { + compile_layer_( + object_layer_, + CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, + optimize_for_size, enable_fast_math, + disable_expensive_passes, GetAvailableIntrinsics(), + std::move(pre_optimization_hook), + std::move(post_optimization_hook))) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() << " features: " << target_machine_->getTargetFeatureString().str(); } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 331e18bc8b3304a366b309d01697f7a2f53a0b6e..e476c0e3812cc0fb2a2d633832374b3165ca072a 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -57,13 +57,15 @@ class SimpleOrcJIT { // generator. // The |optimize_for_size| parameter specifies that the code generator should // optimize to reduce code size, potentially at the cost of performance. + // The |disable_expensive_passes| parameter will disable certain optimization + // passes // The |pre_optimization_hook| is invoked on the module before any IR // level optimizations are applied. // The |post_optimization_hook| is invoked on the module after all IR // level optimizations are applied. SimpleOrcJIT(const llvm::TargetOptions& target_options, llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, - bool enable_fast_math, + bool enable_fast_math, bool disable_expensive_passes, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 669ebb55beca32bdb6739d120fb4ac1585b32cee..6efd0bcee58d19b355b6c2afa6d9497f75ef4b3c 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -24,16 +24,14 @@ limitations under the License. namespace xla { -Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode) { +Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", - HloOpcodeString(opcode).c_str()); + HloOpcodeString(hlo->opcode()).c_str()); } -Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode) { +Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", - HloOpcodeString(opcode).c_str()); + HloOpcodeString(hlo->opcode()).c_str()); } DfsHloVisitor::VisitState DfsHloVisitor::GetVisitState( diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index a1a3a882c7a570a850212ce21285371ecb6e3ed7..2f21043a1d341aecd14c0476fb61a8ff511656ea 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -63,37 +63,37 @@ class DfsHloVisitor { // These routines are self-descriptive, see class comment for usage // information. - virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode); - virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode); + virtual Status HandleElementwiseUnary(HloInstruction* hlo); + virtual Status HandleElementwiseBinary(HloInstruction* hlo); virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) = 0; virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false) = 0; virtual Status HandleMaximum(HloInstruction* maximum) { - return HandleElementwiseBinary(maximum, HloOpcode::kMaximum); + return HandleElementwiseBinary(maximum); } virtual Status HandleMinimum(HloInstruction* minimum) { - return HandleElementwiseBinary(minimum, HloOpcode::kMinimum); + return HandleElementwiseBinary(minimum); } virtual Status HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) = 0; virtual Status HandleConvert(HloInstruction* convert) { - return HandleElementwiseUnary(convert, HloOpcode::kConvert); + return HandleElementwiseUnary(convert); } virtual Status HandleCopy(HloInstruction* copy) { - return HandleElementwiseUnary(copy, HloOpcode::kCopy); + return HandleElementwiseUnary(copy); } virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(multiply, HloOpcode::kMultiply); + return HandleElementwiseBinary(multiply); } virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) = 0; virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(power, HloOpcode::kPower); + return HandleElementwiseBinary(power); } virtual Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, @@ -101,73 +101,72 @@ class DfsHloVisitor { virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0; virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(compare, opcode); + return HandleElementwiseBinary(compare); } virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(add, HloOpcode::kAdd); + return HandleElementwiseBinary(add); } virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(divide, HloOpcode::kDivide); + return HandleElementwiseBinary(divide); } virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(remainder, HloOpcode::kRemainder); + return HandleElementwiseBinary(remainder); } virtual Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(subtract, HloOpcode::kSubtract); + return HandleElementwiseBinary(subtract); } virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { - return HandleElementwiseUnary(abs, HloOpcode::kAbs); + return HandleElementwiseUnary(abs); } virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) { - return HandleElementwiseUnary(sign, HloOpcode::kSign); + return HandleElementwiseUnary(sign); } virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) { - return HandleElementwiseUnary(negate, HloOpcode::kNegate); + return HandleElementwiseUnary(negate); } virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) { - return HandleElementwiseUnary(exp, HloOpcode::kExp); + return HandleElementwiseUnary(exp); } virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) { - return HandleElementwiseUnary(floor, HloOpcode::kFloor); + return HandleElementwiseUnary(floor); } virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) { - return HandleElementwiseUnary(ceil, HloOpcode::kCeil); + return HandleElementwiseUnary(ceil); } virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) { - return HandleElementwiseUnary(log, HloOpcode::kLog); + return HandleElementwiseUnary(log); } virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) { - return HandleElementwiseUnary(cos, HloOpcode::kCos); + return HandleElementwiseUnary(cos); } virtual Status HandleSin(HloInstruction* sin, HloInstruction* operand) { - return HandleElementwiseUnary(sin, HloOpcode::kSin); + return HandleElementwiseUnary(sin); } virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { - return HandleElementwiseUnary(tanh, HloOpcode::kTanh); + return HandleElementwiseUnary(tanh); } virtual Status HandleIsFinite(HloInstruction* is_finite, HloInstruction* operand) { - return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite); + return HandleElementwiseUnary(is_finite); } virtual Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd); + return HandleElementwiseBinary(logical_and); } virtual Status HandleLogicalNot(HloInstruction* logical_not, HloInstruction* operand) { - return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot); + return HandleElementwiseUnary(logical_not); } virtual Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr); + return HandleElementwiseBinary(logical_or); } virtual Status HandleReducePrecision(HloInstruction* reduce_precision) { - return HandleElementwiseUnary(reduce_precision, - HloOpcode::kReducePrecision); + return HandleElementwiseUnary(reduce_precision); } virtual Status HandleInfeed(HloInstruction* infeed) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 10f8ae9b044bda60beec6a19a5080217bfd2ffbb..a5fe120598416235dff2af9d8a5c0ae64ac9edcc 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -41,12 +41,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { // Default action performed on HloInstruction. virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0; - Status HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode) override { + Status HandleElementwiseUnary(HloInstruction* hlo) override { return DefaultAction(hlo); } - Status HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode) override { + Status HandleElementwiseBinary(HloInstruction* hlo) override { return DefaultAction(hlo); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 84bdd5acac37773dd24f78a1ca696f5819558c50..350dbc321fb2234912d2143adfe70b75b48d0e27 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -616,7 +616,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( auto random_value = [hlo]() { const HloModule* module = - hlo->IsFused() ? hlo->fusion_instruction()->parent()->parent() + hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent() : hlo->parent()->parent(); return module->RandomNew64(); }; @@ -709,7 +709,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( } else { auto r = ir_builder_->CreateSub(q, p); auto leading_zeros = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)}, + llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)}, {param_ir_type}, ir_builder_); auto in_block = ir_builder_->GetInsertBlock(); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index d044462f9a710e01fdb9caf4f5cece6f92af53f6..5edaaba3ebe482126c800059968d0e430076f950 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -334,7 +334,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_); IrArray::Index input_index(index.size()); - llvm::Value* in_bounds = ir_builder_->getInt1(1); + llvm::Value* in_bounds = ir_builder_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = ir_builder_->CreateNSWMul( index[i], ir_builder_->getInt64(window.dimensions(i).stride())); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index dd51c11d5e7b2a1dd32c2e27bb75ea3cfb89c42e..7f5be602beb1c6b337ee7ef86ec19271d1f73cb5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -120,17 +120,20 @@ string GetLibdeviceDir(const HloModuleConfig& config) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule(HloModule* hlo_module, - const se::DeviceDescription& device_desc) { +tensorflow::Status OptimizeHloModule( + HloModule* hlo_module, const se::DeviceDescription& device_desc, + const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { { HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(shape_size_function); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), - HloReducePrecisionOptions::BEFORE_OP_FUSION); + ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); { auto& pass = pipeline.AddPass>("simplification"); + pass.AddInvariantChecker(shape_size_function); + // TODO(b/62764704): Do not rewrite on GPU, use cuDNN's BatchNorm APIs // instead. pass.AddPass( @@ -158,18 +161,17 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, } { HloPassFix fusion("fusion"); + fusion.AddInvariantChecker(shape_size_function); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); + reduce_pipeline.AddInvariantChecker(shape_size_function); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), - HloReducePrecisionOptions::AFTER_OP_FUSION); - ReducePrecisionInsertion::AddPasses( - &reduce_pipeline, hlo_module->config().debug_options(), - HloReducePrecisionOptions::FUSION_BY_CONTENT); + ReducePrecisionInsertion::PassTiming::AFTER_FUSION); StatusOr reduce_result = reduce_pipeline.Run(hlo_module); TF_RETURN_IF_ERROR(reduce_result.status()); @@ -184,14 +186,16 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { +tensorflow::Status PrepareHloModuleForIrEmitting( + HloModule* hlo_module, + const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker(); + pipeline.AddInvariantChecker(shape_size_function); pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout()); @@ -260,9 +264,11 @@ StatusOr> GpuCompiler::Compile( std::unique_ptr module, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); + TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), + stream_exec->GetDeviceDescription(), + ShapeSizeBytesFunction())); TF_RETURN_IF_ERROR( - OptimizeHloModule(module.get(), stream_exec->GetDeviceDescription())); - TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); + PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction())); llvm::LLVMContext llvm_context; std::string buffer; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 202a0171dbe742d205f86afc79b663cfabe1c706..a40eb6afc2ffe80ae17980d2ea935b60eb6107c0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -87,6 +87,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( } } + // TODO(b/65380986): Investigate if adding fast math flags for generated + // kernels makes sense. + llvm::BasicBlock* entry_bb = llvm::BasicBlock::Create(function->getContext(), "entry", function); // Emit a "return void" at entry_bb's end, and sets the insert point before diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index e7d32a4ae1448a375a2625707da5e82dc8bb5a2b..b84284046b0e61267d4427b3c6a0f9e5215552f5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -201,6 +201,9 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( } kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias); + // TODO(b/65380986): Investigate if adding fast math flags for generated + // kernels makes sense. + // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX // treats it as a CUDA kernel. llvm::NamedMDNode* nvvm_annotations_node = @@ -894,7 +897,7 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &ir_builder_); const HloInstruction* output = - reduce->IsFused() ? reduce->fusion_instruction() : reduce; + reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_, "output_element_address"); @@ -1142,7 +1145,7 @@ Status IrEmitterUnnested::EmitRowReduction( } const HloInstruction* output = - reduce->IsFused() ? reduce->fusion_instruction() : reduce; + reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; // Emit an atomic operation that accumulates the partial reduction result of // lane 0 (which holds the partially accumulated result for its warp) to the @@ -1913,10 +1916,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - // const HloInstruction* root = hlo.fused_expression_root(); - llvm_ir::EmitTuple( - GetIrArray(*hlo.fused_expression_root()->fusion_instruction()), - tuple_operand_ptrs, &ir_builder_); + llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 2a999f52f012499b14630078612eff2d4f312a3d..2e7765c4c61a18f482fcc659dc1de8408a9d37b8 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -389,7 +389,7 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA // again after the standard optimization passes [http://b/13329423]. - // TODO(jingyue): SROA may further expose more optimization opportunities, such + // TODO(jingyue): SROA may further expose more optimization opportunities such // as more precise alias analysis and more function inlining (SROA may change // the inlining cost of a function). For now, running SROA already emits good // enough code for the evaluated benchmarks. We may want to run more diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index cecbb01ff88e9b63d208467485c6d25008277325..ccdd1717593e4fa7c1d1deb3f0f9ebfab1bf7209 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -308,7 +308,7 @@ class WhileConditionComputationMatcher : public MatcherBase { GetTaggedInstruction("gte.fusion_param.param0", tagged_instructions)); CHECK_EQ(HloOpcode::kParameter, gte_fusion_param0->opcode()); CHECK(gte_fusion_param0->IsFused()); - if (gte_fusion_param0->fusion_instruction()->operand( + if (gte_fusion_param0->parent()->FusionInstruction()->operand( gte_fusion_param0->parameter_number()) != computation_->parameter_instruction(0)) { return InvalidArgument("Could not match fusion param: %s", @@ -469,7 +469,8 @@ class WhileBodyComputationMatcher : public MatcherBase { // Fusion parameter: lookup and compare with associated fusion operand. CHECK_EQ(HloOpcode::kParameter, inst->opcode()); CHECK(inst->IsFused()); - if (inst->fusion_instruction()->operand(inst->parameter_number()) != + if (inst->parent()->FusionInstruction()->operand( + inst->parameter_number()) != computation_->parameter_instruction(0)) { return InvalidArgument("Could not match fusion param: %s", inst->name().c_str()); diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 0beea423798af731a6510ffedf5176bf94a1ff91..3dd8ac6dc5fa46b80328e080e6d1b4e8c402e8b0 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -37,6 +37,230 @@ namespace xla { using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; +// Data structure used to construct the alias analysis. Thrown away after alias +// analysis is complete. This data structure keeps track of which sets of +// HloValues must be in the same HloBuffer. This is maintained as a map from a +// buffer identifier (BufferNumber) to set of HLoValues. +// +// Initially each value is its own buffer. In MergeAliasedBuffers, sets of +// values which must share the same buffer are merged together. The end result +// is a partitioning of all HloValues into sets where each set needs its own +// HloBuffer. By performing this analysis without constructing HloBuffers on the +// fly, we can after-the-fact construct a vector of contiguously numbered +// HloBuffers after the buffer requirement has been determined. +class BufferValueMap { + public: + // A unique identifier for a set of colocated values which must share the same + // buffer. This is not necessarily the same as the HloBuffer::Id which will + // ultimately contain the values. The reason is that HloBuffer::Id's are + // contiguous, while BufferNumbers may not be. BufferNumbers may not be + // dense because buffers may be created and destroyed during the analysis + // construction process. + using BufferNumber = int64; + + explicit BufferValueMap(const HloDataflowAnalysis& dataflow) + : dataflow_(dataflow) { + buffers_.reserve(dataflow_.values().size()); + value_to_buffer_number_.reserve(dataflow_.values().size()); + for (const HloValue* value : dataflow_.values()) { + BufferNumber buffer_number = next_buffer_number_++; + buffers_[buffer_number].insert(value); + value_to_buffer_number_[value] = buffer_number; + } + } + + // Merge together sets of HloValues which must be in the same HloBuffer + // because of aliasing rules (eg, in-place kWhile instruction). + void MergeAliasedBuffers() { + for (const HloValue* value : dataflow_.values()) { + VLOG(3) << "Merging colocated values, value: " << value->ToShortString(); + + // Gather the set of buffers with aliasing rules (eg, kWhile) which this + // value must be contained in. + std::vector aliased_buffers = ComputeAliasedBuffers(*value); + + BufferNumber current_buffer = value_to_buffer_number_.at(value); + if (aliased_buffers.empty()) { + // The buffer containing 'value' aliases no other buffers. If the buffer + // containing 'value' already only contains 'value', then no change is + // necessary. If the buffer containing 'value' does contain other + // values, then remove 'value' from the buffer and create a new buffer + // containing only 'value' + if (buffers_.at(current_buffer).size() == 1) { + CHECK_EQ(*buffers_.at(current_buffer).begin(), value); + } else { + MoveValueToNewBuffer(*value); + } + } else { + // If multiple buffers are aliased merge these buffers together into a + // single buffer (arbitrarily chosen as the first buffer in the vector). + if (aliased_buffers.size() > 1) { + for (int64 i = 1; i < aliased_buffers.size(); ++i) { + MergeBuffers(/*from=*/aliased_buffers[i], + /*to=*/aliased_buffers[0]); + } + } + BufferNumber new_buffer = aliased_buffers[0]; + if (current_buffer != new_buffer) { + MoveValueToBuffer(*value, new_buffer); + } + } + } + } + + // Compute and return a sorted vector of all BufferNumbers. Can be used to + // iterate through all buffers stabily. + std::vector ComputeSortedBufferNumbers() const { + std::vector buffer_numbers; + for (const auto& pair : buffers_) { + buffer_numbers.push_back(pair.first); + } + std::sort(buffer_numbers.begin(), buffer_numbers.end()); + return buffer_numbers; + } + + // Return a set of all the values in the given buffer. + const tensorflow::gtl::FlatSet& GetValuesInBuffer( + BufferNumber buffer_number) const { + return buffers_.at(buffer_number); + } + + private: + // Create a new buffer. + void NewBuffer(const HloValue& value) { + BufferNumber buffer_number = next_buffer_number_++; + buffers_[buffer_number].insert(&value); + value_to_buffer_number_[&value] = buffer_number; + } + + // Move the given value into a new buffer containing only the value. + void MoveValueToNewBuffer(const HloValue& value) { + BufferNumber new_buffer_number = next_buffer_number_++; + buffers_[new_buffer_number]; + MoveValueToBuffer(value, new_buffer_number); + } + + // Move the given value into the given buffer. + void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) { + BufferNumber old_buffer_number = value_to_buffer_number_.at(&value); + buffers_.at(old_buffer_number).erase(&value); + if (buffers_.at(old_buffer_number).empty()) { + buffers_.erase(old_buffer_number); + } + + buffers_.at(buffer_number).insert(&value); + value_to_buffer_number_.at(&value) = buffer_number; + } + + // Merge the buffer 'from' into the buffer 'to'. + void MergeBuffers(BufferNumber from, BufferNumber to) { + auto& from_value_set = buffers_.at(from); + buffers_.at(to).insert(from_value_set.begin(), from_value_set.end()); + // NOTE: using a union-find algorithm to hold the colocated values might be + // faster. + for (const HloValue* value : from_value_set) { + value_to_buffer_number_.at(value) = to; + } + buffers_.erase(from); + } + + BufferNumber GetBufferForValue(const HloValue& value) { + return value_to_buffer_number_.at(&value); + } + + // Compute and return a vector of buffers that the given value must be + // contained in due to HLO aliasing rules. + std::vector ComputeAliasedBuffers(const HloValue& value) { + // Value is init of a while (use is while). + std::vector aliased_buffers; + for (const HloUse& use : value.uses()) { + VLOG(1) << "use of value " << value.ToShortString() << ": " << use; + if (use.instruction->opcode() == HloOpcode::kWhile) { + // Determine the while value that this shares a buffer with. + const HloValue& while_value = + dataflow_.GetUniqueValueAt(use.instruction, use.operand_index); + aliased_buffers.push_back(GetBufferForValue(while_value)); + VLOG(3) << " value is init value to a while; must share buffer with " + "while value " + << while_value.ToShortString(); + } + } + + // Value is a parameter of a while body/condition. + if (value.defining_instruction()->opcode() == HloOpcode::kParameter) { + const HloComputation* computation = + value.defining_instruction()->parent(); + const CallGraphNode& call_graph_node = + dataflow_.call_graph().GetNode(computation); + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + // Call graph must have been flattened. + CHECK_EQ(call_graph_node.caller_callsites().size(), 1); + + const HloValue& while_value = dataflow_.GetUniqueValueAt( + callsite.instruction(), value.defining_index()); + VLOG(3) << " value is parameter value of the body or condition of a " + "while; must share buffer with while value " + << while_value.ToShortString(); + aliased_buffers.push_back(GetBufferForValue(while_value)); + } + } + } + + // Value is the root of a while body. + for (const HloPosition& position : value.positions()) { + const HloComputation* computation = position.instruction->parent(); + const CallGraphNode& call_graph_node = + dataflow_.call_graph().GetNode(computation); + if (position.instruction == computation->root_instruction()) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + if (callsite.instruction()->opcode() == HloOpcode::kWhile && + callsite.instruction()->while_body() == computation) { + // Call graph must have been flattened. + CHECK_EQ(call_graph_node.caller_callsites().size(), 1); + + const HloValue& while_value = dataflow_.GetUniqueValueAt( + callsite.instruction(), position.index); + VLOG(3) << " value is root the body computation of a while; must " + "share buffer with while value " + << while_value.ToShortString(); + aliased_buffers.push_back(GetBufferForValue(while_value)); + } + } + } + } + + // Value is the output of the while instruction itself. + if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { + VLOG(3) << " value is output of a while instruction"; + aliased_buffers.push_back(GetBufferForValue(value)); + } + + // Uniquify aliased buffers. + std::sort(aliased_buffers.begin(), aliased_buffers.end()); + aliased_buffers.erase( + std::unique(aliased_buffers.begin(), aliased_buffers.end()), + aliased_buffers.end()); + + return aliased_buffers; + } + + // Dataflow analysis used to construct the buffer map. + const HloDataflowAnalysis& dataflow_; + + // A map containing the set of values contained in each buffer. + tensorflow::gtl::FlatMap> + buffers_; + + // A map indicating which buffer each value is contained in. + tensorflow::gtl::FlatMap + value_to_buffer_number_; + + // The buffer number of the next buffer to be created. + BufferNumber next_buffer_number_ = 0; +}; + HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {} const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt( @@ -99,10 +323,11 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct( } } else { // It's possible for multiple values at this index to have the same - // HloBuffer. This does not result in non-distictness. To account for this - // case, add all of the buffers at this index after checking whether each - // buffer exists at an earlier index. This is a corner case, however, as - // the number of values at an index is almost always one. + // HloBuffer. This does not result in non-distictness. To account for + // this case, add all of the buffers at this index after checking + // whether each buffer exists at an earlier index. This is a corner + // case, however, as the number of values at an index is almost always + // one. std::vector buffers_at_this_index; for (const HloValue* value : value_set.values()) { const HloBuffer* buffer = &GetBufferContainingValue(*value); @@ -118,15 +343,6 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct( return true; } -void HloAliasAnalysis::InitializeBufferSets() { - // Initially define a buffer for every HloValue in the module. - for (const HloValue& value : dataflow_analysis_->values()) { - HloBuffer& buffer = NewHloBuffer(); - buffer.AddValue(value); - value_to_buffer_[&value] = &buffer; - } -} - Status HloAliasAnalysis::Verify() const { // Verify consistency between the value_to_buffer_ map and // HloBuffer::values(). @@ -137,9 +353,8 @@ Status HloAliasAnalysis::Verify() const { value) != buffer.values().end()); } - for (const auto& pair : buffers_) { - const HloBuffer::Id id = pair.first; - const HloBuffer& buffer = pair.second; + for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) { + const HloBuffer& buffer = buffers_[id]; TF_RET_CHECK(buffer.id() == id); HloValue::Id last_value_id = -1; @@ -152,116 +367,9 @@ Status HloAliasAnalysis::Verify() const { } } - if (!buffers_vector_.empty()) { - // buffers_vector_ should be a vector of all HloBuffers sorted by id. - std::vector buffers; - for (const auto& id_buffer : buffers_) { - buffers.push_back(&id_buffer.second); - } - std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan); - TF_RET_CHECK(buffers_vector_ == buffers); - } - - return Status::OK(); -} - -Status HloAliasAnalysis::VerifyAgainstReference() const { - TF_RETURN_IF_ERROR(Verify()); - - TF_ASSIGN_OR_RETURN(std::unique_ptr reference, - Run(module_)); - TF_RETURN_IF_ERROR(reference->Verify()); - - VLOG(2) << "This analysis:"; - XLA_VLOG_LINES(2, ToString()); - VLOG(2) << "Reference:"; - XLA_VLOG_LINES(2, reference->ToString()); - - // Create map from HloValue in the reference analysis to HloValue in this - // analysis and vice versa. - tensorflow::gtl::FlatMap reference_to_this; - tensorflow::gtl::FlatMap this_to_reference; - for (const HloValue& value : dataflow_analysis().values()) { - const HloValue& reference_value = - reference->dataflow_analysis().GetValueDefinedAt( - value.defining_instruction(), value.defining_index()); - reference_to_this[&reference_value] = &value; - this_to_reference[&value] = &reference_value; - } - - TF_RET_CHECK(buffers_.size() == reference->buffers_.size()) - << "Different number of buffers (" << buffers_.size() - << " != " << reference->buffers_.size() << ")"; - for (const auto& pair : reference->buffers_) { - const HloBuffer& reference_buffer = pair.second; - - // Find the corresponding buffer in the reference by taking the first value - // in the buffer, finding the corresponding value in the reference, and then - // finding the buffer holding that value. - TF_RET_CHECK(!reference_buffer.values().empty()); - const HloValue* reference_value = reference_buffer.values()[0]; - const HloValue* value = reference_to_this.at(reference_value); - const HloBuffer& buffer = GetBufferContainingValue(*value); - - // The buffer and the reference should have the exact same values. To make - // comparison easy, sort the values in the reference buffer identically to - // the values in the non-reference buffer (ie, by the corresponding id of - // the non-reference value). - std::vector reference_values = reference_buffer.values(); - std::sort(reference_values.begin(), reference_values.end(), - [&reference_to_this](const HloValue* a, const HloValue* b) { - return reference_to_this.at(a)->id() < - reference_to_this.at(b)->id(); - }); - TF_RET_CHECK(reference_values.size() == buffer.values().size()); - for (int i = 0; i < buffer.values().size(); ++i) { - TF_RET_CHECK(*reference_values[i] == *buffer.values()[i]) - << "Buffer:\n " << buffer - << "\ndoes not have the same values as reference buffer:\n " - << reference_buffer; - } - } - return Status::OK(); } -HloBuffer& HloAliasAnalysis::NewHloBuffer() { - HloBuffer::Id buffer_id = next_buffer_id_++; - auto emplaced = buffers_.emplace(std::piecewise_construct, - std::forward_as_tuple(buffer_id), - std::forward_as_tuple(buffer_id)); - CHECK(emplaced.second); - - buffers_vector_.clear(); - - return emplaced.first->second; -} - -void HloAliasAnalysis::MoveValueToNewBuffer(const HloValue& value) { - HloBuffer& new_buffer = NewHloBuffer(); - MoveValueToBuffer(value, &new_buffer); - - VLOG(3) << "Moved value " << value.ToShortString() << " into new buffer " - << new_buffer.id(); -} - -void HloAliasAnalysis::MoveValueToBuffer(const HloValue& value, - HloBuffer* buffer) { - HloBuffer& old_buffer = GetBufferContainingValue(value); - CHECK_NE(buffer, &old_buffer); - VLOG(3) << "Moved value " << value.ToShortString() << " from buffer " - << old_buffer.id() << " into buffer " << buffer->id(); - old_buffer.RemoveValue(value); - if (old_buffer.values().empty()) { - VLOG(3) << "Buffer " << old_buffer.id() << " now empty. Removing."; - buffers_.erase(old_buffer.id()); - buffers_vector_.clear(); - } - - buffer->AddValue(value); - value_to_buffer_[&value] = buffer; -} - string HloAliasAnalysis::ToString() const { string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); StrAppend(&out, " Buffers at each position:\n"); @@ -290,10 +398,10 @@ string HloAliasAnalysis::ToString() const { } StrAppend(&out, " Buffers:\n"); - for (const HloBuffer* buffer : buffers()) { - StrAppend(&out, " ", buffer->ToString(), "\n"); + for (const HloBuffer& buffer : buffers()) { + StrAppend(&out, " ", buffer.ToString(), "\n"); StrAppend(&out, " positions:\n"); - for (const HloPosition& position : buffer->ComputePositions()) { + for (const HloPosition& position : buffer.ComputePositions()) { StrAppend(&out, " ", position.ToString(), "\n"); } } @@ -301,217 +409,6 @@ string HloAliasAnalysis::ToString() const { return out; } -const std::vector& HloAliasAnalysis::buffers() const { - if (buffers_vector_.empty()) { - // Lazily construct vector of buffers. - buffers_vector_.reserve(buffers_.size()); - for (auto& pair : buffers_) { - buffers_vector_.push_back(&pair.second); - } - std::sort(buffers_vector_.begin(), buffers_vector_.end(), - HloBuffer::IdLessThan); - } else { - CHECK_EQ(buffers_vector_.size(), buffers_.size()); - for (const HloBuffer* buffer : buffers_vector_) { - DCHECK(ContainsKey(buffers_, buffer->id())); - DCHECK(&GetBuffer(buffer->id()) == buffer); - } - } - return buffers_vector_; -} - -void HloAliasAnalysis::UpdateAtInstructions( - tensorflow::gtl::ArraySlice instructions) { - VLOG(4) << "Updated HLO module:"; - XLA_VLOG_LINES(4, module_->ToString()); - - VLOG(3) << "Before update:"; - XLA_VLOG_LINES(3, ToString()); - - std::vector values_to_update; - for (const HloInstruction* instruction : instructions) { - for (auto& pair : dataflow_analysis().GetInstructionValueSet(instruction)) { - for (const HloValue* value : pair.second.values()) { - values_to_update.push_back(value); - } - } - } - - UpdateBuffersForValues(values_to_update); - - VLOG(3) << "After update:"; - XLA_VLOG_LINES(3, ToString()); -} - -void HloAliasAnalysis::UpdateAfterChangingOperand(HloInstruction* instruction, - HloInstruction* old_operand, - HloInstruction* new_operand) { - VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", " - << old_operand->name() << " => " << new_operand->name() << ")"; - - dataflow_analysis_->UpdateAfterChangingOperand(instruction, old_operand, - new_operand); - TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference()); - - VLOG(4) << "Updated dataflow:"; - XLA_VLOG_LINES(4, dataflow_analysis_->ToString()); - - UpdateAtInstructions({instruction, old_operand, new_operand}); -} - -void HloAliasAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root, - HloInstruction* new_root) { - VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => " - << new_root->name() << ")"; - - dataflow_analysis_->UpdateAfterChangingRoot(old_root, new_root); - TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference()); - - VLOG(4) << "Updated dataflow:"; - XLA_VLOG_LINES(4, dataflow_analysis_->ToString()); - - UpdateAtInstructions({old_root, new_root}); -} - -std::vector HloAliasAnalysis::ComputeAliasedBuffers( - const HloValue& value) { - std::vector aliased_buffers; - - // Value is init of a while (use is while). - for (const HloUse& use : value.uses()) { - VLOG(1) << "use of value " << value.ToShortString() << ": " << use; - if (use.instruction->opcode() == HloOpcode::kWhile) { - // Determine the while value that this shares a buffer with. - const HloValue& while_value = dataflow_analysis().GetUniqueValueAt( - use.instruction, use.operand_index); - aliased_buffers.push_back(&GetBufferContainingValue(while_value)); - VLOG(3) << " value is init value to a while; must share buffer with " - "while value " - << while_value.ToShortString(); - } - } - - // Value is a parameter of a while body/condition. - if (value.defining_instruction()->opcode() == HloOpcode::kParameter) { - const HloComputation* computation = value.defining_instruction()->parent(); - const CallGraphNode& call_graph_node = - dataflow_analysis().call_graph().GetNode(computation); - for (const CallSite& callsite : call_graph_node.caller_callsites()) { - if (callsite.instruction()->opcode() == HloOpcode::kWhile) { - // Call graph must have been flattened. - CHECK_EQ(call_graph_node.caller_callsites().size(), 1); - - const HloValue& while_value = dataflow_analysis().GetUniqueValueAt( - callsite.instruction(), value.defining_index()); - VLOG(3) << " value is parameter value of the body or condition of a " - "while; must share buffer with while value " - << while_value.ToShortString(); - aliased_buffers.push_back(&GetBufferContainingValue(while_value)); - } - } - } - - // Value is the root of a while body. - for (const HloPosition& position : value.positions()) { - const HloComputation* computation = position.instruction->parent(); - const CallGraphNode& call_graph_node = - dataflow_analysis().call_graph().GetNode(computation); - if (position.instruction == computation->root_instruction()) { - for (const CallSite& callsite : call_graph_node.caller_callsites()) { - if (callsite.instruction()->opcode() == HloOpcode::kWhile && - callsite.instruction()->while_body() == computation) { - // Call graph must have been flattened. - CHECK_EQ(call_graph_node.caller_callsites().size(), 1); - - // If the value appears in the root of a while body, then - // necessarily the value is defined in the body as well. - CHECK_EQ(value.defining_instruction()->parent(), computation); - - const HloValue& while_value = dataflow_analysis().GetUniqueValueAt( - callsite.instruction(), position.index); - VLOG(3) << " value is root the body computation of a while; must " - "share buffer with while value " - << while_value.ToShortString(); - aliased_buffers.push_back(&GetBufferContainingValue(while_value)); - } - } - } - } - - // Value is in the while instruction itself. - if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { - VLOG(3) << " value is output of a while instruction"; - aliased_buffers.push_back(&GetUniqueBufferAt(value.defining_instruction(), - value.defining_index())); - } - - // Uniquify aliased buffers. - std::sort(aliased_buffers.begin(), aliased_buffers.end(), - HloBuffer::IdLessThan); - aliased_buffers.erase( - std::unique(aliased_buffers.begin(), aliased_buffers.end()), - aliased_buffers.end()); - - return aliased_buffers; -} - -// This method recomputes the HloBuffer for each of the given HloValues. The -// method does not necessarily update the HloBuffer of values which share a -// buffer with the given values, but are not explicitly passed in -// 'values'. Therefore, the caller must pass in all values which may require an -// update according to the kind of HLO graph change which occurred: operand -// changed (UpdateAfterChangingOperand), or root of computation changed -// (UpdateAfterChangingRoot). -void HloAliasAnalysis::UpdateBuffersForValues( - tensorflow::gtl::ArraySlice values) { - for (const HloValue* value : values) { - VLOG(3) << "Updating buffer for value: " << value->ToShortString(); - - // Gather the set of buffer with aliasing rules (eg, kWhile) which this - // value must be contained in due. - std::vector aliased_buffers = ComputeAliasedBuffers(*value); - - HloBuffer& current_buffer = GetBufferContainingValue(*value); - if (aliased_buffers.empty()) { - // The buffer containing 'value' aliases no other buffers. If the buffer - // containing 'value' already only contains 'value', then no change is - // necessary. If the buffer containing 'value' does contain other values, - // then remove 'value' from the buffer and create a new buffer containing - // only 'value' - if (current_buffer.values().size() == 1) { - CHECK_EQ(current_buffer.values()[0], value); - } else { - MoveValueToNewBuffer(*value); - } - } else { - // If multiple buffers are aliased merge these buffers together into a - // single buffer (arbitrarily chosen as the first buffer in the vector). - if (aliased_buffers.size() > 1) { - for (int64 i = 1; i < aliased_buffers.size(); ++i) { - // Make copy of values vector because MoveValueToBuffer invalidates - // the values iterator. The could be done more efficiently by moving - // all values and once. - std::vector values = aliased_buffers[i]->values(); - for (const HloValue* value : values) { - MoveValueToBuffer(*value, aliased_buffers[0]); - } - } - aliased_buffers.resize(1); - } - - CHECK_EQ(aliased_buffers.size(), 1); - HloBuffer* new_buffer = aliased_buffers[0]; - - if (¤t_buffer != new_buffer) { - MoveValueToBuffer(*value, new_buffer); - } - } - - VLOG(4) << "Analysis after update:"; - XLA_VLOG_LINES(4, ToString()); - } -} - /* static */ StatusOr> HloAliasAnalysis::Run( HloModule* module) { @@ -524,18 +421,28 @@ StatusOr> HloAliasAnalysis::Run( HloDataflowAnalysis::Run(module, /*ssa_form=*/true, /*bitcast_defines_value=*/false)); - alias_analysis->InitializeBufferSets(); - - VLOG(3) << "After initialization:"; - XLA_VLOG_LINES(3, alias_analysis->ToString()); - - std::vector all_values; - for (const HloValue& value : alias_analysis->dataflow_analysis().values()) { - all_values.push_back(&value); + BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); + buffer_map.MergeAliasedBuffers(); + + // Create a vector of HloBuffers, one for each set of values in the + // BufferValueMap. Create the HloBuffers as a vector of contiguously numbered + // buffers. + std::vector sorted_buffer_numbers = + buffer_map.ComputeSortedBufferNumbers(); + alias_analysis->buffers_.reserve(sorted_buffer_numbers.size()); + HloBuffer::Id next_id = 0; + for (BufferValueMap::BufferNumber buffer_number : sorted_buffer_numbers) { + auto& value_set = buffer_map.GetValuesInBuffer(buffer_number); + std::vector sorted_values(value_set.begin(), + value_set.end()); + std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan); + alias_analysis->buffers_.emplace_back(next_id++, sorted_values); + for (const HloValue* value : sorted_values) { + alias_analysis->value_to_buffer_[value] = + &alias_analysis->buffers_.back(); + } } - alias_analysis->UpdateBuffersForValues(all_values); - TF_DCHECK_OK(alias_analysis->Verify()); XLA_VLOG_LINES(1, alias_analysis->ToString()); diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 1b538f6d1cfc21397aaa526b9b21a7a4b7e90940..39554e466488007bfca666b5453ebaa555f598bf 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -74,7 +74,7 @@ class HloAliasAnalysis { // Return a vector of all HloBuffers stabily sorted by HloBuffer::Id. This // vector is lazily computed. Mutating operations on HloAliasAnalysis may // invalidate the underlying vector requiring recomputation. - const std::vector& buffers() const; + const std::vector& buffers() const { return buffers_; } // Returns the underlying dataflow analysis used by this alias analysis. const HloDataflowAnalysis& dataflow_analysis() const { @@ -90,50 +90,13 @@ class HloAliasAnalysis { // output of the given instruction. bool InstructionBuffersAreDistinct(const HloInstruction* instruction) const; - // Updates the analysis after the operands of 'instruction' have changed or if - // 'instruction' has been made the root of a computation. Analysis update is - // not possible if instructions have been added or removed from the graph. - void UpdateAfterChangingOperand(HloInstruction* instruction, - HloInstruction* old_operand, - HloInstruction* new_operand); - void UpdateAfterChangingRoot(HloInstruction* old_root, - HloInstruction* new_root); - // Compare the dataflow analysis against a clean recomputation of the // analysis. Returns an error status if there is a mismatch. Useful for // verifying the correctness after updates to the analysis. Status VerifyAgainstReference() const; protected: - HloAliasAnalysis(HloModule* module); - - // Create a new empty HloBuffer. - HloBuffer& NewHloBuffer(); - - // Move the given value to the given buffer. The value is removed from it's - // current buffer. - void MoveValueToBuffer(const HloValue& value, HloBuffer* buffer); - - // Move the given value to a newly created buffer. The value is removed from - // it's current buffer. - void MoveValueToNewBuffer(const HloValue& value); - - // Construct the initial set of buffer sets where an HloBuffer is created for - // each HloValue in the module. - void InitializeBufferSets(); - - // Compute and return the buffers with aliasing rules (eg, kWhile) which the - // given value must be contained in. - std::vector ComputeAliasedBuffers(const HloValue& value); - - // Recompute the HloBuffers for the given values. - void UpdateBuffersForValues( - tensorflow::gtl::ArraySlice values); - - // Recompute the HloBuffers for all the values which appear in the output of - // the given instructions. - void UpdateAtInstructions( - tensorflow::gtl::ArraySlice instructions); + explicit HloAliasAnalysis(HloModule* module); // Verify various invariants of the alias analysis. Status Verify() const; @@ -143,20 +106,12 @@ class HloAliasAnalysis { // The underlying dataflow analysis used by this alias analysis. std::unique_ptr dataflow_analysis_; - // The map of all HloBuffers in the module. We pass around pointers to the - // mapped HloBuffers, so the underlying container must keep them valid despite - // mutations touching other map entries. - std::unordered_map buffers_; - // A map indicating which buffer a value is contained in. tensorflow::gtl::FlatMap value_to_buffer_; // A lazily constructed vector containing all HloBuffers sorted by // HloBuffer::Id. - mutable std::vector buffers_vector_; - - // The Id to use for the next HloBuffer. - int64 next_buffer_id_ = 0; + std::vector buffers_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index e2815d6e648346ef382bb3d93bc9cad708a216a2..6e311e25fb92f32ae8266bab0c3daad43d2349a3 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -87,14 +87,13 @@ class HloAliasAnalysisTest : public HloTestBase { // constructed. bool AnyValuesInSameBufferInterfere() { DependencyHloOrdering ordering(module_.get()); - for (const HloBuffer* buffer : analysis_->buffers()) { - for (const HloValue* value_a : buffer->values()) { - for (const HloValue* value_b : buffer->values()) { + for (const HloBuffer& buffer : analysis_->buffers()) { + for (const HloValue* value_a : buffer.values()) { + for (const HloValue* value_b : buffer.values()) { if (*value_a != *value_b && - analysis_->dataflow_analysis().MayInterfere(*value_a, *value_b, - ordering)) { + ordering.MayInterfere(*value_a, *value_b)) { VLOG(1) << *value_a << " interferes with " << *value_b - << " in buffer: " << *buffer; + << " in buffer: " << buffer; return true; } } @@ -384,10 +383,7 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { EXPECT_THAT( GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})), - UnorderedElementsAre(GetValueDefinedAt(xla_while, /*index=*/{0}), - GetValueDefinedAt(body_param, /*index=*/{0}), - GetValueDefinedAt(cond_param, /*index=*/{0}), - GetValueDefinedAt(constant1))); + UnorderedElementsAre(GetValueDefinedAt(constant1))); EXPECT_THAT( GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})), UnorderedElementsAre(GetValueDefinedAt(constant2), @@ -631,9 +627,9 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { // HloBuffers. EXPECT_THAT( analysis.buffers(), - UnorderedElementsAre(&analysis.GetUniqueBufferAt(constant1), - &analysis.GetUniqueBufferAt(tuple, /*index=*/{}), - &analysis.GetUniqueBufferAt(cond_constant))); + UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1), + analysis.GetUniqueBufferAt(tuple, /*index=*/{}), + analysis.GetUniqueBufferAt(cond_constant))); // The tuple elements of the while and the three constant inputs should all be // smooshed into the same buffer. @@ -820,127 +816,5 @@ TEST_F(HloAliasAnalysisTest, Bitcast) { analysis.GetUniqueBufferAt(bitcast)); } -TEST_F(HloAliasAnalysisTest, UpdateAnalysisForWhile) { - // Test updating alias analysis after modifying a module with an array shaped - // while: - // - // body(F32[] %param): - // %negate = Negate(%param) - // - // condition(F32[] %param): - // return Constant(false) - // - // entry: - // %constant = Constant(1.0) - // %exp = Exp(%constant) - // return While(%exp, body, condition) - // - auto body_builder = HloComputation::Builder("body"); - auto body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); - auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( - scalar_shape_, HloOpcode::kNegate, body_param)); - HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); - - // Condition computation trivially returns a constant "false". - auto cond_builder = HloComputation::Builder("condition"); - auto cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); - cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); - HloComputation* condition = - module_->AddEmbeddedComputation(cond_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant)); - auto xla_while = builder.AddInstruction( - HloInstruction::CreateWhile(scalar_shape_, condition, body, exp)); - module_->AddEntryComputation(builder.Build()); - - HloAliasAnalysis& analysis = RunAnalysis(); - - // Sanity check some alias information. - EXPECT_EQ(analysis.GetUniqueBufferAt(exp), - analysis.GetUniqueBufferAt(body_param)); - EXPECT_EQ(analysis.GetUniqueBufferAt(exp), - analysis.GetUniqueBufferAt(cond_param)); - EXPECT_EQ(analysis.GetUniqueBufferAt(exp), - analysis.GetUniqueBufferAt(negate)); - EXPECT_EQ(analysis.GetUniqueBufferAt(exp), - analysis.GetUniqueBufferAt(xla_while)); - - // Set the body root to the body_param. Previously it was Negate(body_param). - body->set_root_instruction(body_param); - - // Prior to updating, verify that the analysis is no longer valid. - Status verify_status = analysis.VerifyAgainstReference(); - EXPECT_FALSE(verify_status.ok()); - - analysis.UpdateAfterChangingRoot(/*old_root=*/negate, - /*new_root*/ body_param); - - // Analysis should be valid after the update. - TF_ASSERT_OK(analysis.VerifyAgainstReference()); - - // The exponential should now pass through the body transparently. - EXPECT_EQ(analysis.GetUniqueBufferAt(exp), - analysis.GetUniqueBufferAt(body_param)); - EXPECT_EQ(analysis.GetUniqueBufferAt(exp), - analysis.GetUniqueBufferAt(cond_param)); - EXPECT_NE(analysis.GetUniqueBufferAt(exp), - analysis.GetUniqueBufferAt(negate)); - EXPECT_EQ(analysis.GetUniqueBufferAt(exp), - analysis.GetUniqueBufferAt(xla_while)); - - // Now replace the operand of the while with %constant (was %exp). - TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant)); - analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp, - /*new_operand=*/constant); - - // Analysis should be valid after the update. - TF_ASSERT_OK(analysis.VerifyAgainstReference()); - - EXPECT_EQ(analysis.GetUniqueBufferAt(constant), - analysis.GetUniqueBufferAt(body_param)); - EXPECT_EQ(analysis.GetUniqueBufferAt(constant), - analysis.GetUniqueBufferAt(cond_param)); - EXPECT_EQ(analysis.GetUniqueBufferAt(constant), - analysis.GetUniqueBufferAt(xla_while)); - EXPECT_NE(analysis.GetUniqueBufferAt(constant), - analysis.GetUniqueBufferAt(exp)); - EXPECT_NE(analysis.GetUniqueBufferAt(constant), - analysis.GetUniqueBufferAt(negate)); - - // And finally make the negate the root of the body again. - body->set_root_instruction(negate); - analysis.UpdateAfterChangingRoot(/*old_root=*/body_param, - /*new_root*/ negate); - - // Analysis should be valid after the update. - TF_ASSERT_OK(analysis.VerifyAgainstReference()); - - EXPECT_EQ(analysis.GetUniqueBufferAt(negate), - analysis.GetUniqueBufferAt(body_param)); - EXPECT_EQ(analysis.GetUniqueBufferAt(negate), - analysis.GetUniqueBufferAt(cond_param)); - EXPECT_EQ(analysis.GetUniqueBufferAt(negate), - analysis.GetUniqueBufferAt(xla_while)); - EXPECT_EQ(analysis.GetUniqueBufferAt(constant), - analysis.GetUniqueBufferAt(negate)); - - auto value_of = [&analysis](const HloInstruction* instruction) { - return &analysis.dataflow_analysis().GetValueDefinedAt(instruction); - }; - EXPECT_THAT(analysis.GetUniqueBufferAt(negate).values(), - UnorderedElementsAre(value_of(body_param), value_of(cond_param), - value_of(negate), value_of(constant), - value_of(xla_while))); -} - -// Test update tuple element. - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index 2bfdd9156adf30ac39f7736a3f4e8103e9b8dc47..e16413f361fb0216792b47c3c67ef3c1357c2221 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -36,22 +36,6 @@ namespace xla { using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrCat; -void HloBuffer::AddValue(const HloValue& value) { - values_.push_back(&value); - // Sort vector and remove duplicates. - std::sort(values_.begin(), values_.end(), HloValue::IdLessThan); - values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual), - values_.end()); -} - -void HloBuffer::RemoveValue(const HloValue& value) { - // The values are sorted, so finding the value could be done in log(n) time - // with a binary search. - auto it = std::find(values_.begin(), values_.end(), &value); - CHECK(it != values_.end()); - values_.erase(it); -} - bool HloBuffer::operator==(const HloBuffer& other) const { bool equal = id() == other.id(); if (equal) { diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h index cb961e1601c904a65afa5fc250d1116a96dc4384..4873463b2ea4fee3ee39dff31fc3429a4998142f 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.h +++ b/tensorflow/compiler/xla/service/hlo_buffer.h @@ -84,22 +84,15 @@ class HloBuffer { return a->id() == b->id(); } - HloBuffer(Id id) : id_(id) {} + HloBuffer(Id id, tensorflow::gtl::ArraySlice values) + : id_(id), values_(values.begin(), values.end()) {} // Return the unique identifier for this HloBuffer. Id id() const { return id_; } - // Add a value to the set of values held by this buffer. Also adds the - // HloPositions of the value to the positions vector of the buffer. If the - // buffer already contains this value, then this method is a nop. - void AddValue(const HloValue& value); - void RemoveValue(const HloValue& value); - // Return all values contained in this buffer. const std::vector& values() const { return values_; } - std::vector ComputePositions() const; - // Return the unique HLO value in the buffer. CHECK fails if the buffer does // not contain exactly one value. const HloValue& GetUniqueValue() const { @@ -107,6 +100,8 @@ class HloBuffer { return *values_[0]; } + std::vector ComputePositions() const; + string ToString() const; bool operator==(const HloBuffer& other) const; @@ -118,7 +113,7 @@ class HloBuffer { // The set of values contained in this buffer. Vector contains no duplicates // and is sorted stably by HloValue::Id. - std::vector values_; + const std::vector values_; }; std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index b8133cda30e4859bc13682d6e7ac4dbca23458a7..2d077846196bdaf5183f6ee43ab582ede4ef4f52 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -58,16 +58,16 @@ std::unique_ptr HloComputation::Builder::Build( CHECK_NE(nullptr, root); return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, - root, is_fusion_computation_)); + root, fusion_instruction_)); } HloComputation::HloComputation( const string& name, int parameter_count, std::vector>* instructions, - HloInstruction* root_instruction, bool is_fusion_computation) + HloInstruction* root_instruction, HloInstruction* fusion_instruction) : name_(name), root_instruction_(root_instruction), - is_fusion_computation_(is_fusion_computation) { + fusion_instruction_(fusion_instruction) { param_instructions_.resize(parameter_count, nullptr); bool root_found = false; for (auto& instruction : *instructions) { @@ -112,11 +112,8 @@ HloInstruction* HloComputation::AddInstructionInternal( HloInstruction* HloComputation::AddParameter( std::unique_ptr instruction) { CHECK(instruction->opcode() == HloOpcode::kParameter); - CHECK(is_fusion_computation_); - CHECK(root_instruction_->fusion_instruction() != nullptr); - instruction->SetParentFusion(root_instruction_->fusion_instruction()); - CHECK(root_instruction_->fusion_instruction()->operand_count() == - param_instructions_.size()); + CHECK(IsFusionComputation()); + CHECK(fusion_instruction_->operand_count() == param_instructions_.size()); instruction->set_parent(this); param_instructions_.push_back(instruction.get()); AddInstructionInternal(std::move(instruction)); @@ -126,8 +123,7 @@ HloInstruction* HloComputation::AddParameter( Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); - CHECK(is_fusion_computation_); - CHECK(root_instruction_->fusion_instruction() != nullptr); + CHECK(IsFusionComputation()); HloInstruction* param_instruction = param_instructions_[param_no]; auto param_instruction_iterator = param_instructions_.begin() + param_no; param_instructions_.erase(param_instruction_iterator); @@ -155,7 +151,6 @@ Status HloComputation::RemoveParameter(int64 param_no) { AddInstructionInternal(HloInstruction::CreateParameter( param_no, param_instruction->shape(), param_name)); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); - new_instr->SetParentFusion(root_instruction_->fusion_instruction()); param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); param_no++; @@ -178,7 +173,7 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { } if (instruction->opcode() == HloOpcode::kParameter && - !is_fusion_computation_) { + !IsFusionComputation()) { return false; } @@ -263,7 +258,7 @@ void HloComputation::set_root_instruction( HloInstruction* new_root_instruction) { // The shape of the root (ignoring layout) is an invariant of the computation // for non-fusion cases. - if (!is_fusion_computation_) { + if (!IsFusionComputation()) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) << new_root_instruction->shape().ShortDebugString() diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index f383a17fb85a22b0cb20efa605bb213859862f6b..576c44a9f344160fd6184bf2bd590044676a27d6 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -56,10 +56,11 @@ class HloComputation { // Builder class for HloComputation. class Builder { public: - explicit Builder(const string& name, bool is_fusion_computation = false) + explicit Builder(const string& name, + HloInstruction* fusion_instruction = nullptr) : name_(name), last_added_instruction_(nullptr), - is_fusion_computation_(is_fusion_computation) {} + fusion_instruction_(fusion_instruction) {} // Build and return an HloComputation. The parameter root_instruction // specifies the already-added instruction to use as the root. If @@ -78,7 +79,7 @@ class HloComputation { private: const string name_; HloInstruction* last_added_instruction_; - bool is_fusion_computation_; + HloInstruction* fusion_instruction_; std::vector> instructions_; }; @@ -274,13 +275,18 @@ class HloComputation { bool HasSideEffect() const; // Returns if this computation is a fusion computation. - bool IsFusionComputation() const { return is_fusion_computation_; } + bool IsFusionComputation() const { return fusion_instruction_ != nullptr; } + + // Returns the owning fusion instruction, or nullptr if this is not a fusion + // computation. + HloInstruction* FusionInstruction() const { return fusion_instruction_; } private: explicit HloComputation( const string& name, int parameter_count, std::vector>* instructions, - HloInstruction* root_instruction, bool is_fusion_computation = false); + HloInstruction* root_instruction, + HloInstruction* fusion_instruction = nullptr); // Internal helper for adding instructions. HloInstruction* AddInstructionInternal( @@ -309,8 +315,9 @@ class HloComputation { string name_; HloInstruction* root_instruction_; - // A tag shows if this is a fusion computation. - bool is_fusion_computation_; + // If this computation is a fusion computation, this field points to the + // corresponding fusion instruction. Otherwise, this is null. + HloInstruction* fusion_instruction_; // Module containing this computation. HloModule* parent_ = nullptr; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 9dbde0ec243e81b1aa9eb1ed64ab16b88047f7fa..f6b764732b495a1b60bd7dac114ee99bc70bd1b6 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -118,13 +118,11 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { } } -Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode) { +Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo) { return HandleElementwiseOp(hlo); } -Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode) { +Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo) { return HandleElementwiseOp(hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 6d8fdfa64b5846edbea168816dff4e463c2a1027..eeb3d4edd1be3bb0204d37e3e6591058a687712e 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -49,9 +49,8 @@ class HloCostAnalysis : public DfsHloVisitor { using ShapeSizeFunction = std::function; explicit HloCostAnalysis(const ShapeSizeFunction& shape_size); - Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) override; - Status HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode) override; + Status HandleElementwiseUnary(HloInstruction* hlo) override; + Status HandleElementwiseBinary(HloInstruction* hlo) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element, diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index ea8b239e100fec0ad93c4d76cc4ce44c4b7fda03..2be1645f1b05dc5824faf7f485c3619716726d77 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -67,6 +67,22 @@ HloValue& HloDataflowAnalysis::GetValueDefinedAt( return GetUniqueValueAt(instruction, index); } +HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, + const ShapeIndex& index, + bool is_phi) { + const int64 value_id = next_value_id_++; + auto emplaced = values_.emplace( + std::piecewise_construct, std::forward_as_tuple(value_id), + std::forward_as_tuple(value_id, instruction, index, is_phi)); + CHECK(emplaced.second); + + return &emplaced.first->second; +} + +void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) { + values_.erase(value_id); +} + string HloDataflowAnalysis::ToString() const { string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n"); StrAppend(&out, " Instruction value sets:\n"); @@ -99,20 +115,96 @@ string HloDataflowAnalysis::ToString() const { } } StrAppend(&out, " HloValues:\n"); - for (const HloValue& value : values()) { - StrAppend(&out, value.ToString(/*indent=*/4)); + for (const HloValue* value : values()) { + StrAppend(&out, value->ToString(/*indent=*/4)); + } + return out; +} + +bool HloDataflowAnalysis::Phi( + HloInstruction* instruction, + tensorflow::gtl::ArraySlice inputs) { + CHECK(ssa_form_); + + for (const InstructionValueSet* input : inputs) { + DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); } - StrAppend(&out, " Phi resolutions:\n"); - for (const HloValue& value : values()) { - if (value.is_phi()) { - const HloValue* resolved_value = ResolvePhi(value); - StrAppend(&out, " ", value.ToShortString(), " => ", - resolved_value == nullptr ? "UNKNOWN" - : resolved_value->ToShortString(), - "\n"); + + bool changed = false; + for (auto& pair : GetInstructionValueSet(instruction)) { + const ShapeIndex& index = pair.first; + HloValueSet& value_set = pair.second; + + // Positions with phi values should never have more than one value in the + // value set. + CHECK_LE(value_set.values().size(), 1); + const HloValue* current_value = + value_set.values().size() == 1 ? value_set.values()[0] : nullptr; + + // Construct a vector of unique value IDs of the inputs. + std::vector input_value_ids; + for (const InstructionValueSet* input : inputs) { + for (const HloValue* value : input->element(index).values()) { + input_value_ids.push_back(value->id()); + } + } + std::sort(input_value_ids.begin(), input_value_ids.end()); + input_value_ids.erase( + std::unique(input_value_ids.begin(), input_value_ids.end()), + input_value_ids.end()); + + // Remove the existing phi value (if it exists). The phi can be its own + // input, for example, in while body parameters where the body passes + // through the parameter value. + bool current_value_defined_here = + (current_value != nullptr && + current_value->defining_instruction() == instruction && + current_value->defining_index() == index); + if (current_value_defined_here) { + CHECK(current_value->is_phi()); + auto it = std::find(input_value_ids.begin(), input_value_ids.end(), + current_value->id()); + if (it != input_value_ids.end()) { + input_value_ids.erase(it); + } + } + + if (input_value_ids.empty()) { + // A value set which has at least one element should never have its value + // set reduced to zero elements. During dataflow value sets only can go + // from empty to non-empty, not the reverse. + CHECK_EQ(value_set.values().size(), 0) + << "Instruction " << instruction->name() << " at index " << index + << " previously had non-empty value set. Value set: " << value_set; + } else if (input_value_ids.size() == 1) { + // Only a single value reaches this point. There should be no phi, and + // this value set should contain this single value. + const HloValue& new_value = GetValue(input_value_ids[0]); + if (current_value == nullptr) { + value_set.Clear(); + value_set.AddValue(&new_value); + changed = true; + } else if (current_value != &new_value) { + if (current_value_defined_here) { + // Remove the existing phi. + DeleteHloValue(current_value->id()); + } + value_set.Clear(); + value_set.AddValue(&new_value); + changed = true; + } + } else { + // Multiple distinct values reach this point. A phi value is + // necessary. + CHECK_GT(input_value_ids.size(), 1); + if (current_value == nullptr || !current_value->is_phi()) { + value_set.Clear(); + value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true)); + changed = true; + } } } - return out; + return changed; } const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const { @@ -142,129 +234,6 @@ HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) { return GetValueSet(position.instruction, position.index); } -void HloDataflowAnalysis::UpdateAfterChangingOperand( - HloInstruction* instruction, HloInstruction* old_operand, - HloInstruction* new_operand) { - CHECK(std::find(instruction->operands().begin(), - instruction->operands().end(), - new_operand) != instruction->operands().end()); - VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", " - << old_operand->name() << " => " << new_operand->name() << ")"; - - std::vector to_update = {instruction}; - - // If the instruction calls any computations then add the parameters of called - // computation to capture any changes to the dataflow into the subcomputation - // introduced by the new operand. - for (HloComputation* computation : instruction->called_computations()) { - to_update.insert(to_update.end(), - computation->parameter_instructions().begin(), - computation->parameter_instructions().end()); - } - - UpdateInstructionsAndPropagate(to_update); - - // The uses of the values in the old and new operand may have changed. Uses of - // other HloValues are updated in UpdateInstructionsAndPropagate. - for (auto& pair : GetInstructionValueSet(old_operand)) { - for (const HloValue* value : pair.second.values()) { - GetValue(value->id()).RecomputeUses(); - } - } - for (auto& pair : GetInstructionValueSet(new_operand)) { - for (const HloValue* value : pair.second.values()) { - GetValue(value->id()).RecomputeUses(); - } - } - - TF_DCHECK_OK(VerifyAgainstReference()); -} - -void HloDataflowAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root, - HloInstruction* new_root) { - VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => " - << new_root->name() << ")"; - - CHECK_EQ(new_root, new_root->parent()->root_instruction()); - CHECK_EQ(new_root->parent(), old_root->parent()); - - std::vector to_update = {old_root, new_root}; - - const CallGraphNode& call_graph_node = - call_graph_->GetNode(new_root->parent()); - for (const CallSite& callsite : call_graph_node.caller_callsites()) { - if (callsite.instruction()->opcode() == HloOpcode::kCall) { - to_update.push_back(callsite.instruction()); - } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { - // Add the while itself, and the body and condition parameters. - to_update.push_back(callsite.instruction()); - to_update.push_back( - callsite.instruction()->while_body()->parameter_instruction(0)); - to_update.push_back( - callsite.instruction()->while_condition()->parameter_instruction(0)); - } - } - - UpdateInstructionsAndPropagate(to_update); - - TF_DCHECK_OK(VerifyAgainstReference()); -} - -const HloValue* HloDataflowAnalysis::ResolvePhi(const HloValue& phi) const { - CHECK(phi.is_phi()); - - tensorflow::gtl::FlatSet visited; - std::queue worklist; - auto add_to_worklist = [&worklist, &visited](const HloValue* v) { - if (visited.insert(v).second) { - // 'v' was not previously in visited. - worklist.push(v); - } - }; - add_to_worklist(&phi); - - const HloValue* resolved_value = nullptr; - while (!worklist.empty()) { - const HloValue* value = worklist.front(); - worklist.pop(); - - if (!value->is_phi()) { - if (resolved_value == nullptr) { - resolved_value = value; - } else if (resolved_value != value) { - return nullptr; - } - } else { - for (const HloValue* input : phi_inputs_.at(value)) { - add_to_worklist(input); - } - } - } - return resolved_value; -} - -void HloDataflowAnalysis::UpdatePhiInputs( - const HloInstruction* instruction, - tensorflow::gtl::ArraySlice inputs) { - CHECK(ssa_form_); - for (auto& pair : GetInstructionValueSet(instruction)) { - const ShapeIndex& index = pair.first; - const HloValue& phi_value = GetUniqueValueAt(instruction, index); - auto& phi_inputs = phi_inputs_.at(&phi_value); - phi_inputs.clear(); - for (const InstructionValueSet* input : inputs) { - for (const HloValue* value : input->element(index).values()) { - // The number of phi inputs is typically 2, and virtually always very - // small. - if (std::find(phi_inputs.begin(), phi_inputs.end(), value) == - phi_inputs.end()) { - phi_inputs.push_back(value); - } - } - } - } -} - bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast); const InstructionValueSet& operand_set = @@ -380,8 +349,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { } if (ssa_form_ && called_from_while) { - UpdatePhiInputs(parameter, inputs); - return false; + return Phi(parameter, inputs); } else { return GetInstructionValueSet(parameter).AssignUnionOf(inputs); } @@ -439,8 +407,7 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) { &GetInstructionValueSet(xla_while->while_body()->root_instruction()), &GetInstructionValueSet(xla_while->operand(0))}; if (ssa_form_) { - UpdatePhiInputs(xla_while, inputs); - return false; + return Phi(xla_while, inputs); } else { return GetInstructionValueSet(xla_while).AssignUnionOf(inputs); } @@ -487,38 +454,7 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate( VLOG(3) << "Worklist top: " << instruction->name(); VLOG(3) << ToString(); - // The updating of the instruction value set below in - // UpdateInstructionValueSet does not update HloValue::positions(). To - // perform the positions() update remove all positions in 'instruction' from - // the HloValues in 'instruction's value set prior to the update, then after - // the update add the new positions back in. There is likely a more - // efficient way of doing this. - for (auto& pair : GetInstructionValueSet(instruction)) { - const ShapeIndex& index = pair.first; - HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (value->defining_instruction() != instruction) { - // Use GetValue for a non-const HloValue reference. - GetValue(value->id()).RemovePosition(instruction, index); - } - } - } - - bool changed = UpdateInstructionValueSet(instruction); - - // Add the positions back in. - for (auto& pair : GetInstructionValueSet(instruction)) { - const ShapeIndex& index = pair.first; - HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (value->defining_instruction() != instruction) { - // Use GetValue for a non-const HloValue reference. - GetValue(value->id()).AddPosition(instruction, index); - } - } - } - - if (!changed) { + if (!UpdateInstructionValueSet(instruction)) { // No change to the instruction's value set. VLOG(4) << "No change."; continue; @@ -531,12 +467,16 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate( for (HloInstruction* user : instruction->users()) { worklist.push(user); - // If user calls a computation, then the respective parameter(s) of the - // computation need to be updated. + // If user sequentially calls a computation, then the respective + // parameter(s) of the computation need to be updated. for (HloComputation* called_computation : user->called_computations()) { - for (int64 operand_number : user->OperandIndices(instruction)) { - worklist.push( - called_computation->parameter_instruction(operand_number)); + const CallGraphNode& call_graph_node = + call_graph_->GetNode(called_computation); + if (call_graph_node.context() == CallContext::kSequential) { + for (int64 operand_number : user->OperandIndices(instruction)) { + worklist.push( + called_computation->parameter_instruction(operand_number)); + } } } } @@ -574,25 +514,10 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( } Status HloDataflowAnalysis::InitializeInstructionValueSets() { - // Gather the values to create before creating them. This is done because we - // want to allocate the vector of values only once so references to elements - // are stable. - struct ValueToCreate { - HloInstruction* instruction; - ShapeIndex index; - bool is_phi; - }; - std::vector values_to_create; - for (const std::unique_ptr& computation : module_->computations()) { const CallGraphNode& call_graph_node = call_graph_->GetNode(computation.get()); - bool called_from_while = std::any_of( - call_graph_node.caller_callsites().begin(), - call_graph_node.caller_callsites().end(), [](const CallSite& cs) { - return cs.instruction()->opcode() == HloOpcode::kWhile; - }); for (const std::unique_ptr& instruction : computation->instructions()) { @@ -603,20 +528,22 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { // Lambda to set the value set to define all values in the output of the // instruction. - auto define_all_values = [this, &instruction, - &values_to_create](bool is_phi = false) { + auto define_all_values = [this, &instruction](bool is_phi = false) { for (auto& pair : GetInstructionValueSet(instruction.get())) { const ShapeIndex& index = pair.first; - values_to_create.push_back({instruction.get(), index, is_phi}); + HloValue* value = + NewHloValue(instruction.get(), index, /*is_phi=*/false); + GetValueSet(instruction.get(), index).AddValue(value); } }; // Lambda to set the value set to define only the top-level buffer in the // output of the instruction. Any other values flow from the operands of // the instruction (or from cross-computation dataflow). - auto define_top_level_only = [this, &instruction, &values_to_create]() { - values_to_create.push_back( - {instruction.get(), /*index=*/{}, /*is_phi=*/false}); + auto define_top_level_only = [this, &instruction]() { + HloValue* value = + NewHloValue(instruction.get(), /*index=*/{}, /*is_phi=*/false); + GetValueSet(instruction.get(), /*index=*/{}).AddValue(value); }; switch (instruction->opcode()) { @@ -626,10 +553,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { } break; case HloOpcode::kWhile: - if (ssa_form_) { - define_all_values(/*is_phi=*/true); - } - break; case HloOpcode::kCall: case HloOpcode::kGetTupleElement: // These instructions define no values. The values in their output @@ -654,10 +577,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { // values in their output. Otherwise the values of the parameter // come from the caller (eg, operands to the kCall instruction). define_all_values(); - } else if (call_graph_node.context() == CallContext::kSequential && - called_from_while && ssa_form_) { - // Parameters of while bodies and conditions are phis. - define_all_values(/*is_phi=*/true); } break; case HloOpcode::kCopy: @@ -674,164 +593,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { } } - // Reserve the vector ahead of time so references to elements are stable. - values_.reserve(values_to_create.size()); - for (int64 i = 0; i < values_to_create.size(); ++i) { - const ValueToCreate& to_create = values_to_create[i]; - values_.emplace_back(/*id=*/i, to_create.instruction, to_create.index, - to_create.is_phi); - const HloValue& value = values_.back(); - GetValueSet(to_create.instruction, to_create.index).AddValue(&value); - if (value.is_phi()) { - phi_inputs_[&value] = {}; - } - } return Status::OK(); } -bool HloDataflowAnalysis::IsDefinedBefore(const HloValue& a, const HloValue& b, - const HloOrdering& ordering) const { - // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b' - // is live into the module. - if (b.defining_instruction()->parent() == module_->entry_computation() && - b.defining_instruction()->opcode() == HloOpcode::kParameter) { - return false; - } - - // Phi values require special handling. Because XLA does not have a phi - // instruction, the definition instruction of the phis values are - // placeholders: either the subcomputation parameter (body or condition) or - // the while instruction. However, the program point where these values are - // logically defined does not necessarily coincide exactly with program point - // of these place-holder instructions. So we explicitly define the following - // order for phi values: - // - // body/condition parameter phi: - // Defined before all values defined in its computation excepting other - // phis. - // - // while phi: - // defined after all values defined in the condition or body. - // - auto is_body_or_condition_phi = [](const HloValue& v) { - return v.is_phi() && - v.defining_instruction()->opcode() == HloOpcode::kParameter; - }; - if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) && - call_graph_->InstructionIsNestedIn(b.defining_instruction(), - a.defining_instruction()->parent())) { - return true; - } - if (is_body_or_condition_phi(b) && - call_graph_->InstructionIsNestedIn(a.defining_instruction(), - b.defining_instruction()->parent())) { - return false; - } - - // If 'b' is a while phi and 'a' is in the body or condition, then 'a' - // executes before 'b'. - if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile && - (call_graph_->InstructionIsNestedIn( - a.defining_instruction(), b.defining_instruction()->while_body()) || - call_graph_->InstructionIsNestedIn( - a.defining_instruction(), - b.defining_instruction()->while_condition()))) { - return true; - } - - return ordering.ExecutesBefore(a.defining_instruction(), - b.defining_instruction()); -} - -bool HloDataflowAnalysis::UseIsBeforeValueDefinition( - const HloUse& use, const HloValue& value, - const HloOrdering& ordering) const { - if (ordering.ExecutesBefore(use.instruction, value.defining_instruction())) { - return true; - } - - // If the use is at the instruction where the value is defined, then the use - // is before the def if the instruction allows buffer sharing (in place - // computation). - if (use.instruction == value.defining_instruction() && - CanShareOperandBufferWithUser( - use.instruction->mutable_operand(use.operand_number), - use.operand_index, value.defining_instruction(), - value.defining_index())) { - return true; - } - - // The use at a while is an input to a phi, and logically occurs before values - // are defined in the body or condition computations. - if (use.instruction->opcode() == HloOpcode::kWhile) { - const HloInstruction* xla_while = use.instruction; - if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), - xla_while->while_body()) || - call_graph_->InstructionIsNestedIn(value.defining_instruction(), - xla_while->while_condition())) { - return true; - } - } - - // Similarly if the value is defined at a while, it logically occurs after any - // uses in the body or condition computations. - if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { - CHECK(ssa_form_); - const HloInstruction* xla_while = value.defining_instruction(); - if (call_graph_->InstructionIsNestedIn(use.instruction, - xla_while->while_body()) || - call_graph_->InstructionIsNestedIn(use.instruction, - xla_while->while_condition())) { - return true; - } - } - return false; -} - -bool HloDataflowAnalysis::LiveRangeStrictlyBefore( - const HloValue& a, const HloValue& b, const HloOrdering& ordering) const { - VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() - << ", b = " << b.ToShortString() << ")"; - if (!IsDefinedBefore(a, b, ordering)) { - VLOG(4) << "a not defined before b"; - return false; - } - - // Live-out values from the module can never have ranges strictly before any - // other value. - if (a.live_out_of_module()) { - VLOG(4) << "a is live out of module"; - return false; - } - - // Live-out values of computations can never have ranges strictly before any - // other value in the computation (including values nested in - // subcomputations). - if (a.live_out_of_computation() && - call_graph_->InstructionIsNestedIn(b.defining_instruction(), - a.defining_instruction()->parent())) { - VLOG(4) << "a is live out of computation containing b"; - return false; - } - - // All uses of 'a' must be before 'b' is defined. - for (const HloUse& use : a.uses()) { - if (!UseIsBeforeValueDefinition(use, b, ordering)) { - VLOG(4) << "use of a (" << use << ") not before b is defined"; - return false; - } - } - - return true; -} - -bool HloDataflowAnalysis::MayInterfere(const HloValue& a, const HloValue& b, - const HloOrdering& ordering) const { - // Buffers without disjoint liveness may interfere. - return !LiveRangeStrictlyBefore(a, b, ordering) && - !LiveRangeStrictlyBefore(b, a, ordering); -} - /* static */ StatusOr> HloDataflowAnalysis::Run( HloModule* module, bool ssa_form, bool bitcast_defines_value) { @@ -855,6 +619,33 @@ StatusOr> HloDataflowAnalysis::Run( } dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions); + // Add in positions to all values. + for (const std::unique_ptr& computation : + module->computations()) { + for (const std::unique_ptr& instruction : + computation->instructions()) { + for (const auto& pair : + dataflow_analysis->GetInstructionValueSet(instruction.get())) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (value->defining_instruction() != instruction.get()) { + dataflow_analysis->GetValue(value->id()) + .AddPosition(instruction.get(), index); + } + } + } + } + } + + // Construct vector of values. + dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size()); + for (auto& pair : dataflow_analysis->values_) { + dataflow_analysis->values_vector_.push_back(&pair.second); + } + std::sort(dataflow_analysis->values_vector_.begin(), + dataflow_analysis->values_vector_.end(), HloValue::IdLessThan); + TF_DCHECK_OK(dataflow_analysis->Verify()); XLA_VLOG_LINES(1, dataflow_analysis->ToString()); @@ -865,14 +656,14 @@ StatusOr> HloDataflowAnalysis::Run( Status HloDataflowAnalysis::Verify() const { // Verify each HloValue appears in the value sets that the value's positions() // indicate. - for (const HloValue& value : values()) { - for (const HloPosition& position : value.positions()) { + for (const HloValue* value : values()) { + for (const HloPosition& position : value->positions()) { const HloValueSet& value_set = GetValueSet(position); TF_RET_CHECK(std::find(value_set.values().begin(), value_set.values().end(), - &value) != value_set.values().end()) + value) != value_set.values().end()) << "Value set at position " << position << " does not contain value " - << value.ToShortString(); + << value->ToShortString(); } } @@ -898,75 +689,4 @@ Status HloDataflowAnalysis::Verify() const { return Status::OK(); } -Status HloDataflowAnalysis::VerifyAgainstReference() const { - TF_RETURN_IF_ERROR(Verify()); - - TF_ASSIGN_OR_RETURN(std::unique_ptr reference, - Run(module_, ssa_form_, bitcast_defines_value_)); - TF_RETURN_IF_ERROR(reference->Verify()); - - VLOG(2) << "This analysis:"; - XLA_VLOG_LINES(2, ToString()); - VLOG(2) << "Reference:"; - XLA_VLOG_LINES(2, reference->ToString()); - - // Verify value sets in each position are identical. - for (const auto& computation : module_->computations()) { - for (const auto& instruction : computation->instructions()) { - for (const auto& pair : GetInstructionValueSet(instruction.get())) { - const ShapeIndex& index = pair.first; - const HloValueSet& value_set = pair.second; - const HloValueSet& reference_value_set = - reference->GetValueSet(instruction.get(), index); - - auto value_in_set = [](const HloValue& v, const HloValueSet& vset) { - return std::find_if(vset.values().begin(), vset.values().end(), - [&v](const HloValue* w) { return *w == v; }) != - vset.values().end(); - }; - - for (const HloValue* value : value_set.values()) { - TF_RET_CHECK(value_in_set(*value, reference_value_set)) - << "Value " << value->ToShortString() - << " does not exist in reference"; - } - for (const HloValue* reference_value : reference_value_set.values()) { - TF_RET_CHECK(value_in_set(*reference_value, value_set)) - << "Value " << reference_value->ToShortString() - << " only exists in reference"; - } - } - } - } - - // Verify all phis resolve identically and uses are identical. - for (const HloValue& value : values()) { - const HloValue& reference_value = reference->GetValueDefinedAt( - value.defining_instruction(), value.defining_index()); - TF_RET_CHECK(value.is_phi() == reference_value.is_phi()); - if (value.is_phi()) { - const HloValue* resolved_value = ResolvePhi(value); - const HloValue* reference_resolved_value = - reference->ResolvePhi(reference_value); - if (resolved_value == nullptr) { - TF_RET_CHECK(reference_resolved_value == nullptr); - } else { - TF_RET_CHECK(reference_resolved_value != nullptr); - TF_RET_CHECK(*reference_resolved_value == *resolved_value); - } - } - - for (const HloUse& use : value.uses()) { - TF_RET_CHECK(std::find(reference_value.uses().begin(), - reference_value.uses().end(), - use) != reference_value.uses().end()); - } - for (const HloUse& reference_use : reference_value.uses()) { - TF_RET_CHECK(std::find(value.uses().begin(), value.uses().end(), - reference_use) != value.uses().end()); - } - } - return Status::OK(); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 7781cc58a3a3daeb3cd95f28c5de43baa5803089..aae257dd09e8ee37e040b8c7b673059355615ed4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -88,10 +88,10 @@ class HloDataflowAnalysis { // given position. const HloValueSet& GetValueSet(const HloInstruction* instruction, const ShapeIndex& index = {}) const; - HloValueSet& GetValueSet(const HloInstruction* instruction, - const ShapeIndex& index = {}); const HloValueSet& GetValueSet(const HloPosition& position) const; HloValueSet& GetValueSet(const HloPosition& position); + HloValueSet& GetValueSet(const HloInstruction* instruction, + const ShapeIndex& index = {}); // Return the unique value in the HloValueSet at the given instruction and // shape index. CHECKs if the value set does not contain a exactly one value. @@ -108,49 +108,11 @@ class HloDataflowAnalysis { const HloValue& GetValue(HloValue::Id value_id) const; HloValue& GetValue(HloValue::Id value_id); - // Returns whether the given values interfere assuming the given HLO - // ordering. Two values interfere if they may both be simultaneously live. - bool MayInterfere(const HloValue& a, const HloValue& b, - const HloOrdering& ordering) const; - - // Overload which takes HloValue:Ids. - bool MayInterfere(HloValue::Id a, HloValue::Id b, - const HloOrdering& ordering) const { - return MayInterfere(GetValue(a), GetValue(b), ordering); - } - // Return the total number of HloValues. int64 value_count() const { return values_.size(); } - // Return a vector of all HloValues. - const std::vector& values() const { return values_; } - - // Updates the dataflow after the changing an operand of - // 'instruction'. Dataflow update is not possible if instructions have been - // added or removed from the graph. - void UpdateAfterChangingOperand(HloInstruction* instruction, - HloInstruction* old_operand, - HloInstruction* new_operand); - - // Updates the dataflow after the changing the root of a computation from - // 'old_root' to 'new_root'. - void UpdateAfterChangingRoot(HloInstruction* old_root, - HloInstruction* new_root); - - // Returns the non-phi HloValue that is the unique (transitive) input to the - // given phi. If no such HloValue exists (there are multiple inputs to the - // phi) then nullptr is returned. This is computed by all walking the inputs - // of the given phi value until non-phi HloValue(s) are encountered. - const HloValue* ResolvePhi(const HloValue& phi) const; - const HloValue* ResolvePhi(const HloInstruction* instruction, - const ShapeIndex& index = {}) const { - return ResolvePhi(GetValueDefinedAt(instruction, index)); - } - - // Compare the dataflow analysis against a clean recomputation of the - // analysis. Returns an error status if there is a mismatch. Useful for - // verifying the correctness after updates to the analysis. - Status VerifyAgainstReference() const; + // Return a vector of all HloValues stabily sorted by HloValue::Id. + const std::vector& values() const { return values_vector_; } // Return the call graph used for computing the dataflow. const CallGraph& call_graph() const { return *call_graph_; } @@ -161,6 +123,13 @@ class HloDataflowAnalysis { HloDataflowAnalysis(HloModule* module, bool ssa_form, bool bitcast_defines_value = false); + // Returns a new HloValue defined at the given instruction and shape index. + HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, + bool is_phi = false); + + // Delete the HloValue with the given ID. + void DeleteHloValue(HloValue::Id value_id); + // Constructs and initializes the InstructionValueSets of all instructions to // contain exactly the HloValues defined by each instruction. These values can // then propagated throughout the HLO graph by calling @@ -187,10 +156,11 @@ class HloDataflowAnalysis { void UpdateInstructionsAndPropagate( tensorflow::gtl::ArraySlice instructions); - // Sets the inputs of the given phi to given value(s). - void UpdatePhiInputs( - const HloInstruction* instruction, - tensorflow::gtl::ArraySlice inputs); + // Return the result of the SSA Phi function applied to the given inputs at + // the given instruction. If skip_top_level is true, then the top level of the + // value set of 'instruction' is not modified. + bool Phi(HloInstruction* instruction, + tensorflow::gtl::ArraySlice inputs); // Updates the positions of the HloValues in the output of the given // instruction. This should be called after the instruction value set of @@ -203,20 +173,6 @@ class HloDataflowAnalysis { HloInstruction* instruction, const InstructionValueSet& new_value_set, const InstructionValueSet* prev_value_set = nullptr); - // Returns true if the live range of the given value 'a' is strictly before - // the live range of value 'b' using the given HLO ordering. - bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b, - const HloOrdering& ordering) const; - - // Returns whether the value 'a' is defined before the value 'b' under the - // given ordering. - bool IsDefinedBefore(const HloValue& a, const HloValue& b, - const HloOrdering& ordering) const; - - // Returns whether the given use is before the given value definition. - bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value, - const HloOrdering& ordering) const; - // Verify various invariants of the dataflow analysis. Status Verify() const; @@ -226,19 +182,19 @@ class HloDataflowAnalysis { std::unique_ptr call_graph_; - // Array of all values in the module. This is allocated once at analysis - // construction time so HloValue references are stable. Updates to the - // analysis via UpdateAfterChangingOperand and UpdateAfterChangingRoot do not - // result in the creation or destruction of any HloValues. - std::vector values_; - - // Map hold the inputs to each phi value in the module. Used by ResolvePhi. - tensorflow::gtl::FlatMap> - phi_inputs_; + // The map of all HloValues in the module. We pass around pointers to the + // mapped HloValues, so the underlying container must keep them valid despite + // mutations touching other map entries. + std::unordered_map values_; // A map from instruction to InstructionValueSet. std::unordered_map value_sets_; + + // A vector containing all HloValues sorted by HloValue::Id. + std::vector values_vector_; + + // The Id to use for the next HloValue. + HloValue::Id next_value_id_ = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 9f3dd539efe1ba67514f72cc13ff637a1212aeb2..ef0fa1d745ae38a7f899fe92ee2c5f77e270ec2f 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -44,8 +43,8 @@ class HloDataflowAnalysisTest : public HloTestBase, // Run dataflow analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. - HloDataflowAnalysis& RunAnalysis(bool ssa_form, - bool bitcast_defines_value = false) { + const HloDataflowAnalysis& RunAnalysis(bool ssa_form, + bool bitcast_defines_value = false) { analysis_ = HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); @@ -71,8 +70,8 @@ class HloDataflowAnalysisTest : public HloTestBase, const HloInstruction* b) { EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); - return analysis_->MayInterfere(analysis_->GetValueDefinedAt(a), - analysis_->GetValueDefinedAt(b), ordering); + return ordering.MayInterfere(analysis_->GetValueDefinedAt(a), + analysis_->GetValueDefinedAt(b)); } std::unique_ptr module_; @@ -499,37 +498,26 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module()); if (ssa_form) { - // While instruction should define phi values. The value at index {0} is a - // degenerate phi with a single input 'constant1'. - EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0})); - EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi()); - EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{0}), - &analysis.GetValueDefinedAt(constant1)); - EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0})); - EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi()); - EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{0}), - &analysis.GetValueDefinedAt(constant1)); - EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0})); - EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi()); - EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{0}), - &analysis.GetValueDefinedAt(constant1)); + // Element 0 of the tuple passed through the body so no phi value is + // defined. + EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0})); + // Element 1 of the tuple should be a phi value. EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1})); EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi()); - EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{1}), nullptr); EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1})); EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi()); - EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{1}), nullptr); EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1})); EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi()); - EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{1}), nullptr); - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{xla_while, 0, {0}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}})); - EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}) - .live_out_of_module()); + // Constant1 passes through the body and out of the module. + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) .live_out_of_module()); @@ -613,20 +601,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - if (ssa_form) { - EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while2).live_out_of_module()); - EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); - } else { - // Element 0 is passed through all the while instructions and out of the - // module. - EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}), - analysis.GetValueDefinedAt(constant1)); - EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}), - analysis.GetValueDefinedAt(constant1)); - EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}), - analysis.GetValueDefinedAt(constant1)); - EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); - } + // Element 0 is passed through all the while instructions and out of the + // module.. + EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}), + analysis.GetValueDefinedAt(constant1)); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); } TEST_P(HloDataflowAnalysisTest, NestedWhiles) { @@ -705,13 +688,18 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(negate))); if (ssa_form) { EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1})); EXPECT_TRUE( analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi()); - EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0})); - EXPECT_TRUE( - analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi()); + + // Element 0 of the nested while is %negate. + EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0})); + EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(negate))); + // Element 1 is a phi value (join of %add and %constant2). EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1})); EXPECT_TRUE( analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi()); @@ -724,8 +712,6 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { EXPECT_TRUE( analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi()); } else { - EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}), - UnorderedElementsAre(analysis.GetValueDefinedAt(negate))); EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}), UnorderedElementsAre(analysis.GetValueDefinedAt(add), analysis.GetValueDefinedAt(constant2))); @@ -1496,256 +1482,6 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log)); } -TEST_P(HloDataflowAnalysisTest, UpdateAnalysisForWhile) { - // Test updating dataflow after modifying a module with an array shaped while: - // - // body(F32[] %param): - // %negate = Negate(%param) - // - // condition(F32[] %param): - // return Constant(false) - // - // entry: - // %constant = Constant(1.0) - // %exp = Exp(%constant) - // return While(%exp, body, condition) - // - auto body_builder = HloComputation::Builder("body"); - auto body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); - auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( - scalar_shape_, HloOpcode::kNegate, body_param)); - HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); - - // Condition computation trivially returns a constant "false". - auto cond_builder = HloComputation::Builder("condition"); - auto cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); - cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); - HloComputation* condition = - module_->AddEmbeddedComputation(cond_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant)); - auto xla_while = builder.AddInstruction( - HloInstruction::CreateWhile(scalar_shape_, condition, body, exp)); - module_->AddEntryComputation(builder.Build()); - - bool ssa_form = GetParam(); - HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - - // Sanity check the initial dataflow analysis before transforming the HLO - // graph. - if (ssa_form) { - EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param)); - EXPECT_TRUE(analysis.GetValueDefinedAt(body_param).is_phi()); - EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr); - - EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param)); - EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param).is_phi()); - EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr); - - EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module()); - EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module()); - } else { - EXPECT_THAT(HloValuesAt(body_param), - UnorderedElementsAre(analysis.GetValueDefinedAt(exp), - analysis.GetValueDefinedAt(negate))); - EXPECT_THAT(HloValuesAt(cond_param), - UnorderedElementsAre(analysis.GetValueDefinedAt(exp), - analysis.GetValueDefinedAt(negate))); - EXPECT_THAT(HloValuesAt(xla_while), - UnorderedElementsAre(analysis.GetValueDefinedAt(exp), - analysis.GetValueDefinedAt(negate))); - - EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module()); - } - - // Set the body root to the body_param. Previously it was Negate(body_param). - body->set_root_instruction(body_param); - - // Prior to updating, verify that the dataflow analysis is no longer valid. - Status verify_status = analysis.VerifyAgainstReference(); - EXPECT_FALSE(verify_status.ok()); - - analysis.UpdateAfterChangingRoot(/*old_root=*/negate, - /*new_root=*/body_param); - - // Analysis should be valid after the update. - TF_EXPECT_OK(analysis.VerifyAgainstReference()); - - if (ssa_form) { - // The phis should now be resolvable as 'exp' is passed through the body - // transparently. - EXPECT_EQ(analysis.ResolvePhi(body_param), - &analysis.GetValueDefinedAt(exp)); - EXPECT_EQ(analysis.ResolvePhi(cond_param), - &analysis.GetValueDefinedAt(exp)); - EXPECT_EQ(analysis.ResolvePhi(xla_while), &analysis.GetValueDefinedAt(exp)); - EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module()); - } else { - EXPECT_THAT(HloValuesAt(body_param), - UnorderedElementsAre(analysis.GetValueDefinedAt(exp))); - EXPECT_THAT(HloValuesAt(cond_param), - UnorderedElementsAre(analysis.GetValueDefinedAt(exp))); - EXPECT_THAT(HloValuesAt(xla_while), - UnorderedElementsAre(analysis.GetValueDefinedAt(exp))); - EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module()); - } - EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module()); - - // Now replace the operand of the while with %constant (was %exp). - TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant)); - analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp, - /*new_operand=*/constant); - - // Verify that the dataflow is correct. - TF_ASSERT_OK(analysis.VerifyAgainstReference()); - - if (ssa_form) { - // The phis now resolve to 'constant'. - EXPECT_EQ(analysis.ResolvePhi(body_param), - &analysis.GetValueDefinedAt(constant)); - EXPECT_EQ(analysis.ResolvePhi(cond_param), - &analysis.GetValueDefinedAt(constant)); - EXPECT_EQ(analysis.ResolvePhi(xla_while), - &analysis.GetValueDefinedAt(constant)); - } else { - EXPECT_THAT(HloValuesAt(body_param), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant))); - EXPECT_THAT(HloValuesAt(cond_param), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant))); - EXPECT_THAT(HloValuesAt(xla_while), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant))); - EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module()); - } - - // And finally make the negate the root of the body again. - body->set_root_instruction(negate); - analysis.UpdateAfterChangingRoot(/*old_root=*/body_param, - /*new_root=*/negate); - - // Verify that the dataflow is correct. - TF_ASSERT_OK(analysis.VerifyAgainstReference()); - - if (ssa_form) { - // Phis should no longer be resolvable. - EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr); - EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr); - EXPECT_EQ(analysis.ResolvePhi(xla_while), nullptr); - } else { - EXPECT_THAT(HloValuesAt(body_param), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant), - analysis.GetValueDefinedAt(negate))); - EXPECT_THAT(HloValuesAt(cond_param), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant), - analysis.GetValueDefinedAt(negate))); - EXPECT_THAT(HloValuesAt(xla_while), - UnorderedElementsAre(analysis.GetValueDefinedAt(constant), - analysis.GetValueDefinedAt(negate))); - - EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module()); - } - - // After the updates, verify that the dataflow is correct. - TF_ASSERT_OK(analysis.VerifyAgainstReference()); -} - -TEST_P(HloDataflowAnalysisTest, UpdateOfATupleSelect) { - // Test changing the operands of kSelects of a tuple value and updating the - // dataflow. - auto builder = HloComputation::Builder(TestName()); - auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); - auto a = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto b = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(2.0))); - auto c = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(3.0))); - auto d = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(4.0))); - auto tuple_a = builder.AddInstruction(HloInstruction::CreateTuple({a})); - auto tuple_b = builder.AddInstruction(HloInstruction::CreateTuple({b})); - auto tuple_c = builder.AddInstruction(HloInstruction::CreateTuple({c})); - auto tuple_d = builder.AddInstruction(HloInstruction::CreateTuple({d})); - const Shape tuple_shape = tuple_a->shape(); - auto select_aa = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_a)); - auto select_ab = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_b)); - auto select_cd = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, tuple_c, tuple_d)); - auto select_abcd = builder.AddInstruction(HloInstruction::CreateTernary( - tuple_shape, HloOpcode::kSelect, pred, select_ab, select_cd)); - - module_->AddEntryComputation(builder.Build()); - - bool ssa_form = GetParam(); - HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - - // Sanity check dataflow before changing the graph and updating. - EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}), - UnorderedElementsAre(analysis.GetValueDefinedAt(a))); - EXPECT_THAT(HloValuesAt(select_ab, /*index=*/{0}), - UnorderedElementsAre(analysis.GetValueDefinedAt(a), - analysis.GetValueDefinedAt(b))); - EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}), - UnorderedElementsAre(analysis.GetValueDefinedAt(c), - analysis.GetValueDefinedAt(d))); - EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}), - UnorderedElementsAre(analysis.GetValueDefinedAt(a), - analysis.GetValueDefinedAt(b), - analysis.GetValueDefinedAt(c), - analysis.GetValueDefinedAt(d))); - EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(c).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module()); - - // Set the rhs of 'select_aa' to be 'd'. - TF_ASSERT_OK(select_aa->ReplaceOperandWith(2, tuple_d)); - analysis.UpdateAfterChangingOperand(select_aa, /*old_operand=*/tuple_a, - /*new_operand=*/tuple_d); - - // Verify that the dataflow is correct. - TF_ASSERT_OK(analysis.VerifyAgainstReference()); - - EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}), - UnorderedElementsAre(analysis.GetValueDefinedAt(a), - analysis.GetValueDefinedAt(d))); - - // Set the lhs of 'select_cd' to be 'a'. - TF_ASSERT_OK(select_cd->ReplaceOperandWith(1, tuple_a)); - analysis.UpdateAfterChangingOperand(select_cd, /*old_operand=*/tuple_c, - /*new_operand=*/tuple_a); - - // Verify that the dataflow is correct. - TF_ASSERT_OK(analysis.VerifyAgainstReference()); - - EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}), - UnorderedElementsAre(analysis.GetValueDefinedAt(a), - analysis.GetValueDefinedAt(d))); - EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}), - UnorderedElementsAre(analysis.GetValueDefinedAt(a), - analysis.GetValueDefinedAt(b), - analysis.GetValueDefinedAt(d))); - EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module()); - EXPECT_FALSE(analysis.GetValueDefinedAt(c).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module()); - - // After the updates, verify that the dataflow is correct. - TF_ASSERT_OK(analysis.VerifyAgainstReference()); -} - INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 3cb8f10dd0cdcecb97244b474f65384db5d1c6ee..e09c9d3beb2d33008adade37e30ee2828df721df 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -900,6 +900,93 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; + Status HandleReduceWindow(HloInstruction* reduce_window, + HloInstruction* operand, const Window& window, + HloComputation* function) override { + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferReduceWindowShape( + /*operand_shape=*/reduce_window->operand(0)->shape(), + /*init_value=*/reduce_window->operand(1)->shape(), window, + /*to_apply_shape=*/function->ComputeProgramShape())); + TF_RET_CHECK( + ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) + << "return shape is set to: " + << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanStringWithLayout(inferred_return_shape); + + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); + VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); + VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + auto result = Literal::CreateFromShape(reduce_window->shape()); + + // Creates a Shape object from window, for iteration below. + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + DimensionVector window_index(window.dimensions_size()); + DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + + // For each resulting dimension, calculate and assign computed value. + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + ReturnT result_val = init_scalar; + + std::fill(window_index.begin(), window_index.end(), 0); + std::fill(operand_index.begin(), operand_index.end(), 0); + + do { + // Set curr_val to 0 if out of bound (padded). + ReturnT curr_val = static_cast(0); + bool out_of_bound = false; + for (int i = 0; i < operand_index.size(); ++i) { + operand_index[i] = + output_index[i] * window.dimensions(i).stride() + + window_index[i] - window.dimensions(i).padding_low(); + if (operand_index[i] < 0 || + operand_index[i] >= operand_literal.shape().dimensions(i)) { + out_of_bound = true; + break; + } + } + if (!out_of_bound) { + curr_val = operand_literal.Get(operand_index); + } + // Evaluate computation with specified literal operands. + const auto curr_val_literal = Literal::CreateR0(curr_val); + const auto result_val_literal = + Literal::CreateR0(result_val); + const std::vector args = {curr_val_literal.get(), + result_val_literal.get()}; + // We need a new visitor for each evaluation, so that the same + // computation can be visited more than once (with different + // inputs). + HloEvaluator embedded_evaluator; + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*function, args) + .ConsumeValueOrDie(); + + result_val = computed_result->Get({}); + } while (IndexUtil::BumpIndices(window_shape, &window_index)); + + return result_val; + })); + + parent_->evaluated_[reduce_window] = std::move(result); + return Status::OK(); + }; + Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override { const Shape& shape = slice->shape(); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, @@ -1070,7 +1157,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator* parent_; -}; +}; // namespace xla HloEvaluator::HloEvaluator() { typed_visitors_[PRED] = MakeUnique>(this); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 1bf483b3209e47a169caa19a332dde0cb8cbaad1..9205f5dc4e8bee56e53a99295ca916dcbd5789db 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -332,6 +332,53 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { LiteralTestUtil::ExpectEqual(*result, *output_literal); } +TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { + HloComputation::Builder b(TestName()); + + HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{-1, -2}, {100, 200}}))); + HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{-2, -3}, {-100, -200}}))); + + std::vector operands = {operand1, operand2}; + + Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + + auto expected = + Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { + HloComputation::Builder b(TestName()); + + HloInstruction* operand1 = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({100, 200}))); + HloInstruction* operand2 = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({}))); + + std::vector operands = {operand1, operand2}; + + Shape shape = ShapeUtil::MakeShape(S64, {2}); + b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); + + HloModule module(TestName()); + auto computation = module.AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR1({100, 200}); + LiteralTestUtil::ExpectEqual(*expected, *result); +} + TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); @@ -1097,6 +1144,180 @@ TEST_F(HloEvaluatorTest, ReduceAdd) { LiteralTestUtil::ExpectEqual(*expected, *result); } +TEST_F(HloEvaluatorTest, ReduceWindowMax) { + HloComputation::Builder b(TestName()); + + // arg: + // f32[2,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // } + auto arg_array = MakeUnique>(2, 3); + arg_array->FillUnique(1.0f); + auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + + HloInstruction* arg_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder max_computation("max"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + max_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); + HloModule module(TestName()); + auto max_func = module.AddEmbeddedComputation(max_computation.Build()); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + Shape shape = ShapeUtil::MakeShape(F32, {1, 2}); + b.AddInstruction(HloInstruction::CreateReduceWindow( + shape, arg_instruction, init_value, window, max_func)); + + auto computation = module.AddEntryComputation(b.Build()); + std::unique_ptr result = + evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR2({{6, 7}}); + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, ReduceWindowAdd) { + HloComputation::Builder b(TestName()); + + // arg: + // f32[2,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // } + auto arg_array = MakeUnique>(2, 3); + arg_array->FillUnique(1.0f); + auto arg_literal = Literal::CreateR2FromArray2D(*arg_array); + + HloInstruction* arg_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + HloModule module(TestName()); + auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + + Window window; + WindowDimension dim; + dim.set_size(1); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(1); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + b.AddInstruction(HloInstruction::CreateReduceWindow( + shape, arg_instruction, init_value, window, add_func)); + + auto computation = module.AddEntryComputation(b.Build()); + std::unique_ptr result = + evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); + LiteralTestUtil::ExpectEqual(*expected, *result); +} + +TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { + HloComputation::Builder b(TestName()); + + // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. + std::vector input_dims(6, 4); + std::unique_ptr arg_literal = + Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); + + HloInstruction* arg_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + HloModule module(TestName()); + auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + + Window window; + + WindowDimension trivial_dim; + trivial_dim.set_size(1); + trivial_dim.set_stride(1); + trivial_dim.set_padding_low(0); + trivial_dim.set_padding_high(0); + trivial_dim.set_window_dilation(1); + trivial_dim.set_base_dilation(1); + + WindowDimension active_dim; + active_dim.set_size(2); + active_dim.set_stride(1); + active_dim.set_padding_low(0); + active_dim.set_padding_high(0); + active_dim.set_window_dilation(1); + active_dim.set_base_dilation(1); + + *window.add_dimensions() = trivial_dim; + *window.add_dimensions() = active_dim; + *window.add_dimensions() = active_dim; + *window.add_dimensions() = active_dim; + *window.add_dimensions() = trivial_dim; + *window.add_dimensions() = trivial_dim; + + Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 3, 3, 4, 4}); + b.AddInstruction(HloInstruction::CreateReduceWindow( + shape, arg_instruction, init_value, window, add_func)); + + auto computation = module.AddEntryComputation(b.Build()); + std::unique_ptr result = + evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + + std::vector output_dims = {4, 3, 3, 3, 4, 4}; + std::unique_ptr result_literal = + Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 8.0f); + LiteralTestUtil::ExpectEqual(*result_literal, *result); +} + TEST_F(HloEvaluatorTest, StridedSlice) { HloComputation::Builder b(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 24a47f80af517db3b254e679bed4a25a0d60f1d7..07b3369d5c1276f0a62af4d3882fed70277f9a91 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -479,7 +479,7 @@ stylesheet=" // If this edge crosses a fusion cluster boundary, highlight it when the // cluster is hovered over. if (from_node->IsFused() && - from_node->fusion_instruction()->fused_expression_root() == from_node) { + from_node->parent()->root_instruction() == from_node) { int64 cluster_id = cluster_ids_.at(from_node->parent()); add_hover_css_rule("clust", cluster_id, kBlue); } @@ -561,13 +561,21 @@ tooltip = " "; } string comp_body = DumpComputation(subcomp); - string computation = - Printf(computation_fmt, id, style, subcomp_label, comp_body, id); - // Add an edge from the subcomputation to its parent node. If subcomp - // belongs to a fusion node, it's drawn in place of the fusion instruction, so - // there's no need to link those. - if (parent_instr->opcode() != HloOpcode::kFusion) { + if (parent_instr->opcode() == HloOpcode::kFusion) { + // Dump any nested fusion nodes. + for (const auto& subcomp_instr : subcomp->instructions()) { + if (subcomp_instr->opcode() == HloOpcode::kFusion) { + StrAppend( + &comp_body, + DumpSubcomputation(subcomp_instr->fused_instructions_computation(), + subcomp_instr.get())); + } + } + } else { + // Add an edge from the subcomputation to its parent node. If subcomp + // belongs to a fusion node, it's drawn in place of the fusion instruction, + // so there's no need to link those. edge_ids_.insert( {{subcomp->root_instruction(), parent_instr}, next_edge_id_++}); const char* edge_fmt = @@ -578,6 +586,9 @@ tooltip = " "; subcomp->name(), parent_instr->name())); } + string computation = + Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + return computation; } @@ -657,7 +668,7 @@ string HloDotDumper::GetInstructionNodeInlinedConstants( // Special case: If instr is a parameter to a fusion node, check whether the // corresponding operand to the fusion node is a constant. if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { - const HloInstruction* fusion = instr->fusion_instruction(); + const HloInstruction* fusion = instr->parent()->FusionInstruction(); const HloInstruction* operand = fusion->operand(instr->parameter_number()); if (operand->opcode() != HloOpcode::kConstant) { return ""; @@ -898,7 +909,7 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { // expressions are handled specially -- we draw an edge from the corresponding // operand on the fusion node itself to the parameter. if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { - const HloInstruction* fusion = instr->fusion_instruction(); + const HloInstruction* fusion = instr->parent()->FusionInstruction(); add_edge(fusion->operand(instr->parameter_number()), instr, /*operand_num=*/0); } else { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 28ca9153105f75d1093deca54df405f0680106f7..ce9e0db77e1d89819d211eb047452f0d3ec0de7d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -512,7 +512,6 @@ HloInstruction::CreateSelectAndScatter( instruction->set_parent(fused_root->parent()); instruction->set_metadata(fused_root->metadata()); instruction->CloneAndFuseInternal(fused_root); - instruction->CheckFusionInstruction(); return instruction; } @@ -532,13 +531,13 @@ void HloInstruction::MergeFusionInstruction( HloInstruction* instruction_to_merge) { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); + CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != + operands().end()); // Clone the instruction from which to merge fused instructions. std::unique_ptr clone = instruction_to_merge->Clone(); // Replace uses of fused parameters with the corresponding operand of the - // fusion. - // Add all non-parameter fused instructions to 'unfused_instructions' to be - // merged into 'this'. - // This is done in reverse post order. + // fusion. Add all non-parameter fused instructions to 'unfused_instructions' + // to be merged into 'this'. This is done in reverse post order. std::vector unfused_instructions; auto fused_instructions = clone->fused_instructions_computation()->MakeInstructionPostOrder(); @@ -563,6 +562,8 @@ void HloInstruction::MergeFusionInstruction( } CHECK_EQ(0, clone->user_count()); clone->DetachFromOperands(); + TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( + clone->fused_instructions_computation())); } void HloInstruction::MergeFusionInstructionIntoMultiOutput( @@ -634,7 +635,6 @@ HloInstruction* HloInstruction::FuseInstructionInternal( } HloInstruction* fused_instruction = CloneAndFuseInternal(instruction_to_fuse, add_output); - CheckFusionInstruction(); return fused_instruction; } @@ -642,23 +642,19 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse, bool add_output) { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(instruction_to_fuse->IsFusable()); - if (GetModule()) { - XLA_VLOG_LINES(3, GetModule()->ToString()); - } + VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); HloInstruction* clone = nullptr; if (called_computations_.empty()) { // New fusion instruction. It should not be a multioutput instruction. CHECK(!add_output); - auto builder = HloComputation::Builder("fused_computation", true); + auto builder = HloComputation::Builder("fused_computation", this); builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); called_computations_.push_back( CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); clone = fused_expression_root(); - clone->parent_fusion_instruction_ = this; } else { clone = fused_instructions_computation()->AddInstruction( instruction_to_fuse->Clone(/*suffix=*/"")); - clone->parent_fusion_instruction_ = this; // When add_output is false, instruction_to_fuse is necessarily an operand // of the fusion instruction. After fusion this will no longer be the case. // Remove the operand from the operand list and remove its corresponding @@ -727,12 +723,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // to avoid a double %%. string param_name = StrCat(operand->name().substr(1), ".param_", param_no); - std::unique_ptr param_instruction = - CreateParameter(param_no, operand->shape(), param_name); - - param_instruction->parent_fusion_instruction_ = this; fused_param = fused_instructions_computation()->AddParameter( - std::move(param_instruction)); + CreateParameter(param_no, operand->shape(), param_name)); AppendOperand(operand); } TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); @@ -762,7 +754,6 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( HloInstruction::CreateTuple(tuple_elements)); fused_instructions_computation()->set_root_instruction(new_root); shape_ = new_root->shape(); - new_root->parent_fusion_instruction_ = this; if (fused_root->opcode() == HloOpcode::kTuple) { TF_CHECK_OK( fused_instructions_computation()->RemoveInstruction(fused_root)); @@ -800,13 +791,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( } } - for (HloComputation* computation : - instruction_to_fuse->called_computations()) { - if (std::find(called_computations_.begin(), called_computations_.end(), - computation) == called_computations_.end()) { - called_computations_.push_back(computation); - } - } + VLOG(2) << "New clone:\n" << clone->ToString(); return clone; } @@ -835,82 +820,6 @@ bool HloInstruction::HasSideEffect() const { } } -void HloInstruction::CheckFusionInstruction() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - - const std::list>& fused_instructions_ = - fused_instructions_computation()->instructions(); - // All instructions owned by this fusion instruction must be fused, and the - // parent fusion instruction of the fused instructions must be 'this'. - for (auto& instruction : fused_instructions_) { - CHECK(instruction->IsFused()); - CHECK_EQ(this, instruction->fusion_instruction()); - CHECK_EQ(fused_instructions_computation(), instruction->parent()) - << instruction->ToString(); - } - - // Fused root instruction and fused parameters must all be owned by the fusion - // instruction. - bool root_owned = false; - const std::vector& fused_parameters_ = fused_parameters(); - const HloInstruction* fused_root_ = fused_expression_root(); - std::vector parameter_owned(fused_parameters_.size(), false); - for (auto& instruction : fused_instructions_) { - if (fused_root_ == instruction.get()) { - CHECK(!root_owned); - root_owned = true; - } - for (int i = 0; i < fused_parameters_.size(); ++i) { - if (fused_parameters_[i] == instruction.get()) { - CHECK(!parameter_owned[i]); - parameter_owned[i] = true; - } - } - } - CHECK(root_owned); - // Make sure all the parameter_owned entries are set - for (int i = 0; i < parameter_owned.size(); i++) { - CHECK(parameter_owned[i]); - } - - // Fused root must have no users. - CHECK_EQ(0, fused_root_->user_count()); - - // All uses of fused instructions must be in the fusion instruction, and every - // non-root instruction must have at least one use. - for (auto& instruction : fused_instructions_) { - if (instruction.get() != fused_root_) { - CHECK_GT(instruction->user_count(), 0); - for (auto& user : instruction->users()) { - CHECK(user->IsFused()); - CHECK_EQ(this, user->fusion_instruction()); - } - } - } - - // Fused parameter instructions must be numbered contiguously and match up - // (shapes compatible) with their respective operand. - CHECK_EQ(operands_.size(), fused_parameters_.size()); - std::vector parameter_numbers(fused_parameters_.size(), false); - for (auto fused_param : fused_parameters_) { - int64 param_no = fused_param->parameter_number(); - CHECK_GE(param_no, 0); - CHECK_LT(param_no, fused_parameters_.size()); - CHECK(!parameter_numbers[param_no]); - parameter_numbers[param_no] = true; - CHECK(ShapeUtil::Compatible(fused_param->shape(), - operands_[param_no]->shape())); - } - // Make sure all the parameter_numbers entries were seen - for (int i = 0; i < parameter_numbers.size(); i++) { - CHECK(parameter_numbers[i]); - } - - // Operands must be distinct. - std::set operand_set(operands_.begin(), operands_.end()); - CHECK_EQ(operand_set.size(), operands_.size()); -} - /* static */ std::unique_ptr HloInstruction::CreateCall( const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* computation) { @@ -948,6 +857,12 @@ void HloInstruction::CheckFusionInstruction() const { std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands) { + VLOG(3) << "CloneWithNewOperands:\n " << ToString(); + VLOG(3) << " new operands:"; + for (const HloInstruction* new_operand : new_operands) { + VLOG(3) << " " << new_operand->name(); + } + // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. @@ -1166,15 +1081,10 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( std::list> new_fused_instructions; // Create the list of fused parameters by mapping through the cloned, // fused instructions. - std::vector new_fused_parameters; - const std::vector& fused_parameters_ = - fused_instructions_computation()->parameter_instructions(); - - for (HloInstruction* old_fused_parameter : fused_parameters_) { + for (HloInstruction* old_fused_parameter : + fused_instructions_computation()->parameter_instructions()) { new_fused_instructions.push_back(old_fused_parameter->Clone()); HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); - new_fusion_parameter->parent_fusion_instruction_ = new_instruction.get(); - new_fused_parameters.push_back(new_fusion_parameter); InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); } for (auto old_fused_instruction : @@ -1195,12 +1105,12 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( old_fused_instruction->shape(), new_operands)); HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); new_fused_instruction->set_parent(parent()); - new_fused_instruction->parent_fusion_instruction_ = new_instruction.get(); InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); } new_instruction->fusion_kind_ = fusion_kind_; auto computation_builder = HloComputation::Builder( - fused_instructions_computation()->name() + ".clone", true); + fused_instructions_computation()->name() + ".clone", + new_instruction.get()); // We iterated the fusion instructions in reverse post order which means // that we must reverse our new list of fusion instructions. for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); @@ -1214,7 +1124,6 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( ->AddEmbeddedComputation( computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); new_instruction->set_parent(parent()); - new_instruction->CheckFusionInstruction(); return new_instruction; } @@ -1554,6 +1463,7 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { } void HloInstruction::DetachFromOperands() { + VLOG(3) << "DetachFromOperands:\n " << ToString(); CHECK_EQ(0, user_count()); // An instruction may be repeated as an operand. To avoid calling RemoveUser // twice on the same operand, keep a set of already detached operands. @@ -1681,6 +1591,21 @@ string HloInstruction::ExtendedOpcodeStr() const { string HloInstruction::ToString(bool compact_operands, bool include_metadata) const { + string result = + StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", + ExtendedOpcodeStr(), "(", OperandsToString(compact_operands), ")"); + for (const string& extra : ExtraAttributesToString()) { + StrAppend(&result, ", ", extra); + } + if (include_metadata && + (!metadata_.op_type().empty() || !metadata_.op_name().empty() || + !metadata_.source_file().empty())) { + StrAppend(&result, " # metadata=", metadata_.ShortDebugString()); + } + return result; +} + +string HloInstruction::OperandsToString(bool compact) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. @@ -1709,12 +1634,12 @@ string HloInstruction::ToString(bool compact_operands, } else { tensorflow::gtl::ArraySlice slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; - if (compact_operands && slice.size() > kMaxOperandsToShowIfCompact) { + if (compact && slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { *out += ShapeUtil::HumanStringWithLayout(operand->shape()); - if (!compact_operands) { + if (!compact) { StrAppend(out, " ", operand->name()); } }); @@ -1723,15 +1648,19 @@ string HloInstruction::ToString(bool compact_operands, StrAppend(&operands, ", ...(+", remaining, ")"); } } - string extra; + return operands; +} + +std::vector HloInstruction::ExtraAttributesToString() const { + std::vector extra; if (CanHaveDimensionsField()) { - StrAppend(&extra, ", dimensions={", Join(dimensions(), ","), "}"); + extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); } if (window_ != nullptr) { - StrAppend(&extra, ", ", window_util::ToString(*window_)); + extra.push_back(window_util::ToString(*window_)); } if (padding_config_ != nullptr) { - StrAppend(&extra, ", padding=", padding_config_->ShortDebugString()); + extra.push_back(StrCat("padding=", padding_config_->ShortDebugString())); } if (!slice_starts_.empty() && !slice_limits_.empty()) { std::vector bounds; @@ -1740,45 +1669,38 @@ string HloInstruction::ToString(bool compact_operands, bounds.push_back( StrCat("[", slice_starts_[i], ":", slice_limits_[i], "]")); } - StrAppend(&extra, ", slice={", Join(bounds, ", "), "}"); + extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); } if (convolution_dimension_numbers_ != nullptr) { - StrAppend(&extra, ", ", ConvolutionDimensionNumbersToString()); + extra.push_back(ConvolutionDimensionNumbersToString()); } if (opcode() == HloOpcode::kWhile) { - StrAppend(&extra, ", condition=", while_condition()->name()); - StrAppend(&extra, ", body=", while_body()->name()); + extra.push_back(StrCat("condition=", while_condition()->name())); + extra.push_back(StrCat("body=", while_body()->name())); } else if (opcode() == HloOpcode::kSelectAndScatter) { - StrAppend(&extra, ", select=", select()->name()); - StrAppend(&extra, ", scatter=", scatter()->name()); + extra.push_back(StrCat("select=", select()->name())); + extra.push_back(StrCat("scatter=", scatter()->name())); } else if (!called_computations().empty()) { - StrAppend(&extra, ", calls=", - Join(called_computations(), ", ", - [](string* out, const HloComputation* computation) { - StrAppend(out, computation->name()); - })); + extra.push_back(StrCat( + "calls=", Join(called_computations(), ", ", + [](string* out, const HloComputation* computation) { + StrAppend(out, computation->name()); + }))); } if (opcode() == HloOpcode::kGetTupleElement) { - StrAppend(&extra, ", index=", tuple_index()); + extra.push_back(StrCat("index=", tuple_index())); } if (!control_successors_.empty()) { - StrAppend( - &extra, ", control-successors=", + extra.push_back(StrCat( + "control-successors=", Join(control_successors_, ", ", [](string* out, HloInstruction* succ) { StrAppend(out, succ->name()); - })); - } - if (include_metadata && - (!metadata_.op_type().empty() || !metadata_.op_name().empty() || - !metadata_.source_file().empty())) { - StrAppend(&extra, " # metadata=", metadata_.ShortDebugString()); + }))); } - - return StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", - ExtendedOpcodeStr(), "(", operands, ")", extra); + return extra; } string HloInstruction::ToShortString() const { @@ -1904,9 +1826,7 @@ string HloInstruction::TracingTag() const { return literal_->u8s_string(); } -bool HloInstruction::IsFused() const { - return parent_fusion_instruction_ != nullptr; -} +bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } bool HloInstruction::IsFusable() const { // Instructions which are traced should not be fused. @@ -1941,11 +1861,6 @@ HloComputation* HloInstruction::fused_instructions_computation() const { return fused_instructions_computation; } -HloInstruction* HloInstruction::fusion_instruction() const { - CHECK(IsFused()); - return parent_fusion_instruction_; -} - HloInstruction* HloInstruction::fused_expression_root() const { CHECK_EQ(opcode_, HloOpcode::kFusion); return fused_instructions_computation()->root_instruction(); @@ -2131,6 +2046,7 @@ using DFSStack = // cycle was detected, and true otherwise. inline bool PushDFSChild(DfsHloVisitor* visitor, DFSStack* dfs_stack, HloInstruction* child) { + CHECK(child != nullptr); const int id = child->unique_id(); CHECK_GE(id, 0) << "instruction may not have a parent computation"; switch (visitor->GetVisitState(id)) { @@ -2193,7 +2109,6 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, visitor->SetVisitState(current_id, DfsHloVisitor::kVisiting); const size_t old_dfs_stack_size = dfs_stack.size(); - for (HloInstruction* child : current_node->operands()) { if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a97066a7854b980a24d28ff618edc41192bd5486..bd8b8ac9bd8eea8d4fbda1b68a305f3707b99b10 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -548,6 +548,14 @@ class HloInstruction { string ToString(bool compact_operands = false, bool include_metadata = true) const; + // Components of the ToString() representation: + + // Returns a string representation of the operand list. + string OperandsToString(bool compact) const; + + // Returns string representation of op-specific attributes. + std::vector ExtraAttributesToString() const; + string ToStringNoMetadata() const { return ToString(false, false); } // As ToString, but returns a shorter string. @@ -603,26 +611,21 @@ class HloInstruction { // instruction. bool IsFused() const; + // Returns the computation for this fused instruction. + // + // Precondition: opcode() == HloOpcode::kFusion + HloComputation* fused_instructions_computation() const; + // Returns true if this instruction can be legally fused into a fusion // instruction. bool IsFusable() const; - // Returns the fusion instruction that contains this instruction. - // - // Note: only valid if this instruction is fused into a fusion instruction. - HloInstruction* fusion_instruction() const; - // Returns the root instruction of the fused expression contained within this // fusion instruction. // // Precondition: opcode() == HloOpcode::kFusion HloInstruction* fused_expression_root() const; - // Returns the computation for this fused instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloComputation* fused_instructions_computation() const; - // Returns the list of fused instructions inside this fusioninstruction. // // Note: although the list itself is const, the instructions contained in the @@ -802,8 +805,7 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice operands); - // Returns the computations this instruction calls (if any). This includes - // computations called by fused instructions inside of a fusion instruction. + // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { return called_computations_; } @@ -898,14 +900,6 @@ class HloInstruction { // instruction to make it a bitcast. bool CouldBeBitcast() const; - // Sets the parent fusion instruction for this instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - void SetParentFusion(HloInstruction* fusion_instruction) { - CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); - parent_fusion_instruction_ = fusion_instruction; - } - // Get/Set the number of partitions per outer dimension (in order, starting // with outer-most dimension first). Currently used by the parallel cpu // backend to partition HLOs into parallel tasks. @@ -976,9 +970,6 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice operands); - // CHECKs various invariants of a fusion instruction. - void CheckFusionInstruction() const; - // Returns true if this instruction can legally have the dimensions field // set. Used for checking precondition of dimensions field accessors. bool CanHaveDimensionsField() const; @@ -1049,10 +1040,6 @@ class HloInstruction { // padding of this pad instruction. Only set for pad instructions. std::unique_ptr padding_config_; - // If this instruction is fused into a fusion instruction, this field points - // to the fusion instruction. - HloInstruction* parent_fusion_instruction_ = nullptr; - // The type of the fusion. Used by kFusion only. FusionKind fusion_kind_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index ea5749581b57bb0ddf4c7c844bc1399c629594e5..2e1eeee36b58826045f2aeabf74497b019aa1764 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -758,16 +758,13 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { auto* fusion = computation->CreateFusionInstruction( {map_3_y}, HloInstruction::FusionKind::kLoop); auto* fused_computation = fusion->fused_instructions_computation(); - EXPECT_THAT(fusion->called_computations(), - ElementsAre(fused_computation, computation_y)); + EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation)); fusion->FuseInstruction(map_2_x); - EXPECT_THAT(fusion->called_computations(), - ElementsAre(fused_computation, computation_y, computation_x)); + EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation)); fusion->FuseInstruction(map_1_x); - EXPECT_THAT(fusion->called_computations(), - ElementsAre(fused_computation, computation_y, computation_x)); + EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation)); } TEST_F(HloInstructionTest, ComplexFusionOp) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 4c3ff3bdafc0e5184b715b938b317c3ff85fbfa8..08f572bb2aba6c972ca0e8ee826c2ffac2e739c2 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -31,66 +32,6 @@ limitations under the License. namespace xla { -namespace { - -// Returns the nearest call graph ancestors of instructions 'a' and 'b' for -// which the ancestors are in the same computation. An instruction is an call -// graph ancestor of 'a' if the instruction calls the computation containing 'a' -// either directly or transitively. Degeneratively an instruction is an ancestor -// of itself. nullptr is returned if there is no common ancestor or if the -// caller chain of 'a' or 'b' diverges (has multiple callers) before the nearest -// common ancestor. -// -// Example: -// -// Entry computation: -// %x = Call(A, {Constant(42.0)}) -// %y = Call(B, {%x}) -// -// Computation A: -// %a = Negate(Param()) -// -// Computation B: -// %b = Exp(Param()); -// -// If called with %a and %b, this function would return (%x, %y). %x is an -// ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same -// computation. -std::pair -GetNearestCallGraphAncestorsInSameComputation(const HloInstruction* a, - const HloInstruction* b, - const CallGraph& call_graph) { - // Lambda which returns the next instruction in the callee->caller chain in - // the call graph. This is the unique instruction which calls the computation - // containing 'instruction'. If more than one instruction calls the - // computation containing 'instruction' or no instructions call the - // computation then nullptr is returned. - auto next_caller = - [&call_graph]( - const HloInstruction* instruction) -> const HloInstruction* { - const CallGraphNode& node = call_graph.GetNode(instruction->parent()); - if (node.caller_callsites().size() != 1) { - return nullptr; - } - return node.caller_callsites()[0].instruction(); - }; - - // Iterate through the callee->caller chains and find the earliest common - // element. - for (const HloInstruction* a_ancestor = a; a_ancestor != nullptr; - a_ancestor = next_caller(a_ancestor)) { - for (const HloInstruction* b_ancestor = b; b_ancestor != nullptr; - b_ancestor = next_caller(b_ancestor)) { - if (a_ancestor->parent() == b_ancestor->parent()) { - return {a_ancestor, b_ancestor}; - } - } - } - return {nullptr, nullptr}; -} - -} // namespace - bool HloOrdering::ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const { // 'a' and 'b' may be in different computations. In this case, find the @@ -100,7 +41,8 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, const HloInstruction* a_ancestor; const HloInstruction* b_ancestor; std::tie(a_ancestor, b_ancestor) = - GetNearestCallGraphAncestorsInSameComputation(a, b, *call_graph_); + call_graph_->NearestAncestorsInSameComputation( + const_cast(a), const_cast(b)); if (a_ancestor == nullptr) { // Ancestors in a common computation could not be found so consider the @@ -127,6 +69,155 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); } +bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { + // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b' + // is live into the module. + const HloModule* module = b.defining_instruction()->parent()->parent(); + if (b.defining_instruction()->parent() == module->entry_computation() && + b.defining_instruction()->opcode() == HloOpcode::kParameter) { + return false; + } + + // Phi values require special handling. Because XLA does not have a phi + // instruction, the definition instruction of the phis values are + // placeholders: either the subcomputation parameter (body or condition) or + // the while instruction. However, the program point where these values are + // logically defined does not necessarily coincide exactly with program point + // of these place-holder instructions. So we explicitly define the following + // order for phi values: + // + // body/condition parameter phi: + // Defined before all values defined in its computation excepting other + // phis. + // + // while phi: + // defined after all values defined in the condition or body. + // + auto is_body_or_condition_phi = [](const HloValue& v) { + return v.is_phi() && + v.defining_instruction()->opcode() == HloOpcode::kParameter; + }; + if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) && + call_graph_->InstructionIsNestedIn(b.defining_instruction(), + a.defining_instruction()->parent())) { + return true; + } + if (is_body_or_condition_phi(b) && + call_graph_->InstructionIsNestedIn(a.defining_instruction(), + b.defining_instruction()->parent())) { + return false; + } + + // If 'b' is a while phi and 'a' is in the body or condition, then 'a' + // executes before 'b'. + if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile && + (call_graph_->InstructionIsNestedIn( + a.defining_instruction(), b.defining_instruction()->while_body()) || + call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->while_condition()))) { + return true; + } + + return ExecutesBefore(a.defining_instruction(), b.defining_instruction()); +} + +/* static */ +bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use, + const HloValue& value) const { + VLOG(4) << "UseIsBeforeValueDefinition(use=" << use + << ", value=" << value.ToShortString() << ")"; + if (ExecutesBefore(use.instruction, value.defining_instruction())) { + VLOG(4) << " use instruction executes before value-defining instruction"; + return true; + } + + // If the use is at the instruction where the value is defined, then the use + // is before the def if the instruction allows buffer sharing (in place + // computation). + if (use.instruction == value.defining_instruction() && + CanShareOperandBufferWithUser( + use.instruction->mutable_operand(use.operand_number), + use.operand_index, value.defining_instruction(), + value.defining_index())) { + VLOG(4) << " use is value def, and instruction can share use buffer"; + return true; + } + + // The use at a while is an input to a phi, and logically occurs before values + // are defined in the body or condition computations. + if (use.instruction->opcode() == HloOpcode::kWhile) { + const HloInstruction* xla_while = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + xla_while->while_body()) || + call_graph_->InstructionIsNestedIn(value.defining_instruction(), + xla_while->while_condition())) { + VLOG(4) << " use is while " << use.instruction->name() + << " and def is in condition or body"; + return true; + } + } + + // Similarly if the value is defined at a while, it logically occurs after any + // uses in the body or condition computations. + if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { + CHECK(value.is_phi()); + const HloInstruction* xla_while = value.defining_instruction(); + if (call_graph_->InstructionIsNestedIn(use.instruction, + xla_while->while_body()) || + call_graph_->InstructionIsNestedIn(use.instruction, + xla_while->while_condition())) { + VLOG(4) << " value is while " << value.defining_instruction()->name() + << " and use is in condition or body"; + return true; + } + } + VLOG(4) << " use is not before while"; + return false; +} + +bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a, + const HloValue& b) const { + VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() + << ", b = " << b.ToShortString() << ")"; + if (!IsDefinedBefore(a, b)) { + VLOG(4) << "a not defined before b"; + return false; + } + + // Live-out values from the module can never have ranges strictly before any + // other value. + if (a.live_out_of_module()) { + VLOG(4) << "a is live out of module"; + return false; + } + + // Live-out values of computations can never have ranges strictly before any + // other value in the computation (including values nested in + // subcomputations). + if (a.live_out_of_computation() && + call_graph_->InstructionIsNestedIn(b.defining_instruction(), + a.defining_instruction()->parent())) { + VLOG(4) << "a is live out of computation containing b"; + return false; + } + + // All uses of 'a' must be before 'b' is defined. + for (const HloUse& use : a.uses()) { + if (!UseIsBeforeValueDefinition(use, b)) { + VLOG(4) << "use of a (" << use << ") not before b is defined"; + return false; + } + } + + return true; +} + +bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b) const { + // Buffers without disjoint liveness may interfere. + return !LiveRangeStrictlyBefore(a, b) && !LiveRangeStrictlyBefore(b, a); +} + HloOrderingProto HloOrdering::ToProto() const { HloOrderingProto proto; for (const auto& computation : module_->computations()) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 130431f28070d52c3a76befa0d5272a3cc295711..efb5fca188a756b1fadda25f90defd94d8e3cb1c 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -41,11 +42,30 @@ class HloOrdering { // not reflexive, that is, an instruction does not execute before itself. bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const; + // Returns whether the value 'a' is defined before the value 'b' under the + // given ordering. + bool IsDefinedBefore(const HloValue& a, const HloValue& b) const; + + // Returns whether the given use is before the given value definition under + // the given ordering. + bool UseIsBeforeValueDefinition(const HloUse& use, + const HloValue& value) const; + // Returns whether the given values interfere. Two values interfere if they + // may both be simultaneously live. + bool MayInterfere(const HloValue& a, const HloValue& b) const; + + // Returns true if the live range of the given value 'a' is strictly before + // the live range of value 'b' using the given HLO ordering. + bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b) const; + // Returns the sequential instruction order for the given computation, or // nullptr if the computation does not have a sequential ordering. virtual const std::vector* SequentialOrder( const HloComputation& computation) const = 0; + // Return the call graph of the module used to compute ordering. + const CallGraph& call_graph() const { return *call_graph_; } + virtual string ToString() const = 0; // Returns the serialized representation of this ordering. @@ -81,6 +101,14 @@ class PredecessorHloOrdering : public HloOrdering { return nullptr; } + HloReachabilityMap& reachability_map(const HloComputation* computation) { + return *predecessors_.at(computation); + } + const HloReachabilityMap& reachability_map( + const HloComputation* computation) const { + return *predecessors_.at(computation); + } + protected: explicit PredecessorHloOrdering(const HloModule* module); string ToStringHelper(const string& name) const; diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index ad6070a9c1b45afd418c9210a2d1b3def3eaf4d5..c95e44bd5d9d2ed87992d630bed4c1fe5c161383 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" @@ -218,6 +219,94 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param)); } +TEST_F(HloOrderingTest, ValuesInWhileComputations) { + // Tests the ordering of values (defined by dataflow analysis) in the body and + // condition of a while instruction. HLO code: + // + // body(F32[]) %param): + // %negate = Negate(%param) + // + // condition(F32[] %param): + // %convert = Convert(%param) + // + // entry: + // %constant = Constant(1.0) + // %while = While(%constant, body, condition) + // %add = Add(%constant, %while) + // + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "body_param")); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape, HloOpcode::kNegate, body_param)); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "cond_param")); + auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(xla::PRED, {}), cond_param)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, constant, xla_while)); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN( + auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true)); + DependencyHloOrdering ordering(module.get()); + + // Init value is defined before the while, but live range is not before the + // while because of the use of the init value in the add. + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant), + dataflow->GetValueDefinedAt(xla_while))); + EXPECT_FALSE( + ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant), + dataflow->GetValueDefinedAt(xla_while))); + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant), + dataflow->GetValueDefinedAt(xla_while))); + + // Any value defined in the body or condition is defined before the while, and + // has a live range strictly before the while. + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate), + dataflow->GetValueDefinedAt(xla_while))); + EXPECT_TRUE( + ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate), + dataflow->GetValueDefinedAt(xla_while))); + EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate), + dataflow->GetValueDefinedAt(xla_while))); + + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert), + dataflow->GetValueDefinedAt(xla_while))); + EXPECT_TRUE( + ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert), + dataflow->GetValueDefinedAt(xla_while))); + EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert), + dataflow->GetValueDefinedAt(xla_while))); + + // The live range of the while should be before the add. + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while), + dataflow->GetValueDefinedAt(add))); + ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1); + + const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0]; + EXPECT_EQ(while_use.instruction, add); + EXPECT_TRUE(ordering.UseIsBeforeValueDefinition( + while_use, dataflow->GetValueDefinedAt(add))); + EXPECT_TRUE( + ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while), + dataflow->GetValueDefinedAt(add))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 4b824f8240074e7ae70b9d9fa82dfa0706d5b355..7ad33c8947c0bf7b013d26bb47b14e62688151e9 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; namespace xla { @@ -54,10 +55,18 @@ StatusOr HloPassPipeline::Run(HloModule* module) { << tensorflow::str_util::Join(disabled_passes, ", "); } - auto run_invariant_checkers = [this, module]() -> Status { + auto run_invariant_checkers = [this, + module](const string& message) -> Status { for (auto& invariant_checker : invariant_checkers_) { - TF_ASSIGN_OR_RETURN(bool changed, invariant_checker->Run(module)); - TF_RET_CHECK(!changed) << "invariant checkers must not change the graph"; + VLOG(1) << " Invariant checker " << invariant_checker->name(); + StatusOr changed_status = invariant_checker->Run(module); + if (!changed_status.ok()) { + return Status(changed_status.status().code(), + StrCat(changed_status.status().error_message(), + "\n\nFailed ", message)); + } + TF_RET_CHECK(!changed_status.ValueOrDie()) + << "invariant checkers must not change the graph"; } return Status::OK(); }; @@ -65,6 +74,8 @@ StatusOr HloPassPipeline::Run(HloModule* module) { string prefix = name().ToString() + ": pipeline start"; bool changed = false; string message; + TF_RETURN_IF_ERROR( + run_invariant_checkers(StrCat("before running pipeline: ", name()))); for (auto& pass : passes_) { if (disabled_passes.count(pass->name().ToString()) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() @@ -79,14 +90,14 @@ StatusOr HloPassPipeline::Run(HloModule* module) { StrAppend(&message, prefix, ", before ", pass->name()); DumpModule(*module, message); - TF_RETURN_IF_ERROR(run_invariant_checkers()); TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); + TF_RETURN_IF_ERROR( + run_invariant_checkers(StrCat("after running pass: ", pass->name()))); changed |= changed_this_pass; prefix.clear(); StrAppend(&prefix, name(), ": after ", pass->name()); } - TF_RETURN_IF_ERROR(run_invariant_checkers()); DumpModule(*module, prefix + ", pipeline end"); return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 278a1d7efadf6f524c490fdfe778648e1b820ef2..6e5d7bca75cba020af5ec7e0bec511d2fa693286 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1202,7 +1202,7 @@ StatusOr HloRematerialization::RematerializeComputation( StatusOr HloRematerialization::Run( HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes) { + int64 memory_limit_bytes, RematerializationSizes* sizes) { // The sequence is constructed entirely by this method. TF_RET_CHECK(sequence->empty()); @@ -1248,7 +1248,8 @@ StatusOr HloRematerialization::Run( sequence->at(node.computation()))); } return Status::OK(); - })); + }, + /*visit_unreachable_nodes=*/false)); // The peak memory usage of the module equals the peak memory use of the entry // computation plus the output size of the computation. This is because the @@ -1318,13 +1319,20 @@ StatusOr HloRematerialization::Run( << HumanReadableNumBytes(reduced_peak_memory) << " (" << reduced_peak_memory << " bytes)"; + if (sizes != nullptr) { + sizes->before_bytes = before_peak_memory; + sizes->after_bytes = current_peak_memory; + } + XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); if (current_peak_memory > memory_limit_bytes) { - LOG(WARNING) << "Can't reduce memory use below " - << HumanReadableNumBytes(memory_limit_bytes) - << " by rematerialization (only reduced to " - << HumanReadableNumBytes(current_peak_memory) << ")"; + LOG(WARNING) << tensorflow::strings::Printf( + "Can't reduce memory use below %s (%lld bytes) by rematerialization; " + "only reduced to %s (%lld bytes)", + HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes, + HumanReadableNumBytes(current_peak_memory).c_str(), + current_peak_memory); } return changed; @@ -1333,9 +1341,10 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - SequentialHloOrdering::HloModuleSequence* sequence) { + SequentialHloOrdering::HloModuleSequence* sequence, + RematerializationSizes* sizes) { HloRematerialization remat(size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes); + return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 42c279d440b78d90b9f19b92155c52787156e4b7..11f79a6d4158c6251c2faf63e9cac4e742440863 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -28,6 +28,13 @@ class HloRematerialization { public: using ShapeSizeFunction = std::function; + // Helper struct that communicates the before / after sizes for the + // rematerialization process. + struct RematerializationSizes { + int64 before_bytes; + int64 after_bytes; + }; + // Rematerialize HLO instructions in the given module to reduce peak memory // use below memory_limit_bytes where memory use is defined as the total size // of all live HLO instruction values. Parameters and constants are included @@ -46,6 +53,9 @@ class HloRematerialization { // rematerialization. This is the order in which HLO instructions should // be emitted to minimize memory use. // + // sizes: Optional outparam that indicates the peak memory usage of the HLO + // module before/after rematerialization. + // // Returns whether any instructions were rematerialized. If memory use is // already below the given limit then no instructions are rematerialized and // false is returned. @@ -55,8 +65,8 @@ class HloRematerialization { // code generation. static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, - SequentialHloOrdering::HloModuleSequence* sequence); + HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence, + RematerializationSizes* sizes = nullptr); protected: HloRematerialization(const ShapeSizeFunction& size_function) @@ -69,7 +79,7 @@ class HloRematerialization { // contains the memory-minimizing order in which to emit the HLO instructions. StatusOr Run(HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit); + int64 memory_limit, RematerializationSizes* sizes); // Rematerializes instructions within the given computation. 'order' is the // order in which the computation's instructions will be emitted in the diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 3df760d159a5d2d5bdca4b12ce6e4b23c75f9ac0..25be448c8d186514e5d5d04382f4733fee3af68b 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -72,6 +72,13 @@ class ListScheduler { return scheduler.CreateSchedule(); } + // Returns whether the memory used by the given HLO should be ignored by the + // scheduling heuristic. + static bool IgnoreInstruction(const HloInstruction& instruction) { + return instruction.opcode() == HloOpcode::kParameter || + instruction.opcode() == HloOpcode::kConstant; + } + private: // The scheduling priority of an instruction is first the number of bytes // freed by scheduling the instruction, and second (tie-breaker) by the number @@ -127,9 +134,8 @@ class ListScheduler { // Returns whether the memory used by the given buffer should be ignored by // the scheduling heuristic. - bool IgnoreBuffer(const LogicalBuffer& buffer) { - return buffer.instruction()->opcode() == HloOpcode::kParameter || - buffer.instruction()->opcode() == HloOpcode::kConstant; + static bool IgnoreBuffer(const LogicalBuffer& buffer) { + return IgnoreInstruction(*buffer.instruction()); } // An entry in the worklist used by CreateSchedule. Corresponds to one @@ -306,6 +312,11 @@ StatusOr> RunDFSMemoryScheduler( tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { + if (ListScheduler::IgnoreInstruction(*hlo)) { + extra_users[hlo] = 0; + total_sizes[hlo] = 0; + continue; + } extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; total_sizes[hlo] = SumLogicalBufferSizes( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 76177462aa4959261483045296d2388acabe46a5..5a4c93b59a6810b962e3c8f54b2964dffa8ecd6d 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -91,10 +91,11 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction( string node_name; // If an instruction is fused, put it in the subgraph of the fusion; // otherwise, put it in the computation subgraph. - if (instruction->IsFused()) { - node_name = GetNodeNameForInstruction(instruction->fusion_instruction()); + const HloComputation* computation = instruction->parent(); + if (computation->IsFusionComputation()) { + node_name = GetNodeNameForInstruction(computation->FusionInstruction()); } else { - node_name = instruction->parent()->name(); + node_name = computation->name(); if (!instruction->metadata().op_name().empty()) { // Always make computations contain TF ops but not the other way around. StrAppend(&node_name, "/", instruction->metadata().op_name()); diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index f85d8ec50deae670dae632066d06890da126a09b..e6cf0d37b8a0f42dc04cfaad067a4741bc803705 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -159,12 +159,6 @@ void HloValue::AddPosition(HloInstruction* instruction, for (const HloPosition& position : positions_) { DCHECK_NE(position, new_position); } - // The shape of the new position must match existing positions. - if (!positions_.empty()) { - CHECK( - ShapeUtil::Compatible(positions_.front().shape(), new_position.shape())) - << "front: " << positions_.front() << " new: " << new_position; - } positions_.push_back(std::move(new_position)); diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index 63ecc25020b87cb3b650d20099dd5c6fddba9052..6872bc76a82253b916e826aa1afabc3d309c1d12 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -225,6 +225,9 @@ class HloValueSet { // already exist in the set. bool AddValue(const HloValue* value); + // Clear all values from the set. + void Clear() { values_.clear(); } + // Return the unique HLO value in the set. CHECKs if the set does not contain // exactly one value. const HloValue& GetUniqueValue() const { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 01fba49bc567900418f9e4622351373abe7b2e18..d40fceb0765f3808510b8651db15b3b6094aab2a 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -14,17 +14,421 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { +namespace { + +// Visitor which verifies that the output shape is correctly set. Verifies +// against the inferred shape for the instruction. +// TODO(b/26024837): Check output shape for all instruction types. +class ShapeVerifier : public DfsHloVisitor { + public: + explicit ShapeVerifier( + const std::function& shape_size_fn) + : shape_size_fn_(shape_size_fn) {} + + Status HandleElementwiseUnary(HloInstruction* hlo) override { + return CheckUnaryShape(hlo); + } + + Status HandleElementwiseBinary(HloInstruction* hlo) override { + return CheckBinaryShape(hlo); + } + + Status HandleClamp(HloInstruction* clamp, HloInstruction* min, + HloInstruction* arg, HloInstruction* max) override { + return CheckTernaryShape(clamp); + } + + Status HandleSelect(HloInstruction* select, HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) override { + return CheckTernaryShape(select); + } + + Status HandleConcatenate( + HloInstruction* concatenate, + tensorflow::gtl::ArraySlice operands) override { + return tensorflow::Status::OK(); + } + + Status HandleConvert(HloInstruction* convert) override { + return tensorflow::Status::OK(); + } + + Status HandleCopy(HloInstruction* copy) override { + return CheckUnaryShape(copy); + } + + Status HandleDot(HloInstruction* dot, HloInstruction* lhs, + HloInstruction* rhs) override { + return CheckBinaryShape(dot); + } + + Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, + HloInstruction* rhs, const Window& window) override { + return tensorflow::Status::OK(); + } + + Status HandleCrossReplicaSum(HloInstruction* crs) override { + return tensorflow::Status::OK(); + } + + Status HandleReducePrecision(HloInstruction* reduce_precision) override { + return tensorflow::Status::OK(); + } + + Status HandleInfeed(HloInstruction* infeed) override { + return tensorflow::Status::OK(); + } + + Status HandleOutfeed(HloInstruction* outfeed) override { + return tensorflow::Status::OK(); + } + + Status HandleRng(HloInstruction* random, + RandomDistribution distribution) override { + return tensorflow::Status::OK(); + } + + Status HandleReverse(HloInstruction* reverse, + HloInstruction* operand) override { + return tensorflow::Status::OK(); + } + + Status HandleSort(HloInstruction* sort, HloInstruction* operand) override { + return tensorflow::Status::OK(); + } + + Status HandleConstant(HloInstruction* constant, + const Literal& literal) override { + return tensorflow::Status::OK(); + } + + Status HandleGetTupleElement(HloInstruction* get_tuple_element, + HloInstruction* operand) override { + return tensorflow::Status::OK(); + } + + Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, + HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions, + HloComputation* function) override { + return tensorflow::Status::OK(); + } + + Status HandleBitcast(HloInstruction* bitcast) override { + // Bitcasts can be any shape, as long as the size matches the operand size. + TF_RET_CHECK(shape_size_fn_(bitcast->shape()) == + shape_size_fn_(bitcast->operand(0)->shape())); + return tensorflow::Status::OK(); + } + + Status HandleBroadcast(HloInstruction* broadcast) override { + TF_RET_CHECK(ShapeUtil::Rank(broadcast->operand(0)->shape()) == + broadcast->dimensions().size()); + return tensorflow::Status::OK(); + } + + Status HandleReshape(HloInstruction* reshape) override { + return tensorflow::Status::OK(); + } + + Status HandleTranspose(HloInstruction* transpose) override { + return tensorflow::Status::OK(); + } + + Status HandleParameter(HloInstruction* parameter) override { + return tensorflow::Status::OK(); + } + + Status HandleFusion(HloInstruction* fusion) override { + return tensorflow::Status::OK(); + } + + Status HandleCall(HloInstruction* call) override { + return tensorflow::Status::OK(); + } + + Status HandleCustomCall(HloInstruction* custom_call, + tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece custom_call_target) override { + return tensorflow::Status::OK(); + } + + Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override { + return tensorflow::Status::OK(); + } + + Status HandleDynamicSlice(HloInstruction* dynamic_slice, + HloInstruction* operand, + HloInstruction* start_indices) override { + return tensorflow::Status::OK(); + } + + Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, + HloInstruction* operand, + HloInstruction* update, + HloInstruction* start_indices) override { + return tensorflow::Status::OK(); + } + + Status HandleTuple( + HloInstruction* tuple, + tensorflow::gtl::ArraySlice operands) override { + return CheckVariadicShape(tuple); + } + + Status HandleMap( + HloInstruction* map, + tensorflow::gtl::ArraySlice operands, + HloComputation* function, + tensorflow::gtl::ArraySlice static_operands) override { + return tensorflow::Status::OK(); + } + + Status HandleReduceWindow(HloInstruction* reduce_window, + HloInstruction* operand, const Window& window, + HloComputation* function) override { + return tensorflow::Status::OK(); + } + + Status HandleSelectAndScatter(HloInstruction* instruction) override { + return tensorflow::Status::OK(); + } + + Status HandleWhile(HloInstruction* xla_while) override { + return tensorflow::Status::OK(); + } + + Status HandlePad(HloInstruction* pad) override { + return tensorflow::Status::OK(); + } + + Status HandleSend(HloInstruction* send) override { + return tensorflow::Status::OK(); + } + + Status HandleRecv(HloInstruction* recv) override { + return tensorflow::Status::OK(); + } + + Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override { + return tensorflow::Status::OK(); + } + + Status HandleBatchNormInference(HloInstruction* batchNormInference) override { + return tensorflow::Status::OK(); + } + + Status HandleBatchNormGrad(HloInstruction* batchNormGrad) override { + return tensorflow::Status::OK(); + } + + Status FinishVisit(HloInstruction* root) override { + return tensorflow::Status::OK(); + } + + private: + // Check the instruction's shape against the given expected shape and return + // an appropriate error if there is a mismatch. + Status CheckShape(const HloInstruction* instruction, + const Shape& expected_shape) { + if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) { + return InvalidArgument( + "Expected instruction to have shape compatible with %s, actual " + "shape is %s:\n%s", + ShapeUtil::HumanString(expected_shape).c_str(), + ShapeUtil::HumanString(instruction->shape()).c_str(), + instruction->ToString().c_str()); + } + return tensorflow::Status::OK(); + } + + // Check a unary (binary, etc) instruction's shape against the inferred shape. + Status CheckUnaryShape(const HloInstruction* instruction) { + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferUnaryOpShape( + instruction->opcode(), instruction->operand(0))); + return CheckShape(instruction, expected); + } + Status CheckBinaryShape(const HloInstruction* instruction) { + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferBinaryOpShape( + instruction->opcode(), instruction->operand(0), + instruction->operand(1))); + return CheckShape(instruction, expected); + } + Status CheckTernaryShape(const HloInstruction* instruction) { + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferTernaryOpShape( + instruction->opcode(), instruction->operand(0), + instruction->operand(1), instruction->operand(2))); + return CheckShape(instruction, expected); + } + Status CheckVariadicShape(const HloInstruction* instruction) { + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferVariadicOpShape( + instruction->opcode(), instruction->operands())); + return CheckShape(instruction, expected); + } + + // Returns the size of a Shape in bytes. + const std::function shape_size_fn_; +}; + +string ComputationsToString( + tensorflow::gtl::ArraySlice computations) { + return tensorflow::str_util::Join( + computations, ",", [](string* s, const HloComputation* computation) { + s->append(computation->name()); + }); +} + +} // namespace + +Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { + // The parent fusion instruction of the fusion computation must be 'fusion'. + HloComputation* fused_computation = fusion->fused_instructions_computation(); + if (fusion != fused_computation->FusionInstruction()) { + return FailedPrecondition( + "Instruction of fused computation does not match expected instruction " + "%s.", + fusion->ToString().c_str()); + } + + // Fused root instruction and fused parameters must all be owned by the fusion + // computation. + bool root_owned = false; + const std::vector& fused_parameters = + fusion->fused_parameters(); + const HloInstruction* fused_root = fusion->fused_expression_root(); + std::vector parameter_owned(fused_parameters.size(), false); + for (auto& instruction : fused_computation->instructions()) { + if (fused_root == instruction.get()) { + if (root_owned) { + return FailedPrecondition("Root appears more than once in %s.", + fusion->ToString().c_str()); + } + root_owned = true; + } + for (int i = 0; i < fused_parameters.size(); ++i) { + if (fused_parameters[i] == instruction.get()) { + if (parameter_owned[i]) { + return FailedPrecondition("Parameter appears more than once in %s.", + fusion->ToString().c_str()); + } + parameter_owned[i] = true; + } + } + } + if (!root_owned) { + return FailedPrecondition("Root not found in computation of %s.", + fusion->ToString().c_str()); + } + // Make sure all the parameter_owned entries are set + for (int i = 0; i < parameter_owned.size(); i++) { + if (!parameter_owned[i]) { + return FailedPrecondition("Parameter %d not found in computation of %s.", + i, fusion->ToString().c_str()); + } + } + + // Fused root must have no users. + if (fused_root->user_count() != 0) { + return FailedPrecondition("Root of %s may not have users.", + fusion->ToString().c_str()); + } + + // All uses of fused instructions must be in the fusion computation, and every + // non-root instruction must have at least one use. + for (auto& instruction : + fusion->fused_instructions_computation()->instructions()) { + if (instruction.get() != fused_root) { + if (instruction->user_count() == 0) { + return FailedPrecondition( + "Non-root instruction %s in %s must have users.", + instruction->ToString().c_str(), fusion->ToString().c_str()); + } + for (auto& user : instruction->users()) { + if (fused_computation != user->parent()) { + return FailedPrecondition( + "Non-root instruction %s in %s may not have external users.", + instruction->ToString().c_str(), fusion->ToString().c_str()); + } + } + } + } + + // Fused parameter instructions must be numbered contiguously and match up + // (shapes compatible) with their respective operand. + CHECK_EQ(fusion->operands().size(), fused_parameters.size()); + std::vector parameter_numbers(fused_parameters.size(), false); + for (auto fused_param : fused_parameters) { + int64 param_no = fused_param->parameter_number(); + if (param_no < 0) { + return FailedPrecondition( + "Unexpected negative parameter number %lld in %s.", param_no, + fusion->ToString().c_str()); + } + if (param_no >= fused_parameters.size()) { + return FailedPrecondition( + "Unexpected parameter number %lld in %s: higher then number of " + "parameters %lu.", + param_no, fusion->ToString().c_str(), fused_parameters.size()); + } + if (parameter_numbers[param_no]) { + return FailedPrecondition( + "Did not expect parameter number %lld more than once in %s.", + param_no, fusion->ToString().c_str()); + } + parameter_numbers[param_no] = true; + if (!ShapeUtil::Compatible(fused_param->shape(), + fusion->operand(param_no)->shape())) { + return FailedPrecondition( + "Shape mismatch between parameter number %lld and its operand in %s.", + param_no, fusion->ToString().c_str()); + } + } + // Make sure all the parameter_numbers entries were seen + for (int i = 0; i < parameter_numbers.size(); i++) { + if (!parameter_numbers[i]) { + return FailedPrecondition("Did not see parameter number %d in %s.", i, + fusion->ToString().c_str()); + } + } + + // TODO(b/65423525): We'd like to check that all operands are distinct. + // This is currently disabled due to the invariant being violated by + // multi-output fusion. + return tensorflow::Status::OK(); +} + StatusOr HloVerifier::Run(HloModule* module) { tensorflow::gtl::FlatMap instructions; + ShapeVerifier shape_verifier(shape_size_fn_); for (auto& computation : module->computations()) { for (const auto& instruction : computation->instructions()) { TF_RET_CHECK(instruction->parent() == computation.get()); if (instruction->opcode() == HloOpcode::kFusion) { + TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction.get())); + TF_RET_CHECK( + ContainersEqual(instruction->called_computations(), + {instruction->fused_instructions_computation()})) + << "Fusion HLO calls computations other than the " + "fused_instructions_computation: " + << instruction->ToString() + << " instruction->fused_instructions_computation(): " + << instruction->fused_instructions_computation()->ToString() + << " instruction->called_computations(): " + << ComputationsToString(instruction->called_computations()); + for (const auto& fused : instruction->fused_instructions()) { TF_RET_CHECK(fused->parent() == instruction->fused_instructions_computation()) @@ -44,6 +448,8 @@ StatusOr HloVerifier::Run(HloModule* module) { << " in computation: " << previous->second->parent()->name(); instructions[instruction->name()] = instruction.get(); } + + TF_RETURN_IF_ERROR(computation->Accept(&shape_verifier)); } return false; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 5159420b3fbea5d3d01950fa379e8ba39437ab85..e35a7f3642ccf91df37f69a3a11bd8c8e428b846 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -24,12 +24,21 @@ namespace xla { // the module. class HloVerifier : public HloPassInterface { public: + explicit HloVerifier(const std::function& shape_size_fn) + : shape_size_fn_(shape_size_fn) {} ~HloVerifier() override = default; tensorflow::StringPiece name() const override { return "verifier"; } // Note: always returns false (no instructions are ever modified by this // pass). StatusOr Run(HloModule* module) override; + + private: + // CHECKs various invariants of a fusion instruction. + Status CheckFusionInstruction(HloInstruction* fusion) const; + + // Returns the size of a Shape in bytes. + const std::function shape_size_fn_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index edfcb0922d6c9e87eb02e5b2a95ec4d355ad1756..265be54116c406de433da5f07bdec724a1f9f580 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -111,19 +111,11 @@ namespace xla { return false; } -namespace { -// Returns true if fusing producer into consumer would cause producer to be -// duplicated. This is the case if producer has uses other than consumer. -bool FusionWouldDuplicate(const HloInstruction& producer, - const HloInstruction& consumer) { - return !(producer.users().size() == 1 && consumer.IsUserOf(&producer)); -} - // An "effectively unary" operation is one that has one "large" // input with the others being negligible in terms of memory usage. // We use "has a smaller true rank than the output" as a heuristic // for "negligible" memory usage. -bool EffectivelyUnary(HloInstruction* hlo) { +bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) { int64 output_rank = 0; ShapeUtil::ForEachSubshape( hlo->shape(), @@ -145,7 +137,6 @@ bool EffectivelyUnary(HloInstruction* hlo) { output_rank; }) <= 1; } -} // namespace bool InstructionFusion::CanFuseOnAllPaths( const HloReachabilityMap& reachability_map, HloInstruction* producer, @@ -212,7 +203,7 @@ bool InstructionFusion::CanFuseOnAllPaths( StatusOr InstructionFusion::Run(HloModule* module) { bool changed = false; - + module_ = module; std::vector computations; for (auto& computation : module->computations()) { if (computation->IsFusionComputation()) { @@ -243,7 +234,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { DoNotFuseSet do_not_fuse; auto reachability = computation->ComputeReachability(); - auto cheap_to_duplicate = [](HloInstruction* producer) { + auto cheap_to_duplicate = [this](HloInstruction* producer) { if (producer->opcode() == HloOpcode::kBroadcast) { return true; } @@ -395,7 +386,6 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, VLOG(2) << "Fusing " << producer->ToString() << " into " << consumer->ToString(); - auto kind = ChooseKind(producer, consumer); if (consumer->opcode() == HloOpcode::kFusion) { fusion_instruction = consumer; @@ -407,8 +397,8 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction)); } - fusion_instruction->FuseInstruction(producer); + fusion_instruction->FuseInstruction(producer); return fusion_instruction; } @@ -423,13 +413,15 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer, if (consumer->opcode() == HloOpcode::kFusion && consumer->fusion_kind() != HloInstruction::FusionKind::kLoop && - consumer->fusion_kind() != HloInstruction::FusionKind::kInput) { + consumer->fusion_kind() != HloInstruction::FusionKind::kInput && + consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) { return false; } - // Cost condition: not fuse (expensive producers) and (consumers who reuse - // operand elements). - if (consumer->ReusesOperandElements(operand_index) && + // Cost condition: not fuse (simple, expensive producers) and (consumers who + // reuse operand elements). + if (producer->opcode() != HloOpcode::kFusion && + consumer->ReusesOperandElements(operand_index) && is_expensive_(*producer)) { return false; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index f6f37bb79b9fe1480db61b10b9810347960f9a72..0eb8d03489d29b701f09c2912bc906778f02a99b 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -66,12 +66,28 @@ class InstructionFusion : public HloPassInterface { virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, const HloInstruction* consumer); + // Fuses producer into consumer. + virtual HloInstruction* Fuse(HloInstruction* producer, + HloInstruction* consumer); + + // An "effectively unary" operation is one that has one "large" + // input with the others being negligible in terms of memory usage. + // We use "has a smaller true rank than the output" as a heuristic + // for "negligible" memory usage. + bool EffectivelyUnary(HloInstruction* hlo); + + // Returns true if fusing producer into consumer would cause producer to be + // duplicated. This is the case if producer has uses other than consumer. + bool FusionWouldDuplicate(const HloInstruction& producer, + const HloInstruction& consumer) { + return !(producer.users().size() == 1 && consumer.IsUserOf(&producer)); + } + // Current HloComputation instance the loop fuser is traversing. HloComputation* computation_; + HloModule* module_; private: - HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); - // The set of producers whose consumers we cannot fuse into. using DoNotFuseSet = std::unordered_set; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index 353485c5a5cd214f2e1a6cff64f60d68ddb82cfe..8275531111ce10e05d81a77c739757a649f97a1c 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -21,52 +21,35 @@ limitations under the License. namespace xla { -// For now, ReducePrecision is only implemented for F32 arrays, so this -// ignores instructions that produce other data. In particular, this -// currently ignores instructions producing tuples, even if those tuples -// contain F32 arrays inside them. The assumption is that in most cases -// equivalent behavior can be obtained by adding ReducePrecision -// instructions after the instructions that pull the F32 arrays out of -// the tuples. -// -// TODO(b/64093391): Remove the IsScalar check once this won't cause -// failures on the GPU backend if the ReducePrecision instruction ends up -// inserted between a scalar constant and the init_value argument of a -// Reduce operation. -std::vector ReducePrecisionInsertion::instructions_to_suffix( +std::vector ReducePrecisionInsertion::instructions_to_modify( const HloComputation* computation) { - std::vector instructions_to_suffix; + std::vector instruction_list; - switch (pass_timing_) { - case HloReducePrecisionOptions::BEFORE_OP_FUSION: - case HloReducePrecisionOptions::AFTER_OP_FUSION: + switch (location_) { + case HloReducePrecisionOptions::OP_INPUTS: + case HloReducePrecisionOptions::OP_OUTPUTS: + case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS: for (auto& instruction : computation->instructions()) { VLOG(4) << "Visited instruction: " << instruction->ToString(); - - if (instruction->shape().element_type() == PrimitiveType::F32 && - !ShapeUtil::IsScalar(instruction->shape()) && - instruction_filter_function_(instruction.get())) { - instructions_to_suffix.push_back(instruction.get()); + if (instruction_filter_function_(instruction.get())) { + instruction_list.push_back(instruction.get()); } } break; - case HloReducePrecisionOptions::FUSION_BY_CONTENT: + case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT: + case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT: for (auto& instruction : computation->instructions()) { VLOG(4) << "Visited instruction: " << instruction->ToString(); - - if (instruction->opcode() != HloOpcode::kFusion || - instruction->shape().element_type() != PrimitiveType::F32 || - ShapeUtil::IsScalar(instruction->shape())) { + if (instruction->opcode() != HloOpcode::kFusion) { continue; } - for (auto& fused_instruction : instruction->fused_instructions_computation()->instructions()) { VLOG(4) << "Checking sub-instruction: " << fused_instruction->ToString(); if (instruction_filter_function_(fused_instruction.get())) { - instructions_to_suffix.push_back(instruction.get()); + instruction_list.push_back(instruction.get()); break; } } @@ -76,70 +59,170 @@ std::vector ReducePrecisionInsertion::instructions_to_suffix( default: break; } - VLOG(1) << "Adding " << instructions_to_suffix.size() - << " reduce-precision operations."; + VLOG(1) << "Found " << instruction_list.size() + << " candidate instruction(s) for reduce-precision insertion"; - return instructions_to_suffix; + return instruction_list; } -StatusOr ReducePrecisionInsertion::Run(HloModule* module) { - bool changed = false; - VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name(); +StatusOr ReducePrecisionInsertion::insert_after( + HloInstruction* instruction) { + // Check that this isn't already an equivalent operation. + if (is_redundant(instruction)) { + VLOG(2) << "Skipped: instruction is already an equivalent" + " reduce-precision instruction:" + << instruction->ToString(); + return false; + } - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; + // Check that we haven't already inserted an equivalant reduce-precision + // operation after this instruction. (The zero-user case occurs when this is + // the root instruction.) + if (instruction->user_count() > 0) { + bool redundant_followers = true; + for (HloInstruction* user : instruction->users()) { + if (!is_redundant(user)) { + redundant_followers = false; + break; + } + } + if (redundant_followers) { + VLOG(2) << "Skipped: instruction already followed by equivalent" + " reduce-precision instructions"; + return false; } + } - bool computation_changed = false; - for (auto& instruction : instructions_to_suffix(computation.get())) { - VLOG(2) << "Adding reduce-precision operation to output of instruction: " - << instruction->ToString(); - - // Check that we haven't already inserted an equivalant reduce-precision - // operation after this instruction. - if (instruction->user_count() == 1) { - HloInstruction* user = instruction->users()[0]; - - if (user->opcode() == HloOpcode::kReducePrecision && - user->exponent_bits() == exponent_bits_ && - user->mantissa_bits() == mantissa_bits_) { - VLOG(2) << "Skipped; instruction already followed by equivalent" - " reduce-precision instruction:" - << user->ToString(); - continue; - } - } + HloInstruction* reduced = instruction->parent()->AddInstruction( + HloInstruction::CreateReducePrecision(instruction->shape(), instruction, + exponent_bits_, mantissa_bits_)); + TF_RETURN_IF_ERROR( + instruction->parent()->ReplaceUsesOfInstruction(instruction, reduced)); + return true; +} + +StatusOr ReducePrecisionInsertion::insert_on_inputs( + const std::vector& instructions) { + bool computation_changed = false; + for (auto instruction : instructions) { + VLOG(2) << "Adding reduce-precision operation to inputs of instruction: " + << instruction->ToString(); + for (int64 i = 0; i < instruction->operand_count(); i++) { + HloInstruction* operand = instruction->mutable_operand(i); + VLOG(2) << "Adding to operand " << i << ": " << operand; - if (instruction->opcode() == HloOpcode::kFusion) { - // Insert the reduce-precision operation as the last operation inside - // the fusion computation. - instruction = instruction->fused_expression_root(); + if (!is_valid_shape(operand->shape())) { + VLOG(2) << "Skipped: value is not an F32 vector"; + continue; + } - VLOG(2) << "Inserting new operation after existing fusion root: " - << instruction->ToString(); + if (is_redundant(operand)) { + VLOG(2) << "Skipped: operand is already an equivalent reduce-precision" + " instruction"; + continue; + } - if (instruction->opcode() == HloOpcode::kReducePrecision && - instruction->exponent_bits() == exponent_bits_ && - instruction->mantissa_bits() == mantissa_bits_) { - VLOG(2) << "Skipped; fused computation already ends in equivalent" - " reduce-precision instruction:" - << instruction->ToString(); - continue; + if (instruction->opcode() == HloOpcode::kFusion && + (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop || + instruction->fusion_kind() == HloInstruction::FusionKind::kInput)) { + // Insert the reduce-precision operation inside the fusion computation, + // after the corresponding parameter instruction. + TF_ASSIGN_OR_RETURN( + bool instruction_changed, + insert_after(instruction->fused_instructions_computation() + ->parameter_instruction(i))); + computation_changed |= instruction_changed; + } else { + // Look for an existing reduce-precision operation on the operand. (We + // need to be careful not to create a loop, though!) + HloInstruction* reduced = nullptr; + for (auto& user : operand->users()) { + if (user != instruction && + user->opcode() == HloOpcode::kReducePrecision && + user->exponent_bits() == exponent_bits_ && + user->mantissa_bits() == mantissa_bits_) { + reduced = user; + break; + } + } + // If there wasn't an existing reduce-precision operation, create one. + if (!reduced) { + reduced = instruction->parent()->AddInstruction( + HloInstruction::CreateReducePrecision( + operand->shape(), operand, exponent_bits_, mantissa_bits_)); } + // Insert the reduce-precision operation before the operand. + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(i, reduced)); + computation_changed = true; } + } + } + + return computation_changed; +} - HloInstruction* reduced = instruction->parent()->AddInstruction( - HloInstruction::CreateReducePrecision(instruction->shape(), - instruction, exponent_bits_, - mantissa_bits_)); +StatusOr ReducePrecisionInsertion::insert_on_outputs( + const std::vector& instructions) { + bool computation_changed = false; + for (const auto& instruction : instructions) { + VLOG(2) << "Adding reduce-precision operation to output of instruction: " + << instruction->ToString(); - TF_RETURN_IF_ERROR(instruction->parent()->ReplaceUsesOfInstruction( - instruction, reduced)); - computation_changed = true; + if (!is_valid_shape(instruction->shape())) { + VLOG(2) << "Skipped: value is not an F32 nonscalar array"; + continue; } - if (computation_changed) { + if (instruction->opcode() == HloOpcode::kFusion && + (instruction->fusion_kind() == HloInstruction::FusionKind::kLoop || + instruction->fusion_kind() == HloInstruction::FusionKind::kOutput)) { + // Insert the reduce-precision operation as the last operation inside + // the fusion computation. + HloInstruction* fusion_root = instruction->fused_expression_root(); + VLOG(2) << "Inserting new operation after existing fusion root: " + << fusion_root->ToString(); + + TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(fusion_root)); + computation_changed |= instruction_changed; + } else { + // Insert the reduce-precision operation after the instruction. + TF_ASSIGN_OR_RETURN(bool instruction_changed, insert_after(instruction)); + computation_changed |= instruction_changed; + } + } + + return computation_changed; +} + +StatusOr ReducePrecisionInsertion::Run(HloModule* module) { + bool changed = false; + VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name(); + + for (auto& computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + + StatusOr computation_changed; + switch (location_) { + case HloReducePrecisionOptions::OP_INPUTS: + case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT: + computation_changed = ReducePrecisionInsertion::insert_on_inputs( + instructions_to_modify(computation.get())); + break; + + case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT: + case HloReducePrecisionOptions::OP_OUTPUTS: + case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS: + computation_changed = ReducePrecisionInsertion::insert_on_outputs( + instructions_to_modify(computation.get())); + break; + default: + break; + } + TF_RETURN_IF_ERROR(computation_changed.status()); + + if (computation_changed.ValueOrDie()) { changed = true; VLOG(3) << "Computation after reduce-precision insertion:"; XLA_VLOG_LINES(3, computation->ToString()); @@ -186,12 +269,12 @@ ReducePrecisionInsertion::make_filter_function( } HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto( - const HloReducePrecisionOptions::PassTiming pass_timing, - const int exponent_bits, const int mantissa_bits, + const HloReducePrecisionOptions::Location location, const int exponent_bits, + const int mantissa_bits, const std::function& opcode_filter_function, const std::vector& opname_substring_list) { HloReducePrecisionOptions options; - options.set_pass_timing(pass_timing); + options.set_location(location); options.set_exponent_bits(exponent_bits); options.set_mantissa_bits(mantissa_bits); for (uint32_t opcode = 0; opcode < HloOpcodeCount(); opcode++) { @@ -205,13 +288,27 @@ HloReducePrecisionOptions ReducePrecisionInsertion::make_options_proto( return options; } -bool ReducePrecisionInsertion::AddPasses( - HloPassPipeline* pipeline, const DebugOptions& debug_options, - const HloReducePrecisionOptions::PassTiming pass_timing) { +bool ReducePrecisionInsertion::AddPasses(HloPassPipeline* pipeline, + const DebugOptions& debug_options, + const PassTiming pass_timing) { bool passes_added = false; for (const auto& pass_options : debug_options.hlo_reduce_precision_options()) { - if (pass_options.pass_timing() == pass_timing) { + bool add_pass; + switch (pass_options.location()) { + case HloReducePrecisionOptions::OP_INPUTS: + case HloReducePrecisionOptions::OP_OUTPUTS: + add_pass = pass_timing == PassTiming::BEFORE_OPTIMIZATION; + break; + case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS: + case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT: + case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT: + add_pass = pass_timing == PassTiming::AFTER_FUSION; + break; + default: + add_pass = false; + } + if (add_pass) { pipeline->AddPass(pass_options); passes_added = true; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index 6c5b8cdc0fd3cb45a26813a2de39417dcdae4571..afde3cf95c721b59a39b74b4e1ff3f15a335fe97 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -39,11 +39,11 @@ class ReducePrecisionInsertion : public HloPassInterface { // function returns true and the output type is F32. explicit ReducePrecisionInsertion( const int exponent_bits, const int mantissa_bits, - const HloReducePrecisionOptions::PassTiming pass_timing, + const HloReducePrecisionOptions::Location location, const InstructionFilterFunction& instruction_filter_function) : exponent_bits_(exponent_bits), mantissa_bits_(mantissa_bits), - pass_timing_(pass_timing), + location_(location), instruction_filter_function_(instruction_filter_function) {} // Version of the constructor that takes an HloReducePrecisionOptions proto @@ -53,7 +53,7 @@ class ReducePrecisionInsertion : public HloPassInterface { const HloReducePrecisionOptions& reduce_precision_options) : exponent_bits_(reduce_precision_options.exponent_bits()), mantissa_bits_(reduce_precision_options.mantissa_bits()), - pass_timing_(reduce_precision_options.pass_timing()), + location_(reduce_precision_options.location()), instruction_filter_function_( make_filter_function(reduce_precision_options)) {} @@ -72,33 +72,79 @@ class ReducePrecisionInsertion : public HloPassInterface { static InstructionFilterFunction make_filter_function( const HloReducePrecisionOptions& reduce_precision_options); static HloReducePrecisionOptions make_options_proto( - const HloReducePrecisionOptions::PassTiming pass_timing, + const HloReducePrecisionOptions::Location location, const int exponent_bits, const int mantissa_bits, const std::function& opcode_filter_function, const std::vector& opname_substring_list = {}); + // Enumeration to control which passes should be added. + enum class PassTiming { BEFORE_OPTIMIZATION, AFTER_FUSION }; + // Add ReducePrecisionInsertion passes to an HloPassPipeline based on the list // of HloReducePrecisionOptions in a DebugOptions proto. Returns true if any // passes were added. - static bool AddPasses( - HloPassPipeline* pipeline, const DebugOptions& debug_options, - const HloReducePrecisionOptions::PassTiming pass_timing); + static bool AddPasses(HloPassPipeline* pipeline, + const DebugOptions& debug_options, + const PassTiming pass_timing); private: - // Select the instructions that should be suffixed with reduce-precision - // operators. - std::vector instructions_to_suffix( + // Select the instructions that should have reduce-precision operations + // attached to them. + std::vector instructions_to_modify( const HloComputation* computation); + // Insert a reduce-precision operation into the graph on the output of the + // given instruction. + StatusOr insert_after(HloInstruction* instruction); + + // Insert reduce-precision operations into the graph on the inputs of the + // given instructions. (For fusion instructions, the operations will be + // inserted inside the fusion computation, on the outputs of the relevant + // input parameters.) + StatusOr insert_on_inputs( + const std::vector& instructions); + + // Insert reduce-precision operations into the graph on the outputs of the + // given instructions. (For fusion instructions, the operations will be + // inserted inside the fusion computation as a new root.) + StatusOr insert_on_outputs( + const std::vector& instructions); + + // Is this shape valid for inserting a reduce-precision operation? + bool is_valid_shape(const Shape& shape) { + // For now, ReducePrecision is only implemented for F32 arrays, so this + // ignores instructions that produce other data. In particular, this + // currently ignores instructions producing tuples, even if those tuples + // contain F32 arrays inside them. The assumption is that in most cases + // equivalent behavior can be obtained by adding ReducePrecision + // instructions after the instructions that pull the F32 arrays out of + // the tuples. + // + // TODO(b/64093391): Remove the IsScalar check once this won't cause + // failures on the GPU backend if the ReducePrecision instruction ends up + // inserted between a scalar constant and the init_value argument of a + // Reduce operation. + return shape.element_type() == PrimitiveType::F32 && + !ShapeUtil::IsScalar(shape); + } + + // Is this instruction one such that following or preceding it with a new + // reduce-precision operation will be redundant? + bool is_redundant(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kReducePrecision && + instruction->exponent_bits() <= exponent_bits_ && + instruction->mantissa_bits() <= mantissa_bits_; + } + // Parameters for the precision reduction to be added. const int exponent_bits_; const int mantissa_bits_; // Pass "timing" parameter. This also controls aspects of how the pass // selects locations to insert instructions. - const HloReducePrecisionOptions::PassTiming pass_timing_; + const HloReducePrecisionOptions::Location location_; - // Function to determine (from the opcode) whether a given instruction should + // User-provided Function to determine whether a given instruction should // have a reduce-precision instruction inserted in its output stream. const InstructionFilterFunction instruction_filter_function_; }; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc index d4640bec0e3732f8e44472ea4e3340bf897d44ca..064020896e7058a323a4dfd70996abd8a821adf4 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -35,16 +35,16 @@ using ::testing::UnorderedElementsAre; class ReducePrecisionInsertionTest : public HloTestBase { protected: bool InsertOps(HloModule* module, - const HloReducePrecisionOptions::PassTiming pass_timing, + const HloReducePrecisionOptions::Location location, const std::function& filter) { - ReducePrecisionInsertion op_insertion(5, 10, pass_timing, filter); + ReducePrecisionInsertion op_insertion(5, 10, location, filter); StatusOr result = op_insertion.Run(module); EXPECT_IS_OK(result.status()); return result.ValueOrDie(); } }; -TEST_F(ReducePrecisionInsertionTest, RootInstruction) { +TEST_F(ReducePrecisionInsertionTest, BeforeUnaryInstruction) { auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -59,19 +59,141 @@ TEST_F(ReducePrecisionInsertionTest, RootInstruction) { // Confirm expected state before adding ops. EXPECT_EQ(computation->root_instruction(), b); + EXPECT_EQ(b->operand(0), a); - EXPECT_TRUE(InsertOps(module.get(), - HloReducePrecisionOptions::BEFORE_OP_FUSION, + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS, [](const HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCos; })); // Confirm expected graph after adding ops. - EXPECT_THAT(computation->root_instruction(), op::ReducePrecision()); - EXPECT_EQ(computation->root_instruction()->operand(0), b); + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_THAT(b->operand(0), op::ReducePrecision(a)); } -TEST_F(ReducePrecisionInsertionTest, NonRootInstruction) { +TEST_F(ReducePrecisionInsertionTest, BeforeBinaryInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + + // Create a simple graph with parameter feeding a binary add function. + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), c); + EXPECT_EQ(c->operand(0), a); + EXPECT_EQ(c->operand(1), b); + + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kAdd; + })); + + // Confirm expected graph after adding ops. + EXPECT_EQ(computation->root_instruction(), c); + EXPECT_THAT(c->operand(0), op::ReducePrecision(a)); + EXPECT_THAT(c->operand(1), op::ReducePrecision(b)); +} + +TEST_F(ReducePrecisionInsertionTest, BeforeZeroInputInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + + // Create a simple graph with a parameter feeding a unary cosine function. + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_EQ(b->operand(0), a); + + EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == + HloOpcode::kParameter; + })); + + // Confirm that graph has not changed. + EXPECT_EQ(computation->root_instruction(), b); + EXPECT_EQ(b->operand(0), a); +} + +TEST_F(ReducePrecisionInsertionTest, AvoidAddingDuplicateInstructions) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + + // Create a simple graph with parameter feeding a binary add function. + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kSin, a)); + HloInstruction* d = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), d); + EXPECT_EQ(b->operand(0), a); + EXPECT_EQ(c->operand(0), a); + + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_INPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCos || + instruction->opcode() == HloOpcode::kSin; + })); + + // Confirm expected graph after adding ops. In particular, we want to confirm + // that the reduced-precision operation added for the input to b is re-used + // for the input to c. + EXPECT_THAT(b->operand(0), op::ReducePrecision(a)); + EXPECT_THAT(c->operand(0), op::ReducePrecision(a)); + EXPECT_EQ(b->operand(0), c->operand(0)); +} + +TEST_F(ReducePrecisionInsertionTest, AfterRootInstruction) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + + // Create a simple graph with a parameter feeding a unary cosine function. + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, a)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Confirm expected state before adding ops. + EXPECT_EQ(computation->root_instruction(), b); + + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCos; + })); + + // Confirm expected graph after adding ops. + EXPECT_THAT(computation->root_instruction(), op::ReducePrecision(b)); +} + +TEST_F(ReducePrecisionInsertionTest, AfterNonRootInstruction) { auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -100,8 +222,7 @@ TEST_F(ReducePrecisionInsertionTest, NonRootInstruction) { EXPECT_EQ(c->operand(0), a_cos); EXPECT_EQ(c->operand(1), b_cos); - EXPECT_TRUE(InsertOps(module.get(), - HloReducePrecisionOptions::BEFORE_OP_FUSION, + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS, [](const HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCos; })); @@ -131,7 +252,7 @@ TEST_F(ReducePrecisionInsertionTest, OutputIsNotFloat) { // Since none of the instructions produce F32 data, this should not change // the graph. EXPECT_FALSE( - InsertOps(module.get(), HloReducePrecisionOptions::BEFORE_OP_FUSION, + InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS, [](const HloInstruction* instruction) { return true; })); // Confirm that graph has not changed. @@ -157,7 +278,7 @@ TEST_F(ReducePrecisionInsertionTest, ShouldReduceOutputPrecisionIsFalse) { // Since none of the instructions match the should_reduce_output_precision // function, this should not change the graph. EXPECT_FALSE( - InsertOps(module.get(), HloReducePrecisionOptions::BEFORE_OP_FUSION, + InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS, [](const HloInstruction* instruction) { return false; })); // Confirm that graph has not changed. @@ -181,18 +302,18 @@ TEST_F(ReducePrecisionInsertionTest, InsertionIsNotRecursive) { // This should insert a new ReducePrecision after the existing one, but // should not then recurse by adding another after the just-inserted one. - EXPECT_TRUE( - InsertOps(module.get(), HloReducePrecisionOptions::BEFORE_OP_FUSION, - [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kReducePrecision; - })); + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == + HloOpcode::kReducePrecision; + })); // Confirm expected graph after adding ops. EXPECT_THAT(computation->root_instruction(), op::ReducePrecision()); EXPECT_EQ(computation->root_instruction()->operand(0), b); } -TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecision) { +TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecisionAfter) { auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); HloInstruction* x = @@ -209,11 +330,11 @@ TEST_F(ReducePrecisionInsertionTest, SkipRedundantReducePrecision) { // Since the new reduce-precision operation would be redundant, this // should not change the graph. - EXPECT_FALSE( - InsertOps(module.get(), HloReducePrecisionOptions::BEFORE_OP_FUSION, - [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kParameter; - })); + EXPECT_FALSE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS, + [](const HloInstruction* instruction) { + return instruction->opcode() == + HloOpcode::kParameter; + })); // Confirm that graph has not changed. EXPECT_THAT(x->users(), UnorderedElementsAre(y)); @@ -237,8 +358,7 @@ TEST_F(ReducePrecisionInsertionTest, AddNonRedundantReducePrecision) { // Since the new reduce-precision operation is not the same as the existing // one, this should add a new one. - EXPECT_TRUE(InsertOps(module.get(), - HloReducePrecisionOptions::BEFORE_OP_FUSION, + EXPECT_TRUE(InsertOps(module.get(), HloReducePrecisionOptions::OP_OUTPUTS, [](const HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kParameter; })); @@ -273,7 +393,7 @@ TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) { // The ReducePrecisionInsertion pass should not see inside the fusion // operation, so this should not change the graph. EXPECT_FALSE(InsertOps(module.get(), - HloReducePrecisionOptions::AFTER_OP_FUSION, + HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, [](const HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCos; })); @@ -284,6 +404,53 @@ TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) { EXPECT_EQ(z->fused_expression_root(), y_fused); } +TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + HloInstruction* y = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCos, x)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + // Manually fuse the kCos operation into a fusion operation. + HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion( + shape, HloInstruction::FusionKind::kLoop, y)); + EXPECT_IS_OK(computation->ReplaceUsesOfInstruction(y, z)); + EXPECT_IS_OK(computation->RemoveInstruction(y)); + + // Confirm expected graph before adding reduce-precision ops. + EXPECT_THAT(x->users(), UnorderedElementsAre(z)); + EXPECT_EQ(computation->root_instruction(), z); + HloInstruction* y_fused = z->fused_expression_root(); + EXPECT_EQ(y_fused->opcode(), HloOpcode::kCos); + + // This should see that the fusion computation contains a kCos operation, + // and insert a new reduce-precision node at its input. + EXPECT_TRUE(InsertOps(module.get(), + HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT, + [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCos; + })); + + // This should refuse to insert a second reduce-precision operation, as + // it would be redundant with the first. + EXPECT_FALSE(InsertOps(module.get(), + HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT, + [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCos; + })); + + // Confirm that the top-level computation still only contains the fusion + // instruction, but that the fused computation now has a reduce-precision + // instruction inserted after its parameter instruction. + EXPECT_THAT(x->users(), UnorderedElementsAre(z)); + EXPECT_EQ(computation->root_instruction(), z); + EXPECT_THAT(z->fused_expression_root(), y_fused); + EXPECT_THAT(y_fused->operand(0), op::ReducePrecision(op::Parameter())); +} + TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) { auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -309,7 +476,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) { // This should see that the fusion computation contains a kCos operation, // and insert a new reduce-precision node at its root. EXPECT_TRUE(InsertOps(module.get(), - HloReducePrecisionOptions::FUSION_BY_CONTENT, + HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT, [](const HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCos; })); @@ -317,7 +484,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) { // This should refuse to insert a second reduce-precision operation, as // it would be redundant with the first. EXPECT_FALSE(InsertOps(module.get(), - HloReducePrecisionOptions::FUSION_BY_CONTENT, + HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT, [](const HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCos; })); @@ -341,7 +508,7 @@ TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionNoSubstrings) { HloInstruction::CreateUnary(shape, HloOpcode::kSin, a)); auto options_proto = ReducePrecisionInsertion::make_options_proto( - HloReducePrecisionOptions::BEFORE_OP_FUSION, 5, 10, + HloReducePrecisionOptions::OP_OUTPUTS, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kCos; }); auto filter_function = @@ -370,7 +537,7 @@ TEST_F(ReducePrecisionInsertionTest, MakeFilterFunctionWithSubstrings) { c->set_metadata(c_metadata); auto options_proto = ReducePrecisionInsertion::make_options_proto( - HloReducePrecisionOptions::BEFORE_OP_FUSION, 5, 10, + HloReducePrecisionOptions::OP_OUTPUTS, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kCos; }, {"foo", "baz"}); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 8eeb1cd5d20d322780c3f833e828e4441e4ec1e3..1a24c6c4939af6119e17c400eaad4539e1e5cb1a 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -39,6 +39,112 @@ namespace xla { namespace { +// Return the UnaryOperation proto enum value associated with the given HLO +// opcode. +UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kAbs: + return UNOP_ABS; + case HloOpcode::kCeil: + return UNOP_CEIL; + case HloOpcode::kCos: + return UNOP_COS; + case HloOpcode::kExp: + return UNOP_EXP; + case HloOpcode::kFloor: + return UNOP_FLOOR; + case HloOpcode::kIsFinite: + return UNOP_IS_FINITE; + case HloOpcode::kLog: + return UNOP_LOG; + case HloOpcode::kLogicalNot: + return UNOP_LOGICAL_NOT; + case HloOpcode::kNegate: + return UNOP_NEGATE; + case HloOpcode::kSign: + return UNOP_SIGN; + case HloOpcode::kSin: + return UNOP_SIN; + case HloOpcode::kSort: + return UNOP_SORT; + case HloOpcode::kTanh: + return UNOP_TANH; + default: + LOG(FATAL) << "unhandled opcode " << opcode; + } +} + +// Return the BinaryOperation proto enum value associated with the given HLO +// opcode. +BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kDot: + return BINOP_DOT; + case HloOpcode::kMultiply: + return BINOP_MUL; + case HloOpcode::kAdd: + return BINOP_ADD; + case HloOpcode::kSubtract: + return BINOP_SUB; + case HloOpcode::kIndex: + return BINOP_INDEX; + case HloOpcode::kDivide: + return BINOP_DIV; + case HloOpcode::kEq: + return BINOP_EQ; + case HloOpcode::kGe: + return BINOP_GE; + case HloOpcode::kGt: + return BINOP_GT; + case HloOpcode::kLe: + return BINOP_LE; + case HloOpcode::kLt: + return BINOP_LT; + case HloOpcode::kNe: + return BINOP_NE; + case HloOpcode::kMaximum: + return BINOP_MAX; + case HloOpcode::kMinimum: + return BINOP_MIN; + case HloOpcode::kPower: + return BINOP_POW; + case HloOpcode::kRemainder: + return BINOP_REM; + case HloOpcode::kLogicalOr: + return BINOP_LOGICAL_OR; + case HloOpcode::kLogicalAnd: + return BINOP_LOGICAL_AND; + default: + LOG(FATAL) << "unhandled opcode " << opcode; + } +} + +// Return the TernaryOperation proto enum value associated with the given HLO +// opcode. +TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kClamp: + return TRIOP_CLAMP; + case HloOpcode::kSelect: + return TRIOP_SELECT; + case HloOpcode::kUpdate: + return TRIOP_UPDATE; + default: + LOG(FATAL) << "unhandled opcode " << opcode; + } +} + +// Return the VariadicOperation proto enum value associated with the given HLO +// opcode. +VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kTuple: + return VAROP_TUPLE; + default: + LOG(FATAL) << "unhandled opcode " << opcode; + } +} + // Returns true if no element is present in slice more than once. bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); @@ -176,11 +282,21 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } // namespace +/* static */ StatusOr ShapeInference::InferUnaryOpShape( + HloOpcode opcode, const HloInstruction* operand) { + // There is no copy operation at the proto level, so handle copy explicitly. + if (opcode == HloOpcode::kCopy) { + return operand->shape(); + } + + return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), operand->shape()); +} + /* static */ StatusOr ShapeInference::InferUnaryOpShape( UnaryOperation operation, const Shape& arg) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation")); - TF_DCHECK_OK(ShapeUtil::ValidateShape(arg)); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg)); switch (operation) { case UNOP_FLOOR: case UNOP_CEIL: @@ -410,7 +526,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions); - TF_DCHECK_OK(ShapeUtil::ValidateShape(result)); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); return result; } @@ -592,6 +708,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } } +/* static */ StatusOr ShapeInference::InferBinaryOpShape( + HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) { + return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(), + rhs->shape(), /*broadcast_dimensions=*/{}); +} + /* static */ StatusOr ShapeInference::InferBinaryOpShape( BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { @@ -600,8 +722,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( BinaryOperation_Name(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), tensorflow::str_util::Join(broadcast_dimensions, ", ").c_str()); - TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs)); - TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs)); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of binary operation")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of binary operation")); @@ -660,12 +782,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } } +/* static */ StatusOr ShapeInference::InferTernaryOpShape( + HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs, + const HloInstruction* ehs) { + return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs->shape(), + rhs->shape(), ehs->shape()); +} + /* static */ StatusOr ShapeInference::InferTernaryOpShape( TernaryOperation operation, const Shape& lhs, const Shape& rhs, const Shape& ehs) { - TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs)); - TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs)); - TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs)); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs)); switch (operation) { case TRIOP_CLAMP: return InferClampShape(lhs, rhs, ehs); @@ -686,9 +815,21 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferVariadicOpShape( - VariadicOperation operation, std::vector operand_shapes) { + HloOpcode opcode, + tensorflow::gtl::ArraySlice operands) { + std::vector operand_shapes; + for (const HloInstruction* operand : operands) { + operand_shapes.push_back(&operand->shape()); + } + return InferVariadicOpShape(OpcodeToVariadicOperation(opcode), + operand_shapes); +} + +/* static */ StatusOr ShapeInference::InferVariadicOpShape( + VariadicOperation operation, + tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* shape : operand_shapes) { - TF_DCHECK_OK(ShapeUtil::ValidateShape(*shape)); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } switch (operation) { case VAROP_TUPLE: { @@ -792,11 +933,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( scale_shape, "scale input of batch norm training")); - TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) == + TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == tensorflow::Status::OK()); - TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) == + TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == tensorflow::Status::OK()); - TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) == + TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == tensorflow::Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { @@ -896,15 +1037,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( scale_shape, "scale input of batch norm inference")); - TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) == + TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == tensorflow::Status::OK()); - TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) == + TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == tensorflow::Status::OK()); - TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) == + TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == tensorflow::Status::OK()); - TF_RET_CHECK(ShapeUtil::ValidateShape(mean_shape) == + TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) == tensorflow::Status::OK()); - TF_RET_CHECK(ShapeUtil::ValidateShape(variance_shape) == + TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == tensorflow::Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { @@ -1044,11 +1185,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( output_grad_shape, "output_grad input of batch norm grad")); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(operand_shape)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(mean_shape)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(scale_shape)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(var_shape)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(output_grad_shape)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(var_shape)); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(output_grad_shape)); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( @@ -1230,8 +1372,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "lhs: %s", num_dims, ShapeUtil::HumanString(lhs).c_str()); } - TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs)); - TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs)); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); // Verifies that the input and window dimensions are a permutation of // the dimension numbers. diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 5d55df92a91cb725f33786152ea9ceb629e09230..96e3b46c7dece6c945ae9b2a2a0a4eac8a0eb350 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -21,6 +21,8 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -31,32 +33,48 @@ limitations under the License. namespace xla { // For a given operation and input shapes, infers what the resulting shape is -// for the operation. With this functionality, the user does not need to -// specify the expected result type for computations that are built up via the -// API -- the shape that results from an operation is inferred. +// for the operation. With this functionality, the user does not need to specify +// the expected result type for computations that are built up via the API -- +// the shape that results from an operation is inferred. Some methods have +// overloads for inferring shape at the HLO level. +// TODO(b/166374537): Complete HLO level inference overloads and use to +// automatically infer shape in HloInstruction::Create* methods. class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. static StatusOr InferUnaryOpShape(UnaryOperation operation, const Shape& arg); + static StatusOr InferUnaryOpShape(HloOpcode opcode, + const HloInstruction* operand); // Infers the shape produced by applying the given binary operation to the // given input shapes. static StatusOr InferBinaryOpShape( BinaryOperation operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); + static StatusOr InferBinaryOpShape(HloOpcode opcode, + const HloInstruction* lhs, + const HloInstruction* rhs); // Infers the shape produced by applying the given ternary operation to the // given input shapes. static StatusOr InferTernaryOpShape(TernaryOperation operation, const Shape& lhs, const Shape& rhs, const Shape& ehs); + static StatusOr InferTernaryOpShape(HloOpcode opcode, + const HloInstruction* lhs, + const HloInstruction* rhs, + const HloInstruction* ehs); // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. static StatusOr InferVariadicOpShape( - VariadicOperation operation, std::vector operand_shapes); + VariadicOperation operation, + tensorflow::gtl::ArraySlice operand_shapes); + static StatusOr InferVariadicOpShape( + HloOpcode opcode, + tensorflow::gtl::ArraySlice operands); // Infers the shape produced by applying the given mapping computation shape // to the given operand shapes. diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index cfa5c98f593a44eabc8bc442a6c28301a46c7d50..858db8fa0e1dfcc1239b89996db050f786b3b965 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -2471,14 +2471,14 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( operand->shape().element_type(), AsInt64Slice(output_shape.dimensions())); // Do explicit broadcast for scalar. if (ShapeUtil::IsScalar(operand->shape())) { - return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, operand, AsInt64Slice(broadcast_shape.dimensions()))); + return hlo_builder_.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, operand, {})); } // Do explicit broadcast for degenerate broadcast. std::vector broadcast_dimensions; std::vector reshaped_dimensions; for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { - if (operand->shape().dimensions(i) > 1) { + if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { broadcast_dimensions.push_back(i); reshaped_dimensions.push_back(operand->shape().dimensions(i)); } diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 709749592949f0dc4f608af32ae5f2e89666f3b5..6b0d6b9e11cd638b8f8a2d6f6be7e5a96b351382 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -197,6 +197,65 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { operands[1]->opcode() == HloOpcode::kBroadcast); } +TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { + auto debug_options = DebugOptions(); + debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); + + // Build a binary computation with degenerate broadcast. + // + // %a = Param({1, 2, 3}); + // %b = Param({1, 2, 1}); + // %add = Add(%a, %b, {}); + ComputationHandle handle; + handle.set_handle(123); + UserComputation computation("TheComputation", handle); + + ParameterRequest a_request; + *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3}); + a_request.set_name("a"); + a_request.set_parameter(0); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, + computation.AddParameterInstruction(a_request)); + + ParameterRequest b_request; + *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1}); + b_request.set_name("b"); + b_request.set_parameter(1); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, + computation.AddParameterInstruction(b_request)); + + BinaryOpRequest add; + add.set_binop(BINOP_ADD); + *add.mutable_lhs() = a_handle; + *add.mutable_rhs() = b_handle; + TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); + + auto hlo_resolver = [](const VersionedComputationHandle& handle) { + return nullptr; + }; + VersionedComputationHandle latest_version = computation.GetVersionedHandle(); + + // Build the HLO computation. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr hlo_computation, + computation.BuildHloComputation(latest_version.version, hlo_resolver, + debug_options)); + + // b a + // | | + // reshape | + // | | + // broadcast | + // \ / + // add + EXPECT_EQ(5, hlo_computation->instruction_count()); + EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); + const auto& operands = hlo_computation->root_instruction()->operands(); + ASSERT_EQ(2, operands.size()); + EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kParameter && + operands[1]->opcode() == HloOpcode::kBroadcast); +} + TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { auto debug_options = DebugOptions(); debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 0a2d337752b91238be93ebb49101d1c9b6aabae3..9f7ae4ae8730fa4caf79e7992f89e4b71d27d8a9 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -215,6 +215,7 @@ cc_library( ], deps = [ "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -899,6 +900,7 @@ xla_test_library( "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index 407b5f4ada517d54af2a44742376348a1625b9b9..b61544466a163fb59c132bcd3880b32dee94f6a7 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ b/tensorflow/compiler/xla/tests/filecheck.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" diff --git a/tensorflow/compiler/xla/tests/filecheck.h b/tensorflow/compiler/xla/tests/filecheck.h index 599bf57ad327fe0ef3b4972395eb4e0c883f763b..493ff7414bde31b18a39a5098925d9c991529b00 100644 --- a/tensorflow/compiler/xla/tests/filecheck.h +++ b/tensorflow/compiler/xla/tests/filecheck.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 4356307660260decd9f717aca65deb0d8797d0e8..2be409561ab3e23d9ea2e49aac381a90395380d0 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -153,8 +153,8 @@ float FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, } template <> -uint8 FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, - ArraySlice xs) { +bool FusionTest::ComputeElementwiseAnswer(HloOpcode opcode, + ArraySlice xs) { switch (opcode) { case HloOpcode::kEq: return xs[0] == xs[1]; @@ -569,12 +569,12 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kNegate, reduce2)); + ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, reduce2)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-15}), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(-15), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -690,26 +690,24 @@ XLA_TEST_F(FusionTest, Maximum2D) { TestElementwise2D(HloOpcode::kMaximum); } -XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D(HloOpcode::kEq); } +XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D(HloOpcode::kEq); } XLA_TEST_F(FusionTest, Inequal2D) { - TestElementwise2D(HloOpcode::kNe); + TestElementwise2D(HloOpcode::kNe); } XLA_TEST_F(FusionTest, Greater2D) { - TestElementwise2D(HloOpcode::kGt); + TestElementwise2D(HloOpcode::kGt); } -XLA_TEST_F(FusionTest, Lesser2D) { - TestElementwise2D(HloOpcode::kLt); -} +XLA_TEST_F(FusionTest, Lesser2D) { TestElementwise2D(HloOpcode::kLt); } XLA_TEST_F(FusionTest, GreaterOrEqual2D) { - TestElementwise2D(HloOpcode::kGe); + TestElementwise2D(HloOpcode::kGe); } XLA_TEST_F(FusionTest, LesserOrEqual2D) { - TestElementwise2D(HloOpcode::kLe); + TestElementwise2D(HloOpcode::kLe); } XLA_TEST_F(FusionTest, Clamp2D) { diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index bba5b5aa044447937d2e6a55c5f9bde799d90ea3..22d2b917a1d55f4f453e21c2d8fea38e32ff796b 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -67,7 +67,7 @@ class MultiOutputFusionTest : public HloTestBase { elem_shape0, HloOpcode::kAdd, param0, const0)); HloInstruction* broadcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(elem_shape2, add1, {0, 1})); + HloInstruction::CreateBroadcast(elem_shape2, add1, {})); auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, elem_shape2, "1")); @@ -134,7 +134,7 @@ class MultiOutputFusionTest : public HloTestBase { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {size, 1}), add)); HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {}), HloOpcode::kDot, sub, reshape)); + ShapeUtil::MakeShape(F32, {1}), HloOpcode::kDot, sub, reshape)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -160,7 +160,7 @@ class MultiOutputFusionTest : public HloTestBase { auto p0 = TransferToDevice(input0); auto p1 = TransferToDevice(input1); - Literal expect = *Literal::CreateR0(size * 1.5f * 3.5f); + Literal expect = *Literal::CreateR1({size * 1.5f * 3.5f}); auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index c2a708e87212c9149eb3d4a72aa35af072c58eb5..4756ba096896806ece8fe35d18c4eaef041b8830 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -270,7 +270,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionBeforeFusion) { auto reduce_precision_pass = execution_options_.mutable_debug_options() ->add_hlo_reduce_precision_options(); *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto( - HloReducePrecisionOptions::BEFORE_OP_FUSION, 5, 10, + HloReducePrecisionOptions::OP_OUTPUTS, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; }); ComputeAndCompareR1(&builder, {0.0f}, {a_data.get()}); @@ -294,7 +294,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedAfterFusion) { auto reduce_precision_pass = execution_options_.mutable_debug_options() ->add_hlo_reduce_precision_options(); *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto( - HloReducePrecisionOptions::AFTER_OP_FUSION, 5, 10, + HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; }); ComputeAndCompareR1(&builder, {-1.00001f}, {a_data.get()}); @@ -316,7 +316,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionAddedAfterFusion) { auto reduce_precision_pass = execution_options_.mutable_debug_options() ->add_hlo_reduce_precision_options(); *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto( - HloReducePrecisionOptions::AFTER_OP_FUSION, 5, 10, + HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kFusion; }); ComputeAndCompareR1(&builder, {-1.0f}, {a_data.get()}); @@ -339,7 +339,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedFusionContains) { auto reduce_precision_pass = execution_options_.mutable_debug_options() ->add_hlo_reduce_precision_options(); *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto( - HloReducePrecisionOptions::FUSION_BY_CONTENT, 5, 10, + HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kCos; }); ComputeAndCompareR1(&builder, {-1.00001f}, {a_data.get()}); @@ -362,7 +362,7 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionAddedFusionContains) { auto reduce_precision_pass = execution_options_.mutable_debug_options() ->add_hlo_reduce_precision_options(); *reduce_precision_pass = ReducePrecisionInsertion::make_options_proto( - HloReducePrecisionOptions::FUSION_BY_CONTENT, 5, 10, + HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT, 5, 10, [](const HloOpcode opcode) { return opcode == HloOpcode::kAbs; }); ComputeAndCompareR1(&builder, {-1.0f}, {a_data.get()}); diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 60d6d19ce18704439fd8ef92924a17a39a9b3238..6ef5c4a8c8b0f103cdfdbf4b12344d34b964cac2 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -319,6 +320,68 @@ TEST_F(ReduceWindowTest, R4UnitWindow) { ErrorSpec(1e-3, 1e-3)); } +XLA_TEST_F(HloTestBase, R6Add) { + auto b = HloComputation::Builder(TestName()); + + std::vector input_dims(6, 8); + std::unique_ptr arg_literal = + Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); + auto input = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.f))); + + HloComputation::Builder add_computation("add"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = add_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + add_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); + + auto module = CreateNewModule(); + auto add_func = module->AddEmbeddedComputation(add_computation.Build()); + + WindowDimension trivial_dim; + trivial_dim.set_size(1); + trivial_dim.set_stride(1); + trivial_dim.set_padding_low(0); + trivial_dim.set_padding_high(0); + trivial_dim.set_window_dilation(1); + trivial_dim.set_base_dilation(1); + + WindowDimension active_dim; + active_dim.set_size(3); + active_dim.set_stride(1); + active_dim.set_padding_low(0); + active_dim.set_padding_high(0); + active_dim.set_window_dilation(1); + active_dim.set_base_dilation(1); + + Window window; + *window.add_dimensions() = trivial_dim; + *window.add_dimensions() = trivial_dim; + *window.add_dimensions() = active_dim; + *window.add_dimensions() = active_dim; + *window.add_dimensions() = trivial_dim; + *window.add_dimensions() = trivial_dim; + + Shape shape = ShapeUtil::MakeShape(F32, {8, 8, 6, 6, 8, 8}); + b.AddInstruction(HloInstruction::CreateReduceWindow(shape, input, init_value, + window, add_func)); + + std::vector output_dims = {8, 8, 6, 6, 8, 8}; + std::unique_ptr expected = + Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 9.0f); + + module->AddEntryComputation(b.Build()); + auto actual = ExecuteAndTransfer(std::move(module), {}); + + LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); +} + XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index a946d335ca6c8be583528caf9bcc97baf6245ae8..da39ba3ffc3b136b0d2c6f28a02b900893366a70 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -111,6 +111,11 @@ cc_binary( deps = [ ":replay_computation_library", "//tensorflow/compiler/plugin/executor:plugin_lib", + # TODO: This dependency is a workaround for linking error with clang. + # Without it, linker complains about missing symbols from + # 'xla_device_launch_op'. This dependency should be propagated from + # plugin_lib instead, but no targets other than this break without it. + "//tensorflow/compiler/jit", ], ) diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index bd93e114b73aeb38c04c2c6a5169b9dc82d27e51..89b26b8916b67eeb38852c9e91314187fc8a7d48 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -144,7 +144,7 @@ int RealMain(tensorflow::gtl::ArraySlice args, int main(int argc, char** argv) { // Flags - string fake_infeed_shape; + xla::string fake_infeed_shape; bool use_fake_data = false; const std::vector flag_list = { tensorflow::Flag("use_fake_data", &use_fake_data, diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 6f62d77208adbdf7e46dbcc182cede5c104fb517..4840ddb8817a37c7dabcfb27e24a2a5472f4b6a2 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -22,19 +22,28 @@ package xla; // Options for the HLO insert-reduce-precision-operations pass. message HloReducePrecisionOptions { - // When and how the pass will be run. - enum PassTiming { - // Pass runs before operation fusion or other optimization occurs, - // selecting operations to suffix according to their type. - BEFORE_OP_FUSION = 0; - // Pass runs after the operation-fusion pass, selecting operations to - // suffix according to their type. - AFTER_OP_FUSION = 1; - // Pass runs after the operation-fusion pass, and selects fusion operations - // to suffix according to their contents. - FUSION_BY_CONTENT = 2; + // Where and when the reduce-precision operations will be added. + enum Location { + // Add reduce-precision operations to the inputs of selected instructions. + // This is done before any optimization occurs. + OP_INPUTS = 0; + // Add reduce-precision operations to the outputs of selected instructions. + // This is done before any optimization occurs. + OP_OUTPUTS = 1; + // After operation-fusion occurs, add reduce-precision operations to the + // outputs of any selected instructions that have not been fused into + // fusion instructions. + UNFUSED_OP_OUTPUTS = 2; + // After operation-fusion occurs, add reduce-precision operations to the + // outputs of any fusion instructions that contain operations matching the + // selection criteria. + FUSION_INPUTS_BY_CONTENT = 3; + // After operation-fusion occurs, add reduce-precision operations to the + // outputs of any fusion instructions that contain operations matching the + // selection criteria. + FUSION_OUTPUTS_BY_CONTENT = 4; } - PassTiming pass_timing = 1; + Location location = 1; // Exponent and mantissa bit counts for the reduced precision. uint32 exponent_bits = 2; @@ -138,6 +147,9 @@ message DebugOptions { // the generated IR. bool xla_llvm_enable_invariant_load_metadata = 72; + // If true, a set of expensive LLVM optimization passes will not be run. + bool xla_llvm_disable_expensive_passes = 73; + // Options for inserting reduce-precision operations for numerical // experimentation. This is a repeated field, as we may want to have // multiple passes with different parameters. diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 47a0f54a023860413fb58f16fab6b3f13af5a573..11e4ea888c7015468fbe7bfe026ba7f6faafedcd 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -24,10 +24,12 @@ py_library( "//tensorflow/contrib/deprecated:deprecated_py", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/eager/python:tfe", + "//tensorflow/contrib/estimator:estimator_py", "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/fused_conv:fused_conv_py", + "//tensorflow/contrib/gan", "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/contrib/grid_rnn:grid_rnn_py", "//tensorflow/contrib/hooks", @@ -72,6 +74,7 @@ py_library( "//tensorflow/contrib/staging", "//tensorflow/contrib/stat_summarizer:stat_summarizer_py", "//tensorflow/contrib/stateless", + "//tensorflow/contrib/summary:summary_ops", "//tensorflow/contrib/tensor_forest:init_py", "//tensorflow/contrib/tensorboard", "//tensorflow/contrib/testing:testing_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 315ea943cf3f4929ed0252273e74987308975651..5b3f0b3f6eee6c49a85ff6e3654e390da64ab762 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -29,8 +29,10 @@ from tensorflow.contrib import cudnn_rnn from tensorflow.contrib import data from tensorflow.contrib import deprecated from tensorflow.contrib import distributions +from tensorflow.contrib import estimator from tensorflow.contrib import factorization from tensorflow.contrib import framework +from tensorflow.contrib import gan from tensorflow.contrib import graph_editor from tensorflow.contrib import grid_rnn from tensorflow.contrib import image diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index f61e9560ef36b32388940ef8bda8a73bb6c6e7bb..11e2128d72718e5bdbf4e86dca99784af85d54ea 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -36,7 +36,7 @@ set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION # Change to compile flags should be replicated into bazel build file set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \ -std=c++11 -fno-rtti -fno-exceptions \ - -O2 -Wno-narrowing \ + -O2 -Wno-narrowing -fomit-frame-pointer \ -mfpu=neon -mfloat-abi=softfp -fPIE \ -ftemplate-depth=900 \ -DGOOGLE_PROTOBUF_NO_RTTI \ diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 9b7f394258d73775b5b46d3858b575d92bee1712..6389ef1f5dae9b39b9531c2ba76adf67f3dfcd57 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -16,8 +16,8 @@ limitations under the License. package org.tensorflow.contrib.android; import android.content.res.AssetManager; -import android.os.Trace; import android.os.Build.VERSION; +import android.os.Trace; import android.text.TextUtils; import android.util.Log; import java.io.FileInputStream; @@ -86,7 +86,7 @@ public class TensorFlowInferenceInterface { throw new RuntimeException("Failed to load model from '" + model + "'", e); } } - + /* * Load a TensorFlow model from provided InputStream. * Note: The InputStream will not be closed after loading model, users need to @@ -96,7 +96,7 @@ public class TensorFlowInferenceInterface { */ public TensorFlowInferenceInterface(InputStream is) { prepareNativeRuntime(); - + // modelName is redundant for model loading from input stream, here is for // avoiding error in initialization as modelName is marked final. this.modelName = ""; @@ -191,8 +191,9 @@ public class TensorFlowInferenceInterface { } /** - * Cleans up the state associated with this Object. initializeTensorFlow() can then be called - * again to initialize a new session. + * Cleans up the state associated with this Object. + * + *

The TenosrFlowInferenceInterface object is no longer usable after this method returns. */ public void close() { closeFeeds(); @@ -266,6 +267,25 @@ public class TensorFlowInferenceInterface { addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src))); } + /** + * Copy a byte sequence into the input Tensor with name {@link inputName} as a string-valued + * scalar tensor. In the TensorFlow type system, a "string" is an arbitrary sequence of + * bytes, not a Java {@code String} (which is a sequence of characters). + */ + public void feedString(String inputName, byte[] src) { + addFeed(inputName, Tensor.create(src)); + } + + /** + * Copy an array of byte sequences into the input Tensor with name {@link inputName} as a + * string-valued one-dimensional tensor (vector). In the TensorFlow type system, a "string" + * is an arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of + * characters). + */ + public void feedString(String inputName, byte[][] src) { + addFeed(inputName, Tensor.create(src)); + } + // Methods for taking a native Tensor and filling it with src from Java native IO buffers. /** @@ -417,7 +437,7 @@ public class TensorFlowInferenceInterface { public void fetch(String outputName, ByteBuffer dst) { getTensor(outputName).writeTo(dst); } - + private void prepareNativeRuntime() { Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded"); try { @@ -442,7 +462,7 @@ public class TensorFlowInferenceInterface { final long startMs = System.currentTimeMillis(); if (VERSION.SDK_INT >= 18) { - Trace.beginSection("initializeTensorFlow"); + Trace.beginSection("loadGraph"); Trace.beginSection("readGraphDef"); } @@ -470,7 +490,7 @@ public class TensorFlowInferenceInterface { if (VERSION.SDK_INT >= 18) { Trace.endSection(); // importGraphDef. - Trace.endSection(); // initializeTensorFlow. + Trace.endSection(); // loadGraph. } final long endMs = System.currentTimeMillis(); @@ -541,7 +561,7 @@ public class TensorFlowInferenceInterface { fetchNames.clear(); } - // State immutable between initializeTensorFlow calls. + // Immutable state. private final String modelName; private final Graph g; private final Session sess; diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index c377c50e9fea69aa147a3910e6899e138804c7a3..a8b60460c8fd8dce56bd9d5aa11d8e72ae5e1f60 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import os + from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch from tensorflow.contrib.decision_trees.proto import generic_tree_model_extensions_pb2 @@ -26,18 +29,21 @@ from tensorflow.contrib.learn.python.learn import export_strategy from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils from tensorflow.python.client import session as tf_session from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader as saved_model_loader from tensorflow.python.saved_model import tag_constants -def make_custom_export_strategy(name, convert_fn, feature_columns, +def make_custom_export_strategy(name, + convert_fn, + feature_columns, export_input_fn): """Makes custom exporter of GTFlow tree format. Args: name: A string, for the name of the export strategy. convert_fn: A function that converts the tree proto to desired format and - saves it to the desired location. + saves it to the desired location. Can be None to skip conversion. feature_columns: A list of feature columns. export_input_fn: A function that takes no arguments and returns an `InputFnOps`. @@ -68,9 +74,22 @@ def make_custom_export_strategy(name, convert_fn, feature_columns, dtec = tree_config_pb2.DecisionTreeEnsembleConfig() dtec.ParseFromString(dfec_str) # Export the result in the same folder as the saved model. - convert_fn(dtec, sorted_feature_names, len(dense_floats), - len(sparse_float_indices), len(sparse_int_indices), - result_dir, eval_result) + if convert_fn: + convert_fn(dtec, sorted_feature_names, + len(dense_floats), + len(sparse_float_indices), + len(sparse_int_indices), result_dir, eval_result) + feature_importances = _get_feature_importances( + dtec, sorted_feature_names, + len(dense_floats), + len(sparse_float_indices), len(sparse_int_indices)) + sorted_by_importance = sorted( + feature_importances.items(), key=lambda x: -x[1]) + assets_dir = os.path.join(result_dir, "assets.extra") + gfile.MakeDirs(assets_dir) + with gfile.GFile(os.path.join(assets_dir, "feature_importances"), + "w") as f: + f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance)) return result_dir return export_strategy.ExportStrategy(name, export_fn) @@ -157,3 +176,41 @@ def convert_to_universal_format(dtec, sorted_feature_names, node.left_child_id.value = split.left_id node.right_child_id.value = split.right_id return model_and_features + + +def _get_feature_importances(dtec, feature_names, num_dense_floats, + num_sparse_float, num_sparse_int): + """Export the feature importance per feature column.""" + del num_sparse_int # Unused. + sums = collections.defaultdict(lambda: 0) + for tree_idx in range(len(dtec.trees)): + tree = dtec.trees[tree_idx] + for tree_node in tree.nodes: + node_type = tree_node.WhichOneof("node") + if node_type == "dense_float_binary_split": + split = tree_node.dense_float_binary_split + split_column = feature_names[split.feature_column] + elif node_type == "sparse_float_binary_split_default_left": + split = tree_node.sparse_float_binary_split_default_left.split + split_column = feature_names[split.feature_column + num_dense_floats] + elif node_type == "sparse_float_binary_split_default_right": + split = tree_node.sparse_float_binary_split_default_right.split + split_column = feature_names[split.feature_column + num_dense_floats] + elif node_type == "categorical_id_binary_split": + split = tree_node.categorical_id_binary_split + split_column = feature_names[split.feature_column + num_dense_floats + + num_sparse_float] + elif node_type == "categorical_id_set_membership_binary_split": + split = tree_node.categorical_id_set_membership_binary_split + split_column = feature_names[split.feature_column + num_dense_floats + + num_sparse_float] + elif node_type == "leaf": + assert tree_node.node_metadata.gain == 0 + continue + else: + raise ValueError("Unexpected split type %s", node_type) + # Apply shrinkage factor. It is important since it is not always uniform + # across different trees. + sums[split_column] += ( + tree_node.node_metadata.gain * dtec.tree_weights[tree_idx]) + return dict(sums) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py index 8d801fa1f382fb1b3f53ac0a1214269837c7c0cc..4ed18b2d34c5af47826ab1c058f5d13797593bd4 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py @@ -27,7 +27,7 @@ from tensorflow.python.platform import googletest class ConvertModelTest(test_util.TensorFlowTestCase): - def testConvertModel(self): + def _make_trees(self): dtec_str = """ trees { nodes { @@ -108,8 +108,12 @@ class ConvertModelTest(test_util.TensorFlowTestCase): """ dtec = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge(dtec_str, dtec) - # The feature columns in the order they were added. feature_columns = ["feature_b", "feature_a", "feature_d"] + return dtec, feature_columns + + def testConvertModel(self): + dtec, feature_columns = self._make_trees() + # The feature columns in the order they were added. out = custom_export_strategy.convert_to_universal_format( dtec, feature_columns, 1, 1, 1) @@ -273,6 +277,16 @@ class ConvertModelTest(test_util.TensorFlowTestCase): }""" self.assertProtoEquals(expected_tree, out) + def testFeatureImportance(self): + dtec, feature_columns = self._make_trees() + feature_importances = custom_export_strategy._get_feature_importances( + dtec, feature_columns, 1, 1, 1) + self.assertItemsEqual(["feature_b", "feature_a", "feature_d"], + feature_importances.keys()) + self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4) + self.assertAlmostEqual(50.0, feature_importances["feature_a"], places=4) + self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index e28adad53ec917de8d60b1eb442602b8e5abbb38..f8028acbdb0be44b7fd81b96b04b6e24d9060aa6 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -61,11 +61,19 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): logits_modifier_function: A modifier function for the logits. center_bias: Whether a separate tree should be created for first fitting the bias. + + Raises: + ValueError: If learner_config is not valid. """ head = head_lib.multi_class_head( n_classes=n_classes, weight_column_name=weight_column_name, enable_centered_bias=False) + if learner_config.num_classes == 0: + learner_config.num_classes = n_classes + elif learner_config.num_classes != n_classes: + raise ValueError("n_classes (%d) doesn't match learner_config (%d)." % + (learner_config.num_classes, n_classes)) super(GradientBoostedDecisionTreeClassifier, self).__init__( model_fn=model.model_builder, params={ @@ -129,6 +137,10 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): label_dimension=label_dimension, weight_column_name=weight_column_name, enable_centered_bias=False) + if label_dimension == 1: + learner_config.num_classes = 2 + else: + learner_config.num_classes = label_dimension super(GradientBoostedDecisionTreeRegressor, self).__init__( model_fn=model.model_builder, params={ diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 2d517f781115f55de99a208419ff300c28470c04..8cda5c8f2b14f2ec3cfe3702e38b81803dd075f7 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -92,6 +92,7 @@ def model_builder(features, labels, mode, params, config): examples_per_layer=examples_per_layer, learner_config=learner_config, feature_columns=feature_columns, + logits_dimension=head.logits_dimension, features=features) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index 42112c586a5f5e940d31e0810ae9589d79239641..f4ad99f779e0d7fcf207934d77776548214371c1 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -74,7 +74,7 @@ class TreeEnsembleStampTokenOp : public OpKernel { decision_tree_ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &decision_tree_ensemble_resource)); - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); + tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); core::ScopedUnref unref_me(decision_tree_ensemble_resource); Tensor* output_stamp_token_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), @@ -95,7 +95,7 @@ class TreeEnsembleSerializeOp : public OpKernel { decision_tree_ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &decision_tree_ensemble_resource)); - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); + tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); core::ScopedUnref unref_me(decision_tree_ensemble_resource); Tensor* output_stamp_token_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index daca0495481fdadebd239933c14e8b6ff08f4558..8ffd7f120b49b09a49fde2ac7319f56a3f03459a 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -143,7 +143,7 @@ class GradientTreesPredictionOp : public OpKernel { // Release the reference to the resource once we're done using it. core::ScopedUnref unref_me(decision_tree_ensemble_resource); if (use_locking_) { - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); + tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); DoCompute(context, decision_tree_ensemble_resource); } else { DoCompute(context, decision_tree_ensemble_resource); @@ -334,7 +334,7 @@ class GradientTreesPartitionExamplesOp : public OpKernel { // Release the reference to the resource once we're done using it. core::ScopedUnref unref_me(decision_tree_ensemble_resource); if (use_locking_) { - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); + tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); DoCompute(context, decision_tree_ensemble_resource); } else { DoCompute(context, decision_tree_ensemble_resource); diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index bbcc308677ec9dd5238807d72ee31d9c82186450..3ccc36dff891d101733e66aadbe3e5744fd352f9 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -901,7 +901,7 @@ class BucketizeWithInputBoundariesOp : public OpKernel { Tensor* output_tensor = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); - auto output = output_tensor->template flat(); + auto output = output_tensor->template flat(); for (size_t i = 0; i < input.size(); i++) { output(i) = CalculateBucketIndex(input(i)); @@ -909,10 +909,14 @@ class BucketizeWithInputBoundariesOp : public OpKernel { } private: - int64 CalculateBucketIndex(const T value) { + int32 CalculateBucketIndex(const T value) { auto first_bigger_it = std::upper_bound(boundaries_.begin(), boundaries_.end(), value); - return first_bigger_it - boundaries_.begin(); + int32 index = first_bigger_it - boundaries_.begin(); + CHECK(index >= 0 && index <= boundaries_.size()) + << "Invalid bucket index: " << index + << " boundaries_.size(): " << boundaries_.size(); + return index; } std::vector boundaries_; }; diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 9e9ef1738cdf941e821f06ea3310b22f8c564134..d528757cf99c9c6dc0b3d75c765e47f5cbcff19c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -656,7 +656,8 @@ class GrowTreeEnsembleOp : public OpKernel { CHECK(split->split_info.split_node().node_case() != TreeNode::NODE_NOT_SET); CHECK(tree_config->nodes(node_id).node_case() == TreeNode::kLeaf) << "Unexpected node type to split " - << tree_config->nodes(node_id).node_case(); + << tree_config->nodes(node_id).node_case() << " for node_id " << node_id + << ". Tree config: " << tree_config->DebugString(); // Add left leaf. int32 left_id = tree_config->nodes_size(); @@ -767,7 +768,7 @@ class TreeEnsembleStatsOp : public OpKernel { OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &decision_tree_ensemble_resource)); core::ScopedUnref unref_me(decision_tree_ensemble_resource); - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); + tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); // Get the stamp token. const Tensor* stamp_token_t; diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc index 82664aed72d99aa3e84d5f3f38bff8ec5e4ca099..f4c7df7fabda1a38d7e6cca4c5c8bc81cb7551b1 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc @@ -42,6 +42,7 @@ class BiasFeatureColumnHandlerTest : public ::testing::Test { example_partitions_({0, 0, 1, 3}) { // Set L2 regularization. learner_config_.mutable_regularization()->set_l2(2.0f); + learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); // Create handler. handler_.reset(new BiasFeatureColumnHandler(kClassId, kSlotId, kBatchSize)); diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc index abd72384648dc3ac5d7f00e3b6d89fea3eb09afb..ea82b3f086d24dc1f9ceb4783abd68be35b34b00 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc @@ -51,7 +51,7 @@ class CategoricalFeatureColumnHandlerTest : public ::testing::Test { values_(test::AsTensor({1, 2, 2, 0}, {4})) { // Set L2 regularization. learner_config_.mutable_regularization()->set_l2(2.0f); - + learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); // Create handler. handler_.reset(new CategoricalFeatureColumnHandler( kClassId, kSlotId, kBatchSize, kFeatureColumn, indices_.matrix(), diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc index 396f48e5321f1012571bcfb2f3f013cf94ffd987..1bc9d733ad3090f1cfc9547644061f54d7d2c8c6 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc @@ -51,7 +51,7 @@ class DenseQuantizedFeatureColumnHandlerTest : public ::testing::Test { dense_quantized_values_(test::AsTensor({1, 1, 0, 1}, {4})) { // Set L2 regularization. learner_config_.mutable_regularization()->set_l2(2.0f); - + learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); // Create handler. handler_.reset(new DenseQuantizedFeatureColumnHandler( kClassId, kSlotId, kBatchSize, kFeatureColumn, diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc index db8c64a617f88ecd5ce9696317c12b632de6f78d..643d936ad23850e601bc5518d69c8637011f53c0 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc @@ -53,7 +53,7 @@ class SparseQuantizedFeatureColumnHandlerTest : public ::testing::Test { sparse_quantized_values_(test::AsTensor({1, 0, 1}, {3})) { // Set L2 regularization. learner_config_.mutable_regularization()->set_l2(2.0f); - + learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); // Create handler. handler_.reset(new SparseQuantizedFeatureColumnHandler( kClassId, kSlotId, kBatchSize, kFeatureColumn, diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc index f99b6826a7819e627d274e23700d0c8c9c53d2af..ecb7a04efb96248210d9af770c8377b7f6906598 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc @@ -30,6 +30,7 @@ const double kDelta = 1e-5; TEST(NodeStatsTest, AlmostZero) { LearnerConfig learner_config; + learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); NodeStats node_stats(learner_config, GradientStats(1e-8f, 1e-8f)); EXPECT_EQ(0, node_stats.weight_contribution[0]); EXPECT_EQ(0, node_stats.gain); @@ -37,6 +38,7 @@ TEST(NodeStatsTest, AlmostZero) { TEST(NodeStatsTest, LessThanMinWeightConstraint) { LearnerConfig learner_config; + learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); learner_config.mutable_constraints()->set_min_node_weight(3.2f); NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f)); EXPECT_EQ(0, node_stats.weight_contribution[0]); @@ -45,6 +47,7 @@ TEST(NodeStatsTest, LessThanMinWeightConstraint) { TEST(NodeStatsTest, L1RegSquashed) { LearnerConfig learner_config; + learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); learner_config.mutable_regularization()->set_l1(10.0f); NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f)); EXPECT_EQ(0, node_stats.weight_contribution[0]); @@ -53,6 +56,7 @@ TEST(NodeStatsTest, L1RegSquashed) { TEST(NodeStatsTest, L1RegPos) { LearnerConfig learner_config; + learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); learner_config.mutable_regularization()->set_l1(5.0f); NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f)); const float expected_clipped_grad = 7.32f - 5.0f; @@ -66,6 +70,7 @@ TEST(NodeStatsTest, L1RegPos) { TEST(NodeStatsTest, L1RegNeg) { LearnerConfig learner_config; + learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); learner_config.mutable_regularization()->set_l1(5.0f); NodeStats node_stats(learner_config, GradientStats(-7.32f, 1.63f)); const float expected_clipped_grad = -7.32f + 5.0f; @@ -79,6 +84,7 @@ TEST(NodeStatsTest, L1RegNeg) { TEST(NodeStatsTest, L2Reg) { LearnerConfig learner_config; + learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); learner_config.mutable_regularization()->set_l2(8.0f); NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f)); const float expected_denom = 1.63f + 8.0f; @@ -91,6 +97,7 @@ TEST(NodeStatsTest, L2Reg) { TEST(NodeStatsTest, L1L2Reg) { LearnerConfig learner_config; + learner_config.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); learner_config.mutable_regularization()->set_l1(5.0f); learner_config.mutable_regularization()->set_l2(8.0f); NodeStats node_stats(learner_config, GradientStats(7.32f, 1.63f)); diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index 314c44fddc5d293d410fc5a69f2dadfe4f35a46a..dad3b4e10deff7b8fb3a2a393e27a5d7099984a1 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ +#include #include #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h" @@ -34,10 +35,27 @@ class WeightedQuantilesSummary { struct SummaryEntry { SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min, - const WeightType& max) - : value(v), weight(w), min_rank(min), max_rank(max) {} + const WeightType& max) { + // Explicitely initialize all of memory (including padding from memory + // alignment) to allow the struct to be msan-resistant "plain old data". + // + // POD = http://en.cppreference.com/w/cpp/concept/PODType + memset(this, 0, sizeof(*this)); + + value = v; + weight = w; + min_rank = min; + max_rank = max; + } + + SummaryEntry() { + memset(this, 0, sizeof(*this)); - SummaryEntry() : value(0), weight(0), min_rank(0), max_rank(0) {} + value = 0; + weight = 0; + min_rank = 0; + max_rank = 0; + } bool operator==(const SummaryEntry& other) const { return value == other.value && weight == other.weight && diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index ac8a5201a081c3af41613c008f8bc325f2201902..0336008e86108bb9813a64ebd254dedf34975b63 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -289,7 +289,7 @@ the sparse feature tensors. REGISTER_OP("BucketizeWithInputBoundaries") .Input("input: T") .Input("boundaries: float") - .Output("output: int64") + .Output("output: int32") .Attr("T: {int32, int64, float, double}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto index 06ee223467b63c0e23f63ac60425c43d53b15ee2..919e7cd81427c27cf892bc77998f52406d2bcf15 100644 --- a/tensorflow/contrib/boosted_trees/proto/learner.proto +++ b/tensorflow/contrib/boosted_trees/proto/learner.proto @@ -17,7 +17,7 @@ message TreeRegularizationConfig { // Tree constraints config. message TreeConstraintsConfig { - // Maximum depth of the trees. + // Maximum depth of the trees. The default value is 6 if not specified. uint32 max_tree_depth = 1; // Min hessian weight per node. @@ -86,20 +86,22 @@ message LearningRateDropoutDrivenConfig { message LearnerConfig { enum PruningMode { - PRE_PRUNE = 0; - POST_PRUNE = 1; + PRUNING_MODE_UNSPECIFIED = 0; + PRE_PRUNE = 1; + POST_PRUNE = 2; } enum GrowingMode { - WHOLE_TREE = 0; - // Layer by layer is only supported by the batch learner. - LAYER_BY_LAYER = 1; + GROWING_MODE_UNSPECIFIED = 0; + WHOLE_TREE = 1; + LAYER_BY_LAYER = 2; } enum MultiClassStrategy { - TREE_PER_CLASS = 0; - FULL_HESSIAN = 1; - DIAGONAL_HESSIAN = 2; + MULTI_CLASS_STRATEGY_UNSPECIFIED = 0; + TREE_PER_CLASS = 1; + FULL_HESSIAN = 2; + DIAGONAL_HESSIAN = 3; } // Number of classes. @@ -118,16 +120,18 @@ message LearnerConfig { // Constraints. TreeConstraintsConfig constraints = 5; - // Pruning. + // Pruning. POST_PRUNE is the default pruning mode. PruningMode pruning_mode = 8; - // Growing Mode. + // Growing Mode. LAYER_BY_LAYER is the default growing mode. GrowingMode growing_mode = 9; - // Learning rate. + // Learning rate. By default we use fixed learning rate of 0.1. LearningRateConfig learning_rate_tuner = 6; - // Multi-class strategy. + // Multi-class strategy. By default we use TREE_PER_CLASS for binary + // classification and linear regression. For other cases, we use + // DIAGONAL_HESSIAN as the default. MultiClassStrategy multi_class_strategy = 10; // If you want to average the ensembles (for regularization), provide the diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index 51e084b79c6bd0d7562e1d2ef6d908de48c5c8a2..37595f1c75deab4db810d6ae49b57f56f417c52f 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -344,6 +344,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Prepare learner config. learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE result, result_no_dropout, dropout_info = ( prediction_ops.gradient_trees_prediction( diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index b6a07eafd594bbca658d47b3e6c7de036483f40a..2d28e0a9f160373b4565d83e9b57de401a052bd6 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -261,6 +261,7 @@ class GradientBoostedDecisionTreeModel(object): examples_per_layer, learner_config, features, + logits_dimension, feature_columns=None): """Construct a new GradientBoostedDecisionTreeModel function. @@ -273,8 +274,8 @@ class GradientBoostedDecisionTreeModel(object): a tree layer. It can also be a function that computes the number of examples based on the depth of the layer that's being built. learner_config: A learner config. - print split, sorted_feature_names[split.feature_column] features: `dict` of `Tensor` objects. + logits_dimension: An int, the dimension of logits. feature_columns: A list of feature columns. Raises: @@ -289,11 +290,39 @@ class GradientBoostedDecisionTreeModel(object): if learner_config.num_classes < 2: raise ValueError("Number of classes must be >=2") + self._logits_dimension = logits_dimension self._is_chief = is_chief self._num_ps_replicas = num_ps_replicas self._ensemble_handle = ensemble_handle self._center_bias = center_bias self._examples_per_layer = examples_per_layer + + # Fill in the defaults. + if (learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED): + if logits_dimension == 1: + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.TREE_PER_CLASS) + else: + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + + if (learner_config.growing_mode == + learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED): + learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER + + if (learner_config.pruning_mode == + learner_pb2.LearnerConfig.PRUNING_MODE_UNSPECIFIED): + learner_config.pruning_mode = learner_pb2.LearnerConfig.POST_PRUNE + + if learner_config.constraints.max_tree_depth == 0: + # Use 6 as the default maximum depth. + learner_config.constraints.max_tree_depth = 6 + + tuner = learner_config.learning_rate_tuner.WhichOneof("tuner") + if not tuner: + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + self._learner_config = learner_config self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() @@ -378,75 +407,81 @@ class GradientBoostedDecisionTreeModel(object): local_stamp), _refresh_local_ensemble_fn, lambda: (control_flow_ops.no_op(), ensemble_stamp)) - # Once updated, Use the the local model for prediction. + # Once updated, use the local model for prediction. with ops.control_dependencies([refresh_local_ensemble]): ensemble_stats = training_ops.tree_ensemble_stats( local_ensemble_handle, ensemble_stamp) - apply_dropout, seed = _dropout_params(mode, ensemble_stats) # We don't need dropout info - we can always restore it based on the # seed. - predictions, predictions_no_dropout, _ = ( - prediction_ops.gradient_trees_prediction( - local_ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=apply_averaging, - use_locking=False, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim)) - partition_ids = prediction_ops.gradient_trees_partition_examples( - local_ensemble_handle, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - use_locking=False) + apply_dropout, seed = _dropout_params(mode, ensemble_stats) + # Make sure ensemble stats run. This will check that the ensemble has + # the right stamp. + with ops.control_dependencies(ensemble_stats): + predictions, predictions_no_dropout, _ = ( + prediction_ops.gradient_trees_prediction( + local_ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=apply_averaging, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim)) + partition_ids = prediction_ops.gradient_trees_partition_examples( + local_ensemble_handle, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + use_locking=True) else: with ops.device(self._ensemble_handle.device): ensemble_stats = training_ops.tree_ensemble_stats( self._ensemble_handle, ensemble_stamp) - apply_dropout, seed = _dropout_params(mode, ensemble_stats) # We don't need dropout info - we can always restore it based on the # seed. - predictions, predictions_no_dropout, _ = ( - prediction_ops.gradient_trees_prediction( - self._ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=apply_averaging, - use_locking=False, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim)) - partition_ids = prediction_ops.gradient_trees_partition_examples( - self._ensemble_handle, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - use_locking=False) + apply_dropout, seed = _dropout_params(mode, ensemble_stats) + # Make sure ensemble stats run. This will check that the ensemble has + # the right stamp. + with ops.control_dependencies(ensemble_stats): + predictions, predictions_no_dropout, _ = ( + prediction_ops.gradient_trees_prediction( + self._ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=apply_averaging, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim)) + partition_ids = prediction_ops.gradient_trees_partition_examples( + self._ensemble_handle, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + use_locking=True) return _make_predictions_dict(ensemble_stamp, predictions, predictions_no_dropout, partition_ids, @@ -467,6 +502,11 @@ class GradientBoostedDecisionTreeModel(object): Raises: ValueError: if inputs are not valid. """ + # Get the worker device from input dependencies. + input_deps = (self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) + worker_device = input_deps[0].device + # Get tensors relevant for training and form the loss. predictions = predictions_dict[PREDICTIONS] partition_ids = predictions_dict[PARTITION_IDS] @@ -478,7 +518,6 @@ class GradientBoostedDecisionTreeModel(object): colocate_gradients_with_ops=False, gate_gradients=0, aggregation_method=None)[0] - strategy = self._learner_config.multi_class_strategy num_classes = self._learner_config.num_classes @@ -541,7 +580,7 @@ class GradientBoostedDecisionTreeModel(object): fc_name_idx = 0 handlers = [] init_stamp_token = constant_op.constant(0, dtype=dtypes.int64) - with ops.device(self._get_replica_device_setter()): + with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns for dense_float_column_idx in range(len(self._dense_floats)): fc_name = self._fc_names[fc_name_idx] @@ -666,10 +705,6 @@ class GradientBoostedDecisionTreeModel(object): # Update handler stats. handler_reads = {} - - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) - worker_device = input_deps[0].device for handler in handlers: handler_reads[handler] = handler.scheduled_reads() @@ -841,7 +876,7 @@ class GradientBoostedDecisionTreeModel(object): return diag_hessian_list - def _get_replica_device_setter(self): + def _get_replica_device_setter(self, worker_device): """Creates a replica device setter.""" ps_tasks = self._num_ps_replicas ps_ops = [ @@ -854,6 +889,7 @@ class GradientBoostedDecisionTreeModel(object): ] ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) return device_setter.replica_device_setter( + worker_device=worker_device, ps_tasks=ps_tasks, merge_devices=True, ps_ops=ps_ops, diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 9ce434edf8bbd4a28dcc22557314263e804c5c61..16e24d97ddee0751e0b808b89080074c1b4baba7 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -164,7 +164,7 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - features=features) + logits_dimension=1, features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -268,7 +268,7 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=num_examples_fn, learner_config=learner_config, - features=features) + logits_dimension=1, features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -371,7 +371,7 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - features=features) + logits_dimension=1, features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -442,7 +442,7 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - features=features) + logits_dimension=1, features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -505,7 +505,7 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - features=features) + logits_dimension=1, features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -588,7 +588,7 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - features=features) + logits_dimension=1, features=features) # Create predict op. mode = model_fn.ModeKeys.EVAL @@ -627,7 +627,7 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - features=features) + logits_dimension=5, features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -730,7 +730,7 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - features=features) + logits_dimension=5, features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -833,7 +833,7 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - features=features) + logits_dimension=5, features=features) batch_size = 3 predictions = array_ops.constant( diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index f6a47d26c805824c0ad31ac42f082d2916bd2e51..c249a2855622581534534a94af9991d12b73f5e9 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -33,6 +33,7 @@ option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for cont option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF) option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON) option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") +option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) if (NOT WIN32) # Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option @@ -125,6 +126,7 @@ include(nsync) include(protobuf) include(re2) include(cub) +include(sqlite) if (tensorflow_BUILD_CC_TESTS) include(googletest) endif() @@ -142,6 +144,7 @@ set(tensorflow_EXTERNAL_LIBRARIES ${nsync_STATIC_LIBRARIES} ${protobuf_STATIC_LIBRARIES} ${re2_STATIC_LIBRARIES} + ${sqlite_STATIC_LIBRARIES} ) set(tensorflow_EXTERNAL_DEPENDENCIES zlib_copy_headers_to_destination @@ -159,6 +162,7 @@ set(tensorflow_EXTERNAL_DEPENDENCIES cub fft2d re2 + sqlite_copy_headers_to_destination ) include_directories( @@ -180,6 +184,7 @@ include_directories( ${nsync_INCLUDE_DIR} ${PROTOBUF_INCLUDE_DIRS} ${re2_INCLUDE_DIR} + ${sqlite_INCLUDE_DIR} ) if(tensorflow_ENABLE_SSL_SUPPORT) @@ -200,6 +205,12 @@ if(tensorflow_ENABLE_JEMALLOC_SUPPORT) list(APPEND tensorflow_EXTERNAL_DEPENDENCIES jemalloc) include_directories(${jemalloc_INCLUDE_DIRS}) endif() +if(tensorflow_ENABLE_SNAPPY_SUPPORT) + include(snappy) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${snappy_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES snappy) + include_directories(${snappy_INCLUDE_DIR}) +endif() if(WIN32) list(APPEND tensorflow_EXTERNAL_LIBRARIES wsock32 ws2_32 shlwapi) endif() diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake index 04a9664701c072d30f19f4435bcf137109eefb37..dc27eadaca14361ffeffa6eadf6d4d97524de310 100644 --- a/tensorflow/contrib/cmake/external/boringssl.cmake +++ b/tensorflow/contrib/cmake/external/boringssl.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(boringssl_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src/boringssl/include) #set(boringssl_EXTRA_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src) set(boringssl_URL https://boringssl.googlesource.com/boringssl) -set(boringssl_TAG 17cf2cb1d226b0ba2401304242df7ddd3b6f1ff2) +set(boringssl_TAG ee7aa02) set(boringssl_BUILD ${CMAKE_BINARY_DIR}/boringssl/src/boringssl-build) #set(boringssl_LIBRARIES ${boringssl_BUILD}/obj/so/libboringssl.so) set(boringssl_STATIC_LIBRARIES diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake index a5ce0059df244cddf2796800664678000c71e8db..d98579d2077f0a3bc58e6466ee830e53f44f40cb 100644 --- a/tensorflow/contrib/cmake/external/cub.cmake +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.6.4.zip) -set(cub_HASH SHA256=966d0c4f41e2bdc81aebf9ccfbf0baffaac5a74f00b826b06f4dee79b2bb8cee) +set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip) +set(cub_HASH SHA256=b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe) set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive) diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a35d8654fb6fa5f5b5d230ffbc061d050e5aeb5e --- /dev/null +++ b/tensorflow/contrib/cmake/external/snappy.cmake @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +set(snappy_URL https://github.com/google/snappy.git) +set(snappy_TAG "55924d11095df25ab25c405fadfe93d0a46f82eb") +set(snappy_BUILD ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy) +set(snappy_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy) + +if(WIN32) + set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/$(Configuration)/snappy.lib) +else() + set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/libsnappy.a) +endif() + +set(snappy_HEADERS + "${snappy_INCLUDE_DIR}/snappy.h" +) + +ExternalProject_Add(snappy + PREFIX snappy + GIT_REPOSITORY ${snappy_URL} + GIT_TAG ${snappy_TAG} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 + INSTALL_COMMAND "" + LOG_DOWNLOAD ON + LOG_CONFIGURE ON + LOG_BUILD ON + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DSNAPPY_BUILD_TESTS:BOOL=OFF + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON +) + +# actually enables snappy in the source code +add_definitions(-DSNAPPY) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake new file mode 100644 index 0000000000000000000000000000000000000000..a8809d4a4a510cb5d0773c48f59b6481ca23d62f --- /dev/null +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -0,0 +1,72 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +set(sqlite_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/sqlite) +set(sqlite_URL http://www.sqlite.org/2017/sqlite-amalgamation-3200000.zip) +set(sqlite_HASH SHA256=208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4) +set(sqlite_BUILD ${CMAKE_CURRENT_BINARY_DIR}/sqlite/src/sqlite) +set(sqlite_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/sqlite/install) + +if(WIN32) + set(sqlite_STATIC_LIBRARIES ${sqlite_INSTALL}/lib/sqlite.lib) +else() + set(sqlite_STATIC_LIBRARIES ${sqlite_INSTALL}/lib/sqlite.a) +endif() + +set(sqlite_HEADERS + "${sqlite_BUILD}/sqlite3.h" +) + +if (WIN32) + ExternalProject_Add(sqlite + PREFIX sqlite + URL ${sqlite_URL} + URL_HASH ${sqlite_HASH} + PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/sqlite/CMakeLists.txt ${sqlite_BUILD} + INSTALL_DIR ${sqlite_INSTALL} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DCMAKE_INSTALL_PREFIX:STRING=${sqlite_INSTALL} + ) + +else() + ExternalProject_Add(sqlite + PREFIX sqlite + URL ${sqlite_URL} + URL_HASH ${sqlite_HASH} + INSTALL_DIR ${sqlite_INSTALL} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_COMMAND $(MAKE) + INSTALL_COMMAND $(MAKE) install + CFLAGS=-fPIC + ) + +endif() + +# put sqlite includes in the directory where they are expected +add_custom_target(sqlite_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${sqlite_INCLUDE_DIR} + DEPENDS sqlite) + +add_custom_target(sqlite_copy_headers_to_destination + DEPENDS sqlite_create_destination_dir) + +foreach(header_file ${sqlite_HEADERS}) + add_custom_command(TARGET sqlite_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${sqlite_INCLUDE_DIR}) +endforeach() \ No newline at end of file diff --git a/tensorflow/contrib/cmake/patches/sqlite/CMakeLists.txt b/tensorflow/contrib/cmake/patches/sqlite/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e67792afd2602da0c18201b098d023020f07a8f9 --- /dev/null +++ b/tensorflow/contrib/cmake/patches/sqlite/CMakeLists.txt @@ -0,0 +1,28 @@ +cmake_minimum_required(VERSION 2.8.3) + +project(sqlite) + +set(SQLITE_SRCS + "sqlite3.c" +) + +set(SQLITE_INCLUDES + "sqlite3.h" +) + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}") + +add_library(sqlite ${SQLITE_SRCS}) + +# C++11 +target_compile_features(sqlite PRIVATE + cxx_rvalue_references +) + +install(TARGETS sqlite + LIBRARY DESTINATION lib COMPONENT RuntimeLibraries + ARCHIVE DESTINATION lib COMPONENT Development) + +foreach(SQLITE_INCLUDE ${SQLITE_INCLUDES}) + install(FILES ${SQLITE_INCLUDE} DESTINATION include COMPONENT Development) +endforeach() diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index 24291a94b3ef3fa25f3d80b57e9fdd5b77504a51..c5a101812710f0e6eb0aa8816acd2b395e7f7472 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -18,6 +18,7 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/c_api.h" + "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" @@ -29,10 +30,19 @@ set(tf_c_srcs ) add_library(tf_c OBJECT ${tf_c_srcs}) -add_dependencies(tf_c tf_cc_framework tf_core_lib tf_protos_cc) +add_dependencies( + tf_c + tf_cc_framework + tf_cc_while_loop + tf_core_lib + tf_protos_cc) add_library(tf_c_python_api OBJECT "${tensorflow_source_dir}/tensorflow/c/python_api.cc" "${tensorflow_source_dir}/tensorflow/c/python_api.h" ) -add_dependencies(tf_c_python_api tf_c tf_cc_framework tf_core_lib tf_protos_cc) +add_dependencies( + tf_c_python_api + tf_c + tf_core_lib + tf_protos_cc) diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index b53f428461d70685b0226f4d0cd0a3f63d8f47d8..6632433087c608a65d9425e5a1efdfccc95af339 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -105,6 +105,16 @@ add_library(tf_cc_ops OBJECT "${tensorflow_source_dir}/tensorflow/cc/ops/standard_ops.h" ) +######################################################## +# tf_cc_while_loop library +######################################################## +add_library(tf_cc_while_loop OBJECT + "${tensorflow_source_dir}/tensorflow/cc/ops/while_loop.h" + "${tensorflow_source_dir}/tensorflow/cc/ops/while_loop.cc" +) + +add_dependencies(tf_cc_while_loop tf_core_framework tf_cc_ops) + ######################################################## # tf_cc library ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 31d4fe5e7f1e2ade2c15fdb03b1dc38b73614a18..ce94f718a109deff9d8214562eccab32766b0a98 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -218,6 +218,48 @@ add_python_module("tensorflow/python/estimator/inputs/queues") add_python_module("tensorflow/python/feature_column") add_python_module("tensorflow/python/framework") add_python_module("tensorflow/python/grappler") +add_python_module("tensorflow/python/keras") +add_python_module("tensorflow/python/keras/activations") +add_python_module("tensorflow/python/keras/applications") +add_python_module("tensorflow/python/keras/applications/inception_v3") +add_python_module("tensorflow/python/keras/applications/mobilenet") +add_python_module("tensorflow/python/keras/applications/resnet50") +add_python_module("tensorflow/python/keras/applications/vgg16") +add_python_module("tensorflow/python/keras/applications/vgg19") +add_python_module("tensorflow/python/keras/applications/xception") +add_python_module("tensorflow/python/keras/backend") +add_python_module("tensorflow/python/keras/callbacks") +add_python_module("tensorflow/python/keras/constraints") +add_python_module("tensorflow/python/keras/datasets") +add_python_module("tensorflow/python/keras/datasets/boston_housing") +add_python_module("tensorflow/python/keras/datasets/cifar10") +add_python_module("tensorflow/python/keras/datasets/cifar100") +add_python_module("tensorflow/python/keras/datasets/imdb") +add_python_module("tensorflow/python/keras/datasets/mnist") +add_python_module("tensorflow/python/keras/datasets/reuters") +add_python_module("tensorflow/python/keras/initializers") +add_python_module("tensorflow/python/keras/layers") +add_python_module("tensorflow/python/keras/losses") +add_python_module("tensorflow/python/keras/metrics") +add_python_module("tensorflow/python/keras/models") +add_python_module("tensorflow/python/keras/optimizers") +add_python_module("tensorflow/python/keras/preprocessing") +add_python_module("tensorflow/python/keras/preprocessing/image") +add_python_module("tensorflow/python/keras/preprocessing/sequence") +add_python_module("tensorflow/python/keras/preprocessing/text") +add_python_module("tensorflow/python/keras/regularizers") +add_python_module("tensorflow/python/keras/utils") +add_python_module("tensorflow/python/keras/wrappers") +add_python_module("tensorflow/python/keras/wrappers/scikit_learn") +add_python_module("tensorflow/python/keras/_impl") +add_python_module("tensorflow/python/keras/_impl/keras") +add_python_module("tensorflow/python/keras/_impl/keras/applications") +add_python_module("tensorflow/python/keras/_impl/keras/datasets") +add_python_module("tensorflow/python/keras/_impl/keras/engine") +add_python_module("tensorflow/python/keras/_impl/keras/layers") +add_python_module("tensorflow/python/keras/_impl/keras/preprocessing") +add_python_module("tensorflow/python/keras/_impl/keras/utils") +add_python_module("tensorflow/python/keras/_impl/keras/wrappers") add_python_module("tensorflow/python/kernel_tests") add_python_module("tensorflow/python/kernel_tests/distributions") add_python_module("tensorflow/python/layers") @@ -299,6 +341,9 @@ add_python_module("tensorflow/contrib/distributions/python") add_python_module("tensorflow/contrib/distributions/python/kernel_tests") add_python_module("tensorflow/contrib/distributions/python/ops") add_python_module("tensorflow/contrib/distributions/python/ops/bijectors") +add_python_module("tensorflow/contrib/estimator") +add_python_module("tensorflow/contrib/estimator/python") +add_python_module("tensorflow/contrib/estimator/python/estimator") add_python_module("tensorflow/contrib/factorization") add_python_module("tensorflow/contrib/factorization/examples") add_python_module("tensorflow/contrib/factorization/kernels") @@ -315,6 +360,7 @@ add_python_module("tensorflow/contrib/framework/ops") add_python_module("tensorflow/contrib/framework/python") add_python_module("tensorflow/contrib/framework/python/framework") add_python_module("tensorflow/contrib/framework/python/ops") +add_python_module("tensorflow/contrib/gan") add_python_module("tensorflow/contrib/graph_editor") add_python_module("tensorflow/contrib/graph_editor/examples") add_python_module("tensorflow/contrib/graph_editor/tests") @@ -819,6 +865,7 @@ if(WIN32) $ $ $ + $ $ $ $ @@ -868,6 +915,7 @@ add_library(pywrap_tensorflow_internal SHARED $ $ $ + $ $ $ $ diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 76531add3120bfc9687ba336ceb7d8cb4076945a..9dff8881559afa8b287c16f06b1787e548d7e08b 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -142,6 +142,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/debug/cli/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/*_test.py" @@ -240,10 +241,13 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows. "${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename. "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker. + "${tensorflow_source_dir}/tensorflow/python/training/server_lib_test.py" # Test occasionally deadlocks. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops # Broken tensorboard test due to cmake issues. "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py" # Results in wrong order. @@ -291,6 +295,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Failing with TF 1.3 (TODO) "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py" "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py" + # Test should only be run manually + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py" ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 470661a9b1194404db5fca9c56099ed44d787462..87ba834770d8f707c5364ed7bb8db4aaaa21f286 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== """Ops for fused Cudnn RNN models. +@@CudnnCompatibleGRUCell +@@CudnnCompatibleLSTMCell @@CudnnGRU @@CudnnLSTM @@CudnnRNNRelu @@ -28,6 +30,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM @@ -36,9 +40,12 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable + from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ + "CudnnCompatibleGRUCell", + "CudnnCompatibleLSTMCell", "CudnnGRU", "CudnnLSTM", "CudnnRNNRelu", diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 7794c371e1ccbdfdbc4d73c2c578440045ea7760..f6eeb016755b66a8ac2a4b4e711543ebdf468269 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -63,9 +63,10 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): this cell seamlessly. """ - def __init__(self, num_units): + def __init__(self, num_units, reuse=None): super(CudnnCompatibleLSTMCell, self).__init__( - num_units, forget_bias=0, clip_cell=False, use_peephole=False) + num_units, forget_bias=0, clip_cell=False, use_peephole=False, + reuse=reuse) self._names.update({"scope": "cudnn_compatible_lstm_cell"}) @@ -692,6 +693,7 @@ _cudnn_rnn_common_doc_string = """ canonical format. This is a typical use case: + * The user creates a CudnnRNN model. * The user query that parameter buffer size. * The user creates a variable of that size that serves as the parameter @@ -715,6 +717,482 @@ _cudnn_rnn_common_doc_string = """ """ +def _check_direction(direction): + if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): + raise ValueError("Invalid direction: %s, expect %s or %s" % + (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)) + + +def _check_rnn_mode(rnn_mode): + if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU): + raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" % + (rnn_mode, CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, + CUDNN_RNN_RELU)) + + +def _get_seed(seed): + seed, seed2 = random_seed.get_seed(seed) + if seed is None and seed2 is None: + seed, seed2 = 0, 0 + return seed, seed2 + + +def _get_num_params(rnn_mode, num_layers, direction): + """Return num params for given Cudnn config.""" + if rnn_mode == CUDNN_LSTM: + num_params_per_layer = 8 + elif rnn_mode == CUDNN_GRU: + num_params_per_layer = 6 + elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH): + num_params_per_layer = 2 + else: + raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode) + num_params = num_layers * num_params_per_layer + if direction != CUDNN_RNN_UNIDIRECTION: + num_params *= 2 + return num_params + + +def _cudnn_rnn(inputs, + input_h, + input_c, + params, + is_training, + rnn_mode, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + seed=0, + name=None): + """Cudnn RNN. + + Args: + inputs: the input sequence to the RNN model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. + input_c: the initial hidden state for c. This is only relevant for LSTM. + A Tensor of the same shape as input_h. + params: the parameter buffer created for this model. + is_training: whether this operation will be used in training or inference + rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + name: name of the operation. + Returns: + outputs, output_h, output_c + """ + _check_rnn_mode(rnn_mode) + _check_direction(direction) + seed, seed2 = random_seed.get_seed(seed) + outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn( + input=inputs, + input_h=input_h, + input_c=input_c, + params=params, + is_training=is_training, + rnn_mode=rnn_mode, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + name=name) + return (outputs, output_h, output_c) + + +def cudnn_lstm(inputs, + input_h, + input_c, + params, + is_training, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + seed=0, + name=None): + """Cudnn LSTM. + + Args: + inputs: the input sequence to the RNN model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. + input_c: the initial hidden state for c. This is only relevant for LSTM. + A Tensor of the same shape as input_h. + params: the parameter buffer created for this model. + is_training: whether this operation will be used in training or inference + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + name: name of the operation. + Returns: + outputs, output_h, output_c + """ + return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM, + input_mode, direction, dropout, seed, name) + + +def _cudnn_rnn_no_input_c(inputs, + input_h, + params, + is_training, + rnn_mode, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + seed=0, + name=None): + """Cudnn RNN w/o input_c. + + Args: + inputs: the input sequence to the RNN model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. + params: the parameter buffer created for this model. + is_training: whether this operation will be used in training or inference + rnn_mode: one of ('lstm', 'gru', 'rnn_relu', 'rnn_tanh'). + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + name: name of the operation. + Returns: + outputs, output_h + """ + input_c = array_ops.constant([], dtype=input_h.dtype) + outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params, + is_training, rnn_mode, input_mode, + direction, dropout, seed, name) + return outputs, output_h + + +def cudnn_gru(inputs, + input_h, + params, + is_training, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + seed=0, + name=None): + """Cudnn GRU. + + Args: + inputs: the input sequence to the RNN model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. + params: the parameter buffer created for this model. + is_training: whether this operation will be used in training or inference + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + name: name of the operation. + Returns: + outputs, output_h + """ + return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, CUDNN_GRU, + input_mode, direction, dropout, seed, name) + + +def cudnn_rnn_relu(inputs, + input_h, + params, + is_training, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + seed=0, + name=None): + """Cudnn RNN Relu. + + Args: + inputs: the input sequence to the RNN model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. + params: the parameter buffer created for this model. + is_training: whether this operation will be used in training or inference + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + name: name of the operation. + Returns: + outputs, output_h + """ + return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, + CUDNN_RNN_RELU, input_mode, direction, dropout, + seed, name) + + +def cudnn_rnn_tanh(inputs, + input_h, + params, + is_training, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + seed=0, + name=None): + """Cudnn RNN Tanh. + + Args: + inputs: the input sequence to the RNN model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. + params: the parameter buffer created for this model. + is_training: whether this operation will be used in training or inference + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + name: name of the operation. + Returns: + outputs, output_h + """ + return _cudnn_rnn_no_input_c(inputs, input_h, params, is_training, + CUDNN_RNN_TANH, input_mode, direction, dropout, + seed, name) + + +def cudnn_rnn_params_to_canonical(rnn_mode, + num_layers, + num_units, + input_size, + params, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0, + seed=0, + name=None): + """Convert cudnn opaque params to canonical. + + Args: + rnn_mode: a string specifies the mode, under which this RNN model runs. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_size: the size of the input, it could be different from the + num_units. + params: opaque cudnn params var. + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + name: name of the operation. + Returns: + weights list and bias list + Raises: + ValueError: if rnn_mode or direction is invalid. + """ + + _check_rnn_mode(rnn_mode) + _check_direction(direction) + num_params = _get_num_params(rnn_mode, num_layers, direction) + seed, seed2 = random_seed.get_seed(seed) + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + params=params, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + num_params=num_params, + name=name) + return weights, biases + + +def cudnn_rnn_canonical_to_params(rnn_mode, + num_layers, + num_units, + input_size, + weights, + biases, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0, + seed=0, + name=None): + """Converts params from the canonical format to a specific format of cuDNN. + + Args: + rnn_mode: a string specifies the mode, under which this RNN model runs. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_size: the size of the input, it could be different from the + num_units. + weights: a Tensor for weight parameters. + biases: a Tensor for bias parameters. + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + name: name of the operation. + Returns: + an opaque Cudnn param. + Raises: + ValueError: if rnn_mode or direction is invalid. + """ + _check_rnn_mode(rnn_mode) + _check_direction(direction) + seed, seed2 = random_seed.get_seed(seed) + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + weights=weights, + biases=biases, + input_mode=input_mode, + direction=direction, + dropout=dropout, + seed=seed, + seed2=seed2, + name=name) + + +def cudnn_opaque_params_size(rnn_mode, + num_layers, + num_units, + input_size, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dtype=dtypes.float32, + dropout=0, + seed=0, + name=None): + """Returns opaque params size for specific Cudnn config. + + Args: + rnn_mode: a string specifies the mode, under which this RNN model runs. + Could be either 'lstm', 'gru', 'rnn_tanh' or 'rnn_relu'. + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_size: the size of the input, it could be different from the + num_units. + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It could be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Could be either + 'unidirectional' or 'bidirectional' + dtype: one of tf.float32 or tf.float64. + dropout: whether to enable dropout. With it is 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + name: name of the operation. + Returns: + a int, size of Cudnn opaque params. + Raises: + ValueError: if rnn_mode or direction is invalid. + """ + _check_rnn_mode(rnn_mode) + _check_direction(direction) + seed, seed2 = random_seed.get_seed(seed) + return gen_cudnn_rnn_ops.cudnn_rnn_params_size( + rnn_mode=rnn_mode, + num_layers=num_layers, + num_units=num_units, + input_size=input_size, + T=dtype, + S=dtypes.int32, + dropout=dropout, + seed=seed, + seed2=seed2, + input_mode=input_mode, + direction=direction, + name=name)[0] + + class _CudnnRNN(object): """Creates an RNN model using the underlying Cudnn implementation. @@ -760,9 +1238,6 @@ class _CudnnRNN(object): Raises: ValueError: if direction is invalid. """ - if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): - raise ValueError("Invalid direction: %s, expect %s or %s", - direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION) self._num_layers = num_layers self._num_units = num_units self._input_size = input_size @@ -771,10 +1246,7 @@ class _CudnnRNN(object): self._direction = direction self._dtype = dtype self._dropout = dropout - # get graph and op seed. - self._seed, self._seed2 = random_seed.get_seed(seed) - if self._seed is None and self._seed2 is None: - self._seed, self._seed2 = 0, 0 + self._seed = seed @property def input_mode(self): @@ -806,18 +1278,16 @@ class _CudnnRNN(object): Returns: The calculated parameter buffer size. """ - return gen_cudnn_rnn_ops.cudnn_rnn_params_size( + return cudnn_opaque_params_size( + rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, input_size=self._input_size, - T=self._dtype, - S=dtypes.int32, + dtype=self._dtype, dropout=self._dropout, seed=self._seed, - seed2=self._seed2, - rnn_mode=self._rnn_mode, input_mode=self._input_mode, - direction=self._direction)[0] + direction=self._direction) def __call__(self, input_data, input_h, input_c, params, is_training=True): """Runs the forward step for the RNN model. @@ -836,22 +1306,17 @@ class _CudnnRNN(object): output_h: the final state for h. output_c: the final state for c. This is only relevant for LSTM. """ - if self._rnn_mode != CUDNN_LSTM: - # For model that doesn't take input_c, replace with a dummy tensor. - input_c = array_ops.constant([], dtype=self._dtype) - output, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn( - input=input_data, - input_h=input_h, - input_c=input_c, - params=params, - rnn_mode=self._rnn_mode, + return _cudnn_rnn( + input_data, + input_h, + input_c, + params, + is_training, + self._rnn_mode, input_mode=self._input_mode, direction=self._direction, dropout=self._dropout, - seed=self._seed, - seed2=self._seed2, - is_training=is_training) - return (output, output_h, output_c) + seed=self._seed) def params_to_canonical(self, params): """Converts params from a specific format of cuDNN to the canonical format. @@ -862,22 +1327,16 @@ class _CudnnRNN(object): Returns: A function for the specific-to-canonical conversion. """ - num_params = self._num_layers * self._NUM_PARAMS_PER_LAYER - if self._direction != CUDNN_RNN_UNIDIRECTION: - num_params *= 2 - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + return cudnn_rnn_params_to_canonical( + rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, input_size=self._input_size, params=params, - dropout=self._dropout, - seed=self._seed, - seed2=self._seed2, - num_params=num_params, - rnn_mode=self._rnn_mode, input_mode=self._input_mode, - direction=self._direction) - return weights, biases + direction=self._direction, + dropout=self._dropout, + seed=self._seed) def canonical_to_params(self, weights, biases): """Converts params from the canonical format to a specific format of cuDNN. @@ -889,18 +1348,17 @@ class _CudnnRNN(object): Returns: A function for the canonical-to-params-to-specific conversion.. """ - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + return cudnn_rnn_canonical_to_params( + rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, input_size=self._input_size, weights=weights, biases=biases, - dropout=self._dropout, - seed=self._seed, - seed2=self._seed2, - rnn_mode=self._rnn_mode, input_mode=self._input_mode, - direction=self._direction) + direction=self._direction, + dropout=self._dropout, + seed=self._seed) class CudnnLSTM(_CudnnRNN): @@ -1035,9 +1493,16 @@ class _CudnnRNNNoInputC(_CudnnRNN): output: the output sequuence. output_h: the final state for h. """ - output, output_h, _ = super(_CudnnRNNNoInputC, self).__call__( - input_data, input_h, None, params, is_training=is_training) - return (output, output_h) + return _cudnn_rnn_no_input_c( + input_data, + input_h, + params, + is_training, + self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction, + dropout=self._dropout, + seed=self._seed) class CudnnGRU(_CudnnRNNNoInputC): diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 7b916d82c1c02ffab1841e630d12a66359bd2a97..c417650a96f1bf3a642c7a6eb0736ea60d0bf0bf 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -10,6 +10,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:sloppy_ops", "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 5308ab64ace297d89dacddcf17921f428cd72b1d..c74e1369d5d2962090f8e8a762698e3ed136f160 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -22,6 +22,9 @@ @@read_batch_features @@rejection_resample +@@group_by_window +@@sloppy_interleave +@@sloppy_map """ from __future__ import absolute_import @@ -31,11 +34,13 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.contrib.data.python.ops.dataset_ops import Dataset from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset +from tensorflow.contrib.data.python.ops.dataset_ops import group_by_window from tensorflow.contrib.data.python.ops.dataset_ops import Iterator from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample from tensorflow.contrib.data.python.ops.dataset_ops import TextLineDataset from tensorflow.contrib.data.python.ops.dataset_ops import TFRecordDataset +from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index ae87a60c78a4530681fb8ae4706a15614ee2b7d2..2f93c3450274dad32fdddaa0c5631eb14f152f2f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -146,6 +146,25 @@ py_test( ], ) +py_test( + name = "sloppy_transformation_dataset_op_test", + size = "small", + srcs = ["sloppy_transformation_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:sloppy_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + py_test( name = "list_files_dataset_op_test", size = "small", @@ -222,6 +241,21 @@ py_test( ], ) +py_test( + name = "sql_dataset_op_test", + size = "small", + srcs = ["sql_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + py_test( name = "resample_test", size = "medium", diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 71df1ee0a501f16571017dd61e1635a8ae866d07..0111aae1035cdad19300e35a452f255cab76a3fa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -37,7 +37,9 @@ class GroupByWindowTest(test.TestCase): components = np.random.randint(100, size=(200,)).astype(np.int64) iterator = dataset_ops.Iterator.from_dataset( dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x) - .group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)) + .apply( + dataset_ops.group_by_window, + args=(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))) init_op = iterator.initializer get_next = iterator.get_next() @@ -61,8 +63,9 @@ class GroupByWindowTest(test.TestCase): components = np.array( [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) iterator = dataset_ops.Iterator.from_dataset( - dataset_ops.Dataset.from_tensor_slices(components).repeat(-1) - .group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) + dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( + dataset_ops.group_by_window, + args=(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))) init_op = iterator.initializer get_next = iterator.get_next() @@ -81,8 +84,9 @@ class GroupByWindowTest(test.TestCase): def testSmallGroups(self): components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64) iterator = dataset_ops.Iterator.from_dataset( - dataset_ops.Dataset.from_tensor_slices(components) - .group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), 4)) + dataset_ops.Dataset.from_tensor_slices(components).apply( + dataset_ops.group_by_window, + args=(lambda x: x % 2, lambda _, xs: xs.batch(4), 4))) init_op = iterator.initializer get_next = iterator.get_next() @@ -108,8 +112,9 @@ class GroupByWindowTest(test.TestCase): iterator = dataset_ops.Iterator.from_dataset( dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: (x, ops.convert_to_tensor([x * x]))) - .group_by_window(lambda x, _: x % 2, reduce_func, 32)) + .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply( + dataset_ops.group_by_window, + args=(lambda x, _: x % 2, reduce_func, 32))) init_op = iterator.initializer get_next = iterator.get_next() @@ -124,17 +129,20 @@ class GroupByWindowTest(test.TestCase): def reduce_func(key, window): # Apply two different kinds of padding to the input: tight # padding, and quantized (to a multiple of 10) padding. - return dataset_ops.Dataset.zip((window.padded_batch( - 4, - padded_shapes=tensor_shape.TensorShape([None])), window.padded_batch( + return dataset_ops.Dataset.zip(( + window.padded_batch( + 4, padded_shapes=tensor_shape.TensorShape([None])), + window.padded_batch( 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),)) iterator = dataset_ops.Iterator.from_dataset( dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x)) - .group_by_window( - lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64), - reduce_func, 4)) + .apply( + dataset_ops.group_by_window, + args= + (lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64), + reduce_func, 4))) init_op = iterator.initializer get_next = iterator.get_next() @@ -151,10 +159,9 @@ class GroupByWindowTest(test.TestCase): self.assertEqual(len(components), sum(counts)) -# NOTE(mrry): These tests are based on the tests in -# bucket_ops_test.py. Currently, different batch sizes for each key -# are not supported, although this would be possible to add to -# `Dataset.group_by_window()`. +# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. +# Currently, they use a constant batch size, though should be made to use a +# different batch size per key. class BucketTest(test.TestCase): def _dynamicPad(self, bucket, window, window_size): @@ -168,6 +175,7 @@ class BucketTest(test.TestCase): tensor_shape.TensorShape([3]))))) def testSingleBucket(self): + def _map_fn(v): return (v, array_ops.fill([v], v), array_ops.fill([3], string_ops.as_string(v))) @@ -175,9 +183,10 @@ class BucketTest(test.TestCase): input_dataset = ( dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn)) - bucketed_dataset = input_dataset.group_by_window( - lambda x, y, z: 0, lambda k, bucket: self._dynamicPad(k, bucket, 32), - 32) + bucketed_dataset = input_dataset.apply( + dataset_ops.group_by_window, + args=(lambda x, y, z: 0, + lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)) iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset) init_op = iterator.initializer @@ -201,6 +210,7 @@ class BucketTest(test.TestCase): self.assertAllEqual(expected_vec3_str, bucketed_values[2]) def testEvenOddBuckets(self): + def _map_fn(v): return (v, array_ops.fill([v], v), array_ops.fill([3], string_ops.as_string(v))) @@ -208,9 +218,10 @@ class BucketTest(test.TestCase): input_dataset = ( dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn)) - bucketed_dataset = input_dataset.group_by_window( - lambda x, y, z: math_ops.cast(x % 2, dtypes.int64), - lambda k, bucket: self._dynamicPad(k, bucket, 32), 32) + bucketed_dataset = input_dataset.apply( + dataset_ops.group_by_window, + args=(lambda x, y, z: math_ops.cast(x % 2, dtypes.int64), + lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)) iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset) init_op = iterator.initializer @@ -256,25 +267,31 @@ class BucketTest(test.TestCase): self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2]) def testEvenOddBucketsFilterOutAllOdd(self): + def _map_fn(v): - return {"x": v, - "y": array_ops.fill([v], v), - "z": array_ops.fill([3], string_ops.as_string(v))} + return { + "x": v, + "y": array_ops.fill([v], v), + "z": array_ops.fill([3], string_ops.as_string(v)) + } def _dynamic_pad_fn(bucket, window, _): return dataset_ops.Dataset.zip( (dataset_ops.Dataset.from_tensors(bucket), window.padded_batch( - 32, {"x": tensor_shape.TensorShape([]), - "y": tensor_shape.TensorShape([None]), - "z": tensor_shape.TensorShape([3])}))) + 32, { + "x": tensor_shape.TensorShape([]), + "y": tensor_shape.TensorShape([None]), + "z": tensor_shape.TensorShape([3]) + }))) input_dataset = ( dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn) .filter(lambda d: math_ops.equal(d["x"] % 2, 0))) - bucketed_dataset = input_dataset.group_by_window( - lambda d: math_ops.cast(d["x"] % 2, dtypes.int64), - lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32) + bucketed_dataset = input_dataset.apply( + dataset_ops.group_by_window, + args=(lambda d: math_ops.cast(d["x"] % 2, dtypes.int64), + lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)) iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset) init_op = iterator.initializer @@ -295,6 +312,40 @@ class BucketTest(test.TestCase): self.assertAllEqual( np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"]) + def testDynamicWindowSize(self): + components = np.arange(100).astype(np.int64) + + # Key fn: even/odd + # Reduce fn: batches of 5 + # Window size fn: even=5, odd=10 + + def window_size_func(key): + window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64) + return window_sizes[key] + + dataset = dataset_ops.Dataset.from_tensor_slices(components).apply( + dataset_ops.group_by_window, + args=(lambda x: x % 2, lambda _, xs: xs.batch(20), None, + window_size_func)) + iterator = dataset_ops.Iterator.from_dataset(dataset) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.OutOfRangeError): + batches = 0 + while True: + result = sess.run(get_next) + is_even = all(x % 2 == 0 for x in result) + is_odd = all(x % 2 == 1 for x in result) + self.assertTrue(is_even or is_odd) + expected_batch_size = 5 if is_even else 10 + self.assertEqual(expected_batch_size, result.shape[0]) + batches += 1 + + self.assertEqual(batches, 15) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index b3af84a9536bc7d0b28881f62db9ddd84e131010..4c1496ccf98f7306653a70ded78db501d751ba76 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -16,9 +16,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import namedtuple import os import threading +from collections import namedtuple import numpy as np @@ -383,6 +385,32 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testCaptureSameResourceMultipleTimes(self): + elements = np.random.randint(100, size=[200]) + queue = data_flow_ops.FIFOQueue( + 200, dtypes.int64, shapes=[], shared_name="shared_queue") + queue_2 = data_flow_ops.FIFOQueue( + 200, dtypes.int64, shapes=[], shared_name="shared_queue") + + enqueue_op = queue.enqueue_many(elements) + close_op = queue.close() + + iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1) + .map(lambda _: (queue.dequeue(), queue_2.dequeue())) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(enqueue_op) + sess.run(close_op) + sess.run(init_op) + for i in range(100): + self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]), + sorted(sess.run(get_next))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testCaptureVariable(self): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) @@ -455,6 +483,40 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testMapNamedtuple(self, count=10): + # construct dataset of tuples + labels = dataset_ops.Dataset.range(count) + images = labels.map(lambda l: -l) + dataset_tuple = dataset_ops.Dataset.zip((labels, images)) + + # convert dataset of tuples to dataset of namedtuples + example = namedtuple("Example", ["label", "image"]) + dataset_namedtuple = dataset_tuple.map(example) + + def preprocess_tuple(label, image): + image = 2 * image + return label, image + + def preprocess_namedtuple(example): + return example._replace(image=2 * example.image) + + # preprocess both datasets + dataset_tuple = dataset_tuple.map(preprocess_tuple) + dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple) + + next_tuple = dataset_tuple.make_one_shot_iterator().get_next() + next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next() + + # make sure both datasets contain the same data + with self.test_session() as sess: + for i in range(count): + tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple]) + self.assertEqual(tuple_, namedtuple_) + self.assertEqual(tuple_, (i, -2 * i)) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_namedtuple) + def testUseStepContainerInMap(self): row = np.arange(6) iterator = ( diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index a8edbbd20c87fb5a6a49ee2b6345713c0f27c174..87bab43ccf508241702807897b4326c105ed8651 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -17,17 +17,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import test class RangeDatasetTest(test.TestCase): + def tearDown(self): + # Remove all checkpoint files. + prefix = self._iterator_checkpoint_prefix() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) + def testStop(self): stop = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator() @@ -175,6 +187,196 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def _iterator_checkpoint_prefix(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def testSaveRestore(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + path = self._iterator_checkpoint_prefix() + save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) + restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, + path) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Saving and restoring in same session. + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testMultipleSaves(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + path = self._iterator_checkpoint_prefix() + save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) + restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, + path) + return init_op, get_next, save_op, restore_op + + start = 2 + stop = 10 + break_point1 = 5 + break_point2 = 7 + + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point1): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point1, break_point2): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + break_point2 = 7 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point2, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSaveRestoreWithRepeat(self): + + def _build_graph(start, stop, num_epochs): + iterator = dataset_ops.Dataset.range( + start, stop).repeat(num_epochs).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + path = self._iterator_checkpoint_prefix() + save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) + restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, + path) + return init_op, get_next, save_op, restore_op + + start = 2 + stop = 10 + num_epochs = 5 + break_range = 5 + break_epoch = 3 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph( + start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for _ in range(break_epoch - 1): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + for i in range(start, break_range): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_range, stop): + self.assertEqual(i, sess.run(get_next)) + for _ in range(break_epoch, num_epochs): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSaveRestoreExhaustedIterator(self): + + def _build_graph(start, stop, num_epochs): + iterator = dataset_ops.Dataset.range( + start, stop).repeat(num_epochs).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + path = self._iterator_checkpoint_prefix() + save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) + restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, + path) + return init_op, get_next, save_op, restore_op + + start = 2 + stop = 10 + num_epochs = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph( + start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for _ in range(num_epochs): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 72a3ff17898cbc6aeabb0c32099a4615bdb13028..d631fbc76e3b24e0b121d12451e25a294e765324 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -255,7 +256,6 @@ class FixedLengthRecordReaderTest(test.TestCase): def testFixedLengthRecordDatasetBuffering(self): test_filenames = self._createFiles() - dataset = dataset_ops.FixedLengthRecordDataset( test_filenames, self._record_bytes, @@ -271,6 +271,124 @@ class FixedLengthRecordReaderTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) + def _build_iterator_graph(self, num_epochs): + filenames = self._createFiles() + path = os.path.join(self.get_temp_dir(), "iterator") + dataset = (dataset_ops.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, self._footer_bytes) + .repeat(num_epochs)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next_op = iterator.get_next() + save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) + restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, + path) + return init_op, get_next_op, save_op, restore_op + + def testSaveRestore(self): + num_epochs = 10 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreUnusedIterator(self): + num_epochs = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + # Save unused iterator. + sess.run(save_op) + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for _ in range(num_epochs * self._num_files * self._num_records): + sess.run(get_next_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreExhaustedIterator(self): + num_epochs = 10 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for _ in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + class TFRecordDatasetTest(test.TestCase): @@ -558,8 +676,8 @@ class ReadBatchFeaturesTest(test.TestCase): def testRead(self): for batch_size in [1, 2]: for num_epochs in [1, 10]: - with ops.Graph().as_default(): - with self.test_session(graph=ops.get_default_graph()) as sess: + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: # Basic test: read from file 0. self.outputs = self._read_batch_features( filenames=self.test_filenames[0], @@ -569,8 +687,8 @@ class ReadBatchFeaturesTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) - with ops.Graph().as_default(): - with self.test_session(graph=ops.get_default_graph()) as sess: + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: # Basic test: read from file 1. self.outputs = self._read_batch_features( filenames=self.test_filenames[1], @@ -580,8 +698,8 @@ class ReadBatchFeaturesTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) - with ops.Graph().as_default(): - with self.test_session(graph=ops.get_default_graph()) as sess: + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: # Basic test: read from both files. self.outputs = self._read_batch_features( filenames=self.test_filenames, diff --git a/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f9198bacfbdd08317db47e4f0a5bce60d0691cb2 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py @@ -0,0 +1,475 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import math +import threading +import time + +from six.moves import zip_longest + +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import sloppy_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import script_ops +from tensorflow.python.platform import test + + +class SloppyInterleaveDatasetTest(test.TestCase): + + def setUp(self): + self.input_values = array_ops.placeholder(dtypes.int64, shape=[None]) + self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) + self.block_length = array_ops.placeholder(dtypes.int64, shape=[]) + + self.repeat_count = 2 + + # Set up threading events used to sequence when items are produced that + # are subsequently interleaved. These events allow us to deterministically + # simulate slowdowns and force sloppiness. + self.read_coordination_events = {} + self.write_coordination_events = {} + # input values [4, 5, 6] are the common case for the tests; set defaults + for i in range(4, 7): + self.read_coordination_events[i] = threading.Semaphore(0) + self.write_coordination_events[i] = threading.Event() + + def map_py_fn(x): + self.write_coordination_events[x].wait() + self.write_coordination_events[x].clear() + self.read_coordination_events[x].release() + return x * x + + def map_fn(x): + return script_ops.py_func(map_py_fn, [x], x.dtype) + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + dataset = dataset.repeat(x) + return dataset.map(map_fn) + + self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values) + .repeat(self.repeat_count).apply( + sloppy_ops.sloppy_interleave, + args=(interleave_fn, self.cycle_length, + self.block_length))) + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + def _interleave(self, lists, cycle_length, block_length): + """Python implementation of interleave used for testing.""" + num_open = 0 + + # `all_iterators` acts as a queue of iterators over each element of `lists`. + all_iterators = [iter(l) for l in lists] + + # `open_iterators` are the iterators whose elements are currently being + # interleaved. + open_iterators = [] + for i in range(cycle_length): + if all_iterators: + open_iterators.append(all_iterators.pop(0)) + num_open += 1 + else: + open_iterators.append(None) + + while num_open or all_iterators: + for i in range(cycle_length): + if open_iterators[i] is None: + if all_iterators: + open_iterators[i] = all_iterators.pop(0) + num_open += 1 + else: + continue + for _ in range(block_length): + try: + yield next(open_iterators[i]) + except StopIteration: + open_iterators[i] = None + num_open -= 1 + break + + def testPythonImplementation(self): + input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], + [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] + + # Cycle length 1 acts like `Dataset.flat_map()`. + expected_elements = itertools.chain(*input_lists) + for expected, produced in zip(expected_elements, + self._interleave(input_lists, 1, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1. + expected_elements = [ + 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, + 6, 5, 6, 5, 6, 6 + ] + for index, (expected, produced) in enumerate( + zip_longest(expected_elements, self._interleave(input_lists, 2, 1))): + self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % + (index, expected, produced)) + + def testPythonImplementationBlockLength(self): + input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2 + expected_elements = [ + 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5, + 5, 6, 6, 5, 6, 6 + ] + for index, (expected, produced) in enumerate( + zip_longest(expected_elements, self._interleave(input_lists, 2, 2))): + self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % + (index, expected, produced)) + + def testPythonImplementationEmptyLists(self): + input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [], + [6, 6, 6, 6, 6, 6]] + + expected_elements = [ + 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6 + ] + for index, (expected, produced) in enumerate( + zip_longest(expected_elements, self._interleave(input_lists, 2, 1))): + self.assertEqual(expected, produced, "Values differ at %s. %s != %s" % + (index, expected, produced)) + + def _clear_coordination_events(self): + for i in range(4, 7): + self.read_coordination_events[i] = threading.Semaphore(0) + self.write_coordination_events[i].clear() + + def _allow_all_map_threads(self): + for i in range(4, 7): + self.write_coordination_events[i].set() + + def testSingleThreaded(self): + # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and + # `Dataset.flat_map()` and is single-threaded. No synchronization required. + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 1, + self.block_length: 1 + }) + + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1): + self.write_coordination_events[expected_element].set() + self.assertEqual(expected_element * expected_element, + sess.run(self.next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContention(self): + # num_threads > 1. + # Explicit coordination should result in `Dataset.interleave()` behavior + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads. + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + self.read_coordination_events[expected_element].acquire() + done_first_event = True + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionWithRaces(self): + """Tests where all the workers race in producing elements. + + Note: this is in contrast with the prevous test which carefully sequences + the execution of the map functions. + """ + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + if done_first_event: # First event starts the worker threads. + self._allow_all_map_threads() + self.read_coordination_events[expected_element].acquire() + else: + self.write_coordination_events[expected_element].set() + time.sleep(0.01) # Sleep to consistently "avoid" the race condition. + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.assertTrue( + self.read_coordination_events[expected_element].acquire(False)) + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionBlockLength(self): + # num_threads > 1. + # Explicit coordination should result in `Dataset.interleave()` behavior + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 2 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 2)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads. + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.read_coordination_events[expected_element].acquire() + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testTwoThreadsNoContentionWithRacesAndBlocking(self): + """Tests where all the workers race in producing elements. + + Note: this is in contrast with the prevous test which carefully sequences + the execution of the map functions. + """ + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 2 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 2)): + if done_first_event: # First event starts the worker threads. + self._allow_all_map_threads() + self.read_coordination_events[expected_element].acquire() + else: + self.write_coordination_events[expected_element].set() + time.sleep(0.01) # Sleep to consistently "avoid" the race condition. + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.assertTrue( + self.read_coordination_events[expected_element].acquire(False)) + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testEmptyInput(self): + with self.test_session() as sess: + # Empty input. + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [], + self.cycle_length: 2, + self.block_length: 3 + }) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testNonEmptyInputIntoEmptyOutputs(self): + # Non-empty input leading to empty output. + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [0, 0, 0], + self.cycle_length: 2, + self.block_length: 3 + }) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testPartiallyEmptyOutputs(self): + # Mixture of non-empty and empty interleaved datasets. + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 0, 6], + self.cycle_length: 2, + self.block_length: 1 + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + done_first_event = True + self.read_coordination_events[expected_element].acquire() + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testDelayedOutput(self): + # Explicitly control the sequence of events to ensure we correctly avoid + # head-of-line blocking. + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1 + }) + + mis_ordering = [ + 4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6, + 6, 5, 5, 5, 5, 6, 6 + ] + for element in mis_ordering: + self.write_coordination_events[element].set() + self.assertEqual(element * element, sess.run(self.next_element)) + self.assertTrue(self.read_coordination_events[element].acquire(False)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testBlockLengthWithContention(self): + with self.test_session() as sess: + self._clear_coordination_events() + done_first_event = False + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 3 + }) + # Test against a generating sequence that differs from the uncontended + # case, in order to prove sloppy correctness. + for i, expected_element in enumerate( + self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, + cycle_length=2, + block_length=2)): + self.write_coordination_events[expected_element].set() + if done_first_event: # First event starts the worker threads. + self.read_coordination_events[expected_element].acquire() + actual_element = sess.run(self.next_element) + if not done_first_event: + self.read_coordination_events[expected_element].acquire() + done_first_event = True + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testEarlyExit(self): + # Exiting without consuming all input should not block + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 3, + self.block_length: 2 + }) + for i in range(4, 7): + self.write_coordination_events[i].set() + elem = sess.run(self.next_element) # Start all workers + # Allow the one successful worker to progress beyond the py_func again. + elem = int(math.sqrt(elem)) + self.write_coordination_events[elem].set() + self.read_coordination_events[elem].acquire() + # Allow the prefetch to succeed + for i in range(4, 7): + self.read_coordination_events[i].acquire() + self.write_coordination_events[i].set() + + def testTooManyReaders(self): + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64)) + return dataset + + dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6]) + dataset = dataset.repeat(self.repeat_count) + dataset = dataset.apply( + sloppy_ops.sloppy_interleave, + args=(interleave_fn,), + kwargs={"cycle_length": 16, + "block_length": 2}) + iterator = dataset.make_one_shot_iterator() + + with self.test_session() as sess: + output_values = [] + for _ in range(30): + output_values.append(sess.run(iterator.get_next())) + + expected_values = self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2) + self.assertItemsEqual(output_values, expected_values) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e520bc05d643e929f8f13945487a4fbe20e4ee12 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -0,0 +1,188 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for experimental sql input op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sqlite3 + +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class SqlDatasetTest(test.TestCase): + + def setUp(self): + self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + self.driver_name = array_ops.placeholder(dtypes.string, shape=[]) + self.query = array_ops.placeholder(dtypes.string, shape=[]) + self.output_types = (dtypes.string, dtypes.string, dtypes.string) + + conn = sqlite3.connect(self.data_source_name) + c = conn.cursor() + c.execute("DROP TABLE IF EXISTS students") + c.execute("DROP TABLE IF EXISTS people") + c.execute( + "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY," + " first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100))") + c.execute( + "INSERT INTO students (first_name, last_name, motto) VALUES ('John', " + "'Doe', 'Hi!'), ('Apple', 'Orange', 'Hi again!')") + c.execute( + "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, " + "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))") + c.execute( + "INSERT INTO people (first_name, last_name, state) VALUES ('Benjamin'," + " 'Franklin', 'Pennsylvania'), ('John', 'Doe', 'California')") + conn.commit() + conn.close() + + dataset = dataset_ops.SqlDataset(self.driver_name, self.data_source_name, + self.query, self.output_types).repeat(2) + iterator = dataset.make_initializable_iterator() + self.init_op = iterator.initializer + self.get_next = iterator.get_next() + + # Test that SqlDataset can read from a database table. + def testReadResultSet(self): + with self.test_session() as sess: + for _ in range(2): # Run twice to verify statelessness of db operations. + sess.run( + self.init_op, + feed_dict={ + self.driver_name: "sqlite", + self.query: "SELECT first_name, last_name, motto FROM students " + "ORDER BY first_name DESC" + }) + for _ in range(2): # Dataset is repeated. See setUp. + self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(self.get_next)) + self.assertEqual((b"Apple", b"Orange", b"Hi again!"), + sess.run(self.get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.get_next) + + # Test that SqlDataset works on a join query. + def testReadResultSetJoinQuery(self): + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.driver_name: "sqlite", + self.query: + "SELECT students.first_name, state, motto FROM students " + "INNER JOIN people " + "ON students.first_name = people.first_name " + "AND students.last_name = people.last_name" + }) + for _ in range(2): + self.assertEqual((b"John", b"California", b"Hi!"), + sess.run(self.get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.get_next) + + # Test that an `OutOfRangeError` is raised on the first call to `get_next` + # if result set is empty. + def testReadEmptyResultSet(self): + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.driver_name: "sqlite", + self.query: "SELECT first_name, last_name, motto FROM students " + "WHERE first_name = 'Nonexistent'" + }) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.get_next) + + # Test that an error is raised when `driver_name` is invalid. + def testReadResultSetWithInvalidDriverName(self): + with self.test_session() as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run( + self.init_op, + feed_dict={ + self.driver_name: "sqlfake", + self.query: "SELECT first_name, last_name, motto FROM students " + "ORDER BY first_name DESC" + }) + + # Test that an error is raised when a column name in `query` is nonexistent + def testReadResultSetWithInvalidColumnName(self): + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.driver_name: "sqlite", + self.query: + "SELECT first_name, last_name, fake_column FROM students " + "ORDER BY first_name DESC" + }) + with self.assertRaises(errors.UnknownError): + sess.run(self.get_next) + + # Test that an error is raised when there is a syntax error in `query`. + def testReadResultSetOfQueryWithSyntaxError(self): + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.driver_name: "sqlite", + self.query: + "SELEmispellECT first_name, last_name, motto FROM students " + "ORDER BY first_name DESC" + }) + with self.assertRaises(errors.UnknownError): + sess.run(self.get_next) + + # Test that an error is raised when the number of columns in `query` + # does not match the length of `output_types`. + def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self): + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.driver_name: "sqlite", + self.query: "SELECT first_name, last_name FROM students " + "ORDER BY first_name DESC" + }) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.get_next) + + # Test that no results are returned when `query` is an insert query rather + # than a select query. In particular, the error refers to the number of + # output types passed to the op not matching the number of columns in the + # result set of the query (namely, 0 for an insert statement.) + def testReadResultSetOfInsertQuery(self): + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.driver_name: "sqlite", + self.query: + "INSERT INTO students (first_name, last_name, motto) " + "VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')" + }) + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.get_next) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 8afd122d82d6845c88b1de9ec38cd360d2f6a11f..94969c1c704bc3aaf984a4f397624b5b45b5a473 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -32,6 +32,21 @@ py_library( ], ) +py_library( + name = "sloppy_ops", + srcs = ["sloppy_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", + "//tensorflow/contrib/data/python/framework:function", + "//tensorflow/contrib/data/python/util:nest", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index aadcf77f3c393a23f4e2ceda077655ad902ca4bc..0ee9acfc97f839642ac84e9060dee678aa4f400e 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -595,6 +595,23 @@ class Dataset(object): The elements generated by `generator` must be compatible with the given `output_types` and (optional) `output_shapes` arguments. + For example: + + ```python + import itertools + + def gen(): + for i in itertools.count(1): + yield (i, [1] * i) + + ds = Dataset.from_generator( + gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None]))) + value = ds.make_one_shot_iterator().get_next() + + sess.run(value) # (1, array([1])) + sess.run(value) # (2, array([1, 1])) + ``` + Args: generator: A callable object that takes no arguments and returns an object that supports the `iter()` protocol. @@ -1017,7 +1034,7 @@ class Dataset(object): d = d.shard(FLAGS.num_workers, FLAGS.worker_index) d = d.repeat(FLAGS.num_epochs) d = d.shuffle(FLAGS.shuffle_buffer_size) - d = d.map(parser_fn, num_threads=FLAGS.num_map_threads) + d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) ``` Important caveats: @@ -1037,7 +1054,7 @@ class Dataset(object): d = d.repeat() d = d.interleave(tf.contrib.data.TFRecordDataset, cycle_length=FLAGS.num_readers, block_length=1) - d = d.map(parser_fn, num_threads=FLAGS.num_map_threads) + d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) ``` Args: @@ -1182,51 +1199,39 @@ class Dataset(object): return DenseToSparseBatchDataset(self, batch_size, row_shape) def group_by_window(self, key_func, reduce_func, window_size): - """Performs a windowed "group-by" operation on this dataset. - - This method maps each consecutive element in this dataset to a key - using `key_func` and groups the elements by key. It then applies - `reduce_func` to at most `window_size` elements matching the same - key. All execpt the final window for each key will contain - `window_size` elements; the final window may be smaller. - - Args: - key_func: A function mapping a nested structure of tensors - (having shapes and types defined by `self.output_shapes` and - `self.output_types`) to a scalar `tf.int64` tensor. - reduce_func: A function mapping a key and a dataset of up to `batch_size` - consecutive elements matching that key to another dataset. - window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements matching the same key to combine in a single - batch, which will be passed to `reduce_func`. - - Returns: - A `Dataset`. - """ - return GroupByWindowDataset(self, key_func, reduce_func, window_size) - - def map(self, map_func, num_threads=None, output_buffer_size=None): + """See group_by_window().""" + return self.apply( + group_by_window, args=(key_func, reduce_func, window_size)) + + def map(self, + map_func, + num_threads=None, + output_buffer_size=None, + num_parallel_calls=None): """Maps `map_func` across this datset. Args: map_func: A function mapping a nested structure of tensors (having shapes and types defined by `self.output_shapes` and `self.output_types`) to another nested structure of tensors. - num_threads: (Optional.) A `tf.int32` scalar `tf.Tensor`, representing - the number of threads to use for processing elements in parallel. If - not specified, elements will be processed sequentially without - buffering. + num_threads: (Optional.) Deprecated, use `num_parallel_calls` instead. output_buffer_size: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the maximum number of processed elements that will be - buffered when processing in parallel. + buffered. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number elements to process in parallel. If not + specified, elements will be processed sequentially. Returns: A `Dataset`. """ - if num_threads is None: + if num_threads is None and num_parallel_calls is None: ret = MapDataset(self, map_func) else: - ret = ParallelMapDataset(self, map_func, num_threads) + if num_threads is None: + ret = ParallelMapDataset(self, map_func, num_parallel_calls) + else: + ret = ParallelMapDataset(self, map_func, num_threads) if output_buffer_size is not None: ret = ret.prefetch(output_buffer_size) return ret @@ -1255,8 +1260,8 @@ class Dataset(object): # each file. filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ..."] dataset = (Dataset.from_tensor_slices(filenames) - .interleave( - lambda x: TextLineDataset(x).map(parse_fn, num_threads=1), + .interleave(lambda x: + TextLineDataset(x).map(parse_fn, num_parallel_calls=1), cycle_length=4, block_length=16)) ``` @@ -1346,6 +1351,43 @@ class Dataset(object): """ return FilterDataset(self, predicate) + def apply(self, fn, args=(), kwargs={}): # pylint: disable=dangerous-default-value + """Apply a function to this dataset. + + `apply` enables chaining of custom `Dataset` transformations. + + For example: + + ``` + dataset.map( + lambda x: x**2 + ).apply( + group_by_window, args=(key_func, reduce_func, window_size) + ).map( + lambda x: x**3 + ) + ``` + + Args: + fn: A function that takes a `Dataset`, `args`, and `kwargs`, and + returns a `Dataset`. + args: A `tuple` or `list` of arguments to be passed to `fn`. + kwargs: A `dict` of keyword arguments to be passed to `fn`. + + Returns: + The `Dataset` returned by `fn`. + """ + if not (isinstance(args, tuple) or isinstance(args, list)): + raise TypeError("args must be a tuple or list.") + if not isinstance(kwargs, dict): + raise TypeError("kwargs must be a dict.") + + dataset = fn(self, *args, **kwargs) + + if not isinstance(dataset, Dataset): + raise TypeError("fn must return a Dataset.") + return dataset + class TensorDataset(Dataset): """A `Dataset` with a single element, viz. a nested structure of tensors.""" @@ -1879,7 +1921,7 @@ class DenseToSparseBatchDataset(Dataset): def _should_unpack_args(args): """Returns `True` if `args` should be `*args` when passed to a callable.""" - return nest.is_sequence(args) and not isinstance(args, dict) + return type(args) is tuple # pylint: disable=unidiomatic-typecheck class _ResourceDataset(Dataset): @@ -1903,71 +1945,6 @@ class _ResourceDataset(Dataset): return self._output_types -class GroupByWindowDataset(Dataset): - """A `Dataset` that groups its input and performs a windowed reduction.""" - - def __init__(self, input_dataset, key_func, reduce_func, window_size): - """See `Dataset.group_by_window()` for details.""" - super(GroupByWindowDataset, self).__init__() - self._input_dataset = input_dataset - self._window_size = window_size - - @function.Defun(*nest.flatten(input_dataset.output_types)) - def tf_key_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): - arg.set_shape(shape) - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - if _should_unpack_args(nested_args): - ret = key_func(*nested_args) - else: - ret = key_func(nested_args) - ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) - if ret.dtype != dtypes.int64: - raise ValueError("`key_func` must return a single tf.int64 tensor.") - return ret - - self._key_func = tf_key_func - self._key_func.add_to_graph(ops.get_default_graph()) - - @function.Defun(dtypes.int64, dtypes.resource) - def tf_reduce_func(key, window_dataset_resource): - """A wrapper for Defun that facilitates shape inference.""" - key.set_shape([]) - window_dataset = _ResourceDataset(window_dataset_resource, - input_dataset.output_types, - input_dataset.output_shapes) - output_dataset = reduce_func(key, window_dataset) - if not isinstance(output_dataset, Dataset): - raise TypeError("`reduce_func` must return a `Dataset` object.") - self._output_types = output_dataset.output_types - self._output_shapes = output_dataset.output_shapes - return output_dataset.make_dataset_resource() - - self._reduce_func = tf_reduce_func - self._reduce_func.add_to_graph(ops.get_default_graph()) - - def make_dataset_resource(self): - return gen_dataset_ops.group_by_window_dataset( - self._input_dataset.make_dataset_resource(), - self._key_func.captured_inputs, - self._reduce_func.captured_inputs, - self._window_size, - key_func=self._key_func, - reduce_func=self._reduce_func, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - class MapDataset(Dataset): """A `Dataset` that maps a function over elements in its input.""" @@ -2110,7 +2087,8 @@ class FlatMapDataset(Dataset): class InterleaveDataset(Dataset): - """A `Dataset` that maps a function over its input and flattens the result.""" + """A `Dataset` that maps a function over its input and interleaves the result. + """ def __init__(self, input_dataset, map_func, cycle_length, block_length): """See `Dataset.interleave()` for details.""" @@ -2126,7 +2104,7 @@ class InterleaveDataset(Dataset): nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - if nest.is_sequence(nested_args): + if _should_unpack_args(nested_args): dataset = map_func(*nested_args) else: dataset = map_func(nested_args) @@ -2292,6 +2270,46 @@ class TextLineDataset(Dataset): return dtypes.string +class SqlDataset(Dataset): + """A `Dataset` consisting of the results from a SQL query.""" + + def __init__(self, driver_name, data_source_name, query, output_types): + """Creates a `SqlDataset`. + + Args: + driver_name: A 0-D `tf.string` tensor containing the database type. + Currently, the only supported value is 'sqlite'. + data_source_name: A 0-D `tf.string` tensor containing a connection string + to connect to the database. + query: A 0-D `tf.string` tensor containing the SQL query to execute. + output_types: A tuple of `tf.DType` objects representing the types of the + columns returned by `query`. + """ + super(SqlDataset, self).__init__() + self._driver_name = ops.convert_to_tensor( + driver_name, dtype=dtypes.string, name="driver_name") + self._data_source_name = ops.convert_to_tensor( + data_source_name, dtype=dtypes.string, name="data_source_name") + self._query = ops.convert_to_tensor( + query, dtype=dtypes.string, name="query") + self._output_types = output_types + + def make_dataset_resource(self): + return gen_dataset_ops.sql_dataset(self._driver_name, + self._data_source_name, self._query, + nest.flatten(self.output_types), + nest.flatten(self.output_shapes)) + + @property + def output_shapes(self): + return nest.map_structure(lambda _: tensor_shape.TensorShape([]), + self._output_types) + + @property + def output_types(self): + return self._output_types + + class TFRecordDataset(Dataset): """A `Dataset` comprising records from one or more TFRecord files.""" @@ -2395,7 +2413,7 @@ def rejection_resample(dataset, shapes and types defined by `dataset.output_shapes` and `dataset.output_types`) to a scalar `tf.int32` tensor. Values should be in `[0, num_classes)`. - target_dist: A floating point type tensor, shaped `[num_classes]. + target_dist: A floating point type tensor, shaped `[num_classes]`. initial_dist: (Optional.) A floating point type tensor, shaped `[num_classes]`. If not provided, the true class distribution is estimated live in a streaming fashion. @@ -2595,3 +2613,149 @@ def _get_file_names(file_pattern, randomize_input): if not randomize_input: file_names = sorted(file_names) return file_names + + +class GroupByWindowDataset(Dataset): + """A `Dataset` that groups its input and performs a windowed reduction.""" + + def __init__(self, input_dataset, key_func, reduce_func, window_size_func): + """See `group_by_window()` for details.""" + super(GroupByWindowDataset, self).__init__() + + self._input_dataset = input_dataset + + self._make_key_func(key_func, input_dataset) + self._make_reduce_func(reduce_func, input_dataset) + self._make_window_size_func(window_size_func) + + def _make_window_size_func(self, window_size_func): + """Make wrapping Defun for window_size_func.""" + + @function.Defun(dtypes.int64) + def tf_window_size_func(key): + key.set_shape([]) + window_size = ops.convert_to_tensor( + window_size_func(key), dtype=dtypes.int64) + if window_size.dtype != dtypes.int64: + raise ValueError( + "`window_size_func` must return a single tf.int64 tensor.") + return window_size + + self._window_size_func = tf_window_size_func + self._window_size_func.add_to_graph(ops.get_default_graph()) + + def _make_key_func(self, key_func, input_dataset): + """Make wrapping Defun for key_func.""" + + @function.Defun(*nest.flatten(input_dataset.output_types)) + def tf_key_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the input_dataset. + for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): + arg.set_shape(shape) + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + if _should_unpack_args(nested_args): + ret = key_func(*nested_args) + else: + ret = key_func(nested_args) + ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) + if ret.dtype != dtypes.int64: + raise ValueError("`key_func` must return a single tf.int64 tensor.") + return ret + + self._key_func = tf_key_func + self._key_func.add_to_graph(ops.get_default_graph()) + + def _make_reduce_func(self, reduce_func, input_dataset): + """Make wrapping Defun for reduce_func.""" + + @function.Defun(dtypes.int64, dtypes.resource) + def tf_reduce_func(key, window_dataset_resource): + """A wrapper for Defun that facilitates shape inference.""" + key.set_shape([]) + window_dataset = _ResourceDataset(window_dataset_resource, + input_dataset.output_types, + input_dataset.output_shapes) + output_dataset = reduce_func(key, window_dataset) + if not isinstance(output_dataset, Dataset): + raise TypeError("`reduce_func` must return a `Dataset` object.") + self._output_types = output_dataset.output_types + self._output_shapes = output_dataset.output_shapes + return output_dataset.make_dataset_resource() + + self._reduce_func = tf_reduce_func + self._reduce_func.add_to_graph(ops.get_default_graph()) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + def make_dataset_resource(self): + return gen_dataset_ops.group_by_window_dataset( + self._input_dataset.make_dataset_resource(), + self._key_func.captured_inputs, + self._reduce_func.captured_inputs, + self._window_size_func.captured_inputs, + key_func=self._key_func, + reduce_func=self._reduce_func, + window_size_func=self._window_size_func, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + + +def group_by_window(dataset, + key_func, + reduce_func, + window_size=None, + window_size_func=None): + """Performs a windowed "group-by" operation on this dataset. + + This method maps each consecutive element in this dataset to a key + using `key_func` and groups the elements by key. It then applies + `reduce_func` to at most `window_size_func(key)` elements matching the same + key. All execpt the final window for each key will contain + `window_size_func(key)` elements; the final window may be smaller. + + You may provide either a constant `window_size` or a window size determined by + the key through `window_size_func`. + + Args: + dataset: A `Dataset`. + key_func: A function mapping a nested structure of tensors + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to a scalar `tf.int64` tensor. + reduce_func: A function mapping a key and a dataset of up to `batch_size` + consecutive elements matching that key to another dataset. + window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements matching the same key to combine in a single + batch, which will be passed to `reduce_func`. Mutually exclusive with + `window_size_func`. + window_size_func: A function mapping a key to a `tf.int64` scalar + `tf.Tensor`, representing the number of consecutive elements matching + the same key to combine in a single batch, which will be passed to + `reduce_func`. Mutually exclusive with `window_size`. + + Returns: + A `Dataset`. + + Raises: + ValueError: if neither or both of {`window_size`, `window_size_func`} are + passed. + """ + if (window_size is not None and window_size_func or + not (window_size is not None or window_size_func)): + raise ValueError("Must pass either window_size or window_size_func.") + + if window_size is not None: + + def constant_window_func(unused_key): + return ops.convert_to_tensor(window_size, dtype=dtypes.int64) + + window_size_func = constant_window_func + + assert window_size_func is not None + return GroupByWindowDataset(dataset, key_func, reduce_func, window_size_func) diff --git a/tensorflow/contrib/data/python/ops/sloppy_ops.py b/tensorflow/contrib/data/python/ops/sloppy_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..010bd31161fe07964d4b92854aefb12d7b1cb54d --- /dev/null +++ b/tensorflow/contrib/data/python/ops/sloppy_ops.py @@ -0,0 +1,120 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Non-deterministic dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.framework import function +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +class SloppyInterleaveDataset(dataset_ops.Dataset): + """A `Dataset` that maps a function over its input and flattens the result.""" + + def __init__(self, input_dataset, map_func, cycle_length, block_length): + """See `tf.contrib.data.sloppy_interleave()` for details.""" + super(SloppyInterleaveDataset, self).__init__() + self._input_dataset = input_dataset + + @function.Defun(*nest.flatten(input_dataset.output_types)) + def tf_map_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the input_dataset. + for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + + if nest.is_sequence(nested_args): + dataset = map_func(*nested_args) + else: + dataset = map_func(nested_args) + + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`map_func` must return a `Dataset` object.") + + self._output_types = dataset.output_types + self._output_shapes = dataset.output_shapes + + return dataset.make_dataset_resource() + + self._map_func = tf_map_func + self._map_func.add_to_graph(ops.get_default_graph()) + + self._cycle_length = ops.convert_to_tensor( + cycle_length, dtype=dtypes.int64, name="cycle_length") + self._block_length = ops.convert_to_tensor( + block_length, dtype=dtypes.int64, name="block_length") + + def make_dataset_resource(self): + return gen_dataset_ops.sloppy_interleave_dataset( + self._input_dataset.make_dataset_resource(), + self._map_func.captured_inputs, + self._cycle_length, + self._block_length, + f=self._map_func, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +def sloppy_interleave(dataset, map_func, cycle_length, block_length): + """Maps `map_func` across `dataset`, and interleaves the results. + + The resulting dataset is almost identical to `interleave`. The key + difference being that if retrieving a value from a given output iterator would + cause `get_next` to block, that iterator will be skipped, and consumed + when next available. If consuming from all iterators would cause the + `get_next` call to block, the `get_next` call blocks until the first value is + available. + + If the underlying datasets produce elements as fast as they are consumed, the + `sloppy_interleave` dataset behaves identically to the `interleave` dataset. + However, if an underlying dataset would block the consumer, the + `sloppy_interleave` dataset can violate to the round-robin order (respected by + the `interleave` dataset), producing an element from a different underlying + dataset instead. + + WARNING: The order of elements in the resulting dataset is not + deterministic. Use `Dataset.interleave()` if you want the elements to have a + deterministic order. + + Args: + dataset: A `Dataset` that produces elements to feed to `map_func`. + map_func: A function mapping a nested structure of tensors (having shapes + and types defined by `self.output_shapes` and `self.output_types`) to a + `Dataset`. + cycle_length: The number of threads to interleave from in parallel. + block_length: The number of consecutive elements to pull from a thread + before advancing to the next thread. Note: sloppy_interleave will + skip the remainder of elements in the block_length in order to avoid + blocking. + + Returns: + A `Dataset`. + """ + return SloppyInterleaveDataset(dataset, map_func, cycle_length, block_length) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index c78b064b4fd0707f4c39de2e70f88416f284e988..c2b99d67c7fe7cc7af298c4910c5032833118229 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -341,7 +341,7 @@ cuda_py_test( cuda_py_test( name = "sample_stats_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/sample_stats_test.py"], additional_deps = [ ":distributions_py", diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index da1cd72a6f13f7c585a60d0be122c212671fe5e8..699cf45a73883a49d116fa70c81a4f9ecb36e598 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -150,7 +150,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): `N - 1` dimensions index into a batch of independent distributions and the last dimension represents a vector of probabilities for each class. Only one of `logits` or `probs` should be passed in. - dtype: The type of the event samples (default: int32). + dtype: The type of the event samples (default: float32). validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -388,7 +388,7 @@ class RelaxedOneHotCategorical( dimensions index into a batch of independent distributions and the last dimension represents a vector of probabilities for each class. Only one of `logits` or `probs` should be passed in. - dtype: The type of the event samples (default: int32). + dtype: The type of the event samples (default: float32). validate_args: Unused in this distribution. allow_nan_stats: Python `bool`, default `True`. If `False`, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index e29314099d725a722b0b4eeec9f228cedd3d52c3..1b831f8afba5e402de873719928e0e0436952a74 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -2,11 +2,14 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + py_library( name = "tfe", srcs = ["tfe.py"], srcs_version = "PY2AND3", deps = [ + ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:util", "//tensorflow/python/eager:backprop", @@ -18,6 +21,28 @@ py_library( ], ) +py_library( + name = "saver", + srcs = ["saver.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:training", + ], +) + +cuda_py_test( + name = "saver_test", + srcs = ["saver_test.py"], + additional_deps = [ + ":saver", + "//tensorflow/python:array_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py new file mode 100644 index 0000000000000000000000000000000000000000..12c902a4b668eb1e8bf460a4a75609bbd31501f5 --- /dev/null +++ b/tensorflow/contrib/eager/python/saver.py @@ -0,0 +1,122 @@ +"""Saver for eager mode TensorFlow.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib + +from tensorflow.python.framework import errors +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import saver as _saver + + +def _init_from_checkpoint(self, *args, **kwargs): + """Overrides default init by loading value from checkpoint.""" + self.old_init(*args, **kwargs) + # pylint: disable=protected-access + if self._shared_name not in self.ckpt_var_cache: + raise errors.NotFoundError(None, None, + "%s not found in checkpoint" % self._shared_name) + + val = self.ckpt_var_cache[self._shared_name] + if val is not None: + self.assign(self.ckpt_var_cache[self._shared_name]) + # Avoid assigning for the second time. + self.ckpt_var_cache[self._shared_name] = None + # pylint: enable=protected-access + + +class Saver(object): + """A simple tf.train.Saver adapter for eager mode. + + save and restore API are similar to the tf.train.Saver, except that + session is not needed. + + restore_on_create is eager mode's way to reload checkpoint value during + the execution. (unlike graph mode's reload before run). + + Args: + var_list: See tf.train.Saver. Works the same for save/restore. Ignored + by restore_on_create. + """ + + def __init__(self, var_list=None): + self._saver = _saver.Saver(var_list=var_list) + + def save(self, save_path, global_step=None): + """Saves variables. + + Args: + save_path: See save method in tf.train.Saver. + global_step: See save method in tf.train.Saver. + + Returns: + See save method in tf.train.Saver. + """ + return self._saver.save(None, save_path, global_step=global_step) + + def restore(self, save_path): + """Restores previously saved variables. + + Args: + save_path: See restore method in tf.train.Saver. + """ + self._saver.restore(None, save_path) + + @contextlib.contextmanager + def maybe_restore_on_create(self, save_path): + """ContextManager that restores variables on creation. + + When save_path is None (e.g. No checkpoint), does nothing. + Otherwise, it preloads all values from checkpoint. When the + corresponding variable is first created, it assigns the checkpoint + value to the variable. + + Args: + save_path: Same as save_path of retore. If None, do not restore. + + Yields: + Nothing. + + Raises: + NotFoundError: If the variable is not found in checkpoint. + """ + if save_path: + ckpt_var_cache = dict() + reader = checkpoint_utils.load_checkpoint(save_path) + for k, _ in checkpoint_utils.list_variables(save_path): + ckpt_var_cache[k] = reader.get_tensor(k) + + old_init = getattr( + resource_variable_ops.ResourceVariable, "_init_from_args", None) + assert old_init, "ResourceVariable misses _init_from_args method." + setattr(resource_variable_ops.ResourceVariable, "_init_from_args", + _init_from_checkpoint) + setattr(resource_variable_ops.ResourceVariable, "old_init", old_init) + setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", + ckpt_var_cache) + try: + yield + except Exception as e: + raise e + finally: + if save_path: + setattr(resource_variable_ops.ResourceVariable, "_init_from_args", + old_init) + setattr(resource_variable_ops.ResourceVariable, "old_init", None) + setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", None) diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b8ff566ec2e0c51f6c22791b06369e2278d06c5b --- /dev/null +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -0,0 +1,88 @@ +"""Tests for eager mode Saver.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.eager.python import saver as _saver +from tensorflow.python.eager import context +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test + + +class SaverTest(test.TestCase): + + def testBasics(self): + with context.eager_mode(): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + def model(): + return array_ops.constant(2.0) * v1 + + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + + _ = model() + saver = _saver.Saver() + saver.save(ckpt_prefix) + v1.assign(2.0) + self.assertEqual(v1.read_value().numpy(), 2.0) + + saver.restore(ckpt_prefix) + self.assertEqual(v1.read_value().numpy(), 1.0) + + def testRestoreOnCreate(self): + with context.eager_mode(): + def model(init_val): + v1 = resource_variable_ops.ResourceVariable(init_val, name='v1') + return array_ops.constant(1.0) * v1 + + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + _ = model(1.0) + saver = _saver.Saver() + saver.save(ckpt_prefix) + + with ops.Graph().as_default(): + saver = _saver.Saver() + with saver.maybe_restore_on_create(ckpt_prefix): + # Value is from checkpoint, but not from argument. + ret = model(2.0) + self.assertEqual(ret.numpy(), 1.0) + # Create it a second time won't re-assign the checkpoint value. + v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1') + self.assertEqual(v1_2.read_value().numpy(), 3.0) + + def testRestoreNotFound(self): + with context.eager_mode(): + def model(v): + return array_ops.constant(1.0) * v + + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + _ = model(resource_variable_ops.ResourceVariable(1.0, name='v1')) + saver = _saver.Saver() + saver.save(ckpt_prefix) + + with self.assertRaisesRegexp(errors.NotFoundError, + 'v2 not found in checkpoint'): + with saver.maybe_restore_on_create(ckpt_prefix): + _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2')) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index aa0276dfd91fe4603b117627d54cc7048d2a7c74..2c7494a0a86db3f021916135e3e22f2f5b2b3d00 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -42,6 +42,8 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@inf_nan_callback @@nan_callback @@seterr + +@@Saver """ from __future__ import absolute_import @@ -51,6 +53,7 @@ from __future__ import print_function # pylint:disable=g-bad-import-order,g-import-not-at-top,unused-import # +from tensorflow.contrib.eager.python.saver import Saver from tensorflow.python.util.all_util import remove_undocumented from tensorflow.python.eager import backprop from tensorflow.python.eager.custom_gradient import custom_gradient diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..46cdf086ddca883fda0b20943389b84b4d2bc014 --- /dev/null +++ b/tensorflow/contrib/estimator/BUILD @@ -0,0 +1,61 @@ +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "estimator_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":extenders", + ], +) + +py_library( + name = "extenders", + srcs = [ + "python/estimator/extenders.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:util", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:util", + ], +) + +py_test( + name = "extenders_test", + size = "small", + srcs = ["python/estimator/extenders_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":extenders", + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:metrics", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:linear", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9180a3acc366dd86c60f71f24c13764716e747b5 --- /dev/null +++ b/tensorflow/contrib/estimator/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental utilities re:tf.estimator.*.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.estimator.python.estimator.extenders import * + +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = ['add_metrics'] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py new file mode 100644 index 0000000000000000000000000000000000000000..451cd6373d6ac31c7cec5f4495f1f73787ad49a7 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -0,0 +1,124 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Extenders of tf.estimator.Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.util import tf_inspect + +_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) + + +def add_metrics(estimator, metric_fn): + """Creates new ${tf.estimator.Estimator} which has given metrics. + + Example: + + ```python + def my_auc(labels, predictions): + return {'auc': tf.metrics.auc(labels, predictions['logistic'])} + + estimator = tf.estimator.DNNClassifier(...) + estimator = tf.contrib.estimator.add_metrics(estimator, my_auc) + estimator.train(...) + estimator.evaluate(...) + ``` + Example usage of custom metric which uses features: + + ```python + def my_auc(features, labels, predictions): + return {'auc': tf.metrics.auc( + labels, predictions['logistic'], weights=features['weight'])} + + estimator = tf.estimator.DNNClassifier(...) + estimator = tf.contrib.estimator.add_metrics(estimator, my_auc) + estimator.train(...) + estimator.evaluate(...) + ``` + + Args: + estimator: A ${tf.estimator.Estimator} object. + metric_fn: A function which should obey the following signature: + - Args: can only have following four arguments in any order: + * predictions: Predictions `Tensor` or dict of `Tensor` created by given + `estimator`. + * features: Input `dict` of `Tensor` objects created by `input_fn` which + is given to `estimator.evaluate` as an argument. + * labels: Labels `Tensor` or dict of `Tensor` created by `input_fn` + which is given to `estimator.evaluate` as an argument. + * config: config attribute of the `estimator`. + - Returns: + Dict of metric results keyed by name. Final metrics are a union of this + and `estimator's` existing metrics. If there is a name conflict between + this and `estimator`s existing metrics, this will override the existing + one. The values of the dict are the results of calling a metric + function, namely a `(metric_tensor, update_op)` tuple. + + Returns: + A new ${tf.estimator.Estimator} which has a union of original metrics with + given ones. + """ + _verify_metric_fn_args(metric_fn) + + def new_model_fn(features, labels, mode): + spec = _get_model_fn(estimator)(features, labels, mode) + if mode != model_fn_lib.ModeKeys.EVAL: + return spec + new_metrics = _call_metric_fn(metric_fn, features, labels, spec.predictions, + estimator.config) + all_metrics = spec.eval_metric_ops or {} + all_metrics.update(new_metrics) + return spec._replace(eval_metric_ops=all_metrics) + + return estimator_lib.Estimator( + model_fn=new_model_fn, + model_dir=estimator.model_dir, + config=estimator.config) + + +# TODO(ispir): Move this to tf.estimator.Estimator. +def _get_model_fn(estimator): + return estimator._call_model_fn # pylint: disable=protected-access + + +def _verify_metric_fn_args(metric_fn): + args = set(estimator_util.fn_args(metric_fn)) + if tf_inspect.ismethod(metric_fn): + if 'self' in args: + args.remove('self') + invalid_args = list(args - _VALID_METRIC_FN_ARGS) + if invalid_args: + raise ValueError('metric_fn (%s) has following not expected args: %s' % + (metric_fn, invalid_args)) + + +def _call_metric_fn(metric_fn, features, labels, predictions, config): + """Calls metric fn with proper arguments.""" + metric_fn_args = estimator_util.fn_args(metric_fn) + kwargs = {} + if 'features' in metric_fn_args: + kwargs['features'] = features + if 'labels' in metric_fn_args: + kwargs['labels'] = labels + if 'predictions' in metric_fn_args: + kwargs['predictions'] = predictions + if 'config' in metric_fn_args: + kwargs['config'] = config + return metric_fn(**kwargs) diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py new file mode 100644 index 0000000000000000000000000000000000000000..422c16d24e9de21ac8c9236844effadcbe452667 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py @@ -0,0 +1,135 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""extenders tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.python.estimator import run_config +from tensorflow.python.estimator.canned import linear +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.platform import test + + +def get_input_fn(x, y): + + def input_fn(): + dataset = dataset_ops.Dataset.from_tensor_slices({'x': x, 'y': y}) + iterator = dataset.make_one_shot_iterator() + features = iterator.get_next() + labels = features.pop('y') + return features, labels + + return input_fn + + +class AddMetricsTest(test.TestCase): + + def test_should_add_metrics(self): + input_fn = get_input_fn( + x=np.arange(4)[:, None, None], y=np.ones(4)[:, None]) + estimator = linear.LinearClassifier([fc.numeric_column('x')]) + + def metric_fn(features): + return {'mean_x': metrics_lib.mean(features['x'])} + + estimator = extenders.add_metrics(estimator, metric_fn) + + estimator.train(input_fn=input_fn) + metrics = estimator.evaluate(input_fn=input_fn) + self.assertIn('mean_x', metrics) + self.assertEqual(1.5, metrics['mean_x']) + # assert that it keeps original estimators metrics + self.assertIn('auc', metrics) + + def test_should_error_out_for_not_recognized_args(self): + estimator = linear.LinearClassifier([fc.numeric_column('x')]) + + def metric_fn(features, not_recognized): + _, _ = features, not_recognized + return {} + + with self.assertRaisesRegexp(ValueError, 'not_recognized'): + estimator = extenders.add_metrics(estimator, metric_fn) + + def test_all_supported_args(self): + input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) + estimator = linear.LinearClassifier([fc.numeric_column('x')]) + + def metric_fn(features, predictions, labels, config): + self.assertIn('x', features) + self.assertIsNotNone(labels) + self.assertIn('logistic', predictions) + self.assertTrue(isinstance(config, run_config.RunConfig)) + return {} + + estimator = extenders.add_metrics(estimator, metric_fn) + + estimator.train(input_fn=input_fn) + estimator.evaluate(input_fn=input_fn) + + def test_all_supported_args_in_different_order(self): + input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) + estimator = linear.LinearClassifier([fc.numeric_column('x')]) + + def metric_fn(labels, config, features, predictions): + self.assertIn('x', features) + self.assertIsNotNone(labels) + self.assertIn('logistic', predictions) + self.assertTrue(isinstance(config, run_config.RunConfig)) + return {} + + estimator = extenders.add_metrics(estimator, metric_fn) + + estimator.train(input_fn=input_fn) + estimator.evaluate(input_fn=input_fn) + + def test_all_args_are_optional(self): + input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) + estimator = linear.LinearClassifier([fc.numeric_column('x')]) + + def metric_fn(): + return {'two': metrics_lib.mean(constant_op.constant([2.]))} + + estimator = extenders.add_metrics(estimator, metric_fn) + + estimator.train(input_fn=input_fn) + metrics = estimator.evaluate(input_fn=input_fn) + self.assertEqual(2., metrics['two']) + + def test_overrides_existing_metrics(self): + input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]]) + estimator = linear.LinearClassifier([fc.numeric_column('x')]) + estimator.train(input_fn=input_fn) + metrics = estimator.evaluate(input_fn=input_fn) + self.assertNotEqual(2., metrics['auc']) + + def metric_fn(): + return {'auc': metrics_lib.mean(constant_op.constant([2.]))} + + estimator = extenders.add_metrics(estimator, metric_fn) + metrics = estimator.evaluate(input_fn=input_fn) + self.assertEqual(2., metrics['auc']) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/python/ops/checkpoint_ops.py b/tensorflow/contrib/framework/python/ops/checkpoint_ops.py index 848e26ab966efeb9968cfbb0288959b81b373235..26146790b653a3b13c4a06d2113f14e1296ccbd7 100644 --- a/tensorflow/contrib/framework/python/ops/checkpoint_ops.py +++ b/tensorflow/contrib/framework/python/ops/checkpoint_ops.py @@ -17,440 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_checkpoint_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops - -ops.NotDifferentiable("GenerateVocabRemapping") -ops.NotDifferentiable("LoadAndRemapMatrix") - - -def _load_and_remap_matrix(ckpt_path, - old_tensor_name, - new_row_vocab_offset, - num_rows_to_load, - new_col_vocab_size, - initializer, - old_row_vocab_file=None, - new_row_vocab_file=None, - old_col_vocab_file=None, - new_col_vocab_file=None, - num_row_oov_buckets=0, - num_col_oov_buckets=0, - max_rows_in_memory=-1): - """Loads a 2-D (matrix) `Tensor` from checkpoint. - - Generates 1D-remappings for rows and columns using the - `GenerateVocabRemapping` op, and initializes any anticipated values with the - provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a - matrix that loads existing values from the checkpoint, while filling out - "missing" values with the newly initialized values. See - contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped - functionality (LoadAndRemapMatrix). This wrapper can be used to perform only - row remapping or only col remapping. If only row remapping is desired, - {new,old}_col_vocab_file should be `None`, and vice versa for column - remapping. - - NOTE: This only supports div-partitioning the vocabulary on the 1st dimension - (row axis) via `new_row_vocab_offset`. - - Args: - ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) - from which the old matrix `Tensor` will be loaded. - old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. - new_row_vocab_offset: A 0-indexed integer representing what line to - start reading at in the new row vocabulary. Used for partitioned - variables. - num_rows_to_load: Number of rows to load for the new vocabulary (note: to - support variable partitioning and partial loading, this does not need to - be the same as the number of entries in `new_row_vocab_file`). - new_col_vocab_size: Number of columns to load - should be the same as the - number of entries in `new_col_vocab_file`, since we don't support - partitioning along the column axis. - initializer: Callable initializer function that accepts a 1-D tensor as the - arg to specify the shape of the returned tensor. Used to initialize - missing values. - old_row_vocab_file: A scalar `Tensor` of type `string` containing the - path to the old row vocabulary file. Can be None, which represents no - remapping on the row axis. - new_row_vocab_file: A scalar `Tensor` of type `string` containing the path - to the new row vocabulary file. Can be None, which represents no remapping - on the row axis - in which case, `new_row_vocab_offset` and - `num_rows_to_load` work under the assumption that the new row vocab is the - same as the old row vocab. - old_col_vocab_file: A scalar `Tensor` of type `string` containing the - path to the old column vocabulary file. Can be None, which represents no - remapping on the column axis. - new_col_vocab_file: A scalar `Tensor` of type `string` containing the path - to the new column vocabulary file. Can be None, which represents no - remapping on the column axis - in which case, `new_col_vocab_size` works - under the assumption that the new col vocab is the same as the old col - vocab. - num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows - to append. Must be >= 0. - num_col_oov_buckets: `int` specifying the number of out-of-vocabulary - columns to append. Must be >= 0. - max_rows_in_memory: `int` specifying the maximum number of rows to load from - the checkpoint at once. If less than or equal to 0, the entire matrix will - be loaded into memory. Setting this arg trades increased disk reads for - lower memory usage. - - Returns: - A Tensor of shape `[num_rows_to_load + num_row_oov_buckets, - new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the - specified tensor in the checkpoint, and any missing or OOV values - initialized with the given `initializer`. - - Raises: - ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0. - ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is - provided, while the other is not. Same for `old_col_vocab_file` and - `new_col_vocab_file`. - ValueError: If neither row vocabs or col vocabs are provided. - """ - if num_row_oov_buckets < 0: - raise ValueError("num_row_oov_buckets must be >= 0, but received %d" % - num_row_oov_buckets) - if num_col_oov_buckets < 0: - raise ValueError("num_col_oov_buckets must be >= 0, but received %d" % - num_col_oov_buckets) - - if bool(old_row_vocab_file) != bool(new_row_vocab_file): - raise ValueError( - "old_row_vocab_file and new_row_vocab_file must both be specified or " - "left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'". - format(old_row_vocab_file, new_row_vocab_file)) - if bool(old_col_vocab_file) != bool(new_col_vocab_file): - raise ValueError( - "old_col_vocab_file and new_col_vocab_file must both be specified or " - "left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'". - format(old_col_vocab_file, new_col_vocab_file)) - - remap_rows = new_row_vocab_file and old_row_vocab_file - remap_cols = new_col_vocab_file and old_col_vocab_file - if not (remap_rows or remap_cols): - raise ValueError( - "Must provide either row or column vocab files. If no remapping is " - "necessary, consider using `tf.contrib.framework.init_from_checkpoint` " - "instead.") - - num_rows_present = num_rows_to_load - if remap_rows: - row_remapping, num_rows_present = ( - gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access - new_vocab_file=new_row_vocab_file, - old_vocab_file=old_row_vocab_file, - new_vocab_offset=new_row_vocab_offset, - num_new_vocab=num_rows_to_load)) - else: - # Even when the rows are not being reordered, we still need to generate a - # remapping to account for initializing partitioned Variables (when - # new_row_vocab_offset is non-zero). - row_remapping = math_ops.range( - new_row_vocab_offset, - new_row_vocab_offset + num_rows_to_load, - dtype=dtypes.int64) +from tensorflow.python.training import checkpoint_ops - col_remapping = [] - num_cols_present = new_col_vocab_size - if remap_cols: - col_remapping, num_cols_present = ( - gen_checkpoint_ops._generate_vocab_remapping( # pylint: disable=protected-access - new_vocab_file=new_col_vocab_file, - old_vocab_file=old_col_vocab_file, - new_vocab_offset=0, # Offset is unused for cols (no partitioning). - num_new_vocab=new_col_vocab_size)) - init_vals = initializer([ - num_rows_to_load * new_col_vocab_size - - num_rows_present * num_cols_present, 1 - ]) - return_tensor = gen_checkpoint_ops._load_and_remap_matrix( # pylint: disable=protected-access - ckpt_path=ckpt_path, - old_tensor_name=old_tensor_name, - row_remapping=row_remapping, - col_remapping=col_remapping, - initializing_values=init_vals, - num_rows=num_rows_to_load, - num_cols=new_col_vocab_size, - max_rows_in_memory=max_rows_in_memory) - - # Add OOV row(s) and column(s). - if num_row_oov_buckets > 0: - init_row_oov_val = initializer([num_row_oov_buckets, new_col_vocab_size]) - init_row_oov_val = ops.convert_to_tensor(init_row_oov_val) - return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0) - if num_col_oov_buckets > 0: - # We need to add any row OOV to the new column shape. - init_col_oov_val = initializer( - [num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets]) - init_col_oov_val = ops.convert_to_tensor(init_col_oov_val) - return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1) - - return return_tensor - - -def load_and_remap_matrix_initializer(ckpt_path, - old_tensor_name, - new_row_vocab_size, - new_col_vocab_size, - old_row_vocab_file=None, - new_row_vocab_file=None, - old_col_vocab_file=None, - new_col_vocab_file=None, - num_row_oov_buckets=0, - num_col_oov_buckets=0, - initializer=None, - max_rows_in_memory=-1): - r"""Returns a var initializer for loading and remapping a 2-D (matrix) tensor. - - The returned initializer loads a 2-D (matrix) `Tensor` with name - `old_tensor_name` from the checkpoint at `ckpt_path`. It will reorder the - rows/columns according to the specified vocab files and append additional - out-of-vocabulary rows/columns according to the number of OOV buckets. - - The format of the file at the `{old,new}_{row,col}_vocab_file` path should be - a text file, with each line containing a single entity within the vocabulary. - Let the function `line_of(f, "x")` return the 0-indexed line number of the - entity "x" in file f, and the function `entity_at(f, i)` return the entity at - line i of file f. Then, row i of the new output matrix will be taken from row - `line_of(old_row_vocab_file, entity_at(new_row_vocab_file, i))` of the old - matrix. If any entity in `new_row_vocab_file` is not found in - `old_row_vocab_file`, that row is considered a "missing" row, and its values - will be initialized using the `initializer` arg. The same logic also applies - for the columns. - - For example, assuming that: - - * `old_row_vocab_file` contains "mercury\nvenus\nmars" - * `new_row_vocab_file` contains "venus\njupiter\nmercury" - * `old_col_vocab_file` contains "good\nbetter\nbest" - * `new_col_vocab_file` contains "good\nbest\nfantastic" - * `initializer` returns the natural numbers `[1, 2, 3, 4, ...]` - * `w(i, j)` represents the value from row i, column j of the old matrix - - Then the new output matrix will look like: - - `[[w(1, 0), w(1, 2), 1], - [2, 3, 4], - [w(0, 0), w(0, 2), 5]]` - - If we further specify that: - - * `num_row_oov_buckets` == 2 - * `num_col_oov_buckets` == 1 - - Then the new output matrix will look like: - - `[[w(1, 0), w(1, 2), 1, 12], - [2, 3, 4, 13], - [w(0, 0), w(0, 2), 5, 14], - [6, 7, 8, 15], - [9, 10, 11, 16]]` - - If `{old,new}_row_vocab_file` are None, we assume that the old and new row - vocab files are the same, and no row remapping is done. If - `{old,new}_col_vocab_file` are None, we assume that the old and new column - vocab files are the same, and no column remapping is done. - - The returned initializer only supports div-partitioning along the row axis. It - does not support partitioning along the column axis or mod-partitioning. - - NOTE: When this is used to warm-start variables, client code should use - `tf.lookup.index_table_from_tensor()` like - contrib/layers/python/layers/feature_column.py does, as opposed to - `tf.feature_to_id()` - in order to ensure the underlying lookup tables are the - same. - - Args: - ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) - from which the old matrix `Tensor` will be loaded. - old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. - new_row_vocab_size: `int` specifying the number of entries in - `new_row_vocab_file`. If no row remapping is needed (no row vocab - provided), this should be equal to the number of rows to load from the old - matrix (which can theoretically be smaller than the number of rows in the - old matrix). - new_col_vocab_size: `int` specifying the number of entries in - `new_col_vocab_file`. If no column remapping is needed (no column vocab - provided), this should be equal to the number of columns in the old - matrix. - old_row_vocab_file: A scalar `Tensor` of type `string` containing the - path to the old row vocabulary file. Can be None, which represents no - remapping on the row axis. - new_row_vocab_file: A scalar `Tensor` of type `string` containing the path - to the new row vocabulary file. Can be None, which represents no remapping - on the row axis. - old_col_vocab_file: A scalar `Tensor` of type `string` containing the - path to the old column vocabulary file. Can be None, which represents no - remapping on the column axis. - new_col_vocab_file: A scalar `Tensor` of type `string` containing the path - to the new column vocabulary file. Can be None, which represents no - remapping on the column axis. - num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows - to append. Must be >= 0. - num_col_oov_buckets: `int` specifying the number of out-of-vocabulary - columns to append. Must be >= 0. - initializer: Initializer function to initialize missing values. Accepts a - 1-D tensor as the arg to specify the shape of the returned tensor. If - `None`, defaults to using `zeros_initializer()`. - max_rows_in_memory: `int` specifying the maximum number of rows to load from - the checkpoint at once. If less than or equal to 0, the entire matrix will - be loaded into memory. Setting this arg trades increased disk reads for - lower memory usage. - - Returns: - A variable initializer function that should be used to initialize a - (potentially partitioned) `Variable` whose complete shape is - `[new_row_vocab_size + num_row_oov_buckets, new_col_vocab_size + - num_col_oov_buckets]`. - - Raises: - TypeError: If `initializer` is specified but not callable. - """ - if initializer is None: - # TODO(b/25671353): Consider using sqrt(6/(fan_in + fan_out)) instead, from - # Glorot and Bengio, 2010. - initializer = init_ops.zeros_initializer() - - if not callable(initializer): - raise TypeError( - "initializer must be callable, instead of being {} of type {}.".format( - initializer, type(initializer))) - - def _initializer(shape, dtype=dtypes.float32, partition_info=None): - """Variable initializer. - - Args: - shape: Shape of `Tensor` to return. Should include OOV on both axes. - dtype: Must be float32. - partition_info: variable_scope._PartitionInfo. - - Returns: - `Tensor` of shape `shape`. - - Raises: - TypeError: If `dtype` is anything other than float32. - ValueError: For shape mismatch upon invocation. - """ - # Sanity checks. - if dtype != dtypes.float32: - raise TypeError( - "Currently, only float32 is supported. Received dtype: {}".format( - dtype)) - if len(shape) != 2: - raise ValueError("Expected 2-dim shape, but received: {}".format(shape)) - if shape[0] <= 0: - raise ValueError( - "Expected 1st dim of shape to be > 0, but received shape: {}".format( - shape)) - if shape[1] != (new_col_vocab_size + num_col_oov_buckets): - raise ValueError( - "Expected 2nd dim of shape to be new_col_vocab_size ({}) + " - "num_col_oov_buckets ({}) = {}, but received shape: {}".format( - new_col_vocab_size, num_col_oov_buckets, - new_col_vocab_size + num_col_oov_buckets, shape)) - - offset = 0 - if partition_info is not None: - offset = partition_info.single_offset(shape) - - if offset + shape[0] > new_row_vocab_size + num_row_oov_buckets: - raise ValueError( - "Trying to initialize {} additional rows after {} rows have already " - "been initialized, which would exceed expected total row count of " - "new_row_vocab_size ({}) + num_row_oov_buckets ({}) = {}.".format( - shape[0], offset, new_row_vocab_size, num_row_oov_buckets, - new_row_vocab_size + num_row_oov_buckets)) - - row_oov_buckets_to_use = min(shape[0], - max(0, offset + shape[0] - new_row_vocab_size)) - num_rows_to_load = shape[0] - row_oov_buckets_to_use - - return _load_and_remap_matrix( - ckpt_path=ckpt_path, - old_tensor_name=old_tensor_name, - new_row_vocab_offset=offset, - num_rows_to_load=num_rows_to_load, - new_col_vocab_size=new_col_vocab_size, - initializer=initializer, - old_row_vocab_file=old_row_vocab_file, - new_row_vocab_file=new_row_vocab_file, - old_col_vocab_file=old_col_vocab_file, - new_col_vocab_file=new_col_vocab_file, - num_row_oov_buckets=row_oov_buckets_to_use, - num_col_oov_buckets=num_col_oov_buckets, - max_rows_in_memory=max_rows_in_memory) - - return _initializer - - -def load_embedding_initializer(ckpt_path, - embedding_tensor_name, - new_vocab_size, - embedding_dim, - old_vocab_file, - new_vocab_file, - num_oov_buckets=0, - initializer=None, - max_rows_in_memory=-1): - """Returns a variable initializer for loading pre-trained embeddings. - - Wrapper around `load_and_remap_matrix_initializer()` specialized for loading - embedding weights and remapping according to the provided vocab files. See - docs for `load_and_remap_matrix_initializer()` for more details. - - NOTE: Only for use with div-partitioned variables / vocabularies. - - Args: - ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`) - from which the old matrix `Tensor` will be loaded. - embedding_tensor_name: Name of the 2-D `Tensor` to load from checkpoint. - new_vocab_size: Number of entries in the new vocab. - embedding_dim: `int` specifying the dimension of the embedding vectors from - the checkpoint. Must match the number of columns in the old embedding - matrix. - old_vocab_file: A scalar `Tensor` of type `string` containing the - path to the old vocabulary file. - new_vocab_file: A scalar `Tensor` of type `string` containing the - path to the new vocabulary file. - num_oov_buckets: `int` specifying the number of out-of-vocabulary - buckets to use. Must be >= 0. - initializer: Initializer function that accepts a 1-D tensor as the arg to - specify the shape of the returned tensor. If `None`, defaults to using - `truncated_normal_initializer()`. - max_rows_in_memory: `int` specifying the maximum number of rows to load from - the checkpoint at once. If less than or equal to 0, the entire matrix will - be loaded into memory. Setting this arg trades increased disk reads for - lower memory usage. - - Returns: - A variable initializer function. - """ - if initializer is None: - # TODO(b/25671353): This should be kept in sync with the stddev used by - # feature_column.py's _EmbeddingColumn. - initializer = init_ops.truncated_normal_initializer( - stddev=1.0 / math.sqrt(embedding_dim)) - - return load_and_remap_matrix_initializer( - ckpt_path=ckpt_path, - old_tensor_name=embedding_tensor_name, - new_row_vocab_size=new_vocab_size, - new_col_vocab_size=embedding_dim, - old_row_vocab_file=old_vocab_file, - new_row_vocab_file=new_vocab_file, - old_col_vocab_file=None, - new_col_vocab_file=None, - num_row_oov_buckets=num_oov_buckets, - num_col_oov_buckets=0, - initializer=initializer, - max_rows_in_memory=max_rows_in_memory) +# pylint: disable=protected-access,line-too-long +load_and_remap_matrix_initializer = checkpoint_ops._load_and_remap_matrix_initializer +# pylint: enable=line-too-long +load_embedding_initializer = checkpoint_ops._load_embedding_initializer +# pylint: enable=protected-access def load_linear_multiclass_bias_initializer(ckpt_path, diff --git a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py index a11d373244d870a09d8d4b58e849d2975a64e60b..b7b9f5c59e12ec0ac44455f00d8285c196a7ac39 100644 --- a/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py +++ b/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py @@ -21,7 +21,6 @@ import os import numpy as np from tensorflow.contrib import framework as contrib_framework -from tensorflow.contrib.framework.python.ops import checkpoint_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -38,250 +37,6 @@ FLAGS = flags.FLAGS _TESTDATA_PATH = 'contrib/framework/testdata' -class LoadAndRemapWrappersTest(test.TestCase): - """Tests for the functionality of the Python wrappers.""" - - def setUp(self): - self.bundle_file = os.path.join( - test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint') - self.new_feature_vocab_file = os.path.join( - test.test_src_dir_path(_TESTDATA_PATH), 'bundle_checkpoint_vocab.txt') - self.old_feature_vocab_file = os.path.join( - test.test_src_dir_path(_TESTDATA_PATH), - 'bundle_checkpoint_vocab_with_oov.txt') - self.new_class_vocab_file = os.path.join( - test.test_src_dir_path(_TESTDATA_PATH), 'keyword_new.txt') - self.old_class_vocab_file = os.path.join( - test.test_src_dir_path(_TESTDATA_PATH), 'keyword.txt') - self.init_val = 42 - - def _init_val_initializer(shape, dtype=None, partition_info=None): - del dtype, partition_info # Unused by this unit-testing initializer. - return array_ops.tile( - constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape) - - self.initializer = _init_val_initializer - - def test_load_and_remap_matrix(self): - """Tests the end-to-end loading / remapping of weights.""" - # _load_and_remap_matrix() is the generalized wrapper that takes in row and - # column vocabulary files, calls the relevant remappings, and returns the - # weight matrix. Take this example to be linear multi-class by providing - # both row and column vocabularies. - remapped_matrix = checkpoint_ops._load_and_remap_matrix( - new_row_vocab_file=self.new_feature_vocab_file, - old_row_vocab_file=self.old_feature_vocab_file, - num_rows_to_load=4, - new_col_vocab_file=self.new_class_vocab_file, - old_col_vocab_file=self.old_class_vocab_file, - new_col_vocab_size=4, - old_tensor_name='some_scope/embeddings', - ckpt_path=[self.bundle_file], - new_row_vocab_offset=1, - initializer=self.initializer, - num_row_oov_buckets=1, - num_col_oov_buckets=1) - - # [4 in vocab + 1 oov features, 4 in vocab + 1 oov classes]. The offset - # means we read - expected_remapped_matrix = np.concatenate( - [ - np.reshape([18, 34, 50, self.init_val, self.init_val], [5, 1]), - np.reshape([16, 32, 48, self.init_val, self.init_val], [5, 1]), - np.reshape([self.init_val] * 5, [5, 1]), - np.reshape([17, 33, 49, self.init_val, self.init_val], [5, 1]), - np.reshape([self.init_val] * 5, [5, 1]) - ], - axis=1) - - with self.test_session(): - self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval()) - - def test_load_and_remap_output_layer_weight_initializer_linear(self): - """Tests for the output layer initializer in the linear multi-class case.""" - loading_initializer = (contrib_framework.load_and_remap_matrix_initializer( - new_row_vocab_size=5, - new_col_vocab_file=self.new_class_vocab_file, - old_col_vocab_file=self.old_class_vocab_file, - new_col_vocab_size=4, - old_tensor_name='some_scope/embeddings', - ckpt_path=[self.bundle_file], - new_row_vocab_file=self.new_feature_vocab_file, - old_row_vocab_file=self.old_feature_vocab_file, - num_row_oov_buckets=1, - num_col_oov_buckets=1, - initializer=self.initializer)) - - expected_remapped_matrix = np.concatenate( - [ - np.reshape([2, 18, 34, 50, self.init_val, self.init_val], [6, 1]), - np.reshape([0, 16, 32, 48, self.init_val, self.init_val], [6, 1]), - np.reshape([self.init_val] * 6, [6, 1]), - np.reshape([1, 17, 33, 49, self.init_val, self.init_val], [6, 1]), - np.reshape([self.init_val] * 6, [6, 1]) - ], - axis=1) - - # The new weight matrix is of size - # [5 feature vocab + 1 feature OOV, 4 class vocab + 1 class OOV]. Use a - # partitioned variable to confirm that the offset logic works. - remapped_matrix = variable_scope.get_variable( - name='linear/obtained_weight_matrix', - shape=[6, 5], - initializer=loading_initializer, - partitioner=partitioned_variables.fixed_size_partitioner(2)) - - with self.test_session(): - variables.global_variables_initializer().run() - self.assertAllClose(expected_remapped_matrix, - remapped_matrix.as_tensor().eval()) - - def test_load_and_remap_output_layer_weight_initializer_dnn_output(self): - """Tests for the output layer initializer in the DNN output case.""" - loading_initializer = (contrib_framework.load_and_remap_matrix_initializer( - new_row_vocab_size=5, - new_col_vocab_file=self.new_class_vocab_file, - old_col_vocab_file=self.old_class_vocab_file, - new_col_vocab_size=4, - old_tensor_name='some_scope/embeddings', - ckpt_path=[self.bundle_file], - num_col_oov_buckets=1, - initializer=self.initializer)) - - expected_remapped_matrix = np.concatenate( - [ - np.reshape([2, 18, 34, 50, 66], [5, 1]), - np.reshape([0, 16, 32, 48, 64], [5, 1]), - np.reshape([self.init_val] * 5, [5, 1]), - np.reshape([1, 17, 33, 49, 65], [5, 1]), - np.reshape([self.init_val] * 5, [5, 1]) - ], - axis=1) - - # The new weight matrix is of size - # [5-sized input layer, 4 class vocab + 1 class OOV]. - remapped_matrix = variable_scope.get_variable( - name='dnn_output/obtained_weight_matrix', - shape=[5, 5], - initializer=loading_initializer, - partitioner=partitioned_variables.fixed_size_partitioner(2)) - - with self.test_session(): - variables.global_variables_initializer().run() - self.assertAllClose(expected_remapped_matrix, - remapped_matrix.as_tensor().eval()) - - def test_initializer_with_oov_only_partition(self): - """Tests for the output layer initializer where one partition is all OOV.""" - loading_initializer = (contrib_framework.load_and_remap_matrix_initializer( - new_row_vocab_size=5, - new_col_vocab_file=self.new_class_vocab_file, - old_col_vocab_file=self.old_class_vocab_file, - new_col_vocab_size=4, - old_tensor_name='some_scope/embeddings', - ckpt_path=[self.bundle_file], - new_row_vocab_file=self.new_feature_vocab_file, - old_row_vocab_file=self.old_feature_vocab_file, - num_row_oov_buckets=5, - num_col_oov_buckets=1, - initializer=self.initializer)) - - expected_remapped_matrix = np.concatenate( - [ - np.reshape([2, 18, 34, 50] + [self.init_val] * 6, [10, 1]), - np.reshape([0, 16, 32, 48] + [self.init_val] * 6, [10, 1]), - np.reshape([self.init_val] * 10, [10, 1]), - np.reshape([1, 17, 33, 49] + [self.init_val] * 6, [10, 1]), - np.reshape([self.init_val] * 10, [10, 1]), - ], - axis=1) - - # The new weight matrix is of size - # [5 feature vocab + 5 feature OOV, 4 class vocab + 1 class OOV]. The - # second partition has only OOV. - remapped_matrix = variable_scope.get_variable( - name='linear_all_oov/obtained_weight_matrix', - shape=[10, 5], - initializer=loading_initializer, - partitioner=partitioned_variables.fixed_size_partitioner(2)) - - with self.test_session(): - variables.global_variables_initializer().run() - self.assertAllClose(expected_remapped_matrix, - remapped_matrix.as_tensor().eval()) - - def test_load_and_remap_linear_multiclass_initializer_default_init(self): - """Tests where the zeros_initializer default is used for linear.""" - loading_initializer = (contrib_framework.load_and_remap_matrix_initializer( - new_row_vocab_size=5, - new_col_vocab_file=self.new_class_vocab_file, - old_col_vocab_file=self.old_class_vocab_file, - new_col_vocab_size=4, - old_tensor_name='some_scope/embeddings', - ckpt_path=[self.bundle_file], - new_row_vocab_file=self.new_feature_vocab_file, - old_row_vocab_file=self.old_feature_vocab_file, - num_row_oov_buckets=1, - num_col_oov_buckets=1)) - - expected_remapped_matrix = np.concatenate( - [ - np.reshape([2, 18, 34, 50, 0, 0], [6, 1]), - np.reshape([0, 16, 32, 48, 0, 0], [6, 1]), - np.reshape([0] * 6, [6, 1]), - np.reshape([1, 17, 33, 49, 0, 0], [6, 1]), - np.reshape([0] * 6, [6, 1]) - ], - axis=1) - - remapped_matrix = variable_scope.get_variable( - name='linear_init_fallback/obtained_weight_matrix', - shape=[6, 5], - initializer=loading_initializer, - partitioner=partitioned_variables.fixed_size_partitioner(2)) - - with self.test_session(): - variables.global_variables_initializer().run() - self.assertAllClose(expected_remapped_matrix, - remapped_matrix.as_tensor().eval()) - - def test_load_embedding_initializer(self): - """Tests for the load_embedding_initializer wrapper.""" - embedding_loading_initializer = ( - contrib_framework.load_embedding_initializer( - new_vocab_file=self.new_feature_vocab_file, - old_vocab_file=self.old_feature_vocab_file, - new_vocab_size=5, - embedding_dim=16, - embedding_tensor_name='some_scope/embeddings', - ckpt_path=[self.bundle_file], - num_oov_buckets=1, - initializer=self.initializer)) - - expected_remapped_embeddings = np.concatenate( - [ - np.reshape(range(64), [4, 16]), - np.reshape([self.init_val] * 32, [2, 16]), - ], - axis=0) - - # The new weight matrix is of size - # [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the - # last vocab row (2nd last row) is newly initialized (wasn't found in - # previous vocab) and the actual last row is OOV and also newly initialized. - # Use a partitioned variable to confirm that the offset logic works. - remapped_embeddings = variable_scope.get_variable( - name='embedding/obtained_embedding_matrix', - shape=[6, 16], - initializer=embedding_loading_initializer, - partitioner=partitioned_variables.fixed_size_partitioner(2)) - - with self.test_session(): - variables.global_variables_initializer().run() - self.assertAllClose(expected_remapped_embeddings, - remapped_embeddings.as_tensor().eval()) - - class LoadMulticlassBiasTest(test.TestCase): """Tests for the load_linear_multiclass_bias_initializer functionality.""" diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index f5d21278db89b96ca21963a2e37bba546cc78096..9b34cf1bdb048fe4653594b4ca3a6971d2275909 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -60,12 +60,14 @@ tf_kernel_library( srcs = [ "kernels/fused_conv2d_bias_activation_op.cc", "kernels/fused_conv2d_bias_activation_op.h", + "kernels/fused_conv_ops_gpu.h", ], prefix = "fused_conv2d_bias_activation_op", deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:stream_executor", "//tensorflow/core/kernels:bounds_check_lib", "//tensorflow/core/kernels:conv_2d_hdrs", "//tensorflow/core/kernels:conv_ops_gpu_hdrs", @@ -81,6 +83,7 @@ tf_custom_op_library( srcs = [ "kernels/fused_conv2d_bias_activation_op.cc", "kernels/fused_conv2d_bias_activation_op.h", + "kernels/fused_conv_ops_gpu.h", "ops/fused_conv2d_bias_activation_op.cc", ], deps = [ @@ -94,12 +97,8 @@ tf_custom_op_library( ) tf_gen_op_libs( - op_lib_names = [ - "fused_conv2d_bias_activation_op", - ], - deps = [ - "//tensorflow/core:lib_proto_parsing", - ], + op_lib_names = ["fused_conv2d_bias_activation_op"], + deps = ["//tensorflow/core:lib_proto_parsing"], ) tf_gen_op_wrapper_py( @@ -109,7 +108,7 @@ tf_gen_op_wrapper_py( cuda_py_test( name = "fused_conv2d_bias_activation_op_test", - size = "small", + size = "large", srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"], additional_deps = [ ":fused_conv_py", diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index dc0701b234fc50aae2478015228a3cd00b60efc7..675ff2be3888b2bdd354217ae16f30c865ccee3a 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#define EIGEN_USE_THREADS - #if GOOGLE_CUDA #define EIGEN_USE_GPU #endif // GOOGLE_CUDA @@ -31,8 +29,8 @@ limitations under the License. #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/padding.h" -#include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" #if GOOGLE_CUDA @@ -40,38 +38,84 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/util/activation_mode.h" #endif // GOOGLE_CUDA + namespace tensorflow { -typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -template -struct LaunchConvOp; +template +struct RawType { + using type = T; +}; + +template <> +struct RawType { + using type = int8; +}; + +// Template struct to convert int8x4 to int32. +// (for NCHW_VECT_C with element type int8, we can consider it to be +// an NCHW layout with element type int32 for operations like padding). +template +struct Int8x4ToInt32 { + // By default, do not change T. + using type = T; +}; + +template <> +struct Int8x4ToInt32 { + using type = int32; +}; -template +// T is the element type of the conv_input, filter and side_input tensors. +// BiasType is the element type of the bias tensor, which can be different. +// ScaleType is the type used for conv_input_scale, side_input_scale. +template class FusedConv2DBiasActivationOp : public OpKernel { public: explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context) : OpKernel(context) { - string data_format; - OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + string data_format_str, filter_format_str; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(context, + context->GetAttr("filter_format", &filter_format_str)); OP_REQUIRES(context, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), - errors::InvalidArgument("Current implementation only supports " - "NHWC and NCHW data formats.")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); - OP_REQUIRES(context, strides_.size() == 4, + FilterFormatFromString(filter_format_str, &filter_format_), + errors::InvalidArgument("Invalid filter format")); + + std::vector strides; + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides)); + OP_REQUIRES(context, strides.size() == 4, errors::InvalidArgument("Sliding window strides field must " "specify 4 dimensions")); + + stride_rows_ = GetTensorDim(strides, data_format_, 'H'); + stride_cols_ = GetTensorDim(strides, data_format_, 'W'); OP_REQUIRES( context, - (GetTensorDim(strides_, data_format_, 'N') == 1 && - GetTensorDim(strides_, data_format_, 'C') == 1), - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + (GetTensorDim(strides, data_format_, 'N') == 1 && + GetTensorDim(strides, data_format_, 'C') == 1), + errors::InvalidArgument("Convolutional strides are not supported in " + "the batch or depth dimensions.")); + + // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here. + constexpr bool is_int8x4 = std::is_same::value; + + // Note: Only NCHW_VECT_C format is supported for int8. + // This is because it is expected to be the fastest, and our previous tests + // found cudnn 6 does not fully support the other formats for int8 mode. + OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)), + errors::InvalidArgument( + "qint8 should be used with data_format NCHW_VECT_C.")); + + OP_REQUIRES(context, (is_int8x4 == (filter_format_ == FORMAT_OIHW_VECT_I)), + errors::InvalidArgument( + "qint8 should be used with filter_format OIHW_VECT_I.")); + + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_type_)); + eigen_padding_type_ = BrainPadding2EigenPadding(padding_type_); string activation_mode_str; OP_REQUIRES_OK(context, context->GetAttr("activation_mode", &activation_mode_str)); @@ -79,130 +123,111 @@ class FusedConv2DBiasActivationOp : public OpKernel { &activation_mode_)); OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU, errors::InvalidArgument("Current implementation only supports " - "relu as the activation mode.")); + "RELU as the activation function.")); cudnn_use_autotune_ = CudnnUseAutotune(); + float conv_input_scale_flt, side_input_scale_flt; + OP_REQUIRES_OK(context, + context->GetAttr("conv_input_scale", &conv_input_scale_flt)); + OP_REQUIRES_OK(context, + context->GetAttr("side_input_scale", &side_input_scale_flt)); + conv_input_scale_ = conv_input_scale_flt; + side_input_scale_ = side_input_scale_flt; + } + + Status CheckShape(const Tensor& tensor, const string& tensor_name) { + const int num_dims = tensor.dims(); + for (int i = 0; i < num_dims; i++) { + if (!FastBoundsCheck(tensor.dim_size(i), + std::numeric_limits::max())) { + return errors::InvalidArgument(tensor_name, " dimension ", i, + " too large"); + } + } + // If there is a 5th dimension it is the VECT_C or VECT_I dimension. + if (num_dims == 5 && tensor.dim_size(4) != 4) { + return errors::InvalidArgument("The last dimension of ", tensor_name, + " must be of size 4 for qint8."); + } + return Status::OK(); } void Compute(OpKernelContext* context) override { - // Input tensor is one of the following shapes: - // [ batch, in_rows, in_cols, in_depth ] (for NHWC data format) - // [ batch, in_depth, in_rows, in_cols ] (for NCHW data format) - const Tensor& input = context->input(0); + // The conv_input tensor is one of the following formats: + // NHWC, NCHW, NCHW_VECT_C. + const Tensor& conv_input = context->input(0); + OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input")); - // Input filter is of the following dimensions: - // [ filter_rows, filter_cols, in_depth, out_depth ] + // The filter tensor is one of the following formats: + // HWIO, OIHW, OIHW_VECT_I. const Tensor& filter = context->input(1); + OP_REQUIRES_OK(context, CheckShape(filter, "filter")); - // Input bias is a 1-D tensor the size of the last - // dimension of Output tensor + // Input bias is a 1-D tensor, with size matching output depth. const Tensor& bias = context->input(2); + OP_REQUIRES_OK(context, CheckShape(bias, "conv_input")); - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES(context, input.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); - OP_REQUIRES(context, filter.dims() == 4, - errors::InvalidArgument("filter must be 4-dimensional: ", - filter.shape().DebugString())); - - // Bias should be a 1-D tensor. - OP_REQUIRES(context, bias.dims() == 1, - errors::InvalidArgument("bias must be 1-dimensional: ", - bias.shape().DebugString())); - - for (int i = 0; i < 4; i++) { - OP_REQUIRES(context, - FastBoundsCheck(filter.dim_size(i), - std::numeric_limits::max()), - errors::InvalidArgument("filter dimension too large")); - OP_REQUIRES( - context, - FastBoundsCheck(input.dim_size(i), std::numeric_limits::max()), - errors::InvalidArgument("input dimension too large")); + // If side_input_scale != 0, then side_input is not ignored and + // has the same type and dimensions as the output. + const Tensor& side_input = context->input(3); + if (side_input_scale_ != 0) { + OP_REQUIRES_OK(context, CheckShape(side_input, "side_input")); } - // The last dimension for input is in_depth. It must be the same as the - // filter's in_depth. - const int64 in_depth = GetTensorDim(input, data_format_, 'C'); - OP_REQUIRES(context, in_depth == filter.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", in_depth, - " vs ", filter.dim_size(2))); - - // The last dimension for filter is out_depth. - const int32 out_depth = static_cast(filter.dim_size(3)); - - // The second dimension for input is rows/height. - // The first dimension for filter is rows/height. - const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H'); - const int32 input_rows = static_cast(input_rows_raw); - const int32 filter_rows = static_cast(filter.dim_size(0)); - - // The third dimension for input is columns/width. - // The second dimension for filter is columns/width. - const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W'); - const int32 input_cols = static_cast(input_cols_raw); - const int32 filter_cols = static_cast(filter.dim_size(1)); - - // The first dimension for input is batch. - const int64 batch_raw = GetTensorDim(input, data_format_, 'N'); - const int32 batch = static_cast(batch_raw); - - // For now we take the stride from the second and third dimensions only (we - // do not support striding on the batch or depth dimension). - const int32 stride_rows = - static_cast(GetTensorDim(strides_, data_format_, 'H')); - const int32 stride_cols = - static_cast(GetTensorDim(strides_, data_format_, 'W')); - const int32 bias_size = static_cast(bias.dim_size(0)); - - int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_rows, filter_rows, stride_rows, - padding_, &out_rows, &pad_rows)); - OP_REQUIRES_OK(context, - GetWindowedOutputSize(input_cols, filter_cols, stride_cols, - padding_, &out_cols, &pad_cols)); - // Output tensor is of the following dimensions: - // [ in_batch, out_rows, out_cols, out_depth ] - TensorShape out_shape = - ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth); + // TODO(pauldonnelly): Switch to a more efficient mechanism to access + // dimension indexes and per-dimension attributes. + const int32 filter_rows = GetFilterDim(filter, filter_format_, 'H'); + const int32 filter_cols = GetFilterDim(filter, filter_format_, 'W'); + const int32 output_depth = GetFilterDim(filter, filter_format_, 'O'); + + const int32 batch_size = GetTensorDim(conv_input, data_format_, 'N'); + const int32 conv_input_rows = GetTensorDim(conv_input, data_format_, 'H'); + const int32 conv_input_cols = GetTensorDim(conv_input, data_format_, 'W'); + + int64 output_rows = 0, output_cols = 0, pad_rows = 0, pad_cols = 0; + OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_rows, filter_rows, + stride_rows_, padding_type_, + &output_rows, &pad_rows)); + OP_REQUIRES_OK(context, GetWindowedOutputSize(conv_input_cols, filter_cols, + stride_cols_, padding_type_, + &output_cols, &pad_cols)); + // Initialize the output tensor shape according to data_format_ + TensorShape output_shape = ShapeFromFormat( + data_format_, batch_size, output_rows, output_cols, output_depth); Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); - - // Bias size should be the same as the size of the channel dimension of - // output. - OP_REQUIRES(context, bias_size == out_depth, - errors::InvalidArgument( - "bias size should equal the channel " - "dimension size of output. bias shape: ", - bias.shape().DebugString() + - ", output shape: " + output->shape().DebugString())); + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - VLOG(2) << "FusedConv2DBiasActivation: in_depth = " << in_depth - << ", input_cols = " << input_cols + VLOG(2) << "FusedConv2DBiasActivation: conv_input_cols = " + << conv_input_cols << ", conv_input_rows = " << conv_input_rows << ", filter_cols = " << filter_cols - << ", input_rows = " << input_rows << ", filter_rows = " << filter_rows - << ", stride_rows = " << stride_rows - << ", stride_cols = " << stride_cols - << ", bias_size = " << bias_size << ", out_depth = " << out_depth; + << ", stride_cols = " << stride_cols_ + << ", stride_rows = " << stride_rows_ + << ", output_depth = " << output_depth + << ", output_cols = " << output_cols + << ", output_rows = " << output_rows + << ", output_shape.num_elements = " << output_shape.num_elements(); // If there is nothing to compute, return. - if (out_shape.num_elements() == 0) { + if (output_shape.num_elements() == 0) { return; } - launcher_.launch(context, cudnn_use_autotune_, input, filter, stride_rows, - stride_cols, bias, activation_mode_, - BrainPadding2EigenPadding(padding_), data_format_, output); + + launcher_.launch(context, cudnn_use_autotune_, conv_input, + conv_input_scale_, filter, stride_rows_, stride_cols_, + eigen_padding_type_, side_input, side_input_scale_, bias, + activation_mode_, data_format_, filter_format_, output); } private: - std::vector strides_; - Padding padding_; + int32 stride_rows_, stride_cols_; + Padding padding_type_; + Eigen::PaddingType eigen_padding_type_; ActivationMode activation_mode_; TensorFormat data_format_; - LaunchFusedConv2DBiasActivationOp launcher_; + FilterTensorFormat filter_format_; + ScaleType conv_input_scale_; + ScaleType side_input_scale_; + LaunchFusedConv2DBiasActivationOp launcher_; bool cudnn_use_autotune_; TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DBiasActivationOp); @@ -211,67 +236,72 @@ class FusedConv2DBiasActivationOp : public OpKernel { #if GOOGLE_CUDA namespace dnn = ::perftools::gputools::dnn; -dnn::ActivationMode BrainActivationMode2CudnnActivationMode( - ActivationMode activation_mode) { - switch (activation_mode) { - case ActivationMode::SIGMOID: - return dnn::ActivationMode::kSigmoid; - case ActivationMode::RELU: - return dnn::ActivationMode::kRelu; - case ActivationMode::RELUX: - return dnn::ActivationMode::kReluX; - case ActivationMode::RELU6: - return dnn::ActivationMode::kRelu6; - case ActivationMode::TANH: - return dnn::ActivationMode::kTanh; - case ActivationMode::BANDPASS: - return dnn::ActivationMode::kBandPass; - } - // Prevent compiler warning about missing return - return dnn::ActivationMode::kRelu; -} - // A dummy type to group forward convolution autotune results together. struct ConvBiasActivationAutoTuneGroup { static string name() { return "ConvBiasActivation"; } }; -typedef AutoTuneSingleton +typedef AutoTuneSingleton AutoTuneConvBiasActivation; -template -void LaunchFusedConv2DBiasActivationOp::launch( - OpKernelContext* ctx, bool cudnn_use_autotune, const Tensor& input_param, - const Tensor& filter, int32 row_stride, int32 col_stride, - const Tensor& bias, const ActivationMode& activation_mode, - const Eigen::PaddingType& padding, TensorFormat data_format, - Tensor* output) { - using perftools::gputools::dnn::AlgorithmConfig; - using perftools::gputools::dnn::AlgorithmType; - using perftools::gputools::dnn::ProfileResult; - using perftools::gputools::dnn::kDefaultAlgorithm; +// Allocates 'transformed_tensor' and transforms 'nhwc_tensor' into it +// using the specified 'batch_size', 'rows', 'cols', and 'depth' dimensions. +template +Status TransformNHWCToNCHW(OpKernelContext* ctx, const Tensor& nhwc_tensor, + int batch_size, int rows, int cols, int depth, + Tensor* transformed_tensor, const Tensor** result) { + TensorShape nchw_shape = + ShapeFromFormat(FORMAT_NCHW, batch_size, rows, cols, depth); + if (depth > 1) { + TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum::value, nchw_shape, + transformed_tensor)); + functor::NHWCToNCHW()( + ctx->eigen_device(), nhwc_tensor.tensor(), + transformed_tensor->tensor()); + } else { + // If depth <= 1, then just reshape. + CHECK(transformed_tensor->CopyFrom(nhwc_tensor, nchw_shape)); + } + *result = transformed_tensor; + return Status::OK(); +} + +template +void LaunchFusedConv2DBiasActivationOp:: + launch(OpKernelContext* ctx, bool cudnn_use_autotune, + const Tensor& conv_input_param, ScaleType conv_input_scale, + const Tensor& filter_param, int32 row_stride, int32 col_stride, + const Eigen::PaddingType& padding, const Tensor& side_input_param, + ScaleType side_input_scale, const Tensor& bias, + ActivationMode activation_mode, TensorFormat data_format, + FilterTensorFormat filter_format, Tensor* output_param) { auto* stream = ctx->op_device_context()->stream(); OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); - Tensor input = input_param; - - perftools::gputools::dnn::ActivationMode cudnn_activation_mode = - BrainActivationMode2CudnnActivationMode(activation_mode); - // TODO(yangzihao): refactor all the complicated/duplicated code in regular // conv ops to a shared conv utility. - int32 padding_rows = 0; - int32 padding_cols = 0; - const int64 in_batch = GetTensorDim(input, data_format, 'N'); - int64 in_rows = GetTensorDim(input, data_format, 'H'); - int64 in_cols = GetTensorDim(input, data_format, 'W'); - const int64 in_depths = GetTensorDim(input, data_format, 'C'); - const int64 out_batch = GetTensorDim(*output, data_format, 'N'); - const int64 out_rows = GetTensorDim(*output, data_format, 'H'); - const int64 out_cols = GetTensorDim(*output, data_format, 'W'); - const int64 out_depths = GetTensorDim(*output, data_format, 'C'); - const int64 patch_rows = filter.dim_size(0); - const int64 patch_cols = filter.dim_size(1); + + // Assuming qint8 <--> NCHW_VECT_C, OIHW_VECT_I (int8x4) here. + constexpr bool is_int8x4 = std::is_same::value; + constexpr int rank = is_int8x4 ? 5 : 4; + constexpr int vect = is_int8x4 ? 4 : 1; + + const int batch_size = GetTensorDim(conv_input_param, data_format, 'N'); + int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H'); + int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W'); + + const int conv_input_depth = + GetTensorDim(conv_input_param, data_format, 'C') * vect; + const int output_rows = GetTensorDim(*output_param, data_format, 'H'); + const int output_cols = GetTensorDim(*output_param, data_format, 'W'); + const int output_depth = GetFilterDim(filter_param, filter_format, 'O'); + const int filter_rows = GetFilterDim(filter_param, filter_format, 'H'); + const int filter_cols = GetFilterDim(filter_param, filter_format, 'W'); + int padding_rows = 0; + int padding_cols = 0; + const Tensor* conv_input = &conv_input_param; + + Tensor maybe_padded_conv_input; if (padding == Eigen::PADDING_SAME) { // Total padding on rows and cols is // Pr = (R' - 1) * S + Kr - R @@ -281,114 +311,152 @@ void LaunchFusedConv2DBiasActivationOp::launch( // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means // we pad more on the right and bottom than on the top and left. - padding_rows = - std::max(0, (out_rows - 1) * row_stride + patch_rows - in_rows); - padding_cols = - std::max(0, (out_cols - 1) * col_stride + patch_cols - in_cols); - const int rows_parity = padding_rows & 1; - const int cols_parity = padding_cols & 1; - if ((rows_parity | cols_parity) != 0) { + padding_rows = std::max( + 0, (output_rows - 1) * row_stride + filter_rows - conv_input_rows); + padding_cols = std::max( + 0, (output_cols - 1) * col_stride + filter_cols - conv_input_cols); + const int padding_rows_parity = padding_rows & 1; + const int padding_cols_parity = padding_cols & 1; + if ((padding_rows_parity | padding_cols_parity) != 0) { Tensor transformed_input; - int64 new_in_rows = in_rows + rows_parity; - int64 new_in_cols = in_cols + cols_parity; + const int new_conv_input_rows = conv_input_rows + padding_rows_parity; + const int new_conv_input_cols = conv_input_cols + padding_cols_parity; + + using VectT = typename Int8x4ToInt32::type>::type; + auto pad_data_format = is_int8x4 ? FORMAT_NCHW : data_format; + OP_REQUIRES_OK( - ctx, - ctx->allocate_temp(DataTypeToEnum::value, - ShapeFromFormat(data_format, in_batch, new_in_rows, - new_in_cols, in_depths), - &transformed_input)); - - functor::PadInput()( - ctx->eigen_device(), To32Bit(input_param.tensor()), - {{0, 0}}, {{rows_parity, cols_parity}}, - To32Bit(transformed_input.tensor()), data_format); - - input = transformed_input; - in_rows = new_in_rows; - in_cols = new_in_cols; + ctx, ctx->allocate_temp( + DataTypeToEnum::value, + ShapeFromFormat(data_format, batch_size, new_conv_input_rows, + new_conv_input_cols, conv_input_depth), + &maybe_padded_conv_input)); + + auto conv_input_eigen_tensor = + To32Bit(conv_input_param.reinterpret_last_dimension()); + auto padded_conv_input_eigen_tensor = To32Bit( + maybe_padded_conv_input.reinterpret_last_dimension()); + + functor::PadInput()( + ctx->eigen_device(), conv_input_eigen_tensor, {{0, 0}}, + {{padding_rows_parity, padding_cols_parity}}, + padded_conv_input_eigen_tensor, pad_data_format); + + conv_input = &maybe_padded_conv_input; + conv_input_rows = new_conv_input_rows; + conv_input_cols = new_conv_input_cols; } } - if (data_format == FORMAT_NHWC) { - // Convert the input tensor from NHWC to NCHW. - TensorShape nchw_shape = - ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths); - if (in_depths > 1) { - Tensor transformed_input; - OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, - nchw_shape, &transformed_input)); - functor::NHWCToNCHW()( - ctx->eigen_device(), - const_cast(input).tensor(), - transformed_input.tensor()); - input = transformed_input; - } else { - // If depth <= 1, then just reshape. - CHECK(input.CopyFrom(input, nchw_shape)); + Tensor maybe_transformed_conv_input, maybe_transformed_side_input; + Tensor maybe_transformed_output; + const Tensor* side_input = &side_input_param; + Tensor* output = output_param; + + // NOTE: Here and elsewhere, checking 'is_int8x4' may look unnecessary + // and inefficient, but it is actually both a time and code size optimization, + // since 'is_int8x4' is a constexpr determined by the template parameter. + if (!is_int8x4 && data_format == FORMAT_NHWC) { + OP_REQUIRES_OK(ctx, (TransformNHWCToNCHW( + ctx, *conv_input, batch_size, conv_input_rows, + conv_input_cols, conv_input_depth, + &maybe_transformed_conv_input, &conv_input))); + if (side_input_scale != 0) { + OP_REQUIRES_OK( + ctx, (TransformNHWCToNCHW( + ctx, side_input_param, batch_size, output_rows, output_cols, + output_depth, &maybe_transformed_side_input, &side_input))); + } + if (output_depth > 1) { + // Allocate a tensor for the NCHW output of the kernel and point output + // to it. Afterwards, we will transform it to NHWC while copying back to + // 'output_param'. + TensorShape nchw_shape = ShapeFromFormat( + FORMAT_NCHW, batch_size, output_rows, output_cols, output_depth); + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum::value, nchw_shape, + &maybe_transformed_output)); + output = &maybe_transformed_output; } } - CHECK(padding_rows >= 0 && padding_cols >= 0) - << "Negative row or col paddings: (" << padding_rows << ", " - << padding_cols << ")"; - perftools::gputools::dnn::BatchDescriptor input_desc; - input_desc.set_count(in_batch) - .set_feature_map_count(in_depths) - .set_height(in_rows) - .set_width(in_cols) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::BatchDescriptor output_desc; - output_desc.set_count(out_batch) - .set_height(out_rows) - .set_width(out_cols) - .set_feature_map_count(out_depths) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::FilterDescriptor filter_desc; - filter_desc.set_input_filter_height(filter.dim_size(0)) - .set_input_filter_width(filter.dim_size(1)) - .set_input_feature_map_count(filter.dim_size(2)) - .set_output_feature_map_count(filter.dim_size(3)); - perftools::gputools::dnn::ConvolutionDescriptor conv_desc; + constexpr auto data_layout = is_int8x4 ? dnn::DataLayout::kBatchDepthYX4 + : dnn::DataLayout::kBatchDepthYX; + constexpr auto filter_layout = is_int8x4 ? dnn::FilterLayout::kOutputInputYX4 + : dnn::FilterLayout::kOutputInputYX; + + dnn::BatchDescriptor conv_input_desc; + conv_input_desc.set_count(batch_size) + .set_feature_map_count(conv_input_depth) + .set_height(conv_input_rows) + .set_width(conv_input_cols) + .set_layout(data_layout); + dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(filter_rows) + .set_input_filter_width(filter_cols) + .set_input_feature_map_count(conv_input_depth) + .set_output_feature_map_count(output_depth) + .set_layout(filter_layout); + dnn::BatchDescriptor side_input_desc; + side_input_desc.set_count(batch_size) + .set_height(output_rows) + .set_width(output_cols) + .set_feature_map_count(output_depth) + .set_layout(data_layout); + dnn::BatchDescriptor bias_desc; + bias_desc.set_count(1) + .set_height(1) + .set_width(1) + .set_feature_map_count(output_depth) + .set_layout(dnn::DataLayout::kBatchDepthYX); + dnn::BatchDescriptor output_desc; + output_desc.set_count(batch_size) + .set_height(output_rows) + .set_width(output_cols) + .set_feature_map_count(output_depth) + .set_layout(data_layout); + dnn::ConvolutionDescriptor conv_desc; conv_desc.set_vertical_filter_stride(row_stride) .set_horizontal_filter_stride(col_stride) .set_zero_padding_height(padding_rows / 2) .set_zero_padding_width(padding_cols / 2); - // Shuffles a filter tensor from: - // [, in, out] - // to: - // [out, in, ] - // TODO(yangzihao): Support a data layout tag for the filter weights, and only - // do the transform if the weights are not already in the correct layout. - Tensor transformed_filter; - OP_REQUIRES_OK(ctx, ctx->allocate_temp( - DataTypeToEnum::value, - TensorShape({filter.dim_size(3), filter.dim_size(2), - filter.dim_size(0), filter.dim_size(1)}), - &transformed_filter)); - - functor::TransformFilter()( - ctx->eigen_device(), To32Bit(filter.tensor()), - To32Bit(transformed_filter.tensor())); - - Tensor transformed_output; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(DataTypeToEnum::value, - ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows, - out_cols, out_depths), - &transformed_output)); - - auto input_ptr = AsDeviceMemory(input.template flat().data(), - input.template flat().size()); + Tensor maybe_transformed_filter; + const Tensor* filter; + if (is_int8x4) { + // We have already checked filter is OIHW_VECT_I in the constructor. + filter = &filter_param; + } else if (filter_format == FORMAT_HWIO) { + // Shuffle filter tensor from HWIO to OIHW: + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DataTypeToEnum::value, + ShapeFromFilterFormat( + FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO), + &maybe_transformed_filter)); + functor::TransformFilter()( + ctx->eigen_device(), To32Bit(filter_param.tensor()), + To32Bit(maybe_transformed_filter.tensor())); + filter = &maybe_transformed_filter; + } + + auto conv_input_ptr = + AsDeviceMemory(reinterpret_cast::type*>( + conv_input->template flat().data()), + conv_input->template flat().size()); auto filter_ptr = - AsDeviceMemory(transformed_filter.template flat().data(), - transformed_filter.template flat().size()); + AsDeviceMemory(reinterpret_cast::type*>( + filter->template flat().data()), + filter->template flat().size()); + auto side_input_ptr = + AsDeviceMemory(reinterpret_cast::type*>( + side_input->template flat().data()), + side_input->template flat().size()); auto output_ptr = - AsDeviceMemory(transformed_output.template flat().data(), - transformed_output.template flat().size()); - - auto bias_ptr = AsDeviceMemory(bias.template flat().data(), - bias.template flat().size()); + AsDeviceMemory(reinterpret_cast::type*>( + output->template flat().data()), + output->template flat().size()); + auto bias_ptr = AsDeviceMemory(bias.template flat().data(), + bias.template flat().size()); static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit( // default value is in bytes despite the name of the environment variable @@ -396,38 +464,42 @@ void LaunchFusedConv2DBiasActivationOp::launch( ); int device_id = stream->parent()->device_ordinal(); - DataType dtype = input.dtype(); - ConvParameters conv_parameters = { - in_batch, - in_depths, - {{in_rows, in_cols}}, - out_depths, - {{patch_rows, patch_cols}}, + FusedConvParameters fused_conv_parameters = { + batch_size, + conv_input_depth, + {{conv_input_rows, conv_input_cols}}, + output_depth, + {{filter_rows, filter_cols}}, {{row_stride, col_stride}}, {{padding_rows, padding_cols}}, - dtype, + conv_input->dtype(), device_id, + (side_input_scale != 0), + activation_mode, }; - AlgorithmConfig algorithm_config; + dnn::AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find( - conv_parameters, &algorithm_config)) { - std::vector algorithms; + fused_conv_parameters, &algorithm_config)) { + std::vector algorithms; CHECK(stream->parent()->GetConvolveAlgorithms( - conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), &algorithms)); - ProfileResult best_result; - ProfileResult best_result_no_scratch; + fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), + &algorithms)); + dnn::ProfileResult best_result; + dnn::ProfileResult best_result_no_scratch; for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); - ProfileResult profile_result; + dnn::ProfileResult profile_result; bool cudnn_launch_status = stream - ->ThenConvolveWithAlgorithm( - input_desc, input_ptr, filter_desc, filter_ptr, conv_desc, - bias_ptr, cudnn_activation_mode, output_desc, &output_ptr, - &scratch_allocator, AlgorithmConfig(profile_algorithm), + ->ThenFusedConvolveWithAlgorithm( + conv_input_desc, conv_input_ptr, conv_input_scale, + filter_desc, filter_ptr, conv_desc, side_input_ptr, + side_input_scale, bias_desc, bias_ptr, + dnn::ActivationMode::kRelu, output_desc, &output_ptr, + &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm), &profile_result) .ok(); if (cudnn_launch_status) { @@ -454,42 +526,68 @@ void LaunchFusedConv2DBiasActivationOp::launch( algorithm_config.set_algorithm_no_scratch( best_result_no_scratch.algorithm()); } - AutoTuneConvBiasActivation::GetInstance()->Insert(conv_parameters, + AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters, algorithm_config); } CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); bool cudnn_launch_status = stream - ->ThenConvolveWithAlgorithm( - input_desc, input_ptr, filter_desc, filter_ptr, conv_desc, - bias_ptr, cudnn_activation_mode, output_desc, &output_ptr, - &scratch_allocator, algorithm_config, + ->ThenFusedConvolveWithAlgorithm( + conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc, + filter_ptr, conv_desc, side_input_ptr, side_input_scale, + bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc, + &output_ptr, &scratch_allocator, algorithm_config, /*output_profile_result=*/nullptr) .ok(); if (!cudnn_launch_status) { - ctx->SetStatus(errors::Internal( - "cuDNN launch failure : input shape(", input.shape().DebugString(), - ") filter shape(", filter.shape().DebugString(), ")")); + ctx->SetStatus(errors::Internal("cuDNN launch failure : conv_input shape(", + conv_input->shape().DebugString(), + ") filter shape(", + filter->shape().DebugString(), ")")); } - // Convert the output tensor back from NCHW to NHWC. - if (data_format == FORMAT_NHWC) { + // Convert the output tensor back from NCHW to NHWC if necessary. + if (!is_int8x4 && (data_format == FORMAT_NHWC) && (output_depth > 1)) { functor::NCHWToNHWC()( ctx->eigen_device(), - const_cast(transformed_output).tensor(), - output->tensor()); - } else { - *output = transformed_output; + const_cast(output)->tensor(), + output_param->tensor()); } } +// Forward declarations of the functor specializations for GPU used above. +namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void PadInput::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const std::array& padding_left, \ + const std::array& padding_right, \ + typename TTypes::Tensor out, TensorFormat data_format); \ + extern template struct PadInput; + +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(int32); +#undef DECLARE_GPU_SPEC +} // namespace functor + // Registration of the GPU implementations. -REGISTER_KERNEL_BUILDER(Name("FusedConv2DBiasActivation") - .Device(DEVICE_GPU) - .TypeConstraint("T"), - FusedConv2DBiasActivationOp); + +REGISTER_KERNEL_BUILDER( + Name("FusedConv2DBiasActivation") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .TypeConstraint("Tbias"), + FusedConv2DBiasActivationOp); + +REGISTER_KERNEL_BUILDER( + Name("FusedConv2DBiasActivation") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .TypeConstraint("Tbias"), + FusedConv2DBiasActivationOp); #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h index d71b26cf1db4bd79f238d66417c437288bf50ad8..7534f5797c4f3eee3b031b2693e212749af85c6e 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.h @@ -24,7 +24,7 @@ limitations under the License. #if GOOGLE_CUDA #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -33,27 +33,30 @@ namespace tensorflow { // Forward declaration. class OpKernelContext; -template +template class LaunchFusedConv2DBiasActivationOp { public: void launch(OpKernelContext* ctx, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, - int col_stride, const Tensor& bias, - const ActivationMode& activation_mode, - const Eigen::PaddingType& padding, TensorFormat data_format, - Tensor* output); + const Tensor& conv_input, ScaleType conv_input_scale, + const Tensor& filter, int32 row_stride, int32 col_stride, + const Eigen::PaddingType& padding, const Tensor& side_input, + ScaleType side_input_scale, const Tensor& bias, + ActivationMode activation_mode, TensorFormat data_format, + FilterTensorFormat filter_format, Tensor* output); }; #ifdef GOOGLE_CUDA -template -class LaunchFusedConv2DBiasActivationOp { +template +class LaunchFusedConv2DBiasActivationOp { public: void launch(OpKernelContext* ctx, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int32 row_stride, - int32 col_stride, const Tensor& bias, - const ActivationMode& activation_mode, - const Eigen::PaddingType& padding, TensorFormat data_format, - Tensor* output); + const Tensor& conv_input, ScaleType conv_input_scale, + const Tensor& filter, int32 row_stride, int32 col_stride, + const Eigen::PaddingType& padding, const Tensor& side_input, + ScaleType side_input_scale, const Tensor& bias, + ActivationMode activation_mode, TensorFormat data_format, + FilterTensorFormat filter_format, Tensor* output); }; #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h new file mode 100644 index 0000000000000000000000000000000000000000..dc43af11580ce5fda74ee25da6c151a5b89c7aee --- /dev/null +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/util/activation_mode.h" + +// TODO(pauldonnelly): Merge this file into core/kernels/conv_ops_gpu.h. + +namespace tensorflow { + +// Add additional parameters specific to fused convolutions. +class FusedConvParameters : public ConvParameters { + public: + FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, + int64 out_depths, const SpatialArray& filter, + const SpatialArray& stride, const SpatialArray& padding, + DataType dtype, int device_id, bool has_side_input, + ActivationMode activation_mode) + : ConvParameters(batch, in_depths, in, out_depths, filter, stride, + padding, dtype, device_id), + activation_mode_(activation_mode), + has_side_input_(has_side_input) { + hash_code_ = Hash64Combine(hash_code_, has_side_input); + hash_code_ = Hash64Combine(hash_code_, activation_mode); + } + + bool operator==(const FusedConvParameters& other) const { + return this->get_data_as_tuple() == other.get_data_as_tuple(); + } + + bool operator!=(const FusedConvParameters& other) const { + return !(*this == other); + } + + string ToString() const { + return strings::StrCat(ConvParameters::ToString(), ", ", has_side_input_, + ", ", activation_mode_, ", "); + } + + private: + using ParameterDataType = + std::tuple; + + ParameterDataType get_data_as_tuple() const { + return std::make_tuple(ConvParameters::get_data_as_tuple(), has_side_input_, + activation_mode_); + } + + ActivationMode activation_mode_; + bool has_side_input_; +}; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 6134c5c699dad7e0464495feb49d6519a333e576..48f058b4c535fd814e0ee8c4757fd3b4706c269c 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -33,40 +33,73 @@ string GetAllActivationModeAttrString() { return "activation_mode: {'Relu'}"; } } // namespace // -------------------------------------------------------------------------- + +// TODO(pauldonnelly): Add support for double inputs and scales to this Op, +// (currently Attr does not support double). + REGISTER_OP("FusedConv2DBiasActivation") - .Input("input: T") + .Input("conv_input: T") .Input("filter: T") - .Input("bias: T") + .Input("bias: Tbias") + .Input("side_input: T") .Output("output: T") - .Attr("T: {float}") + .Attr("T: {float, half, qint8}") + .Attr("Tbias: {float, half}") + .Attr("conv_input_scale: float = 1.0") + .Attr("side_input_scale: float = 0.0") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) - .Attr(GetConvnetDataFormatAttrString()) - .Attr(GetAllActivationModeAttrString()) + .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") + .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") + .Attr("activation_mode: {'Relu'} = 'Relu'") .SetShapeFn(shape_inference::FusedConvBiasActivationShape) .Doc(R"doc( - Computes a fused 2-D convolution, adds bias, and applies an activation function - on the output given 4-D `input`, 4-D `filter`, 1-D `bias` tensors and an activation mode. + Computes a fused kernel which implements: 2-D convolution, adds side input, + with separate scaling on convolution and side inputs, then adds bias and + applies the RELU activation function to the result. Supports both float and + qint8 data formats. In the case of qint8, the output is clipped to [0..127]. - input: A 4-D tensor. The dimension order is interpreted according to the value - of `data_format`, see below for details. - filter: A 4-D tensor of shape - `[filter_height, filter_width, in_channels, out_channels]` - bias: 1-D with size of the `out_channels` dimension in filter. - output: A 4-D tensor. The dimension order is determined by the value of - `data_format`, see below for details. - T: The data type for the elements of input, filter, bias, and output Tensors. + conv_input: A tensor with format as specified by `data_format` (see below). + filter: A tensor with format depending on `data_format` as follows: + "NHWC", "NCHW": + `float [ filter_height, filter_width, in_channels, out_channels ]` + "NCHW_VECT_C": + `qint8 [ out_channels, in_channels, filter_height, filter_width ]` + bias: 1-D float tensor with size matching the `out_channels` dimension of + `filter`. + Note: this tensor is still float, even if other inputs are qint8. + side_input: A tensor with format as specified by `data_format` (see below). + This tensor will be ignored and can be [] if side_input_scale == 0. + Otherwise, the size of each dimension must match the `output` tensor. + output: A tensor with format as specified by `data_format` (see below). + The dimension sizes are determined automatically based on other inputs + and attributes. + T: The element data type of `conv_input`, `side_input` and `output` tensors. + Note: must match with the `data_format`. + Tbias: The element data type of `bias`. + conv_input_scale: scalar float value to be multiplied by `conv_input`. + (conceptually.. in reality it is applied after convolution). + side_input_scale: scalar float value to be multiplied by `side_input`. strides: 1-D tensor of length 4. The stride of the sliding window for each dimension of `input`. The dimension order is determined by the value of `data_format`, see below for details. + Note: the stride for batch and channel dimensions must be 1. padding: The type of padding algorithm to use. - data_format: Specify the data format of the input and output data. With the - default format "NHWC", the data is stored in the order of: - [batch, height, width, channels]. - Alternatively, the format could be "NCHW", the data storage order of: - [batch, channels, height, width]. - activation_mode: Specify the activation function to apply to the output tensor - of bias add. Currently only supports "Relu". + data_format: A string specifying the data format of `conv_input`, + `side_input` and `output` tensors with the following options: + "NHWC": `float [ batch, height, width, channels ]` + "NCHW": `float [ batch, channels, height, width ]` + "NCHW_VECT_C": + `qint8 [ batch, channels / 4, height, width, channels % 4 ]` + Note: for "NCHW_VECT_C", `channels` must be a multiple of 4. + filter_format: A string specifying the data format of `filter`, + "HWIO": `float [ kernel_height, kernel_width, input_channels, + output_channels ]` + "OIHW_VECT_I": + `qint8 [ output_channels, input_channels / 4, + kernel_height, kernel_width, input_channels % 4 ]` + activation_mode: The activation applied to the output. + Currently must be "Relu". )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py index 41f986dd07c7e394604f0532a708809e5f01c598..8f3f31bad0d89a01050de866ee773c28889c0fbe 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op.py @@ -26,62 +26,83 @@ _fused_conv2d_bias_activation_op_so = loader.load_op_library( resource_loader.get_path_to_datafile("_fused_conv2d_bias_activation_op.so")) -def fused_conv2d_bias_activation(input_tensor, - filter_tensor, +# pylint: disable=redefined-builtin +def fused_conv2d_bias_activation(conv_input, + filter, bias, - strides, - padding, - activation_mode, + strides=None, + padding=None, + conv_input_scale=1.0, + side_input_scale=0.0, + side_input=None, + activation_mode="Relu", data_format=None, + filter_format=None, name=None): - """Computes a fused 2-D convolution, adds bias, and applies relu. + """Fused 2D conv, bias and activation with optional side input. - input_tensor: A 4-D tensor. The dimension order is interpreted - according to the value of `data_format`, see below for details. - filter_tensor: A 4-D tensor of shape - `[filter_height, filter_width, in_channels, out_channels]` - bias: 1-D with size of the `out_channels` dimension in filter. - output: A 4-D tensor. The dimension order is determined by the value of - `data_format`, see below for details. - T: The data type for the elements of input, filter, bias, and output - Tensors. - strides: 1-D tensor of length 4. The stride of the sliding window for - each - dimension of `input`. The dimension order is determined by the value - of - `data_format`, see below for details. - padding: The type of padding algorithm to use. - data_format: Specify the data format of the input and output data. With - the - default format "NHWC", the data is stored in the order of: - [batch, height, width, channels]. - Alternatively, the format could be "NCHW", the data storage order of: - [batch, channels, height, width]. - activation_mode: Specify the activation function to apply to the output - tensor - of bias add. Currently only supports "Relu". + Computes a fused 2-D convolution scaled by conv_input_scale, + adds an optional side input scaled by side_input_scale, adds biases, + and applies ReLU. As an equation: + output = ReLU(conv_input_scale * Conv(conv_input, filter) + + side_input_scale * side_input + bias) + Note: In int8 mode, The ReLU will clip the output to the range [0..127]. Args: - input_tensor: A `Tensor`. Must be one of the following types: `float32`. - filter_tensor: A `Tensor`. Must have the same type as `input`. - bias: A `Tensor`. Must have the same type as `input`. - strides: A list of `ints`. + conv_input: A `Tensor` of the format specified by `data_format`. + filter: A `Tensor` whose format depends on `data_format`: + if `data_format` is "NCHW_VECT_C", filter should be "OIHW_VECT_I" + otherwise, it should be "HWIO" format. + bias: A 1-D `Tensor` of type `float32`, and dimensions equal to the + number of output channels. + strides: A list of 4 `ints` specifying convolution strides. + if `data_format` is "NCHW" or "NCHW_VECT_C", the order should be NCHW. + if `data_format` is "NHWC", the order should be NHWC. padding: A `string` from: `"SAME", "VALID"`. - activation_mode: A `string` from: `"Sigmoid", "Relu", "Relu6", "ReluX", - "Tanh", "BandPass"`. - data_format: An optional `string` from: `"NHWC", "NCHW"`. Defaults to - `"NHWC"`. + conv_input_scale: A scalar `float32` that will be multiplied by conv_input. + This is optional and defaults to 1. However it should be set to + specify the quantization scale when `data_format` is "NCHW_VECT_C". + side_input_scale: A scalar `float32` that will be multiplied by side_input. + This is optional and defaults to 0. + side_input: A `Tensor` of the format specified by `data_format`. + This is useful for imlementing ResNet blocks. + activation_mode: (optional) currently must be the default "Relu". + Note that in qint8 mode, it also clips to 127, so acts like ReluX. + data_format: Specifies the data format. + Possible values are: + "NHWC" float [batch, height, width, channels] + "NCHW" float [batch, channels, height, width] + "NCHW_VECT_C" qint8 [batch, channels / 4, height, width, channels % 4] + Defaults to `"NHWC"`. + Performance is worst for `"NHWC"` and best for `"NCHW_VECT_C"`. + filter_format: Specifies the filter format. + Possible values are: + "HWIO" float [kernel_height, kernel_width, input_channels, + output_channels ] + "OIHW" float [output_channels, input_channels, kernel_height, + kernel_width ] + "OIHW_VECT_I" qint8 [ output_channels, input_channels / 4, + kernel_height, kernel_width, input_channels % 4 ] + Defaults to `"HWIO"`. name: A name for the operation (optional). Returns: - A `Tensor`. Has the same type as `input`. + A `Tensor` of the format specified by `data_format`. """ + if strides is None: + strides = [1, 1, 1, 1] + if side_input is None: + side_input = [] return gen_fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( - input=input_tensor, - filter=filter_tensor, - bias=bias, - strides=strides, + conv_input, + filter, + bias, padding=padding, + strides=strides, + conv_input_scale=conv_input_scale, + side_input_scale=side_input_scale, + side_input=side_input, activation_mode=activation_mode, data_format=data_format, + filter_format=filter_format, name=name) diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 5d6a2fa3b83cc36b507947586c24fd2770ffb96a..3b8f7d6ed760647c4c61ce5ea60be1d7d17ddfa0 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -19,13 +19,16 @@ from __future__ import division from __future__ import print_function import numpy as np + from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -484,7 +487,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): with self.test_session() as sess: # Illegal strides. with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "strides in the batch and depth"): + "Convolutional strides are not supported in " + "the batch or depth dimensions."): sess.run( fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( array_ops.placeholder(dtypes.float32), @@ -494,7 +498,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): padding="SAME", activation_mode="Relu")) with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, - "strides in the batch and depth"): + "Convolutional strides are not supported in " + "the batch or depth dimensions."): sess.run( fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( array_ops.placeholder(dtypes.float32), @@ -552,6 +557,286 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding, return Test +def CalculateCovolvedOutputDim(input_dim, filter_dim, stride, padding_type): + """Calculates the size of an output dimension of a strided convolution. + + Given the sizes of the corresponding dimension of the input and filter shapes, + and the stride and padding_types, calculates the size of the output dimension. + This function can be called separately for each input dimension. + + Args: + input_dim: An `int` specifying the size of the input dimension. + filter_dim: An `int` specifying the size of the filter dimension. + stride: An `int` specifying the step size of the convolution along the + input dimension. + padding_type: either 'VALID' or 'SAME'. + + Returns: + The size of the output dimension. + """ + if padding_type == "VALID": + return (input_dim - filter_dim + stride) // stride + else: # padding_type == 'SAME' + return (input_dim + stride - 1) // stride + + +def NchwVectCToNchw(in_tensor): + # [N, C / 4, H, W, 4] => [N, C / 4, 4, H, W] == [N, C, H, W] + t = array_ops.transpose(in_tensor, [0, 1, 4, 2, 3]) + n = in_tensor.shape.dims[0].value + c = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value + h = in_tensor.shape.dims[2].value + w = in_tensor.shape.dims[3].value + return array_ops.reshape(t, [n, c, h, w]) + + +def OihwVectIToHwio(in_tensor): + # [O, I / 4, H, W, 4] => [O, I / 4, 4, H, W] == [O, I, H, W] + t = array_ops.transpose(in_tensor, [2, 3, 1, 4, 0]) + o = in_tensor.shape.dims[0].value + i = in_tensor.shape.dims[1].value * in_tensor.shape.dims[4].value + h = in_tensor.shape.dims[2].value + w = in_tensor.shape.dims[3].value + return array_ops.reshape(t, [h, w, i, o]) + + +def NchwToNchwVectC(in_tensor): + n, c, h, w = in_tensor.shape.as_list() + assert c % 4 == 0 + t = array_ops.reshape(in_tensor, [n, c // 4, 4, h, w]) + return array_ops.transpose(t, [0, 1, 3, 4, 2]) + + +def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, + padding, strides, side_input_scale, + side_input, biases): + """Simulates the int8 fused 2-D convolution op using separate float ops. + + The arguments and return values have the same format, meanings and + restrictions as the actual op. + Args: + conv_input_scale: A scalar 'float'. + conv_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. + kernel: A `Tensor` of type `qint8` in OIHW_VECT_I layout. + padding: A `string` from: `"SAME", "VALID"`. + strides: A list of `ints`. + side_input_scale: A scalar 'float'. + side_input: A `Tensor` of type `qint8` in NCHW_VECT_C layout. + biases: A `Tensor` of type `float32` in NCHW layout. + Returns: + A `Tensor` of type `qint8` in NCHW_VECT_C layout. + """ + conv_result = nn_ops.conv2d( + NchwVectCToNchw(gen_array_ops.dequantize(conv_input, -128, 127)), + OihwVectIToHwio(gen_array_ops.dequantize(kernel, -128, 127)), + strides=strides, + padding=padding, + data_format="NCHW") * conv_input_scale + + conv_and_side_inputs = conv_result + side_input_scale * NchwVectCToNchw( + gen_array_ops.dequantize(side_input, -128, 127)) + + logit = nn_ops.bias_add(conv_and_side_inputs, biases, data_format="NCHW") + + result, _, _ = gen_array_ops.quantize_v2( + NchwToNchwVectC(nn_ops.relu(logit)), -128, 127, dtypes.qint8) + return result + + +class FusedConvInt8Tests(test.TestCase): + _test_params = [ + { + "batch_size": 2, + "input_channels": 8, + "output_channels": 16, + "input_height": 8, + "input_width": 8, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "VALID" + }, + { + "batch_size": 2, + "input_channels": 8, + "output_channels": 16, + "input_height": 8, + "input_width": 8, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 2, + "input_channels": 8, + "output_channels": 16, + "input_height": 8, + "input_width": 8, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "VALID" + }, + { + "batch_size": 2, + "input_channels": 16, + "output_channels": 16, + "input_height": 9, + "input_width": 9, + "filter_height": 3, + "filter_width": 3, + "vertical_stride": 1, + "horizontal_stride": 1, + "conv_input_scale": 0.001, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 3, + "input_channels": 8, + "output_channels": 8, + "input_height": 9, + "input_width": 9, + "filter_height": 5, + "filter_width": 5, + "vertical_stride": 1, + "horizontal_stride": 1, + "conv_input_scale": 0.001, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 3, + "input_channels": 8, + "output_channels": 8, + "input_height": 9, + "input_width": 9, + "filter_height": 7, + "filter_width": 1, + "vertical_stride": 2, + "horizontal_stride": 1, + "conv_input_scale": 0.002, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 3, + "input_channels": 8, + "output_channels": 8, + "input_height": 9, + "input_width": 9, + "filter_height": 1, + "filter_width": 7, + "vertical_stride": 1, + "horizontal_stride": 1, + "conv_input_scale": 0.002, + "side_input_scale": 0.5, + "bias_scale": 1, + "padding_type": "SAME" + }, + ] + + def runTest(self, test_param): + batch_size = test_param["batch_size"] + input_channels = test_param["input_channels"] + output_channels = test_param["output_channels"] + input_height = test_param["input_height"] + input_width = test_param["input_width"] + filter_height = test_param["filter_height"] + filter_width = test_param["filter_width"] + vertical_stride = test_param["vertical_stride"] + horizontal_stride = test_param["horizontal_stride"] + conv_input_scale = test_param["conv_input_scale"] + side_input_scale = test_param["side_input_scale"] + bias_scale = test_param["bias_scale"] + padding_type = test_param["padding_type"] + + conv_input, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform( + [batch_size, input_channels // 4, input_height, input_width, 4], + minval=-0.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) + + kernel, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform( + [ + output_channels, input_channels // 4, filter_height, + filter_width, 4 + ], + minval=-1.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) + + output_height = CalculateCovolvedOutputDim(input_height, filter_height, + vertical_stride, padding_type) + output_width = CalculateCovolvedOutputDim(input_width, filter_width, + horizontal_stride, padding_type) + print("output_height=", output_height, ", output_width=", output_width) + + side_input, _, _ = gen_array_ops.quantize_v2( + random_ops.random_uniform( + [batch_size, output_channels // 4, output_height, output_width, 4], + minval=0.0, + maxval=1.0, + dtype=dtypes.float32), -1.0, 1.0, dtypes.qint8) + + biases = random_ops.random_uniform( + [output_channels], + minval=-10 * bias_scale, + maxval=20 * bias_scale, + dtype=dtypes.float32) + + strides = [1, 1, vertical_stride, horizontal_stride] + + actual = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + conv_input, + kernel, + biases, + strides=strides, + padding=padding_type, + conv_input_scale=conv_input_scale, + side_input_scale=side_input_scale, + side_input=side_input, + data_format="NCHW_VECT_C", + filter_format="OIHW_VECT_I") + + expected = SimulateFusedConv2dBiasActivationInt8( + conv_input_scale, conv_input, kernel, padding_type, strides, + side_input_scale, side_input, biases) + + with self.test_session(use_gpu=True) as sess: + actual_y, expected_y = sess.run([actual, expected]) + print("actual_y = ", actual_y) + print("expected_y = ", expected_y) + self.assertTrue(np.array_equal(actual_y, expected_y)) + + def testFusedConvInt8(self): + if not test.is_gpu_available( + cuda_only=True, min_cuda_compute_capability=(6, 1)): + tf_logging.info("int8 test skipped because not run with --config=cuda or " + "no GPUs with compute capability >= 6.1 are available.") + return + for test_param in self._test_params: + self.runTest(test_param) + + if __name__ == "__main__": for index, (input_size_, filter_size_, output_size_, stride_, padding_) in enumerate(GetShrunkInceptionShapes()): diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b2de2823563c6b964c8b5b66c32141f56c0dc5b2 --- /dev/null +++ b/tensorflow/contrib/gan/BUILD @@ -0,0 +1,27 @@ +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "gan", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + deps = [ + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..586e5ac331c42006d66b9ff1dfd9e0490e8a1004 --- /dev/null +++ b/tensorflow/contrib/gan/README.md @@ -0,0 +1,4 @@ +This directory contains the TFGAN project. + +This file will have more details as code is added. + diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a46b0e8d5de09a56585d6c5f5c103720b127031a --- /dev/null +++ b/tensorflow/contrib/gan/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFGAN grouped API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index 645e364d19120983e4c3b1d11d796c1b2c30b40e..bebcf079ba444946bf0377106cbafcbaa7e94e74 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -62,6 +62,7 @@ tf_cuda_library( }), deps = [ ":gdr_proto_cc", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index e631c243c3cf9f0295f6972f0c20551b81f53e3d..a27bec48010746976ab93f2df9d123e3f03e0441 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -121,12 +121,9 @@ tf_gen_op_wrapper_py( cc_library( name = "image_ops_cc", - srcs = [ - "ops/image_ops.cc", - ], + srcs = ["ops/image_ops.cc"], deps = [ ":image_ops_kernels", - "//tensorflow/core", "//tensorflow/core:framework", ], alwayslink = 1, diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index 7a562f727edfd8b4f458618b273509f9f5ef919e..7e0019ce4ad6c96e09ac9e222e2f4e2840273983 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -1,5 +1,6 @@ # Description: # Contains the Keras API (internal TensorFlow version). +# Note that tf.contrib.keras has been deprecated in favor of tf.keras. licenses(["notice"]) # Apache 2.0 @@ -7,9 +8,6 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "py_test") - py_library( name = "keras", srcs = [ @@ -48,640 +46,10 @@ py_library( "api/keras/utils/__init__.py", "api/keras/wrappers/__init__.py", "api/keras/wrappers/scikit_learn/__init__.py", - "python/keras/__init__.py", - "python/keras/activations.py", - "python/keras/applications/__init__.py", - "python/keras/applications/imagenet_utils.py", - "python/keras/applications/inception_v3.py", - "python/keras/applications/mobilenet.py", - "python/keras/applications/resnet50.py", - "python/keras/applications/vgg16.py", - "python/keras/applications/vgg19.py", - "python/keras/applications/xception.py", - "python/keras/backend.py", - "python/keras/callbacks.py", - "python/keras/constraints.py", - "python/keras/datasets/__init__.py", - "python/keras/datasets/boston_housing.py", - "python/keras/datasets/cifar.py", - "python/keras/datasets/cifar10.py", - "python/keras/datasets/cifar100.py", - "python/keras/datasets/imdb.py", - "python/keras/datasets/mnist.py", - "python/keras/datasets/reuters.py", - "python/keras/engine/__init__.py", - "python/keras/engine/topology.py", - "python/keras/engine/training.py", - "python/keras/initializers.py", - "python/keras/layers/__init__.py", - "python/keras/layers/advanced_activations.py", - "python/keras/layers/convolutional.py", - "python/keras/layers/convolutional_recurrent.py", - "python/keras/layers/core.py", - "python/keras/layers/embeddings.py", - "python/keras/layers/local.py", - "python/keras/layers/merge.py", - "python/keras/layers/noise.py", - "python/keras/layers/normalization.py", - "python/keras/layers/pooling.py", - "python/keras/layers/recurrent.py", - "python/keras/layers/serialization.py", - "python/keras/layers/wrappers.py", - "python/keras/losses.py", - "python/keras/metrics.py", - "python/keras/models.py", - "python/keras/optimizers.py", - "python/keras/preprocessing/__init__.py", - "python/keras/preprocessing/image.py", - "python/keras/preprocessing/sequence.py", - "python/keras/preprocessing/text.py", - "python/keras/regularizers.py", - "python/keras/testing_utils.py", - "python/keras/utils/__init__.py", - "python/keras/utils/conv_utils.py", - "python/keras/utils/data_utils.py", - "python/keras/utils/generic_utils.py", - "python/keras/utils/io_utils.py", - "python/keras/utils/layer_utils.py", - "python/keras/utils/np_utils.py", - "python/keras/utils/vis_utils.py", - "python/keras/wrappers/__init__.py", - "python/keras/wrappers/scikit_learn.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/tensorboard:projector", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client", - "//tensorflow/python:clip_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:ctc_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:framework_ops", - "//tensorflow/python:functional_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:image_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:layers_base", - "//tensorflow/python:logging_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:platform", - "//tensorflow/python:random_ops", - "//tensorflow/python:sparse_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary", - "//tensorflow/python:tensor_array_grad", - "//tensorflow/python:tensor_array_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "@six_archive//:six", - ], -) - -py_test( - name = "integration_test", - size = "medium", - srcs = ["python/keras/integration_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//tensorflow/python:layers", - "//tensorflow/python:nn", - "//third_party/py/numpy", - ], -) - -py_test( - name = "activations_test", - size = "small", - srcs = ["python/keras/activations_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "constraints_test", - size = "small", - srcs = ["python/keras/constraints_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "initializers_test", - size = "small", - srcs = ["python/keras/initializers_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//tensorflow/python:init_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "regularizers_test", - size = "small", - srcs = ["python/keras/regularizers_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "optimizers_test", - size = "medium", - srcs = ["python/keras/optimizers_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], -) - -py_test( - name = "losses_test", - size = "small", - srcs = ["python/keras/losses_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "metrics_test", - size = "small", - srcs = ["python/keras/metrics_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "inception_v3_test", - size = "medium", - srcs = ["python/keras/applications/inception_v3_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "mobilenet_test", - size = "small", - srcs = ["python/keras/applications/mobilenet_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "resnet50_test", - size = "small", - srcs = ["python/keras/applications/resnet50_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "vgg16_test", - size = "small", - srcs = ["python/keras/applications/vgg16_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "vgg19_test", - size = "small", - srcs = ["python/keras/applications/vgg19_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "xception_test", - size = "medium", - srcs = ["python/keras/applications/xception_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "advanced_activations_test", - size = "small", - srcs = ["python/keras/layers/advanced_activations_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "convolutional_recurrent_test", - size = "medium", - srcs = ["python/keras/layers/convolutional_recurrent_test.py"], - shard_count = 2, - srcs_version = "PY2AND3", - tags = ["noasan"], # times out b/63678675 - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "convolutional_test", - size = "medium", - srcs = ["python/keras/layers/convolutional_test.py"], - srcs_version = "PY2AND3", - tags = [ - "manual", - "noasan", # times out b/63678675 - "notsan", - ], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "pooling_test", - size = "small", - srcs = ["python/keras/layers/pooling_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "core_test", - size = "small", - srcs = ["python/keras/layers/core_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "embeddings_test", - size = "small", - srcs = ["python/keras/layers/embeddings_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "local_test", - size = "medium", - srcs = ["python/keras/layers/local_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "merge_test", - size = "small", - srcs = ["python/keras/layers/merge_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "noise_test", - size = "small", - srcs = ["python/keras/layers/noise_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "normalization_test", - size = "small", - srcs = ["python/keras/layers/normalization_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "simplernn_test", - size = "medium", - srcs = ["python/keras/layers/simplernn_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "gru_test", - size = "medium", - srcs = ["python/keras/layers/gru_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], # http://b/62136390 - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "lstm_test", - size = "medium", - srcs = ["python/keras/layers/lstm_test.py"], - srcs_version = "PY2AND3", - tags = [ - "noasan", # times out b/63678675 - "notsan", # http://b/62189182 - ], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "serialization_test", - size = "small", - srcs = ["python/keras/layers/serialization_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "wrappers_test", - size = "small", - srcs = ["python/keras/layers/wrappers_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "scikit_learn_test", - size = "small", - srcs = ["python/keras/wrappers/scikit_learn_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "data_utils_test", - size = "small", - srcs = ["python/keras/utils/data_utils_test.py"], - srcs_version = "PY2AND3", - tags = [ - "noasan", # times out - "notsan", - ], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "generic_utils_test", - size = "small", - srcs = ["python/keras/utils/generic_utils_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - ], -) - -py_test( - name = "io_utils_test", - size = "small", - srcs = ["python/keras/utils/io_utils_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "imagenet_utils_test", - size = "small", - srcs = ["python/keras/applications/imagenet_utils_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "image_test", - size = "medium", - srcs = ["python/keras/preprocessing/image_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "sequence_test", - size = "small", - srcs = ["python/keras/preprocessing/sequence_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "text_test", - size = "small", - srcs = ["python/keras/preprocessing/text_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "callbacks_test", - size = "small", - srcs = ["python/keras/callbacks_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "training_test", - size = "medium", - srcs = ["python/keras/engine/training_test.py"], - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//third_party/py/numpy", - ], -) - -py_test( - name = "topology_test", - size = "small", - srcs = ["python/keras/engine/topology_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//third_party/py/numpy", - ], -) - -py_test( - name = "models_test", - size = "small", - srcs = ["python/keras/models_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], -) - -py_test( - name = "backend_test", - size = "small", - srcs = ["python/keras/backend_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":keras", - "//tensorflow/python:client_testlib", - "//tensorflow/python:util", - "//third_party/py/numpy", - ], -) - -py_library( - name = "testing_utils", - srcs = [ - "python/keras/testing_utils.py", ], srcs_version = "PY2AND3", deps = [ - ":keras", - "//tensorflow/python:util", - "//third_party/py/numpy", + "//tensorflow/python/keras", ], ) diff --git a/tensorflow/contrib/keras/README.md b/tensorflow/contrib/keras/README.md index db2556fe422c179737178f622a53d69d57282b8e..de4c81268d57d5ec7e88d4344c364e9cd99e5204 100644 --- a/tensorflow/contrib/keras/README.md +++ b/tensorflow/contrib/keras/README.md @@ -1,3 +1,6 @@ +NOTE: THE `tensorflow.contrib.keras` MODULE HAS BEEN DEPRECATED. +USE INSTEAD `tensorflow.keras`, PART OF CORE TENSORFLOW. + Keras is an object-oriented API for defining and training neural networks. This module contains a pure-TensorFlow implementation of the Keras API, diff --git a/tensorflow/contrib/keras/api/keras/activations/__init__.py b/tensorflow/contrib/keras/api/keras/activations/__init__.py index af6f249e71c9b6c5c23d0f3c9aef91e52b37e8a5..d04838c218d6643a703723a1d163c88547c14da7 100644 --- a/tensorflow/contrib/keras/api/keras/activations/__init__.py +++ b/tensorflow/contrib/keras/api/keras/activations/__init__.py @@ -19,22 +19,22 @@ from __future__ import division from __future__ import print_function # Activation functions. -from tensorflow.contrib.keras.python.keras.activations import elu -from tensorflow.contrib.keras.python.keras.activations import hard_sigmoid -from tensorflow.contrib.keras.python.keras.activations import linear -from tensorflow.contrib.keras.python.keras.activations import relu -from tensorflow.contrib.keras.python.keras.activations import selu -from tensorflow.contrib.keras.python.keras.activations import sigmoid -from tensorflow.contrib.keras.python.keras.activations import softmax -from tensorflow.contrib.keras.python.keras.activations import softplus -from tensorflow.contrib.keras.python.keras.activations import softsign -from tensorflow.contrib.keras.python.keras.activations import tanh +from tensorflow.python.keras._impl.keras.activations import elu +from tensorflow.python.keras._impl.keras.activations import hard_sigmoid +from tensorflow.python.keras._impl.keras.activations import linear +from tensorflow.python.keras._impl.keras.activations import relu +from tensorflow.python.keras._impl.keras.activations import selu +from tensorflow.python.keras._impl.keras.activations import sigmoid +from tensorflow.python.keras._impl.keras.activations import softmax +from tensorflow.python.keras._impl.keras.activations import softplus +from tensorflow.python.keras._impl.keras.activations import softsign +from tensorflow.python.keras._impl.keras.activations import tanh # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.contrib.keras.python.keras.activations import deserialize -from tensorflow.contrib.keras.python.keras.activations import serialize -from tensorflow.contrib.keras.python.keras.activations import get +from tensorflow.python.keras._impl.keras.activations import deserialize +from tensorflow.python.keras._impl.keras.activations import serialize +from tensorflow.python.keras._impl.keras.activations import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py index d8ca73fb97f8c7e1d817fcfcc3470e3f5317532a..abf8393ae45d71dc0cb746706abb72f77b82d199 100644 --- a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.applications.inception_v3 import decode_predictions -from tensorflow.contrib.keras.python.keras.applications.inception_v3 import InceptionV3 -from tensorflow.contrib.keras.python.keras.applications.inception_v3 import preprocess_input +from tensorflow.python.keras._impl.keras.applications.inception_v3 import decode_predictions +from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 +from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py index 594861fb51c7138dec7f4a8d9badf34cc3870594..b809e91193b459a46906443796344c092e1d2a6b 100644 --- a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.applications.mobilenet import decode_predictions -from tensorflow.contrib.keras.python.keras.applications.mobilenet import MobileNet -from tensorflow.contrib.keras.python.keras.applications.mobilenet import preprocess_input +from tensorflow.python.keras._impl.keras.applications.mobilenet import decode_predictions +from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet +from tensorflow.python.keras._impl.keras.applications.mobilenet import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py index e9b25b66d5a065cad9aa5359407fe765c6c5ca6f..530805d150bfe32c5b81d7d7d3f92e203b83b602 100644 --- a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.applications.resnet50 import decode_predictions -from tensorflow.contrib.keras.python.keras.applications.resnet50 import preprocess_input -from tensorflow.contrib.keras.python.keras.applications.resnet50 import ResNet50 +from tensorflow.python.keras._impl.keras.applications.resnet50 import decode_predictions +from tensorflow.python.keras._impl.keras.applications.resnet50 import preprocess_input +from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py index 2a1f789cc51594ecbe4ee6ba0c91fedb6bb66516..118361604bbc7e0a88ed34243c0d5ea98856a301 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.applications.vgg16 import decode_predictions -from tensorflow.contrib.keras.python.keras.applications.vgg16 import preprocess_input -from tensorflow.contrib.keras.python.keras.applications.vgg16 import VGG16 +from tensorflow.python.keras._impl.keras.applications.vgg16 import decode_predictions +from tensorflow.python.keras._impl.keras.applications.vgg16 import preprocess_input +from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py index 22b5e7c8e495c6f0e0b4fdcb657337a9b93b30b7..cda52628f3c10d65fdbe70b2f86cc12c771870a9 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.applications.vgg19 import decode_predictions -from tensorflow.contrib.keras.python.keras.applications.vgg19 import preprocess_input -from tensorflow.contrib.keras.python.keras.applications.vgg19 import VGG19 +from tensorflow.python.keras._impl.keras.applications.vgg19 import decode_predictions +from tensorflow.python.keras._impl.keras.applications.vgg19 import preprocess_input +from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py index 23d1b6a0b371b49a19284819d7731abbfa2cf917..ae9cd9cd18c5ccc5ec37c8cd1bf36f8aabd9929c 100644 --- a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.applications.xception import decode_predictions -from tensorflow.contrib.keras.python.keras.applications.xception import preprocess_input -from tensorflow.contrib.keras.python.keras.applications.xception import Xception +from tensorflow.python.keras._impl.keras.applications.xception import decode_predictions +from tensorflow.python.keras._impl.keras.applications.xception import preprocess_input +from tensorflow.python.keras._impl.keras.applications.xception import Xception del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/backend/__init__.py b/tensorflow/contrib/keras/api/keras/backend/__init__.py index f3721a8dcb1cae66a1ab594985fd7f3a5ca4fa44..10ef5a75852deb6595bced2703d7c5f29b0efac3 100644 --- a/tensorflow/contrib/keras/api/keras/backend/__init__.py +++ b/tensorflow/contrib/keras/api/keras/backend/__init__.py @@ -19,144 +19,144 @@ from __future__ import division from __future__ import print_function # pylint: disable=redefined-builtin -from tensorflow.contrib.keras.python.keras.backend import abs -from tensorflow.contrib.keras.python.keras.backend import all -from tensorflow.contrib.keras.python.keras.backend import any -from tensorflow.contrib.keras.python.keras.backend import arange -from tensorflow.contrib.keras.python.keras.backend import argmax -from tensorflow.contrib.keras.python.keras.backend import argmin -from tensorflow.contrib.keras.python.keras.backend import backend -from tensorflow.contrib.keras.python.keras.backend import batch_dot -from tensorflow.contrib.keras.python.keras.backend import batch_flatten -from tensorflow.contrib.keras.python.keras.backend import batch_get_value -from tensorflow.contrib.keras.python.keras.backend import batch_normalization -from tensorflow.contrib.keras.python.keras.backend import batch_set_value -from tensorflow.contrib.keras.python.keras.backend import bias_add -from tensorflow.contrib.keras.python.keras.backend import binary_crossentropy -from tensorflow.contrib.keras.python.keras.backend import cast -from tensorflow.contrib.keras.python.keras.backend import cast_to_floatx -from tensorflow.contrib.keras.python.keras.backend import categorical_crossentropy -from tensorflow.contrib.keras.python.keras.backend import clear_session -from tensorflow.contrib.keras.python.keras.backend import clip -from tensorflow.contrib.keras.python.keras.backend import concatenate -from tensorflow.contrib.keras.python.keras.backend import constant -from tensorflow.contrib.keras.python.keras.backend import conv1d -from tensorflow.contrib.keras.python.keras.backend import conv2d -from tensorflow.contrib.keras.python.keras.backend import conv2d_transpose -from tensorflow.contrib.keras.python.keras.backend import conv3d -from tensorflow.contrib.keras.python.keras.backend import cos -from tensorflow.contrib.keras.python.keras.backend import count_params -from tensorflow.contrib.keras.python.keras.backend import ctc_batch_cost -from tensorflow.contrib.keras.python.keras.backend import ctc_decode -from tensorflow.contrib.keras.python.keras.backend import ctc_label_dense_to_sparse -from tensorflow.contrib.keras.python.keras.backend import dot -from tensorflow.contrib.keras.python.keras.backend import dropout -from tensorflow.contrib.keras.python.keras.backend import dtype -from tensorflow.contrib.keras.python.keras.backend import elu -from tensorflow.contrib.keras.python.keras.backend import epsilon -from tensorflow.contrib.keras.python.keras.backend import equal -from tensorflow.contrib.keras.python.keras.backend import eval -from tensorflow.contrib.keras.python.keras.backend import exp -from tensorflow.contrib.keras.python.keras.backend import expand_dims -from tensorflow.contrib.keras.python.keras.backend import eye -from tensorflow.contrib.keras.python.keras.backend import flatten -from tensorflow.contrib.keras.python.keras.backend import floatx -from tensorflow.contrib.keras.python.keras.backend import foldl -from tensorflow.contrib.keras.python.keras.backend import foldr -from tensorflow.contrib.keras.python.keras.backend import function -from tensorflow.contrib.keras.python.keras.backend import gather -from tensorflow.contrib.keras.python.keras.backend import get_session -from tensorflow.contrib.keras.python.keras.backend import get_uid -from tensorflow.contrib.keras.python.keras.backend import get_value -from tensorflow.contrib.keras.python.keras.backend import gradients -from tensorflow.contrib.keras.python.keras.backend import greater -from tensorflow.contrib.keras.python.keras.backend import greater_equal -from tensorflow.contrib.keras.python.keras.backend import hard_sigmoid -from tensorflow.contrib.keras.python.keras.backend import image_data_format -from tensorflow.contrib.keras.python.keras.backend import in_test_phase -from tensorflow.contrib.keras.python.keras.backend import in_top_k -from tensorflow.contrib.keras.python.keras.backend import in_train_phase -from tensorflow.contrib.keras.python.keras.backend import int_shape -from tensorflow.contrib.keras.python.keras.backend import is_sparse -from tensorflow.contrib.keras.python.keras.backend import l2_normalize -from tensorflow.contrib.keras.python.keras.backend import learning_phase -from tensorflow.contrib.keras.python.keras.backend import less -from tensorflow.contrib.keras.python.keras.backend import less_equal -from tensorflow.contrib.keras.python.keras.backend import log -from tensorflow.contrib.keras.python.keras.backend import manual_variable_initialization -from tensorflow.contrib.keras.python.keras.backend import map_fn -from tensorflow.contrib.keras.python.keras.backend import max -from tensorflow.contrib.keras.python.keras.backend import maximum -from tensorflow.contrib.keras.python.keras.backend import mean -from tensorflow.contrib.keras.python.keras.backend import min -from tensorflow.contrib.keras.python.keras.backend import minimum -from tensorflow.contrib.keras.python.keras.backend import moving_average_update -from tensorflow.contrib.keras.python.keras.backend import name_scope -from tensorflow.contrib.keras.python.keras.backend import ndim -from tensorflow.contrib.keras.python.keras.backend import normalize_batch_in_training -from tensorflow.contrib.keras.python.keras.backend import not_equal -from tensorflow.contrib.keras.python.keras.backend import one_hot -from tensorflow.contrib.keras.python.keras.backend import ones -from tensorflow.contrib.keras.python.keras.backend import ones_like -from tensorflow.contrib.keras.python.keras.backend import permute_dimensions -from tensorflow.contrib.keras.python.keras.backend import placeholder -from tensorflow.contrib.keras.python.keras.backend import pool2d -from tensorflow.contrib.keras.python.keras.backend import pool3d -from tensorflow.contrib.keras.python.keras.backend import pow -from tensorflow.contrib.keras.python.keras.backend import print_tensor -from tensorflow.contrib.keras.python.keras.backend import prod -from tensorflow.contrib.keras.python.keras.backend import random_binomial -from tensorflow.contrib.keras.python.keras.backend import random_normal -from tensorflow.contrib.keras.python.keras.backend import random_normal_variable -from tensorflow.contrib.keras.python.keras.backend import random_uniform -from tensorflow.contrib.keras.python.keras.backend import random_uniform_variable -from tensorflow.contrib.keras.python.keras.backend import relu -from tensorflow.contrib.keras.python.keras.backend import repeat -from tensorflow.contrib.keras.python.keras.backend import repeat_elements -from tensorflow.contrib.keras.python.keras.backend import reset_uids -from tensorflow.contrib.keras.python.keras.backend import reshape -from tensorflow.contrib.keras.python.keras.backend import resize_images -from tensorflow.contrib.keras.python.keras.backend import resize_volumes -from tensorflow.contrib.keras.python.keras.backend import reverse -from tensorflow.contrib.keras.python.keras.backend import rnn -from tensorflow.contrib.keras.python.keras.backend import round -from tensorflow.contrib.keras.python.keras.backend import separable_conv2d -from tensorflow.contrib.keras.python.keras.backend import set_epsilon -from tensorflow.contrib.keras.python.keras.backend import set_floatx -from tensorflow.contrib.keras.python.keras.backend import set_image_data_format -from tensorflow.contrib.keras.python.keras.backend import set_learning_phase -from tensorflow.contrib.keras.python.keras.backend import set_session -from tensorflow.contrib.keras.python.keras.backend import set_value -from tensorflow.contrib.keras.python.keras.backend import shape -from tensorflow.contrib.keras.python.keras.backend import sigmoid -from tensorflow.contrib.keras.python.keras.backend import sign -from tensorflow.contrib.keras.python.keras.backend import sin -from tensorflow.contrib.keras.python.keras.backend import softmax -from tensorflow.contrib.keras.python.keras.backend import softplus -from tensorflow.contrib.keras.python.keras.backend import softsign -from tensorflow.contrib.keras.python.keras.backend import sparse_categorical_crossentropy -from tensorflow.contrib.keras.python.keras.backend import spatial_2d_padding -from tensorflow.contrib.keras.python.keras.backend import spatial_3d_padding -from tensorflow.contrib.keras.python.keras.backend import sqrt -from tensorflow.contrib.keras.python.keras.backend import square -from tensorflow.contrib.keras.python.keras.backend import squeeze -from tensorflow.contrib.keras.python.keras.backend import stack -from tensorflow.contrib.keras.python.keras.backend import std -from tensorflow.contrib.keras.python.keras.backend import stop_gradient -from tensorflow.contrib.keras.python.keras.backend import sum -from tensorflow.contrib.keras.python.keras.backend import switch -from tensorflow.contrib.keras.python.keras.backend import tanh -from tensorflow.contrib.keras.python.keras.backend import temporal_padding -from tensorflow.contrib.keras.python.keras.backend import to_dense -from tensorflow.contrib.keras.python.keras.backend import transpose -from tensorflow.contrib.keras.python.keras.backend import truncated_normal -from tensorflow.contrib.keras.python.keras.backend import update -from tensorflow.contrib.keras.python.keras.backend import update_add -from tensorflow.contrib.keras.python.keras.backend import update_sub -from tensorflow.contrib.keras.python.keras.backend import var -from tensorflow.contrib.keras.python.keras.backend import variable -from tensorflow.contrib.keras.python.keras.backend import zeros -from tensorflow.contrib.keras.python.keras.backend import zeros_like +from tensorflow.python.keras._impl.keras.backend import abs +from tensorflow.python.keras._impl.keras.backend import all +from tensorflow.python.keras._impl.keras.backend import any +from tensorflow.python.keras._impl.keras.backend import arange +from tensorflow.python.keras._impl.keras.backend import argmax +from tensorflow.python.keras._impl.keras.backend import argmin +from tensorflow.python.keras._impl.keras.backend import backend +from tensorflow.python.keras._impl.keras.backend import batch_dot +from tensorflow.python.keras._impl.keras.backend import batch_flatten +from tensorflow.python.keras._impl.keras.backend import batch_get_value +from tensorflow.python.keras._impl.keras.backend import batch_normalization +from tensorflow.python.keras._impl.keras.backend import batch_set_value +from tensorflow.python.keras._impl.keras.backend import bias_add +from tensorflow.python.keras._impl.keras.backend import binary_crossentropy +from tensorflow.python.keras._impl.keras.backend import cast +from tensorflow.python.keras._impl.keras.backend import cast_to_floatx +from tensorflow.python.keras._impl.keras.backend import categorical_crossentropy +from tensorflow.python.keras._impl.keras.backend import clear_session +from tensorflow.python.keras._impl.keras.backend import clip +from tensorflow.python.keras._impl.keras.backend import concatenate +from tensorflow.python.keras._impl.keras.backend import constant +from tensorflow.python.keras._impl.keras.backend import conv1d +from tensorflow.python.keras._impl.keras.backend import conv2d +from tensorflow.python.keras._impl.keras.backend import conv2d_transpose +from tensorflow.python.keras._impl.keras.backend import conv3d +from tensorflow.python.keras._impl.keras.backend import cos +from tensorflow.python.keras._impl.keras.backend import count_params +from tensorflow.python.keras._impl.keras.backend import ctc_batch_cost +from tensorflow.python.keras._impl.keras.backend import ctc_decode +from tensorflow.python.keras._impl.keras.backend import ctc_label_dense_to_sparse +from tensorflow.python.keras._impl.keras.backend import dot +from tensorflow.python.keras._impl.keras.backend import dropout +from tensorflow.python.keras._impl.keras.backend import dtype +from tensorflow.python.keras._impl.keras.backend import elu +from tensorflow.python.keras._impl.keras.backend import epsilon +from tensorflow.python.keras._impl.keras.backend import equal +from tensorflow.python.keras._impl.keras.backend import eval +from tensorflow.python.keras._impl.keras.backend import exp +from tensorflow.python.keras._impl.keras.backend import expand_dims +from tensorflow.python.keras._impl.keras.backend import eye +from tensorflow.python.keras._impl.keras.backend import flatten +from tensorflow.python.keras._impl.keras.backend import floatx +from tensorflow.python.keras._impl.keras.backend import foldl +from tensorflow.python.keras._impl.keras.backend import foldr +from tensorflow.python.keras._impl.keras.backend import function +from tensorflow.python.keras._impl.keras.backend import gather +from tensorflow.python.keras._impl.keras.backend import get_session +from tensorflow.python.keras._impl.keras.backend import get_uid +from tensorflow.python.keras._impl.keras.backend import get_value +from tensorflow.python.keras._impl.keras.backend import gradients +from tensorflow.python.keras._impl.keras.backend import greater +from tensorflow.python.keras._impl.keras.backend import greater_equal +from tensorflow.python.keras._impl.keras.backend import hard_sigmoid +from tensorflow.python.keras._impl.keras.backend import image_data_format +from tensorflow.python.keras._impl.keras.backend import in_test_phase +from tensorflow.python.keras._impl.keras.backend import in_top_k +from tensorflow.python.keras._impl.keras.backend import in_train_phase +from tensorflow.python.keras._impl.keras.backend import int_shape +from tensorflow.python.keras._impl.keras.backend import is_sparse +from tensorflow.python.keras._impl.keras.backend import l2_normalize +from tensorflow.python.keras._impl.keras.backend import learning_phase +from tensorflow.python.keras._impl.keras.backend import less +from tensorflow.python.keras._impl.keras.backend import less_equal +from tensorflow.python.keras._impl.keras.backend import log +from tensorflow.python.keras._impl.keras.backend import manual_variable_initialization +from tensorflow.python.keras._impl.keras.backend import map_fn +from tensorflow.python.keras._impl.keras.backend import max +from tensorflow.python.keras._impl.keras.backend import maximum +from tensorflow.python.keras._impl.keras.backend import mean +from tensorflow.python.keras._impl.keras.backend import min +from tensorflow.python.keras._impl.keras.backend import minimum +from tensorflow.python.keras._impl.keras.backend import moving_average_update +from tensorflow.python.keras._impl.keras.backend import name_scope +from tensorflow.python.keras._impl.keras.backend import ndim +from tensorflow.python.keras._impl.keras.backend import normalize_batch_in_training +from tensorflow.python.keras._impl.keras.backend import not_equal +from tensorflow.python.keras._impl.keras.backend import one_hot +from tensorflow.python.keras._impl.keras.backend import ones +from tensorflow.python.keras._impl.keras.backend import ones_like +from tensorflow.python.keras._impl.keras.backend import permute_dimensions +from tensorflow.python.keras._impl.keras.backend import placeholder +from tensorflow.python.keras._impl.keras.backend import pool2d +from tensorflow.python.keras._impl.keras.backend import pool3d +from tensorflow.python.keras._impl.keras.backend import pow +from tensorflow.python.keras._impl.keras.backend import print_tensor +from tensorflow.python.keras._impl.keras.backend import prod +from tensorflow.python.keras._impl.keras.backend import random_binomial +from tensorflow.python.keras._impl.keras.backend import random_normal +from tensorflow.python.keras._impl.keras.backend import random_normal_variable +from tensorflow.python.keras._impl.keras.backend import random_uniform +from tensorflow.python.keras._impl.keras.backend import random_uniform_variable +from tensorflow.python.keras._impl.keras.backend import relu +from tensorflow.python.keras._impl.keras.backend import repeat +from tensorflow.python.keras._impl.keras.backend import repeat_elements +from tensorflow.python.keras._impl.keras.backend import reset_uids +from tensorflow.python.keras._impl.keras.backend import reshape +from tensorflow.python.keras._impl.keras.backend import resize_images +from tensorflow.python.keras._impl.keras.backend import resize_volumes +from tensorflow.python.keras._impl.keras.backend import reverse +from tensorflow.python.keras._impl.keras.backend import rnn +from tensorflow.python.keras._impl.keras.backend import round +from tensorflow.python.keras._impl.keras.backend import separable_conv2d +from tensorflow.python.keras._impl.keras.backend import set_epsilon +from tensorflow.python.keras._impl.keras.backend import set_floatx +from tensorflow.python.keras._impl.keras.backend import set_image_data_format +from tensorflow.python.keras._impl.keras.backend import set_learning_phase +from tensorflow.python.keras._impl.keras.backend import set_session +from tensorflow.python.keras._impl.keras.backend import set_value +from tensorflow.python.keras._impl.keras.backend import shape +from tensorflow.python.keras._impl.keras.backend import sigmoid +from tensorflow.python.keras._impl.keras.backend import sign +from tensorflow.python.keras._impl.keras.backend import sin +from tensorflow.python.keras._impl.keras.backend import softmax +from tensorflow.python.keras._impl.keras.backend import softplus +from tensorflow.python.keras._impl.keras.backend import softsign +from tensorflow.python.keras._impl.keras.backend import sparse_categorical_crossentropy +from tensorflow.python.keras._impl.keras.backend import spatial_2d_padding +from tensorflow.python.keras._impl.keras.backend import spatial_3d_padding +from tensorflow.python.keras._impl.keras.backend import sqrt +from tensorflow.python.keras._impl.keras.backend import square +from tensorflow.python.keras._impl.keras.backend import squeeze +from tensorflow.python.keras._impl.keras.backend import stack +from tensorflow.python.keras._impl.keras.backend import std +from tensorflow.python.keras._impl.keras.backend import stop_gradient +from tensorflow.python.keras._impl.keras.backend import sum +from tensorflow.python.keras._impl.keras.backend import switch +from tensorflow.python.keras._impl.keras.backend import tanh +from tensorflow.python.keras._impl.keras.backend import temporal_padding +from tensorflow.python.keras._impl.keras.backend import to_dense +from tensorflow.python.keras._impl.keras.backend import transpose +from tensorflow.python.keras._impl.keras.backend import truncated_normal +from tensorflow.python.keras._impl.keras.backend import update +from tensorflow.python.keras._impl.keras.backend import update_add +from tensorflow.python.keras._impl.keras.backend import update_sub +from tensorflow.python.keras._impl.keras.backend import var +from tensorflow.python.keras._impl.keras.backend import variable +from tensorflow.python.keras._impl.keras.backend import zeros +from tensorflow.python.keras._impl.keras.backend import zeros_like del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py index 3a970748573004cc43ae0d15d07576d678fee3e3..2d884790ddb9ccf49649c6af4cfd40cddbc38cb3 100644 --- a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py +++ b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py @@ -18,19 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.callbacks import BaseLogger -from tensorflow.contrib.keras.python.keras.callbacks import Callback -from tensorflow.contrib.keras.python.keras.callbacks import CSVLogger -from tensorflow.contrib.keras.python.keras.callbacks import EarlyStopping -from tensorflow.contrib.keras.python.keras.callbacks import History -from tensorflow.contrib.keras.python.keras.callbacks import LambdaCallback -from tensorflow.contrib.keras.python.keras.callbacks import LearningRateScheduler -from tensorflow.contrib.keras.python.keras.callbacks import ModelCheckpoint -from tensorflow.contrib.keras.python.keras.callbacks import ProgbarLogger -from tensorflow.contrib.keras.python.keras.callbacks import ReduceLROnPlateau -from tensorflow.contrib.keras.python.keras.callbacks import RemoteMonitor -from tensorflow.contrib.keras.python.keras.callbacks import TensorBoard -from tensorflow.contrib.keras.python.keras.callbacks import TerminateOnNaN +from tensorflow.python.keras._impl.keras.callbacks import BaseLogger +from tensorflow.python.keras._impl.keras.callbacks import Callback +from tensorflow.python.keras._impl.keras.callbacks import CSVLogger +from tensorflow.python.keras._impl.keras.callbacks import EarlyStopping +from tensorflow.python.keras._impl.keras.callbacks import History +from tensorflow.python.keras._impl.keras.callbacks import LambdaCallback +from tensorflow.python.keras._impl.keras.callbacks import LearningRateScheduler +from tensorflow.python.keras._impl.keras.callbacks import ModelCheckpoint +from tensorflow.python.keras._impl.keras.callbacks import ProgbarLogger +from tensorflow.python.keras._impl.keras.callbacks import ReduceLROnPlateau +from tensorflow.python.keras._impl.keras.callbacks import RemoteMonitor +from tensorflow.python.keras._impl.keras.callbacks import TensorBoard +from tensorflow.python.keras._impl.keras.callbacks import TerminateOnNaN del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/constraints/__init__.py b/tensorflow/contrib/keras/api/keras/constraints/__init__.py index 6b9e3bf46e37faf1804fd9fa8202968da29d1fde..152606d8ebbcadf57d971d508e15283da65e4aa3 100644 --- a/tensorflow/contrib/keras/api/keras/constraints/__init__.py +++ b/tensorflow/contrib/keras/api/keras/constraints/__init__.py @@ -19,21 +19,21 @@ from __future__ import division from __future__ import print_function # Constraints functions / callable classes. -from tensorflow.contrib.keras.python.keras.constraints import Constraint -from tensorflow.contrib.keras.python.keras.constraints import max_norm -from tensorflow.contrib.keras.python.keras.constraints import MaxNorm -from tensorflow.contrib.keras.python.keras.constraints import min_max_norm -from tensorflow.contrib.keras.python.keras.constraints import MinMaxNorm -from tensorflow.contrib.keras.python.keras.constraints import non_neg -from tensorflow.contrib.keras.python.keras.constraints import NonNeg -from tensorflow.contrib.keras.python.keras.constraints import unit_norm -from tensorflow.contrib.keras.python.keras.constraints import UnitNorm +from tensorflow.python.keras._impl.keras.constraints import Constraint +from tensorflow.python.keras._impl.keras.constraints import max_norm +from tensorflow.python.keras._impl.keras.constraints import MaxNorm +from tensorflow.python.keras._impl.keras.constraints import min_max_norm +from tensorflow.python.keras._impl.keras.constraints import MinMaxNorm +from tensorflow.python.keras._impl.keras.constraints import non_neg +from tensorflow.python.keras._impl.keras.constraints import NonNeg +from tensorflow.python.keras._impl.keras.constraints import unit_norm +from tensorflow.python.keras._impl.keras.constraints import UnitNorm # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.contrib.keras.python.keras.constraints import deserialize -from tensorflow.contrib.keras.python.keras.constraints import serialize -from tensorflow.contrib.keras.python.keras.constraints import get +from tensorflow.python.keras._impl.keras.constraints import deserialize +from tensorflow.python.keras._impl.keras.constraints import serialize +from tensorflow.python.keras._impl.keras.constraints import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py index 0bfd3df5401794098a4da2bcb45446e4c4ba6397..b5371a03fd5f5755ba8844415276113c565f52db 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.datasets.boston_housing import load_data +from tensorflow.python.keras._impl.keras.datasets.boston_housing import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py index f5fac6982ac7ed6888f8126126a986d6ca81122a..68d3eb789ea2c410095c0c75e0b79a9b07d209a3 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.datasets.cifar10 import load_data +from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py index a7e69961363bbebc233da82f3998f88d18db198b..ca93742673341660ba69712feb59c5dd32ea3252 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.datasets.cifar100 import load_data +from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py index f141c8a8e981306bb7a9b0fdc031c1ca83370508..1c6396d2d32b88eaa900a5af4e62c7484fceab63 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.datasets.imdb import get_word_index -from tensorflow.contrib.keras.python.keras.datasets.imdb import load_data +from tensorflow.python.keras._impl.keras.datasets.imdb import get_word_index +from tensorflow.python.keras._impl.keras.datasets.imdb import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py index 50b74f149c10496e796c03485c86abc2dbe270d5..364255f3387b59a419c010db9b93cdfbcba36186 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.datasets.mnist import load_data +from tensorflow.python.keras._impl.keras.datasets.mnist import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py index fc7f1235a3aad78389147d6c574a72a00238beb7..bb6791a344ad0c372ac60cd4a332f5632841dd46 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.datasets.reuters import get_word_index -from tensorflow.contrib.keras.python.keras.datasets.reuters import load_data +from tensorflow.python.keras._impl.keras.datasets.reuters import get_word_index +from tensorflow.python.keras._impl.keras.datasets.reuters import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/initializers/__init__.py b/tensorflow/contrib/keras/api/keras/initializers/__init__.py index 9b58723ed5c93d441b2ae8976d5acaba2db3ad40..6b1fcfd2d9585d19ae3fd9705e128b19b1ec40e7 100644 --- a/tensorflow/contrib/keras/api/keras/initializers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/initializers/__init__.py @@ -19,30 +19,30 @@ from __future__ import division from __future__ import print_function # Initializer functions / callable classes. -from tensorflow.contrib.keras.python.keras.initializers import Constant -from tensorflow.contrib.keras.python.keras.initializers import Identity -from tensorflow.contrib.keras.python.keras.initializers import Initializer -from tensorflow.contrib.keras.python.keras.initializers import Ones -from tensorflow.contrib.keras.python.keras.initializers import Orthogonal -from tensorflow.contrib.keras.python.keras.initializers import RandomNormal -from tensorflow.contrib.keras.python.keras.initializers import RandomUniform -from tensorflow.contrib.keras.python.keras.initializers import TruncatedNormal -from tensorflow.contrib.keras.python.keras.initializers import VarianceScaling -from tensorflow.contrib.keras.python.keras.initializers import Zeros +from tensorflow.python.keras._impl.keras.initializers import Constant +from tensorflow.python.keras._impl.keras.initializers import Identity +from tensorflow.python.keras._impl.keras.initializers import Initializer +from tensorflow.python.keras._impl.keras.initializers import Ones +from tensorflow.python.keras._impl.keras.initializers import Orthogonal +from tensorflow.python.keras._impl.keras.initializers import RandomNormal +from tensorflow.python.keras._impl.keras.initializers import RandomUniform +from tensorflow.python.keras._impl.keras.initializers import TruncatedNormal +from tensorflow.python.keras._impl.keras.initializers import VarianceScaling +from tensorflow.python.keras._impl.keras.initializers import Zeros # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.contrib.keras.python.keras.initializers import glorot_normal -from tensorflow.contrib.keras.python.keras.initializers import glorot_uniform -from tensorflow.contrib.keras.python.keras.initializers import he_normal -from tensorflow.contrib.keras.python.keras.initializers import he_uniform -from tensorflow.contrib.keras.python.keras.initializers import lecun_normal -from tensorflow.contrib.keras.python.keras.initializers import lecun_uniform +from tensorflow.python.keras._impl.keras.initializers import glorot_normal +from tensorflow.python.keras._impl.keras.initializers import glorot_uniform +from tensorflow.python.keras._impl.keras.initializers import he_normal +from tensorflow.python.keras._impl.keras.initializers import he_uniform +from tensorflow.python.keras._impl.keras.initializers import lecun_normal +from tensorflow.python.keras._impl.keras.initializers import lecun_uniform # Auxiliary utils. -from tensorflow.contrib.keras.python.keras.initializers import deserialize -from tensorflow.contrib.keras.python.keras.initializers import serialize -from tensorflow.contrib.keras.python.keras.initializers import get +from tensorflow.python.keras._impl.keras.initializers import deserialize +from tensorflow.python.keras._impl.keras.initializers import serialize +from tensorflow.python.keras._impl.keras.initializers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index aafd18921754657be4eb06de98dd52c6ca579564..acf0a5e1799b7c57dfd82861c9ccc1f132c34375 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,128 +20,128 @@ from __future__ import print_function # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.contrib.keras.python.keras.engine import Input -from tensorflow.contrib.keras.python.keras.engine import InputLayer -from tensorflow.contrib.keras.python.keras.engine import InputSpec -from tensorflow.contrib.keras.python.keras.engine import Layer +from tensorflow.python.keras._impl.keras.engine import Input +from tensorflow.python.keras._impl.keras.engine import InputLayer +from tensorflow.python.keras._impl.keras.engine import InputSpec +from tensorflow.python.keras._impl.keras.engine import Layer # Advanced activations. -from tensorflow.contrib.keras.python.keras.layers.advanced_activations import LeakyReLU -from tensorflow.contrib.keras.python.keras.layers.advanced_activations import PReLU -from tensorflow.contrib.keras.python.keras.layers.advanced_activations import ELU -from tensorflow.contrib.keras.python.keras.layers.advanced_activations import ThresholdedReLU +from tensorflow.python.keras._impl.keras.layers.advanced_activations import LeakyReLU +from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU +from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU +from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU # Convolution layers. -from tensorflow.contrib.keras.python.keras.layers.convolutional import Conv1D -from tensorflow.contrib.keras.python.keras.layers.convolutional import Conv2D -from tensorflow.contrib.keras.python.keras.layers.convolutional import Conv3D -from tensorflow.contrib.keras.python.keras.layers.convolutional import Conv2DTranspose -from tensorflow.contrib.keras.python.keras.layers.convolutional import Conv3DTranspose -from tensorflow.contrib.keras.python.keras.layers.convolutional import SeparableConv2D +from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D +from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D +from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D +from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose +from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose +from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D # Convolution layer aliases. -from tensorflow.contrib.keras.python.keras.layers.convolutional import Convolution1D -from tensorflow.contrib.keras.python.keras.layers.convolutional import Convolution2D -from tensorflow.contrib.keras.python.keras.layers.convolutional import Convolution3D -from tensorflow.contrib.keras.python.keras.layers.convolutional import Convolution2DTranspose -from tensorflow.contrib.keras.python.keras.layers.convolutional import Convolution3DTranspose -from tensorflow.contrib.keras.python.keras.layers.convolutional import SeparableConvolution2D +from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution1D +from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2D +from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D +from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose +from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose +from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D # Image processing layers. -from tensorflow.contrib.keras.python.keras.layers.convolutional import UpSampling1D -from tensorflow.contrib.keras.python.keras.layers.convolutional import UpSampling2D -from tensorflow.contrib.keras.python.keras.layers.convolutional import UpSampling3D -from tensorflow.contrib.keras.python.keras.layers.convolutional import ZeroPadding1D -from tensorflow.contrib.keras.python.keras.layers.convolutional import ZeroPadding2D -from tensorflow.contrib.keras.python.keras.layers.convolutional import ZeroPadding3D -from tensorflow.contrib.keras.python.keras.layers.convolutional import Cropping1D -from tensorflow.contrib.keras.python.keras.layers.convolutional import Cropping2D -from tensorflow.contrib.keras.python.keras.layers.convolutional import Cropping3D +from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D +from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling2D +from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling3D +from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding1D +from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding2D +from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding3D +from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping1D +from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping2D +from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping3D # Convolutional-recurrent layers. -from tensorflow.contrib.keras.python.keras.layers.convolutional_recurrent import ConvLSTM2D +from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import ConvLSTM2D # Core layers. -from tensorflow.contrib.keras.python.keras.layers.core import Masking -from tensorflow.contrib.keras.python.keras.layers.core import Dropout -from tensorflow.contrib.keras.python.keras.layers.core import SpatialDropout1D -from tensorflow.contrib.keras.python.keras.layers.core import SpatialDropout2D -from tensorflow.contrib.keras.python.keras.layers.core import SpatialDropout3D -from tensorflow.contrib.keras.python.keras.layers.core import Activation -from tensorflow.contrib.keras.python.keras.layers.core import Reshape -from tensorflow.contrib.keras.python.keras.layers.core import Permute -from tensorflow.contrib.keras.python.keras.layers.core import Flatten -from tensorflow.contrib.keras.python.keras.layers.core import RepeatVector -from tensorflow.contrib.keras.python.keras.layers.core import Lambda -from tensorflow.contrib.keras.python.keras.layers.core import Dense -from tensorflow.contrib.keras.python.keras.layers.core import ActivityRegularization +from tensorflow.python.keras._impl.keras.layers.core import Masking +from tensorflow.python.keras._impl.keras.layers.core import Dropout +from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout1D +from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout2D +from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout3D +from tensorflow.python.keras._impl.keras.layers.core import Activation +from tensorflow.python.keras._impl.keras.layers.core import Reshape +from tensorflow.python.keras._impl.keras.layers.core import Permute +from tensorflow.python.keras._impl.keras.layers.core import Flatten +from tensorflow.python.keras._impl.keras.layers.core import RepeatVector +from tensorflow.python.keras._impl.keras.layers.core import Lambda +from tensorflow.python.keras._impl.keras.layers.core import Dense +from tensorflow.python.keras._impl.keras.layers.core import ActivityRegularization # Embedding layers. -from tensorflow.contrib.keras.python.keras.layers.embeddings import Embedding +from tensorflow.python.keras._impl.keras.layers.embeddings import Embedding # Locally-connected layers. -from tensorflow.contrib.keras.python.keras.layers.local import LocallyConnected1D -from tensorflow.contrib.keras.python.keras.layers.local import LocallyConnected2D +from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected1D +from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected2D # Merge layers. -from tensorflow.contrib.keras.python.keras.layers.merge import Add -from tensorflow.contrib.keras.python.keras.layers.merge import Multiply -from tensorflow.contrib.keras.python.keras.layers.merge import Average -from tensorflow.contrib.keras.python.keras.layers.merge import Maximum -from tensorflow.contrib.keras.python.keras.layers.merge import Concatenate -from tensorflow.contrib.keras.python.keras.layers.merge import Dot -from tensorflow.contrib.keras.python.keras.layers.merge import add -from tensorflow.contrib.keras.python.keras.layers.merge import multiply -from tensorflow.contrib.keras.python.keras.layers.merge import average -from tensorflow.contrib.keras.python.keras.layers.merge import maximum -from tensorflow.contrib.keras.python.keras.layers.merge import concatenate -from tensorflow.contrib.keras.python.keras.layers.merge import dot +from tensorflow.python.keras._impl.keras.layers.merge import Add +from tensorflow.python.keras._impl.keras.layers.merge import Multiply +from tensorflow.python.keras._impl.keras.layers.merge import Average +from tensorflow.python.keras._impl.keras.layers.merge import Maximum +from tensorflow.python.keras._impl.keras.layers.merge import Concatenate +from tensorflow.python.keras._impl.keras.layers.merge import Dot +from tensorflow.python.keras._impl.keras.layers.merge import add +from tensorflow.python.keras._impl.keras.layers.merge import multiply +from tensorflow.python.keras._impl.keras.layers.merge import average +from tensorflow.python.keras._impl.keras.layers.merge import maximum +from tensorflow.python.keras._impl.keras.layers.merge import concatenate +from tensorflow.python.keras._impl.keras.layers.merge import dot # Noise layers. -from tensorflow.contrib.keras.python.keras.layers.noise import AlphaDropout -from tensorflow.contrib.keras.python.keras.layers.noise import GaussianNoise -from tensorflow.contrib.keras.python.keras.layers.noise import GaussianDropout +from tensorflow.python.keras._impl.keras.layers.noise import AlphaDropout +from tensorflow.python.keras._impl.keras.layers.noise import GaussianNoise +from tensorflow.python.keras._impl.keras.layers.noise import GaussianDropout # Normalization layers. -from tensorflow.contrib.keras.python.keras.layers.normalization import BatchNormalization +from tensorflow.python.keras._impl.keras.layers.normalization import BatchNormalization # Pooling layers. -from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling1D -from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling2D -from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling3D -from tensorflow.contrib.keras.python.keras.layers.pooling import AveragePooling1D -from tensorflow.contrib.keras.python.keras.layers.pooling import AveragePooling2D -from tensorflow.contrib.keras.python.keras.layers.pooling import AveragePooling3D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAveragePooling1D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAveragePooling2D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAveragePooling3D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPooling1D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPooling2D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPooling3D +from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D +from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D +from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D +from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D +from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D +from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling1D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling2D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling3D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling1D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling2D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling3D # Pooling layer aliases. -from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPool1D -from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPool2D -from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPool3D -from tensorflow.contrib.keras.python.keras.layers.pooling import AvgPool1D -from tensorflow.contrib.keras.python.keras.layers.pooling import AvgPool2D -from tensorflow.contrib.keras.python.keras.layers.pooling import AvgPool3D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAvgPool1D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAvgPool2D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalAvgPool3D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPool1D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPool2D -from tensorflow.contrib.keras.python.keras.layers.pooling import GlobalMaxPool3D +from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool1D +from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool2D +from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool3D +from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool1D +from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool2D +from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool3D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool1D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool2D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool3D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool1D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D +from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D # Recurrent layers. -from tensorflow.contrib.keras.python.keras.layers.recurrent import SimpleRNN -from tensorflow.contrib.keras.python.keras.layers.recurrent import GRU -from tensorflow.contrib.keras.python.keras.layers.recurrent import LSTM +from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN +from tensorflow.python.keras._impl.keras.layers.recurrent import GRU +from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM # Wrapper functions -from tensorflow.contrib.keras.python.keras.layers.wrappers import Wrapper -from tensorflow.contrib.keras.python.keras.layers.wrappers import Bidirectional -from tensorflow.contrib.keras.python.keras.layers.wrappers import TimeDistributed +from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper +from tensorflow.python.keras._impl.keras.layers.wrappers import Bidirectional +from tensorflow.python.keras._impl.keras.layers.wrappers import TimeDistributed del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py index 06dd679f9cadedb93c87ee7b8210fb91d9a867c4..66721b694f5fd5fae7ca521ff56d4c6c6bce79b5 100644 --- a/tensorflow/contrib/keras/api/keras/losses/__init__.py +++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py @@ -19,26 +19,26 @@ from __future__ import division from __future__ import print_function # Loss functions. -from tensorflow.contrib.keras.python.keras.losses import binary_crossentropy -from tensorflow.contrib.keras.python.keras.losses import categorical_crossentropy -from tensorflow.contrib.keras.python.keras.losses import categorical_hinge -from tensorflow.contrib.keras.python.keras.losses import cosine_proximity -from tensorflow.contrib.keras.python.keras.losses import hinge -from tensorflow.contrib.keras.python.keras.losses import kullback_leibler_divergence -from tensorflow.contrib.keras.python.keras.losses import logcosh -from tensorflow.contrib.keras.python.keras.losses import mean_absolute_error -from tensorflow.contrib.keras.python.keras.losses import mean_absolute_percentage_error -from tensorflow.contrib.keras.python.keras.losses import mean_squared_error -from tensorflow.contrib.keras.python.keras.losses import mean_squared_logarithmic_error -from tensorflow.contrib.keras.python.keras.losses import poisson -from tensorflow.contrib.keras.python.keras.losses import sparse_categorical_crossentropy -from tensorflow.contrib.keras.python.keras.losses import squared_hinge +from tensorflow.python.keras._impl.keras.losses import binary_crossentropy +from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy +from tensorflow.python.keras._impl.keras.losses import categorical_hinge +from tensorflow.python.keras._impl.keras.losses import cosine_proximity +from tensorflow.python.keras._impl.keras.losses import hinge +from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence +from tensorflow.python.keras._impl.keras.losses import logcosh +from tensorflow.python.keras._impl.keras.losses import mean_absolute_error +from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error +from tensorflow.python.keras._impl.keras.losses import mean_squared_error +from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error +from tensorflow.python.keras._impl.keras.losses import poisson +from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy +from tensorflow.python.keras._impl.keras.losses import squared_hinge # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.contrib.keras.python.keras.losses import deserialize -from tensorflow.contrib.keras.python.keras.losses import serialize -from tensorflow.contrib.keras.python.keras.losses import get +from tensorflow.python.keras._impl.keras.losses import deserialize +from tensorflow.python.keras._impl.keras.losses import serialize +from tensorflow.python.keras._impl.keras.losses import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py index 99496edde2daf68e7bc28fa272a34bd295855d86..59faf037bce0f087d244a2faaeb52713bdc3b772 100644 --- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py +++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py @@ -19,28 +19,28 @@ from __future__ import division from __future__ import print_function # Metrics functions. -from tensorflow.contrib.keras.python.keras.metrics import binary_accuracy -from tensorflow.contrib.keras.python.keras.metrics import binary_crossentropy -from tensorflow.contrib.keras.python.keras.metrics import categorical_accuracy -from tensorflow.contrib.keras.python.keras.metrics import categorical_crossentropy -from tensorflow.contrib.keras.python.keras.metrics import cosine_proximity -from tensorflow.contrib.keras.python.keras.metrics import hinge -from tensorflow.contrib.keras.python.keras.metrics import kullback_leibler_divergence -from tensorflow.contrib.keras.python.keras.metrics import mean_absolute_error -from tensorflow.contrib.keras.python.keras.metrics import mean_absolute_percentage_error -from tensorflow.contrib.keras.python.keras.metrics import mean_squared_error -from tensorflow.contrib.keras.python.keras.metrics import mean_squared_logarithmic_error -from tensorflow.contrib.keras.python.keras.metrics import poisson -from tensorflow.contrib.keras.python.keras.metrics import sparse_categorical_crossentropy -from tensorflow.contrib.keras.python.keras.metrics import sparse_top_k_categorical_accuracy -from tensorflow.contrib.keras.python.keras.metrics import squared_hinge -from tensorflow.contrib.keras.python.keras.metrics import top_k_categorical_accuracy +from tensorflow.python.keras._impl.keras.metrics import binary_accuracy +from tensorflow.python.keras._impl.keras.metrics import binary_crossentropy +from tensorflow.python.keras._impl.keras.metrics import categorical_accuracy +from tensorflow.python.keras._impl.keras.metrics import categorical_crossentropy +from tensorflow.python.keras._impl.keras.metrics import cosine_proximity +from tensorflow.python.keras._impl.keras.metrics import hinge +from tensorflow.python.keras._impl.keras.metrics import kullback_leibler_divergence +from tensorflow.python.keras._impl.keras.metrics import mean_absolute_error +from tensorflow.python.keras._impl.keras.metrics import mean_absolute_percentage_error +from tensorflow.python.keras._impl.keras.metrics import mean_squared_error +from tensorflow.python.keras._impl.keras.metrics import mean_squared_logarithmic_error +from tensorflow.python.keras._impl.keras.metrics import poisson +from tensorflow.python.keras._impl.keras.metrics import sparse_categorical_crossentropy +from tensorflow.python.keras._impl.keras.metrics import sparse_top_k_categorical_accuracy +from tensorflow.python.keras._impl.keras.metrics import squared_hinge +from tensorflow.python.keras._impl.keras.metrics import top_k_categorical_accuracy # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.contrib.keras.python.keras.metrics import deserialize -from tensorflow.contrib.keras.python.keras.metrics import serialize -from tensorflow.contrib.keras.python.keras.metrics import get +from tensorflow.python.keras._impl.keras.metrics import deserialize +from tensorflow.python.keras._impl.keras.metrics import serialize +from tensorflow.python.keras._impl.keras.metrics import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/models/__init__.py b/tensorflow/contrib/keras/api/keras/models/__init__.py index 4e5b2a1ed08f309ce0651ae541486d2312668a39..2fb4ac0960d38f28a1c9c897a0f1aedf57e048ac 100644 --- a/tensorflow/contrib/keras/api/keras/models/__init__.py +++ b/tensorflow/contrib/keras/api/keras/models/__init__.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.models import load_model -from tensorflow.contrib.keras.python.keras.models import Model -from tensorflow.contrib.keras.python.keras.models import model_from_config -from tensorflow.contrib.keras.python.keras.models import model_from_json -from tensorflow.contrib.keras.python.keras.models import model_from_yaml -from tensorflow.contrib.keras.python.keras.models import save_model -from tensorflow.contrib.keras.python.keras.models import Sequential +from tensorflow.python.keras._impl.keras.models import load_model +from tensorflow.python.keras._impl.keras.models import Model +from tensorflow.python.keras._impl.keras.models import model_from_config +from tensorflow.python.keras._impl.keras.models import model_from_json +from tensorflow.python.keras._impl.keras.models import model_from_yaml +from tensorflow.python.keras._impl.keras.models import save_model +from tensorflow.python.keras._impl.keras.models import Sequential del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py index b3531d7933f705499a9db6810e22330d3bc53200..44f47bc47f4a0e31aaf2ac8f67cfdbef410d8c44 100644 --- a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py @@ -19,20 +19,20 @@ from __future__ import division from __future__ import print_function # Optimizer classes. -from tensorflow.contrib.keras.python.keras.optimizers import Adadelta -from tensorflow.contrib.keras.python.keras.optimizers import Adagrad -from tensorflow.contrib.keras.python.keras.optimizers import Adam -from tensorflow.contrib.keras.python.keras.optimizers import Adamax -from tensorflow.contrib.keras.python.keras.optimizers import Nadam -from tensorflow.contrib.keras.python.keras.optimizers import Optimizer -from tensorflow.contrib.keras.python.keras.optimizers import RMSprop -from tensorflow.contrib.keras.python.keras.optimizers import SGD +from tensorflow.python.keras._impl.keras.optimizers import Adadelta +from tensorflow.python.keras._impl.keras.optimizers import Adagrad +from tensorflow.python.keras._impl.keras.optimizers import Adam +from tensorflow.python.keras._impl.keras.optimizers import Adamax +from tensorflow.python.keras._impl.keras.optimizers import Nadam +from tensorflow.python.keras._impl.keras.optimizers import Optimizer +from tensorflow.python.keras._impl.keras.optimizers import RMSprop +from tensorflow.python.keras._impl.keras.optimizers import SGD # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.contrib.keras.python.keras.optimizers import deserialize -from tensorflow.contrib.keras.python.keras.optimizers import serialize -from tensorflow.contrib.keras.python.keras.optimizers import get +from tensorflow.python.keras._impl.keras.optimizers import deserialize +from tensorflow.python.keras._impl.keras.optimizers import serialize +from tensorflow.python.keras._impl.keras.optimizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py index 18ce1becc2963bd0742b985be9710aca94c9dee1..b96e7675527041d3952b049f5f431d3df36eea4c 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py @@ -18,20 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.preprocessing.image import apply_transform -from tensorflow.contrib.keras.python.keras.preprocessing.image import array_to_img -from tensorflow.contrib.keras.python.keras.preprocessing.image import DirectoryIterator -from tensorflow.contrib.keras.python.keras.preprocessing.image import flip_axis -from tensorflow.contrib.keras.python.keras.preprocessing.image import ImageDataGenerator -from tensorflow.contrib.keras.python.keras.preprocessing.image import img_to_array -from tensorflow.contrib.keras.python.keras.preprocessing.image import Iterator -from tensorflow.contrib.keras.python.keras.preprocessing.image import load_img -from tensorflow.contrib.keras.python.keras.preprocessing.image import NumpyArrayIterator -from tensorflow.contrib.keras.python.keras.preprocessing.image import random_channel_shift -from tensorflow.contrib.keras.python.keras.preprocessing.image import random_rotation -from tensorflow.contrib.keras.python.keras.preprocessing.image import random_shear -from tensorflow.contrib.keras.python.keras.preprocessing.image import random_shift -from tensorflow.contrib.keras.python.keras.preprocessing.image import random_zoom +from tensorflow.python.keras._impl.keras.preprocessing.image import apply_transform +from tensorflow.python.keras._impl.keras.preprocessing.image import array_to_img +from tensorflow.python.keras._impl.keras.preprocessing.image import DirectoryIterator +from tensorflow.python.keras._impl.keras.preprocessing.image import flip_axis +from tensorflow.python.keras._impl.keras.preprocessing.image import ImageDataGenerator +from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array +from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator +from tensorflow.python.keras._impl.keras.preprocessing.image import load_img +from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator +from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift +from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation +from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear +from tensorflow.python.keras._impl.keras.preprocessing.image import random_shift +from tensorflow.python.keras._impl.keras.preprocessing.image import random_zoom del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py index 2621e9bf53ee6a586e856cfddd5cd4bc5305f33b..112f6af5e588bcb2e85fdbecea86f402742d44e7 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.preprocessing.sequence import make_sampling_table -from tensorflow.contrib.keras.python.keras.preprocessing.sequence import pad_sequences -from tensorflow.contrib.keras.python.keras.preprocessing.sequence import skipgrams +from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table +from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences +from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py index a6b68c3ba68103b4560a8b0cc487c0e50d766949..5bf1a2fb21dc27f7aa10cd08b1496e3991c61d2f 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.preprocessing.text import one_hot -from tensorflow.contrib.keras.python.keras.preprocessing.text import text_to_word_sequence -from tensorflow.contrib.keras.python.keras.preprocessing.text import Tokenizer +from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot +from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence +from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py index a3b0062d5c8ba8c70cd5fe9001f7d21e377928fd..3e707ccab577b5e28febd83d91f84d7b1c0d5d82 100644 --- a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py @@ -19,19 +19,19 @@ from __future__ import division from __future__ import print_function # Regularizer functions / callable classes. -from tensorflow.contrib.keras.python.keras.regularizers import L1L2 -from tensorflow.contrib.keras.python.keras.regularizers import Regularizer +from tensorflow.python.keras._impl.keras.regularizers import L1L2 +from tensorflow.python.keras._impl.keras.regularizers import Regularizer # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.contrib.keras.python.keras.regularizers import l1 -from tensorflow.contrib.keras.python.keras.regularizers import l2 -from tensorflow.contrib.keras.python.keras.regularizers import l1_l2 +from tensorflow.python.keras._impl.keras.regularizers import l1 +from tensorflow.python.keras._impl.keras.regularizers import l2 +from tensorflow.python.keras._impl.keras.regularizers import l1_l2 # Auxiliary utils. -from tensorflow.contrib.keras.python.keras.regularizers import deserialize -from tensorflow.contrib.keras.python.keras.regularizers import serialize -from tensorflow.contrib.keras.python.keras.regularizers import get +from tensorflow.python.keras._impl.keras.regularizers import deserialize +from tensorflow.python.keras._impl.keras.regularizers import serialize +from tensorflow.python.keras._impl.keras.regularizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index d6d70f79d5fae12f624cca17d8496af3340f572f..a7c2179fe7ad434356921a5fb8709aa5b1f33498 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -18,21 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file -from tensorflow.contrib.keras.python.keras.utils.data_utils import Sequence -from tensorflow.contrib.keras.python.keras.utils.data_utils import SequenceEnqueuer -from tensorflow.contrib.keras.python.keras.utils.generic_utils import custom_object_scope -from tensorflow.contrib.keras.python.keras.utils.generic_utils import CustomObjectScope -from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.contrib.keras.python.keras.utils.generic_utils import get_custom_objects -from tensorflow.contrib.keras.python.keras.utils.generic_utils import Progbar -from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object -from tensorflow.contrib.keras.python.keras.utils.io_utils import HDF5Matrix -from tensorflow.contrib.keras.python.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.contrib.keras.python.keras.utils.np_utils import normalize -from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical -from tensorflow.contrib.keras.python.keras.utils.vis_utils import plot_model +from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras._impl.keras.utils.data_utils import get_file +from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence +from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer +from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope +from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects +from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix +from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras._impl.keras.utils.np_utils import normalize +from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical +from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py index ba1d28c5c6871b33722dfdcb467165a70fbb4b06..a46f859273ea0117e29a403057f9f81bc758dd52 100644 --- a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py +++ b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.keras.python.keras.wrappers.scikit_learn import KerasClassifier -from tensorflow.contrib.keras.python.keras.wrappers.scikit_learn import KerasRegressor +from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasClassifier +from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasRegressor del absolute_import del division diff --git a/tensorflow/contrib/keras/python/keras/__init__.py b/tensorflow/contrib/keras/python/keras/__init__.py deleted file mode 100644 index 19380bc8c5aaec057b0280822263bde33ed92e15..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/keras/python/keras/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""The Keras API. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.keras.python.keras import activations -from tensorflow.contrib.keras.python.keras import applications -from tensorflow.contrib.keras.python.keras import backend -from tensorflow.contrib.keras.python.keras import callbacks -from tensorflow.contrib.keras.python.keras import constraints -from tensorflow.contrib.keras.python.keras import datasets -from tensorflow.contrib.keras.python.keras import engine -from tensorflow.contrib.keras.python.keras import initializers -from tensorflow.contrib.keras.python.keras import layers -from tensorflow.contrib.keras.python.keras import losses -from tensorflow.contrib.keras.python.keras import metrics -from tensorflow.contrib.keras.python.keras import models -from tensorflow.contrib.keras.python.keras import optimizers -from tensorflow.contrib.keras.python.keras import preprocessing -from tensorflow.contrib.keras.python.keras import regularizers -from tensorflow.contrib.keras.python.keras import utils -from tensorflow.contrib.keras.python.keras import wrappers -from tensorflow.contrib.keras.python.keras.layers import Input - -__version__ = '2.0.6-tf' diff --git a/tensorflow/contrib/keras/python/keras/layers/__init__.py b/tensorflow/contrib/keras/python/keras/layers/__init__.py deleted file mode 100644 index 9a428f311415bb89ec8994f787308b8b2aabdefa..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/keras/python/keras/layers/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras layers module. -""" -# pylint: disable=wildcard-import -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.keras.python.keras.engine import Input -from tensorflow.contrib.keras.python.keras.engine import InputLayer -from tensorflow.contrib.keras.python.keras.engine import InputSpec -from tensorflow.contrib.keras.python.keras.engine import Layer -from tensorflow.contrib.keras.python.keras.layers.advanced_activations import * -from tensorflow.contrib.keras.python.keras.layers.convolutional import * -from tensorflow.contrib.keras.python.keras.layers.convolutional_recurrent import * -from tensorflow.contrib.keras.python.keras.layers.core import * -from tensorflow.contrib.keras.python.keras.layers.embeddings import * -from tensorflow.contrib.keras.python.keras.layers.local import * -from tensorflow.contrib.keras.python.keras.layers.merge import * -from tensorflow.contrib.keras.python.keras.layers.noise import * -from tensorflow.contrib.keras.python.keras.layers.normalization import * -from tensorflow.contrib.keras.python.keras.layers.pooling import * -from tensorflow.contrib.keras.python.keras.layers.recurrent import * -from tensorflow.contrib.keras.python.keras.layers.serialization import deserialize -from tensorflow.contrib.keras.python.keras.layers.serialization import serialize -from tensorflow.contrib.keras.python.keras.layers.wrappers import * - diff --git a/tensorflow/contrib/keras/python/keras/utils/__init__.py b/tensorflow/contrib/keras/python/keras/utils/__init__.py deleted file mode 100644 index 3b197653f382278afffe2a4f26d73be0fc8ab495..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/keras/python/keras/utils/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras utilities. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.keras.python.keras.utils import conv_utils -from tensorflow.contrib.keras.python.keras.utils import data_utils -from tensorflow.contrib.keras.python.keras.utils import generic_utils -from tensorflow.contrib.keras.python.keras.utils import io_utils -from tensorflow.contrib.keras.python.keras.utils import np_utils -from tensorflow.contrib.keras.python.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file -from tensorflow.contrib.keras.python.keras.utils.data_utils import OrderedEnqueuer -from tensorflow.contrib.keras.python.keras.utils.data_utils import Sequence -from tensorflow.contrib.keras.python.keras.utils.generic_utils import custom_object_scope -from tensorflow.contrib.keras.python.keras.utils.generic_utils import CustomObjectScope -from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.contrib.keras.python.keras.utils.generic_utils import get_custom_objects -from tensorflow.contrib.keras.python.keras.utils.generic_utils import Progbar -from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object -from tensorflow.contrib.keras.python.keras.utils.io_utils import HDF5Matrix -from tensorflow.contrib.keras.python.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.contrib.keras.python.keras.utils.np_utils import normalize -from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical -from tensorflow.contrib.keras.python.keras.utils.vis_utils import plot_model - - -# Globally-importable utils. diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 598d9aee02a00d6547203d5588f8c94339f5418d..da16bf6ce64000c8f4fa971dc3702ae0e928806f 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -2559,10 +2559,10 @@ def _create_sequence_feature_spec_for_parsing(sequence_feature_columns, feature_spec = create_feature_spec_for_parsing(sequence_feature_columns) sequence_feature_spec = {} for key, feature in feature_spec.items(): - if (isinstance(feature, parsing_ops.VarLenFeature) or - isinstance(feature, parsing_ops.FixedLenSequenceFeature)): + if isinstance(feature, parsing_ops.VarLenFeature): sequence_feature = feature - elif isinstance(feature, parsing_ops.FixedLenFeature): + elif (isinstance(feature, parsing_ops.FixedLenFeature) or + isinstance(feature, parsing_ops.FixedLenSequenceFeature)): default_is_set = feature.default_value is not None if default_is_set: logging.warning( diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index 21ab9867102ca41bd745ca5f5712d0fb3278e0de..ab65e47af8899b0fae22c4fbdbc5f577f72d4528 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -912,8 +912,7 @@ class FeatureColumnTest(test.TestCase): parsing_ops.VarLenFeature(dtype=dtypes.float32), "real_valued_var_len_dense_column": parsing_ops.FixedLenSequenceFeature( - shape=[], dtype=dtypes.float32, allow_missing=True, - default_value=4.0), + shape=[], dtype=dtypes.float32, allow_missing=True), } self.assertDictEqual(expected_feature_spec, feature_spec) diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index 7eb410b4c72b3f5c20d0f7d94e8f983d6a8e89dc..33db93b9704eb3c81d042e2636f916d5f685ad97 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -156,9 +156,9 @@ def optimize_loss(loss, loss = ops.convert_to_tensor(loss) contrib_framework.assert_scalar(loss) if global_step is None: - global_step = contrib_framework.get_global_step() + global_step = train.get_global_step() else: - contrib_framework.assert_global_step(global_step) + train.assert_global_step(global_step) with vs.variable_scope(name, "OptimizeLoss", [loss, global_step]): # Update ops take UPDATE_OPS collection if not provided. if update_ops is None: diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index c2e74d1cc2e41fad34b2be1187863a65639599f4..d35b5556fc819cdc0daded761ae87bf16c44d012 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -36,6 +36,7 @@ py_library( "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/contrib/session_bundle:exporter", "//tensorflow/contrib/session_bundle:gc", + "//tensorflow/contrib/tpu:tpu_estimator", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -410,33 +411,6 @@ py_test( ], ) -py_test( - name = "dnn_linear_combined_benchmark_test", - size = "medium", - srcs = ["python/learn/estimators/dnn_linear_combined_benchmark_test.py"], - srcs_version = "PY2AND3", - tags = [ - "guitar", - "local", - "manual", - "notap", - ], - visibility = [ - "//learning/brain/google/guitar:__subpackages__", - "//tensorflow:__subpackages__", - ], - deps = [ - ":learn", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/learn/python/learn/datasets", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - ], -) - py_test( name = "kmeans_test", size = "medium", @@ -458,36 +432,11 @@ py_test( ], ) -py_test( - name = "dnn_benchmark_test", - size = "medium", - srcs = ["python/learn/estimators/dnn_benchmark_test.py"], - srcs_version = "PY2AND3", - tags = [ - "guitar", - "local", - "manual", - "notap", - ], - visibility = [ - "//learning/brain/google/guitar:__subpackages__", - "//tensorflow:__subpackages__", - ], - deps = [ - ":learn", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], -) - py_test( name = "dynamic_rnn_estimator_test", size = "medium", srcs = ["python/learn/estimators/dynamic_rnn_estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":learn", diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_benchmark_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_benchmark_test.py deleted file mode 100644 index 86b3eee6ad140c5a30a1a11b7ac31c7f9cd00d54..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_benchmark_test.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Regression test for DNNEstimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import numpy as np -from tensorflow.contrib.layers.python.layers import feature_column -from tensorflow.contrib.learn.python.learn.estimators import dnn -from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils -from tensorflow.contrib.learn.python.learn.estimators import run_config -from tensorflow.contrib.learn.python.learn.estimators import test_data -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.platform import test -from tensorflow.python.training import input as input_lib - - -_METRIC_KEYS = { - 'accuracy', - 'auc', - 'accuracy/threshold_0.500000_mean', - 'loss', - 'precision/positive_threshold_0.500000_mean', - 'recall/positive_threshold_0.500000_mean', -} - - -class DNNClassifierBenchmark(test.Benchmark): - - def _report_metrics(self, metrics): - self.report_benchmark( - iters=metrics['global_step'], - extras={k: v - for k, v in metrics.items() if k in _METRIC_KEYS}) - - def _report_predictions(self, - benchmark_name_override, - classifier, - input_fn, - iters, - n_examples, - n_classes, - expected_probabilities=None, - expected_classes=None): - probabilities = classifier.predict_proba( - input_fn=input_fn, as_iterable=False) - if expected_probabilities is not None: - np.testing.assert_allclose( - expected_probabilities, tuple(probabilities), atol=0.2) - - classes = classifier.predict(input_fn=input_fn, as_iterable=False) - if expected_classes is not None: - np.testing.assert_array_equal(expected_classes, classes) - - self.report_benchmark( - iters=iters, - extras={ - 'inference.example%d_class%d_prob' % (i, j): probabilities[i][j] - for j in range(n_classes) for i in range(n_examples) - }.update({ - 'inference.example%d_class' % i: classes[i] - for i in range(n_examples) - }), - name=benchmark_name_override) - - def benchmarkLogisticMatrixData(self): - classifier = dnn.DNNClassifier( - feature_columns=(feature_column.real_valued_column( - 'feature', dimension=4),), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - input_fn = test_data.iris_input_logistic_fn - steps = 400 - metrics = classifier.fit(input_fn=input_fn, steps=steps).evaluate( - input_fn=input_fn, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, 'accuracy', metrics) - estimator_test_utils.assert_in_range(0.0, 0.3, 'loss', metrics) - - self._report_metrics(metrics) - - def benchmarkLogisticMatrixDataLabels1D(self): - - def _input_fn(): - iris = test_data.prepare_iris_data_for_logistic_regression() - return { - 'feature': constant_op.constant( - iris.data, dtype=dtypes.float32) - }, constant_op.constant( - iris.target, shape=(100,), dtype=dtypes.int32) - - classifier = dnn.DNNClassifier( - feature_columns=(feature_column.real_valued_column( - 'feature', dimension=4),), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - steps = 1000 - metrics = classifier.fit(input_fn=_input_fn, steps=steps).evaluate( - input_fn=_input_fn, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, 'accuracy', metrics) - - self._report_metrics(metrics) - - def benchmarkLogisticNpMatrixData(self): - classifier = dnn.DNNClassifier( - feature_columns=(feature_column.real_valued_column( - '', dimension=4),), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - iris = test_data.prepare_iris_data_for_logistic_regression() - train_x = iris.data - train_y = iris.target - steps = 100 - metrics = classifier.fit(x=train_x, y=train_y, steps=steps).evaluate( - x=train_x, y=train_y, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.8, 1.0, 'accuracy', metrics) - - self._report_metrics(metrics) - - def benchmarkLogisticTensorData(self): - - def _input_fn(num_epochs=None): - features = { - 'age': - input_lib.limit_epochs( - constant_op.constant(((.8,), (0.2,), (.1,))), - num_epochs=num_epochs), - 'language': - sparse_tensor.SparseTensor( - values=input_lib.limit_epochs( - ('en', 'fr', 'zh'), num_epochs=num_epochs), - indices=((0, 0), (0, 1), (2, 0)), - dense_shape=(3, 2)) - } - return features, constant_op.constant( - ((1,), (0,), (0,)), dtype=dtypes.int32) - - lang_column = feature_column.sparse_column_with_hash_bucket( - 'language', hash_bucket_size=20) - classifier = dnn.DNNClassifier( - feature_columns=(feature_column.embedding_column( - lang_column, dimension=1), - feature_column.real_valued_column('age')), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - steps = 100 - metrics = classifier.fit(input_fn=_input_fn, steps=steps).evaluate( - input_fn=_input_fn, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, 'accuracy', metrics) - estimator_test_utils.assert_in_range(0.0, 0.3, 'loss', metrics) - - self._report_metrics(metrics) - self._report_predictions( - classifier=classifier, - input_fn=functools.partial(_input_fn, num_epochs=1), - iters=metrics['global_step'], - n_examples=3, - n_classes=2, - expected_classes=(1, 0, 0), - benchmark_name_override=( - 'DNNClassifierBenchmark.benchmarkLogisticTensorData_predictions')) - - def benchmarkLogisticFloatLabel(self): - - def _input_fn(num_epochs=None): - features = { - 'age': - input_lib.limit_epochs( - constant_op.constant(((50,), (20,), (10,))), - num_epochs=num_epochs), - 'language': - sparse_tensor.SparseTensor( - values=input_lib.limit_epochs( - ('en', 'fr', 'zh'), num_epochs=num_epochs), - indices=((0, 0), (0, 1), (2, 0)), - dense_shape=(3, 2)) - } - return features, constant_op.constant( - ((0.8,), (0.,), (0.2,)), dtype=dtypes.float32) - - lang_column = feature_column.sparse_column_with_hash_bucket( - 'language', hash_bucket_size=20) - n_classes = 2 - classifier = dnn.DNNClassifier( - n_classes=n_classes, - feature_columns=(feature_column.embedding_column( - lang_column, dimension=1), - feature_column.real_valued_column('age')), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - steps = 1000 - metrics = classifier.fit(input_fn=_input_fn, steps=steps).evaluate( - input_fn=_input_fn, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - - # Prediction probabilities mirror the labels column, which proves that the - # classifier learns from float input. - self._report_metrics(metrics) - self._report_predictions( - classifier=classifier, - input_fn=functools.partial(_input_fn, num_epochs=1), - iters=metrics['global_step'], - n_examples=3, - n_classes=n_classes, - expected_probabilities=((0.2, 0.8), (1., 0.), (0.8, 0.2)), - expected_classes=(1, 0, 0), - benchmark_name_override=( - 'DNNClassifierBenchmark.benchmarkLogisticFloatLabel_predictions')) - - def benchmarkMultiClassMatrixData(self): - """Tests multi-class classification using matrix data as input.""" - classifier = dnn.DNNClassifier( - n_classes=3, - feature_columns=(feature_column.real_valued_column( - 'feature', dimension=4),), - hidden_units=(3, 3), - config=run_config.RunConfig(tf_random_seed=1)) - - input_fn = test_data.iris_input_multiclass_fn - steps = 500 - metrics = classifier.fit(input_fn=input_fn, steps=steps).evaluate( - input_fn=input_fn, steps=1) - estimator_test_utils.assert_in_range(steps, steps + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, 'accuracy', metrics) - estimator_test_utils.assert_in_range(0.0, 0.4, 'loss', metrics) - - self._report_metrics(metrics) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_benchmark_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_benchmark_test.py deleted file mode 100644 index 98b7c7e95c59a7695760013b1e1d62d99b79b1ca..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_benchmark_test.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Regression test for DNNLinearCombinedEstimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import tempfile -from tensorflow.contrib.layers.python.layers import feature_column -from tensorflow.contrib.learn.python.learn.datasets import base -from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined -from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils -from tensorflow.contrib.learn.python.learn.estimators import run_config -from tensorflow.contrib.learn.python.learn.estimators import test_data -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test -from tensorflow.python.training import adagrad -from tensorflow.python.training import ftrl -from tensorflow.python.training import server_lib - - -# Desired training steps, reported in benchmark. Actual steps might be slightly -# more than this since supervisor training runs for a non-detrministic number of -# steps. -_ITERS = 100 - -_METRIC_KEYS = { - 'accuracy', - 'auc', - 'accuracy/threshold_0.500000_mean', - 'loss', - 'precision/positive_threshold_0.500000_mean', - 'recall/positive_threshold_0.500000_mean', -} - - -class DNNLinearCombinedClassifierBenchmark(test.Benchmark): - - def _assertSingleClassMetrics(self, metrics): - estimator_test_utils.assert_in_range(0.9, 1.0, 'auc', metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, - 'accuracy/threshold_0.500000_mean', - metrics) - estimator_test_utils.assert_in_range( - 0.9, 1.0, 'precision/positive_threshold_0.500000_mean', metrics) - estimator_test_utils.assert_in_range( - 0.9, 1.0, 'recall/positive_threshold_0.500000_mean', metrics) - self._assertCommonMetrics(metrics) - - def _assertCommonMetrics(self, metrics): - estimator_test_utils.assert_in_range(_ITERS, _ITERS + 5, 'global_step', - metrics) - estimator_test_utils.assert_in_range(0.9, 1.0, 'accuracy', metrics) - estimator_test_utils.assert_in_range(0.0, 0.2, 'loss', metrics) - self.report_benchmark( - iters=metrics['global_step'], - extras={k: v - for k, v in metrics.items() if k in _METRIC_KEYS}) - - def benchmarkMatrixData(self): - iris = test_data.prepare_iris_data_for_logistic_regression() - cont_feature = feature_column.real_valued_column('feature', dimension=4) - bucketized_feature = feature_column.bucketized_column( - cont_feature, test_data.get_quantile_based_buckets(iris.data, 10)) - - classifier = dnn_linear_combined.DNNLinearCombinedClassifier( - model_dir=tempfile.mkdtemp(), - linear_feature_columns=(bucketized_feature,), - dnn_feature_columns=(cont_feature,), - dnn_hidden_units=(3, 3)) - - input_fn = test_data.iris_input_logistic_fn - metrics = classifier.fit(input_fn=input_fn, steps=_ITERS).evaluate( - input_fn=input_fn, steps=100) - self._assertSingleClassMetrics(metrics) - - def benchmarkTensorData(self): - - def _input_fn(): - iris = test_data.prepare_iris_data_for_logistic_regression() - features = {} - for i in range(4): - # The following shows how to provide the Tensor data for - # RealValuedColumns. - features.update({ - str(i): - array_ops.reshape( - constant_op.constant( - iris.data[:, i], dtype=dtypes.float32), (-1, 1)) - }) - # The following shows how to provide the SparseTensor data for - # a SparseColumn. - features['dummy_sparse_column'] = sparse_tensor.SparseTensor( - values=('en', 'fr', 'zh'), - indices=((0, 0), (0, 1), (60, 0)), - dense_shape=(len(iris.target), 2)) - labels = array_ops.reshape( - constant_op.constant( - iris.target, dtype=dtypes.int32), (-1, 1)) - return features, labels - - iris = test_data.prepare_iris_data_for_logistic_regression() - cont_features = [ - feature_column.real_valued_column(str(i)) for i in range(4) - ] - linear_features = [ - feature_column.bucketized_column( - cont_features[i], - test_data.get_quantile_based_buckets(iris.data[:, i], 10)) - for i in range(4) - ] - linear_features.append( - feature_column.sparse_column_with_hash_bucket( - 'dummy_sparse_column', hash_bucket_size=100)) - - classifier = dnn_linear_combined.DNNLinearCombinedClassifier( - model_dir=tempfile.mkdtemp(), - linear_feature_columns=linear_features, - dnn_feature_columns=cont_features, - dnn_hidden_units=(3, 3)) - - metrics = classifier.fit(input_fn=_input_fn, steps=_ITERS).evaluate( - input_fn=_input_fn, steps=100) - self._assertSingleClassMetrics(metrics) - - def benchmarkCustomOptimizer(self): - iris = test_data.prepare_iris_data_for_logistic_regression() - cont_feature = feature_column.real_valued_column('feature', dimension=4) - bucketized_feature = feature_column.bucketized_column( - cont_feature, test_data.get_quantile_based_buckets(iris.data, 10)) - - classifier = dnn_linear_combined.DNNLinearCombinedClassifier( - model_dir=tempfile.mkdtemp(), - linear_feature_columns=(bucketized_feature,), - linear_optimizer=ftrl.FtrlOptimizer(learning_rate=0.1), - dnn_feature_columns=(cont_feature,), - dnn_hidden_units=(3, 3), - dnn_optimizer=adagrad.AdagradOptimizer(learning_rate=0.1)) - - input_fn = test_data.iris_input_logistic_fn - metrics = classifier.fit(input_fn=input_fn, steps=_ITERS).evaluate( - input_fn=input_fn, steps=100) - self._assertSingleClassMetrics(metrics) - - def benchmarkMultiClass(self): - iris = base.load_iris() - cont_feature = feature_column.real_valued_column('feature', dimension=4) - bucketized_feature = feature_column.bucketized_column( - cont_feature, test_data.get_quantile_based_buckets(iris.data, 10)) - - classifier = dnn_linear_combined.DNNLinearCombinedClassifier( - n_classes=3, - linear_feature_columns=(bucketized_feature,), - dnn_feature_columns=(cont_feature,), - dnn_hidden_units=(3, 3)) - - input_fn = test_data.iris_input_multiclass_fn - metrics = classifier.fit(input_fn=input_fn, steps=_ITERS).evaluate( - input_fn=input_fn, steps=100) - self._assertCommonMetrics(metrics) - - def benchmarkPartitionedVariables(self): - - def _input_fn(): - features = { - 'language': - sparse_tensor.SparseTensor( - values=('en', 'fr', 'zh'), - indices=((0, 0), (0, 1), (2, 0)), - dense_shape=(3, 2)) - } - labels = constant_op.constant(((1,), (0,), (0,))) - return features, labels - - # The given hash_bucket_size results in variables larger than the - # default min_slice_size attribute, so the variables are partitioned. - sparse_feature = feature_column.sparse_column_with_hash_bucket( - 'language', hash_bucket_size=2e7) - embedding_feature = feature_column.embedding_column( - sparse_feature, dimension=1) - - tf_config = { - 'cluster': { - run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] - } - } - with test.mock.patch.dict('os.environ', - {'TF_CONFIG': json.dumps(tf_config)}): - config = run_config.RunConfig() - # Because we did not start a distributed cluster, we need to pass an - # empty ClusterSpec, otherwise the device_setter will look for - # distributed jobs, such as "/job:ps" which are not present. - config._cluster_spec = server_lib.ClusterSpec({}) - - classifier = dnn_linear_combined.DNNLinearCombinedClassifier( - linear_feature_columns=(sparse_feature,), - dnn_feature_columns=(embedding_feature,), - dnn_hidden_units=(3, 3), - config=config) - - metrics = classifier.fit(input_fn=_input_fn, steps=_ITERS).evaluate( - input_fn=_input_fn, steps=100) - self._assertCommonMetrics(metrics) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index c31d5d2d47dfd67506cb459fd88b88a9bb90db9b..861db1f89ef54e524f4ebf9ee81670fa998cbb98 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -24,7 +24,6 @@ import six from tensorflow.contrib import framework as framework_lib from tensorflow.contrib import layers as layers_lib -from tensorflow.contrib import lookup as lookup_lib from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.estimators import prediction_key @@ -35,6 +34,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import nn @@ -1070,9 +1070,8 @@ class _MultiClassHead(_SingleHead): labels_tensor = _to_labels_tensor(labels, self._label_name) _check_no_sparse_tensor(labels_tensor) if self._label_keys: - table = lookup_lib.string_to_index_table_from_tensor( - mapping=self._label_keys, - name="label_id_lookup") + table = lookup_ops.index_table_from_tensor( + self._label_keys, name="label_id_lookup") return { "labels": labels_tensor, "label_ids": table.lookup(labels_tensor), @@ -1106,9 +1105,8 @@ class _MultiClassHead(_SingleHead): class_ids = math_ops.argmax( logits, 1, name=prediction_key.PredictionKey.CLASSES) if self._label_keys: - table = lookup_lib.index_to_string_table_from_tensor( - mapping=self._label_keys, - name="class_string_lookup") + table = lookup_ops.index_to_string_table_from_tensor( + self._label_keys, name="class_string_lookup") classes = table.lookup(class_ids) else: classes = class_ids diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py index 0f09b111bd8dee03633402fbda7654bc4dcdbddc..896b668d4e2f129f443e3d7be39a476b56f7c9fe 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -178,7 +178,7 @@ def select_last_activations(activations, sequence_lengths): """Selects the nth set of activations for each n in `sequence_length`. Reuturns a `Tensor` of shape `[batch_size, k]`. If `sequence_length` is not - `None`, then `output[i, :] = activations[i, sequence_length[i], :]`. If + `None`, then `output[i, :] = activations[i, sequence_length[i] - 1, :]`. If `sequence_length` is `None`, then `output[i, :] = activations[i, -1, :]`. Args: diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index 626c34014bb6aaf71a22104b0c6a6a684cfd02a3..eae35d59aca760aeca7d11821a010a0f54b8d26b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -335,7 +335,9 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): # For class instance without __repr__, some special cares are required. # Otherwise, the object address will be used. if '_cluster_spec' in ordered_state: - ordered_state['_cluster_spec'] = ordered_state['_cluster_spec'].as_dict() + ordered_state['_cluster_spec'] = collections.OrderedDict( + sorted(ordered_state['_cluster_spec'].as_dict().items(), + key=lambda t: t[0])) return ', '.join( '%s=%r' % (k, v) for (k, v) in six.iteritems(ordered_state)) diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py b/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py index 25ef1705c202ca5a4f4e49dc57148182a9f23f3f..e559612f7b1c110bed88548ecb7ee12bc35e0b26 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py @@ -358,6 +358,25 @@ class RunConfigTest(test.TestCase): uid_2 = _create_run_config_with_cluster_spec(tf_config_2_str).uid() self.assertEqual(uid_1, uid_2) + def test_uid_for_different_cluster_specs(self): + tf_config_1 = { + "cluster": { + run_config_lib.TaskType.PS: ["host1:1", "host2:2"], + run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"] + }, + } + + tf_config_2 = { + "cluster": { + run_config_lib.TaskType.PS: ["host1:1"], + run_config_lib.TaskType.WORKER: ["host3:3", "host4:4", "host5:5"] + }, + } + + uid_1 = _create_run_config_with_cluster_spec(json.dumps(tf_config_1)).uid() + uid_2 = _create_run_config_with_cluster_spec(json.dumps(tf_config_2)).uid() + self.assertNotEqual(uid_1, uid_2) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index c35a493086e10bfa8aa67f63c66f0f3618df8905..627d4991f036c440c832a9a9eab9fa69e6efbd3e 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -33,6 +33,7 @@ from tensorflow.contrib.learn.python.learn import export_strategy from tensorflow.contrib.learn.python.learn import monitors from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging @@ -221,6 +222,14 @@ class Experiment(object): "`estimator` must implement `tf.contrib.learn.Trainable`" "or `tf.estimator.`Estimator`.") + if isinstance(estimator, tpu_estimator.TPUEstimator): + raise ValueError( + "`Experiment` class cannot work with `tf.contrib.tpu.TPUEstimator`. " + "Please call `TPUEstimator` train/evaluate directly. \n" + "Details: `Experiment` class is designed for between-graph " + "distributed training, while `TPUEstimator` is working in in-graph " + "distributed mode.") + super(Experiment, self).__init__() # Immutable fields. self._estimator = estimator diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index fe40d27c445d4f560c96fc9b50ceb0daed30ee93..2c68edbb34b1eff2e3ea1bb0379a23989c80e578 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -32,6 +32,8 @@ from tensorflow.contrib.learn.python.learn.estimators import dnn from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib from tensorflow.contrib.learn.python.learn.estimators import test_data from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils +from tensorflow.contrib.tpu.python.tpu import tpu_config +from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.estimator import estimator as core_estimator @@ -935,6 +937,20 @@ class ExperimentTest(test.TestCase): self.assertEqual(ex._maybe_export.call_count, 4) self.assertEqual(ex._call_evaluate.call_count, 4) + def test_fail_with_tpu_estimator(self): + def dummy_model_fn(features, labels): + del features, labels # unused + + with self.assertRaisesRegexp( + ValueError, + '`Experiment` class cannot work with `tf.contrib.tpu.TPUEstimator`'): + experiment.Experiment( + tpu_estimator.TPUEstimator(model_fn=dummy_model_fn, + config=tpu_config.RunConfig(), + train_batch_size=256), + train_input_fn='train_input', + eval_input_fn='eval_input') + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index bafde464afb819730109353a6cc9d0e005a15151..4c50d40aaa9b3c5d94d0a66d08e8ab6173db427a 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -28,7 +28,6 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -44,7 +43,7 @@ def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None): x_is_dict, y_is_dict = isinstance( x_shape, dict), y_shape is not None and isinstance(y_shape, dict) if y_is_dict and n_classes is not None: - assert (isinstance(n_classes, dict)) + assert isinstance(n_classes, dict) if batch_size is None: batch_size = list(x_shape.values())[0][0] if x_is_dict else x_shape[0] @@ -322,10 +321,12 @@ class DataFeeder(object): self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items()) ]) if x_is_dict else check_array(x, x.dtype) - self._y = None if y is None else \ - dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) if y_is_dict else check_array(y, y.dtype) + self._y = None if y is None else ( + dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) + if y_is_dict else check_array(y, y.dtype)) - # self.n_classes is not None means we're converting raw target indices to one-hot. + # self.n_classes is not None means we're converting raw target indices + # to one-hot. if n_classes is not None: if not y_is_dict: y_dtype = (np.int64 @@ -344,12 +345,15 @@ class DataFeeder(object): x_shape, y_shape, n_classes, batch_size) # Input dtype matches dtype of x. - self._input_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) if x_is_dict \ - else _check_dtype(self._x.dtype) - - # note: self._output_dtype = np.float32 when y is None - self._output_dtype = dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) if y_is_dict \ - else _check_dtype(self._y.dtype) if y is not None else np.float32 + self._input_dtype = ( + dict([(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())]) + if x_is_dict else _check_dtype(self._x.dtype)) + + # self._output_dtype == np.float32 when y is None + self._output_dtype = ( + dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) + if y_is_dict else ( + _check_dtype(self._y.dtype) if y is not None else np.float32)) # self.n_classes is None means we're passing in raw target indices if n_classes is not None and y_is_dict: diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 1e68a3ef66039f4831a5061506b54665d2c0a566..676e1f2b51c0a0a48b84f4e1d3d8ad9ae2521f9b 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Utilities supporting export to SavedModel. Some contents of this file are moved to tensorflow/python/estimator/export.py: @@ -39,6 +38,7 @@ import time from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn import export_strategy from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import metric_key from tensorflow.contrib.learn.python.learn.estimators import prediction_key from tensorflow.contrib.learn.python.learn.utils import gc from tensorflow.contrib.learn.python.learn.utils import input_fn_utils @@ -75,8 +75,8 @@ FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative' _FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative' -def build_standardized_signature_def( - input_tensors, output_tensors, problem_type): +def build_standardized_signature_def(input_tensors, output_tensors, + problem_type): """Build a SignatureDef using problem type and input and output Tensors. Note that this delegates the actual creation of the signatures to methods in @@ -116,8 +116,8 @@ def build_standardized_signature_def( (_, predictions), = output_tensors.items() return signature_def_utils.regression_signature_def(examples, predictions) else: - return signature_def_utils.predict_signature_def( - input_tensors, output_tensors) + return signature_def_utils.predict_signature_def(input_tensors, + output_tensors) def _get_classification_scores(output_tensors): @@ -139,17 +139,15 @@ def _is_classification_problem(problem_type, input_tensors, output_tensors): classes = _get_classification_classes(output_tensors) scores = _get_classification_scores(output_tensors) return ((problem_type == constants.ProblemType.CLASSIFICATION or - problem_type == constants.ProblemType.LOGISTIC_REGRESSION) - and len(input_tensors) == 1 - and (classes is not None or - scores is not None or - len(output_tensors) == 1)) + problem_type == constants.ProblemType.LOGISTIC_REGRESSION) and + len(input_tensors) == 1 and + (classes is not None or scores is not None or + len(output_tensors) == 1)) def _is_regression_problem(problem_type, input_tensors, output_tensors): - return (problem_type == constants.ProblemType.LINEAR_REGRESSION - and len(input_tensors) == 1 - and len(output_tensors) == 1) + return (problem_type == constants.ProblemType.LINEAR_REGRESSION and + len(input_tensors) == 1 and len(output_tensors) == 1) def get_input_alternatives(input_ops): @@ -177,9 +175,7 @@ def get_input_alternatives(input_ops): return input_alternatives, features -def get_output_alternatives( - model_fn_ops, - default_output_alternative_key=None): +def get_output_alternatives(model_fn_ops, default_output_alternative_key=None): """Obtain all output alternatives using the model_fn output and heuristics. Args: @@ -218,8 +214,10 @@ def get_output_alternatives( default_outputs = {prediction_key.PredictionKey.GENERIC: default_outputs} actual_default_output_alternative_key = ( _FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY) - output_alternatives = {actual_default_output_alternative_key: - (default_problem_type, default_outputs)} + output_alternatives = { + actual_default_output_alternative_key: (default_problem_type, + default_outputs) + } return output_alternatives, actual_default_output_alternative_key if default_output_alternative_key: @@ -246,13 +244,12 @@ def build_all_signature_defs(input_alternatives, output_alternatives, actual_default_output_alternative_key): """Build `SignatureDef`s from all pairs of input and output alternatives.""" - signature_def_map = { - ('%s:%s' % (input_key, output_key or 'None')): - build_standardized_signature_def( - inputs, outputs, problem_type) - for input_key, inputs in input_alternatives.items() - for output_key, (problem_type, outputs) - in output_alternatives.items()} + signature_def_map = {('%s:%s' % (input_key, output_key or 'None')): + build_standardized_signature_def(inputs, outputs, + problem_type) + for input_key, inputs in input_alternatives.items() + for output_key, (problem_type, + outputs) in output_alternatives.items()} # Add the default SignatureDef default_inputs = input_alternatives.get(DEFAULT_INPUT_ALTERNATIVE_KEY) @@ -263,8 +260,8 @@ def build_all_signature_defs(input_alternatives, output_alternatives, (default_problem_type, default_outputs) = ( output_alternatives[actual_default_output_alternative_key]) signature_def_map[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = ( - build_standardized_signature_def( - default_inputs, default_outputs, default_problem_type)) + build_standardized_signature_def(default_inputs, default_outputs, + default_problem_type)) return signature_def_map @@ -308,9 +305,8 @@ def get_timestamped_export_dir(export_dir_base): return export_dir time.sleep(1) attempts += 1 - logging.warn( - 'Export directory {} already exists; retrying (attempt {}/{})'.format( - export_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)) + logging.warn('Export directory {} already exists; retrying (attempt {}/{})'. + format(export_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)) raise RuntimeError('Failed to obtain a unique export directory name after ' '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) @@ -330,8 +326,7 @@ def get_temp_export_dir(timestamped_export_dir): """ (dirname, basename) = os.path.split(timestamped_export_dir) temp_export_dir = os.path.join( - compat.as_bytes(dirname), - compat.as_bytes('temp-{}'.format(basename))) + compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(basename))) return temp_export_dir @@ -357,8 +352,8 @@ def get_most_recent_export(export_dir_base): A gc.Path, with is just a namedtuple of (path, export_version). """ select_filter = gc.largest_export_versions(1) - results = select_filter(gc.get_paths(export_dir_base, - parser=_export_version_parser)) + results = select_filter( + gc.get_paths(export_dir_base, parser=_export_version_parser)) return next(iter(results or []), None) @@ -378,8 +373,8 @@ def garbage_collect_exports(export_dir_base, exports_to_keep): keep_filter = gc.largest_export_versions(exports_to_keep) delete_filter = gc.negation(keep_filter) - for p in delete_filter(gc.get_paths(export_dir_base, - parser=_export_version_parser)): + for p in delete_filter( + gc.get_paths(export_dir_base, parser=_export_version_parser)): try: gfile.DeleteRecursively(p.path) except errors_impl.NotFoundError as e: @@ -416,10 +411,7 @@ def make_export_strategy(serving_input_fn, An ExportStrategy that can be passed to the Experiment constructor. """ - def export_fn(estimator, - export_dir_base, - checkpoint_path=None - ): + def export_fn(estimator, export_dir_base, checkpoint_path=None): """Exports the given Estimator as a SavedModel. Args: @@ -512,3 +504,128 @@ def make_parsing_export_strategy(feature_columns, assets_extra=assets_extra, as_text=as_text, exports_to_keep=exports_to_keep) + + +def _default_compare_fn(curr_best_eval_result, cand_eval_result): + """Compares two evaluation results and returns true if the 2nd one is better. + + Both evaluation results should have the values for MetricKey.LOSS, which are + used for comparison. + + Args: + curr_best_eval_result: current best eval metrics. + cand_eval_result: candidate eval metrics. + + Returns: + True if cand_eval_result is better. + + Raises: + ValueError: If input eval result is None or no loss is available. + """ + default_key = metric_key.MetricKey.LOSS + if not curr_best_eval_result or default_key not in curr_best_eval_result: + raise ValueError( + 'curr_best_eval_result cannot be empty or no loss is found in it.') + + if not cand_eval_result or default_key not in cand_eval_result: + raise ValueError( + 'cand_eval_result cannot be empty or no loss is found in it.') + + return curr_best_eval_result[default_key] > cand_eval_result[default_key] + + +class BestModelSelector(object): + """A helper that keeps track of export selection candidates.""" + + def __init__(self, compare_fn=None): + """Constructor of this class. + + Args: + compare_fn: a function that returns true if the candidate is better than + the current best model. + """ + self._best_eval_result = None + self._compare_fn = compare_fn or _default_compare_fn + + def update(self, checkpoint_path, eval_result): + """Records a given checkpoint and exports if this is the best model. + + Args: + checkpoint_path: the checkpoint path to export. + eval_result: a dictionary which is usually generated in evaluation runs. + By default, eval_results contains 'loss' field. + + Returns: + A string representing the path to the checkpoint to be exported. + A dictionary of the same type of eval_result. + + Raises: + ValueError: if checkpoint path is empty. + ValueError: if eval_results is None object. + """ + if not checkpoint_path: + raise ValueError('Checkpoint path is empty.') + if eval_result is None: + raise ValueError('%s has empty evaluation results.', checkpoint_path) + + if (self._best_eval_result is None or + self._compare_fn(self._best_eval_result, eval_result)): + self._best_eval_result = eval_result + return checkpoint_path, eval_result + else: + return '', None + + +def make_best_model_export_strategy(serving_input_fn, + exports_to_keep=1, + compare_fn=None, + default_output_alternative_key=None): + """Creates an custom ExportStrategy for use with tf.contrib.learn.Experiment. + + Args: + serving_input_fn: a function that takes no arguments and returns an + `InputFnOps`. + exports_to_keep: an integer indicating how many historical best models need + to be preserved. + compare_fn: a function that select the 'best' candidate from a dictionary + of evaluation result keyed by corresponding checkpoint path. + default_output_alternative_key: the key for default serving signature for + multi-headed inference graphs. + + Returns: + An ExportStrategy that can be passed to the Experiment constructor. + """ + best_model_export_strategy = make_export_strategy( + serving_input_fn, + exports_to_keep=exports_to_keep, + default_output_alternative_key=default_output_alternative_key) + + best_model_selector = BestModelSelector(compare_fn) + + def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None): + """Exports the given Estimator as a SavedModel. + + Args: + estimator: the Estimator to export. + export_dir_base: A string containing a directory to write the exported + graph and checkpoints. + checkpoint_path: The checkpoint path to export. If None (the default), + the most recent checkpoint found within the model directory is chosen. + eval_result: placehold args matching the call signature of ExportStrategy. + + Returns: + The string path to the exported directory. + """ + + export_checkpoint_path, export_eval_result = best_model_selector.update( + checkpoint_path, eval_result) + + if export_checkpoint_path and export_eval_result is not None: + checkpoint_base = os.path.basename(export_checkpoint_path) + export_dir = os.path.join(export_dir_base, checkpoint_base) + return best_model_export_strategy.export( + estimator, export_dir, export_checkpoint_path, export_eval_result) + else: + return '' + + return export_strategy.ExportStrategy('best_model', export_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index 9e778ab72ad2095edc3fcb45af624bb9b09ca5f1..66bca9c0f533dc97c682caf2befd33197eb0a733 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -24,6 +24,7 @@ import time from tensorflow.contrib.layers.python.layers import feature_column as fc from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import estimator as core_estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.utils import input_fn_utils from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils @@ -40,18 +41,43 @@ from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.util import compat +class TestEstimator(core_estimator.Estimator): + + def __init__(self, *args, **kwargs): + super(TestEstimator, self).__init__(*args, **kwargs) + self.last_exported_checkpoint = "" + self.last_exported_dir = "" + + # @Override + def export_savedmodel(self, + export_dir, + serving_input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + checkpoint_path=None): + + if not os.path.exists(export_dir): + os.makedirs(export_dir) + + open(os.path.join(export_dir, "placeholder.txt"), "a").close() + + self.last_exported_checkpoint = checkpoint_path + self.last_exported_dir = export_dir + + return export_dir + + class SavedModelExportUtilsTest(test.TestCase): def test_build_standardized_signature_def_regression(self): input_tensors = { "input-1": - array_ops.placeholder( - dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") } output_tensors = { "output-1": - array_ops.placeholder( - dtypes.float32, 1, name="output-tensor-1") + array_ops.placeholder(dtypes.float32, 1, name="output-tensor-1") } problem_type = constants.ProblemType.LINEAR_REGRESSION actual_signature_def = ( @@ -61,10 +87,9 @@ class SavedModelExportUtilsTest(test.TestCase): shape = tensor_shape_pb2.TensorShapeProto( dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) dtype = types_pb2.DataType.Value("DT_FLOAT") - expected_signature_def.inputs[ - signature_constants.REGRESS_INPUTS].CopyFrom( - meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype, tensor_shape=shape)) + expected_signature_def.inputs[signature_constants.REGRESS_INPUTS].CopyFrom( + meta_graph_pb2.TensorInfo( + name="input-tensor-1:0", dtype=dtype, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.REGRESS_OUTPUTS].CopyFrom( meta_graph_pb2.TensorInfo( @@ -77,13 +102,11 @@ class SavedModelExportUtilsTest(test.TestCase): """Tests classification with one output tensor.""" input_tensors = { "input-1": - array_ops.placeholder( - dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") } output_tensors = { "output-1": - array_ops.placeholder( - dtypes.string, 1, name="output-tensor-1") + array_ops.placeholder(dtypes.string, 1, name="output-tensor-1") } problem_type = constants.ProblemType.CLASSIFICATION actual_signature_def = ( @@ -94,14 +117,14 @@ class SavedModelExportUtilsTest(test.TestCase): dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) dtype_float = types_pb2.DataType.Value("DT_FLOAT") dtype_string = types_pb2.DataType.Value("DT_STRING") - expected_signature_def.inputs[ - signature_constants.CLASSIFY_INPUTS].CopyFrom( - meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( + meta_graph_pb2.TensorInfo( + name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom( meta_graph_pb2.TensorInfo( - name="output-tensor-1:0", dtype=dtype_string, + name="output-tensor-1:0", + dtype=dtype_string, tensor_shape=shape)) expected_signature_def.method_name = ( @@ -112,8 +135,7 @@ class SavedModelExportUtilsTest(test.TestCase): """Tests multiple output tensors that include classes and probabilities.""" input_tensors = { "input-1": - array_ops.placeholder( - dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") } output_tensors = { "classes": @@ -136,19 +158,20 @@ class SavedModelExportUtilsTest(test.TestCase): dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) dtype_float = types_pb2.DataType.Value("DT_FLOAT") dtype_string = types_pb2.DataType.Value("DT_STRING") - expected_signature_def.inputs[ - signature_constants.CLASSIFY_INPUTS].CopyFrom( - meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( + meta_graph_pb2.TensorInfo( + name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom( meta_graph_pb2.TensorInfo( - name="output-tensor-classes:0", dtype=dtype_string, + name="output-tensor-classes:0", + dtype=dtype_string, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom( meta_graph_pb2.TensorInfo( - name="output-tensor-proba:0", dtype=dtype_float, + name="output-tensor-proba:0", + dtype=dtype_float, tensor_shape=shape)) expected_signature_def.method_name = ( @@ -159,8 +182,7 @@ class SavedModelExportUtilsTest(test.TestCase): """Tests multiple output tensors that include classes and scores.""" input_tensors = { "input-1": - array_ops.placeholder( - dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") } output_tensors = { "classes": @@ -182,19 +204,20 @@ class SavedModelExportUtilsTest(test.TestCase): dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) dtype_float = types_pb2.DataType.Value("DT_FLOAT") dtype_string = types_pb2.DataType.Value("DT_STRING") - expected_signature_def.inputs[ - signature_constants.CLASSIFY_INPUTS].CopyFrom( - meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( + meta_graph_pb2.TensorInfo( + name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom( meta_graph_pb2.TensorInfo( - name="output-tensor-classes:0", dtype=dtype_string, + name="output-tensor-classes:0", + dtype=dtype_string, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom( meta_graph_pb2.TensorInfo( - name="output-tensor-scores:0", dtype=dtype_float, + name="output-tensor-scores:0", + dtype=dtype_float, tensor_shape=shape)) expected_signature_def.method_name = ( @@ -205,8 +228,7 @@ class SavedModelExportUtilsTest(test.TestCase): """Tests classification without classes tensor.""" input_tensors = { "input-1": - array_ops.placeholder( - dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") } output_tensors = { "probabilities": @@ -224,14 +246,14 @@ class SavedModelExportUtilsTest(test.TestCase): shape = tensor_shape_pb2.TensorShapeProto( dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) dtype_float = types_pb2.DataType.Value("DT_FLOAT") - expected_signature_def.inputs[ - signature_constants.CLASSIFY_INPUTS].CopyFrom( - meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( + meta_graph_pb2.TensorInfo( + name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom( meta_graph_pb2.TensorInfo( - name="output-tensor-proba:0", dtype=dtype_float, + name="output-tensor-proba:0", + dtype=dtype_float, tensor_shape=shape)) expected_signature_def.method_name = ( @@ -246,8 +268,7 @@ class SavedModelExportUtilsTest(test.TestCase): """ input_tensors = { "input-1": - array_ops.placeholder( - dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") } output_tensors = { "classes": @@ -268,14 +289,14 @@ class SavedModelExportUtilsTest(test.TestCase): shape = tensor_shape_pb2.TensorShapeProto( dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) dtype_float = types_pb2.DataType.Value("DT_FLOAT") - expected_signature_def.inputs[ - signature_constants.CLASSIFY_INPUTS].CopyFrom( - meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( + meta_graph_pb2.TensorInfo( + name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom( meta_graph_pb2.TensorInfo( - name="output-tensor-scores:0", dtype=dtype_float, + name="output-tensor-scores:0", + dtype=dtype_float, tensor_shape=shape)) expected_signature_def.method_name = ( @@ -290,8 +311,7 @@ class SavedModelExportUtilsTest(test.TestCase): """ input_tensors = { "input-1": - array_ops.placeholder( - dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") } output_tensors = { "classes": @@ -310,17 +330,18 @@ class SavedModelExportUtilsTest(test.TestCase): dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) dtype_int64 = types_pb2.DataType.Value("DT_INT64") dtype_float = types_pb2.DataType.Value("DT_FLOAT") - expected_signature_def.inputs[ - "input-1"].CopyFrom( - meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + expected_signature_def.inputs["input-1"].CopyFrom( + meta_graph_pb2.TensorInfo( + name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) expected_signature_def.outputs["classes"].CopyFrom( meta_graph_pb2.TensorInfo( - name="output-tensor-classes:0", dtype=dtype_int64, + name="output-tensor-classes:0", + dtype=dtype_int64, tensor_shape=shape)) expected_signature_def.outputs["logits"].CopyFrom( meta_graph_pb2.TensorInfo( - name="output-tensor-logits:0", dtype=dtype_float, + name="output-tensor-logits:0", + dtype=dtype_float, tensor_shape=shape)) expected_signature_def.method_name = ( @@ -379,8 +400,9 @@ class SavedModelExportUtilsTest(test.TestCase): def test_get_output_alternatives_single_no_default(self): prediction_tensor = constant_op.constant(["bogus"]) provided_output_alternatives = { - "head-1": (constants.ProblemType.LINEAR_REGRESSION, - {"output": prediction_tensor}), + "head-1": (constants.ProblemType.LINEAR_REGRESSION, { + "output": prediction_tensor + }), } model_fn_ops = model_fn.ModelFnOps( model_fn.ModeKeys.INFER, @@ -390,10 +412,11 @@ class SavedModelExportUtilsTest(test.TestCase): output_alternatives, _ = saved_model_export_utils.get_output_alternatives( model_fn_ops) - self.assertEqual({"head-1": - (constants.ProblemType.LINEAR_REGRESSION, - {"output": prediction_tensor})}, - output_alternatives) + self.assertEqual({ + "head-1": (constants.ProblemType.LINEAR_REGRESSION, { + "output": prediction_tensor + }) + }, output_alternatives) def test_get_output_alternatives_multi_no_default(self): provided_output_alternatives = { @@ -424,10 +447,11 @@ class SavedModelExportUtilsTest(test.TestCase): output_alternatives, _ = saved_model_export_utils.get_output_alternatives( model_fn_ops) - self.assertEqual( - {"default_output_alternative": (constants.ProblemType.UNSPECIFIED, { - "some_output": prediction_tensor})}, - output_alternatives) + self.assertEqual({ + "default_output_alternative": (constants.ProblemType.UNSPECIFIED, { + "some_output": prediction_tensor + }) + }, output_alternatives) def test_get_output_alternatives_empty_provided_with_default(self): prediction_tensor = constant_op.constant(["bogus"]) @@ -452,10 +476,11 @@ class SavedModelExportUtilsTest(test.TestCase): output_alternatives, _ = saved_model_export_utils.get_output_alternatives( model_fn_ops) - self.assertEqual( - {"default_output_alternative": (constants.ProblemType.UNSPECIFIED, { - "some_output": prediction_tensor})}, - output_alternatives) + self.assertEqual({ + "default_output_alternative": (constants.ProblemType.UNSPECIFIED, { + "some_output": prediction_tensor + }) + }, output_alternatives) def test_get_output_alternatives_implicit_single(self): prediction_tensor = constant_op.constant(["bogus"]) @@ -506,14 +531,14 @@ class SavedModelExportUtilsTest(test.TestCase): expected_signature_defs = { "serving_default": - signature_def_utils.regression_signature_def(input_example, - output_1), + signature_def_utils.regression_signature_def( + input_example, output_1), "default_input_alternative:head-1": - signature_def_utils.regression_signature_def(input_example, - output_1), + signature_def_utils.regression_signature_def( + input_example, output_1), "default_input_alternative:head-2": - signature_def_utils.classification_signature_def(input_example, - output_2, None), + signature_def_utils.classification_signature_def( + input_example, output_2, None), "default_input_alternative:head-3": signature_def_utils.predict_signature_def({ "default input": input_example @@ -624,17 +649,20 @@ class SavedModelExportUtilsTest(test.TestCase): (most_recent_export_dir, most_recent_export_version) = ( saved_model_export_utils.get_most_recent_export(export_dir_base)) - self.assertEqual(compat.as_bytes(export_dir_4), - compat.as_bytes(most_recent_export_dir)) - self.assertEqual(compat.as_bytes(export_dir_4), - os.path.join(compat.as_bytes(export_dir_base), - compat.as_bytes( - str(most_recent_export_version)))) + self.assertEqual( + compat.as_bytes(export_dir_4), compat.as_bytes(most_recent_export_dir)) + self.assertEqual( + compat.as_bytes(export_dir_4), + os.path.join( + compat.as_bytes(export_dir_base), + compat.as_bytes(str(most_recent_export_version)))) def test_make_export_strategy(self): """Only tests that an ExportStrategy instance is created.""" + def _serving_input_fn(): return array_ops.constant([1]), None + export_strategy = saved_model_export_utils.make_export_strategy( serving_input_fn=_serving_input_fn, default_output_alternative_key="default", @@ -655,14 +683,61 @@ class SavedModelExportUtilsTest(test.TestCase): real_valued_col1 = fc.real_valued_column("real_valued_column1") bucketized_col1 = fc.bucketized_column( fc.real_valued_column("real_valued_column_for_bucketization1"), [0, 4]) - feature_columns = [sparse_col, embedding_col, real_valued_col1, - bucketized_col1] + feature_columns = [ + sparse_col, embedding_col, real_valued_col1, bucketized_col1 + ] export_strategy = saved_model_export_utils.make_parsing_export_strategy( feature_columns=feature_columns) self.assertTrue( isinstance(export_strategy, export_strategy_lib.ExportStrategy)) + def test_make_best_model_export_strategy(self): + export_dir_base = tempfile.mkdtemp() + "export/" + gfile.MkDir(export_dir_base) + + test_estimator = TestEstimator() + export_strategy = saved_model_export_utils.make_best_model_export_strategy( + serving_input_fn=None, exports_to_keep=3, compare_fn=None) + + self.assertNotEqual("", + export_strategy.export(test_estimator, export_dir_base, + "fake_ckpt_0", {"loss": 100})) + self.assertNotEqual("", test_estimator.last_exported_dir) + self.assertNotEqual("", test_estimator.last_exported_checkpoint) + + self.assertEqual("", + export_strategy.export(test_estimator, export_dir_base, + "fake_ckpt_1", {"loss": 101})) + self.assertEqual(test_estimator.last_exported_dir, + os.path.join(export_dir_base, "fake_ckpt_0")) + + self.assertNotEqual("", + export_strategy.export(test_estimator, export_dir_base, + "fake_ckpt_2", {"loss": 10})) + self.assertEqual(test_estimator.last_exported_dir, + os.path.join(export_dir_base, "fake_ckpt_2")) + + self.assertEqual("", + export_strategy.export(test_estimator, export_dir_base, + "fake_ckpt_3", {"loss": 20})) + self.assertEqual(test_estimator.last_exported_dir, + os.path.join(export_dir_base, "fake_ckpt_2")) + + def test_make_best_model_export_strategy_exceptions(self): + export_dir_base = tempfile.mkdtemp() + "export/" + + test_estimator = TestEstimator() + export_strategy = saved_model_export_utils.make_best_model_export_strategy( + serving_input_fn=None, exports_to_keep=3, compare_fn=None) + + with self.assertRaises(ValueError): + export_strategy.export(test_estimator, export_dir_base, "", {"loss": 200}) + + with self.assertRaises(ValueError): + export_strategy.export(test_estimator, export_dir_base, "fake_ckpt_1", + None) + def _create_test_export_dir(export_dir_base): export_dir = saved_model_export_utils.get_timestamped_export_dir( diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md index 835d68489eb717db47fe76309f1609da8296d112..715eb5157762a3a08079d0845682f55dc05d7b76 100644 --- a/tensorflow/contrib/makefile/README.md +++ b/tensorflow/contrib/makefile/README.md @@ -201,7 +201,8 @@ tensorflow/contrib/makefile/compile_ios_protobuf.sh Then, you will need to compile the nsync library for iOS: -```export HOST_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh` +```bash +export HOST_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh` export TARGET_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh -t ios` ``` diff --git a/tensorflow/contrib/makefile/compile_nsync.sh b/tensorflow/contrib/makefile/compile_nsync.sh index 207661ee46d27966ba2875aec27ec152093e7b57..729ab6b0373257f1d28c4d75df63db64e6006466 100755 --- a/tensorflow/contrib/makefile/compile_nsync.sh +++ b/tensorflow/contrib/makefile/compile_nsync.sh @@ -215,12 +215,12 @@ for arch in $archs; do armeabi-v7a) toolchain="arm-linux-androideabi-4.9" sysroot_arch="arm" bin_prefix="arm-linux-androideabi" - march_option="-march=armv7-a" + march_option="-march=armv7-a -mfloat-abi=softfp -mfpu=neon" ;; armeabi-v7a-hard) toolchain="arm-linux-androideabi-4.9" sysroot_arch="arm" bin_prefix="arm-linux-androideabi" - march_option="-march=armv7-a" + march_option="-march=armv7-a -mfpu=neon" ;; mips) toolchain="mipsel-linux-android-4.9" sysroot_arch="mips" @@ -266,8 +266,7 @@ for arch in $archs; do -I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/'"$arch"'/include \ -I../../platform/c++11 -I../../platform/gcc \ -I../../platform/posix -pthread - PLATFORM_CFLAGS=-std=c++11 -Wno-narrowing '"$march_option"' \ - -mfloat-abi=softfp -mfpu=neon -fPIE + PLATFORM_CFLAGS=-std=c++11 -Wno-narrowing '"$march_option"' -fPIE PLATFORM_LDFLAGS=-pthread MKDEP=${CC} -M -std=c++11 PLATFORM_C=../../platform/c++11/src/nsync_semaphore_mutex.cc \ diff --git a/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in b/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in index 631d52235a48f6bbb64389481fa7c88bb586608b..26c1ad4947363e98d9bb8e400f40290fb87b2e4e 100644 --- a/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in +++ b/tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in @@ -52,7 +52,9 @@ $(INFERENCE_SO_PATH): $(LIB_OBJS) $(INFERENCE_OBJS) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ -o $@ $(INFERENCE_OBJS) $(LIB_OBJS) \ - $(LIBFLAGS) $(LDFLAGS) -shared $(LIBS) + $(LIBFLAGS) $(LDFLAGS) \ + -shared -Wl,-soname,$(INFERENCE_SO_NAME) \ + $(LIBS) $(INFERENCE_SO_NAME): $(INFERENCE_SO_PATH) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 00cde08bff1157dd44adad3e1bdeff674fb0a444..9b959b43a9db8baac5b37524e81bfbb11d6ad868 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1496,6 +1496,15 @@ class StreamingAUCTest(test.TestCase): for _ in range(10): self.assertAlmostEqual(initial_auc, auc.eval(), 5) + def testPredictionsOutOfRange(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + _, update_op = metrics.streaming_auc(predictions, labels) + sess.run(variables.local_variables_initializer()) + self.assertRaises(errors_impl.InvalidArgumentError, update_op.eval) + def testAllCorrect(self): self.allCorrectAsExpected('ROC') diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index 338181e4cacab6e849050c670c433510ea75df8d..d6508362b8bf01468a43b26d6a0d0c9807b5967e 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -48,6 +48,8 @@ tf_cuda_cc_test( # Disabled on jenkins until errors finding nvmlShutdown are found. tags = [ "manual", + "no_oss", + "noguitar", # note: is run manually there "notap", ], deps = if_cuda( @@ -112,25 +114,26 @@ tf_custom_op_py_library( ], ) -# http://b/62064807 -# cuda_py_test( -# name = "nccl_ops_test", -# size = "small", -# srcs = ["python/ops/nccl_ops_test.py"], -# additional_deps = [ -# ":nccl_py", -# "//tensorflow/python:array_ops", -# "//tensorflow/python:client_testlib", -# "//tensorflow/python:framework_for_generated_wrappers", -# "//tensorflow/python:framework_test_lib", -# "//tensorflow/python:platform_test", -# ], -# # Disabled on jenkins until errors finding nvmlShutdown are found. -# tags = [ -# "manual", -# "notap", -# ], -# ) +cuda_py_test( + name = "nccl_ops_test", + size = "small", + srcs = ["python/ops/nccl_ops_test.py"], + additional_deps = [ + ":nccl_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + # Disabled on jenkins until errors finding nvmlShutdown are found. + tags = [ + "manual", + "no_oss", + "noguitar", # note: is run manually there + "notap", + ], +) filegroup( name = "all_files", diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h index 1a661e8f7f777365e91bbade09435f911623e1d4..6e2f8e953a5f0094a291e07367e71043925e0a2b 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.h +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "external/nccl_archive/src/nccl.h" +#include "src/nccl.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/mutex.h" diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc index 3c532e3d73121777facadde4a1c19a7669bc1ee9..d4455483f77f170c30ce070d47b7812ad0b44612 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "external/nccl_archive/src/nccl.h" +#include "src/nccl.h" #include "tensorflow/contrib/nccl/kernels/nccl_manager.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py index ae658e732278ab85b97b76290ab5c3d183c11434..1621e9f28e37d5c31cd3818ec75d1aaff41e77b5 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py @@ -43,7 +43,8 @@ class AllReduceTest(test.TestCase): self._testSingleAllReduce(sess, dtype, nccl.all_max, np.maximum) def _testSingleAllReduce(self, sess, np_type, nccl_fn, numpy_accumulation_fn): - for devices in [['/device:GPU:0', '/device:GPU:0', '/device:GPU:0'], ['/device:GPU:0', '/device:GPU:0']]: + for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'], + ['/device:GPU:1', '/device:GPU:0']]: shape = (3, 4) np_ans = None tensors = [] @@ -84,7 +85,8 @@ class BroadcastTest(test.TestCase): # Create session inside outer loop to test use of # same communicator across multiple sessions. with self.test_session(use_gpu=True) as sess: - for devices in [['/device:GPU:0', '/device:GPU:0', '/device:GPU:0'], ['/device:GPU:0', '/device:GPU:0']]: + for devices in [['/device:GPU:1', '/device:GPU:0', '/device:GPU:2'], + ['/device:GPU:1', '/device:GPU:0']]: shape = (3, 4) sender = np.random.randint(0, len(devices) - 1) with ops.device(devices[sender]): @@ -115,7 +117,8 @@ class CombinedTest(test.TestCase): # Create session inside outer loop to test use of # same communicator across multiple sessions. with self.test_session(use_gpu=True) as sess: - for devices in [['/device:GPU:0', '/device:GPU:0', '/device:GPU:0'], ['/device:GPU:0', '/device:GPU:0']]: + for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'], + ['/device:GPU:0', '/device:GPU:1']]: shape = (3, 4) # all-reduce diff --git a/tensorflow/contrib/receptive_field/BUILD b/tensorflow/contrib/receptive_field/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ed2f3af08cbbd8ae5da2a87f4a7dd9854493c346 --- /dev/null +++ b/tensorflow/contrib/receptive_field/BUILD @@ -0,0 +1,71 @@ +# Description: +# Contains modules to compute receptive field parameters for CNN models. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +# Transitive dependencies of this target will be included in the pip package. +py_library( + name = "receptive_field_pip", + deps = [ + ":graph_compute_order_py", + ":receptive_field_py", + ], +) + +py_library( + name = "graph_compute_order_py", + srcs = [ + "__init__.py", + "python/util/graph_compute_order.py", + ], + srcs_version = "PY2AND3", +) + +py_library( + name = "receptive_field_py", + srcs = [ + "__init__.py", + "python/util/receptive_field.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":graph_compute_order_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + ], +) + +py_test( + name = "receptive_field_test", + srcs = ["python/util/receptive_field_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":receptive_field_py", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/slim", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:nn", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b150b903b23581abfb69a59f5ce4cdaa41ab08b9 --- /dev/null +++ b/tensorflow/contrib/receptive_field/README.md @@ -0,0 +1,165 @@ +# Receptive field computation for convnets + +This library enables you to easily compute the receptive field parameters of +your favorite convnet. You can use it to understand how big of an input image +region your output features depend on. Better yet, using the parameters computed +by the library, you can easily find the exact image region which is used to +compute each convnet feature. + +## Basic usage + +The main function to be called is `compute_receptive_field_from_graph_def`, +which will return the receptive field, effective stride and effective padding +for both horizontal and vertical directions. + +For example, if your model is constructed using the function +`my_model_construction()`, you can use the library as follows: + +```python +import tensorflow as tf +from tensorflow.contrib import receptive_field + +# Construct graph. +g = tf.Graph() +with g.as_default(): + images = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='input_image') + my_model_construction(images) + +# Compute receptive field parameters. +rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \ + receptive_field.compute_receptive_field_from_graph_def( \ + g.as_graph_def(), 'input_image', 'my_output_endpoint') +``` + +Here's a simple example of computing the receptive field parameters for +Inception-Resnet-v2. To get this to work, be sure to checkout +[tensorflow/models](https://github.com/tensorflow/models), so that the Inception +models are available to you. This can be done in three simple commands: + +```sh +git clone https://github.com/tensorflow/models +cd models/slim +sudo python setup.py install_lib +``` + +You can then compute the receptive field parameters for Inception-Resnet-v2 as: + +```python +from nets import inception +import tensorflow as tf +from tensorflow.contrib import receptive_field + +# Construct graph. +g = tf.Graph() +with g.as_default(): + images = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='input_image') + inception.inception_resnet_v2_base(images) + +# Compute receptive field parameters. +rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \ + receptive_field.compute_receptive_field_from_graph_def( \ + g.as_graph_def(), 'input_image', 'InceptionResnetV2/Conv2d_7b_1x1/Relu') +``` + +This will give you `rf_x = rf_y = 3039`, `eff_stride_x = eff_stride_y = 32`, and +`eff_pad_x = eff_pad_y = 1482`. This means that each feature that is output at +the node `'InceptionResnetV2/Conv2d_7b_1x1/Relu'` is computed from a region +which is of size `3039x3039`. Further, by using the expressions + +```python +center_x = -eff_pad_x + feature_x*eff_stride_x + (rf_x - 1)/2 +center_y = -eff_pad_y + feature_y*eff_stride_y + (rf_y - 1)/2 +``` + +one can compute the center of the region in the input image that is used to +compute the output feature at position `[feature_x, feature_y]`. For example, +the feature at position `[0, 2]` at the output of the layer +`'InceptionResnetV2/Conv2d_7b_1x1/Relu'` is centered in the original image in +the position `[37, 101]`. + +TODO: include link to derivations and definitions of different parameters. + +## Receptive field benchmark + +As you might expect, it is straightforward to run this library on the popular +convnets, and gather their receptive fields. We provide a python script which +does exactly that, available under `python/util/examples/rf_benchmark.py`. + +To get this to work, be sure to checkout +[tensorflow/models](https://github.com/tensorflow/models) (see the 3-command +instructions for this above). Then, simply: + +```sh +cd python/util/examples +python rf_benchmark.py --csv_path /tmp/rf_benchmark_results.csv +``` + +The script will write to stdout the receptive field parameters for many variants +of several popular convnets: AlexNet, VGG, ResNet, Inception, Mobilenet. They +are also written to the file `/tmp/rf_benchmark_results.csv`. + +TODO: include here a plot for receptive field sizes of different convnets. + +TODO: include table/link to pre-computed RF parameters. + +## Compute RF parameters from a graph pbtxt + +We also provide a utility to compute the receptive field parameters directly +from a graph protobuf file. + +Have a `graph.pbtxt` file and want to compute its receptive field parameters? We +got you covered. The only prerequisite is to install +[google/protobuf](https://github.com/google/protobuf), which you probably +already have if you're using tensorflow (otherwise, follow installation +instructions [here](https://github.com/google/protobuf/tree/master/python)). + +This should work: + +```sh +cd python/util/examples +python compute_rf.py \ + --graph_path /path/to/graph.pbtxt \ + --output_path /path/to/output/rf_info.txt \ + --input_node my_input_node \ + --output_node my_output_node +``` + +Don't know how to generate a graph protobuf file? Take a look at the +`write_inception_resnet_v2_graph.py` script, which shows how to save it for the +Inception-Resnet-v2 model: + +```sh +cd python/util/examples +python write_inception_resnet_v2_graph.py --graph_dir /tmp --graph_filename graph.pbtxt +``` + +This will write the Inception-Resnet-v2 graph protobuf to `/tmp/graph.pbtxt`. + +For completeness, here's how you would use this file to get the receptive field +parameters of the Inception-Resnet-v2 model: + +```sh +cd python/util/examples +python compute_rf.py \ + --graph_path /tmp/graph.pbtxt \ + --output_path /tmp/rf_info.txt \ + --input_node input_image \ + --output_node InceptionResnetV2/Conv2d_7b_1x1/Relu +``` + +This will write the receptive field parameters of the model to +`/tmp/rf_info.txt`, which will look like: + +```sh +Receptive field size (horizontal) = 3039 +Receptive field size (vertical) = 3039 +Effective stride (horizontal) = 32 +Effective stride (vertical) = 32 +Effective padding (horizontal) = 1482 +Effective padding (vertical) = 1482 +``` + +## Authors + +André Araujo (github id: andrefaraujo) and Mark Sandler (github id: +marksandler) diff --git a/tensorflow/contrib/receptive_field/__init__.py b/tensorflow/contrib/receptive_field/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10745a6a53d5b3ef9521b2313ddc28799ee8b886 --- /dev/null +++ b/tensorflow/contrib/receptive_field/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module to compute receptive field parameters for CNN tensorflow models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.contrib.receptive_field.python.util.graph_compute_order import get_compute_order +from tensorflow.contrib.receptive_field.python.util.receptive_field import compute_receptive_field_from_graph_def +# pylint: enable=unused-import diff --git a/tensorflow/contrib/receptive_field/python/__init__.py b/tensorflow/contrib/receptive_field/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..217047f92d33448440c7acd1535fbdbf80bfe011 --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module to compute receptive field parameters for CNN tensorflow models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf978b90a3661a075130790d82a499da4d8a0cc --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/examples/compute_rf.py @@ -0,0 +1,94 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Computes Receptive Field (RF) information given a graph protobuf. + +For an example of usage, see accompanying file compute_rf.sh +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from google.protobuf import text_format + +from tensorflow.contrib import receptive_field +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.platform import app +from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging as logging + +cmd_args = None + + +def _load_graphdef(path): + """Helper function to load GraphDef from file. + + Args: + path: Path to pbtxt file. + + Returns: + graph_def: A GraphDef object. + """ + graph_def = graph_pb2.GraphDef() + pbstr = gfile.Open(path).read() + text_format.Parse(pbstr, graph_def) + return graph_def + + +def main(unused_argv): + + graph_def = _load_graphdef(cmd_args.graph_path) + + (receptive_field_x, receptive_field_y, effective_stride_x, effective_stride_y, + effective_padding_x, effective_padding_y + ) = receptive_field.compute_receptive_field_from_graph_def( + graph_def, cmd_args.input_node, cmd_args.output_node) + + logging.info('Receptive field size (horizontal) = %s', receptive_field_x) + logging.info('Receptive field size (vertical) = %s', receptive_field_y) + logging.info('Effective stride (horizontal) = %s', effective_stride_x) + logging.info('Effective stride (vertical) = %s', effective_stride_y) + logging.info('Effective padding (horizontal) = %s', effective_padding_x) + logging.info('Effective padding (vertical) = %s', effective_padding_y) + + f = gfile.GFile('%s' % cmd_args.output_path, 'w') + f.write('Receptive field size (horizontal) = %s\n' % receptive_field_x) + f.write('Receptive field size (vertical) = %s\n' % receptive_field_y) + f.write('Effective stride (horizontal) = %s\n' % effective_stride_x) + f.write('Effective stride (vertical) = %s\n' % effective_stride_y) + f.write('Effective padding (horizontal) = %s\n' % effective_padding_x) + f.write('Effective padding (vertical) = %s\n' % effective_padding_y) + f.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--graph_path', type=str, default='', help='Graph path (pbtxt format).') + parser.add_argument( + '--output_path', + type=str, + default='', + help='Path to output text file where RF information will be written to.') + parser.add_argument( + '--input_node', type=str, default='', help='Name of input node.') + parser.add_argument( + '--output_node', type=str, default='', help='Name of output node.') + cmd_args, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..94228dfa61b1de617f131611173fda7c3917d250 --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py @@ -0,0 +1,460 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Computes Receptive Field (RF) information for different models. + +The receptive field (and related parameters) for the different models are +printed to stdout, and may also optionally be written to a CSV file. + +For an example of usage, see rf_benchmark.sh +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import csv +import sys + +from nets import alexnet +from nets import inception +from nets import mobilenet_v1 +from nets import resnet_v1 +from nets import resnet_v2 +from nets import vgg +from tensorflow.contrib import framework +from tensorflow.contrib import receptive_field +from tensorflow.contrib import slim +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import app + +cmd_args = None + +# Input node name for all architectures. +_INPUT_NODE = 'input_image' + +# Variants of different network architectures. + +# - resnet: different versions and sizes. +_SUPPORTED_RESNET_VARIANTS = [ + 'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152', 'resnet_v1_200', + 'resnet_v2_50', 'resnet_v2_101', 'resnet_v2_152', 'resnet_v2_200' +] + +# - inception_resnet_v2: default, and version with SAME padding. +_SUPPORTED_INCEPTIONRESNETV2_VARIANTS = [ + 'inception_resnet_v2', 'inception_resnet_v2-same' +] + +# - inception_v2: default, and version with no separable conv. +_SUPPORTED_INCEPTIONV2_VARIANTS = [ + 'inception_v2', 'inception_v2-no-separable-conv' +] + +# - inception_v3: default version. +_SUPPORTED_INCEPTIONV3_VARIANTS = ['inception_v3'] + +# - inception_v4: default version. +_SUPPORTED_INCEPTIONV4_VARIANTS = ['inception_v4'] + +# - alexnet_v2: default version. +_SUPPORTED_ALEXNETV2_VARIANTS = ['alexnet_v2'] + +# - vgg: vgg_a (with 11 layers) and vgg_16 (version D). +_SUPPORTED_VGG_VARIANTS = ['vgg_a', 'vgg_16'] + +# - mobilenet_v1: 100% and 75%. +_SUPPORTED_MOBILENETV1_VARIANTS = ['mobilenet_v1', 'mobilenet_v1_075'] + + +def _construct_model(model_type='resnet_v1_50'): + """Constructs model for the desired type of CNN. + + Args: + model_type: Type of model to be used. + + Returns: + end_points: A dictionary from components of the network to the corresponding + activations. + + Raises: + ValueError: If the model_type is not supported. + """ + # Placeholder input. + images = array_ops.placeholder( + dtypes.float32, shape=(1, None, None, 3), name=_INPUT_NODE) + + # Construct model. + if model_type == 'inception_resnet_v2': + _, end_points = inception.inception_resnet_v2_base(images) + elif model_type == 'inception_resnet_v2-same': + _, end_points = inception.inception_resnet_v2_base( + images, align_feature_maps=True) + elif model_type == 'inception_v2': + _, end_points = inception.inception_v2_base(images) + elif model_type == 'inception_v2-no-separable-conv': + _, end_points = inception.inception_v2_base( + images, use_separable_conv=False) + elif model_type == 'inception_v3': + _, end_points = inception.inception_v3_base(images) + elif model_type == 'inception_v4': + _, end_points = inception.inception_v4_base(images) + elif model_type == 'alexnet_v2': + _, end_points = alexnet.alexnet_v2(images) + elif model_type == 'vgg_a': + _, end_points = vgg.vgg_a(images) + elif model_type == 'vgg_16': + _, end_points = vgg.vgg_16(images) + elif model_type == 'mobilenet_v1': + _, end_points = mobilenet_v1.mobilenet_v1_base(images) + elif model_type == 'mobilenet_v1_075': + _, end_points = mobilenet_v1.mobilenet_v1_base( + images, depth_multiplier=0.75) + elif model_type == 'resnet_v1_50': + _, end_points = resnet_v1.resnet_v1_50( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v1_101': + _, end_points = resnet_v1.resnet_v1_101( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v1_152': + _, end_points = resnet_v1.resnet_v1_152( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v1_200': + _, end_points = resnet_v1.resnet_v1_200( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v2_50': + _, end_points = resnet_v2.resnet_v2_50( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v2_101': + _, end_points = resnet_v2.resnet_v2_101( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v2_152': + _, end_points = resnet_v2.resnet_v2_152( + images, num_classes=None, is_training=False, global_pool=False) + elif model_type == 'resnet_v2_200': + _, end_points = resnet_v2.resnet_v2_200( + images, num_classes=None, is_training=False, global_pool=False) + else: + raise ValueError('Unsupported model_type %s.' % model_type) + + return end_points + + +def _get_desired_end_point_keys(model_type='resnet_v1_50'): + """Gets list of desired end point keys for a type of CNN. + + Args: + model_type: Type of model to be used. + + Returns: + desired_end_point_types: A list containing the desired end-points. + + Raises: + ValueError: If the model_type is not supported. + """ + if model_type in _SUPPORTED_RESNET_VARIANTS: + blocks = ['block1', 'block2', 'block3', 'block4'] + desired_end_point_keys = ['%s/%s' % (model_type, i) for i in blocks] + elif model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS: + desired_end_point_keys = [ + 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3', + 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b', + 'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1' + ] + elif model_type in _SUPPORTED_INCEPTIONV2_VARIANTS: + desired_end_point_keys = [ + 'Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3', + 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b', + 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c' + ] + elif model_type in _SUPPORTED_INCEPTIONV3_VARIANTS: + desired_end_point_keys = [ + 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'MaxPool_3a_3x3', + 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 'Mixed_5b', + 'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', + 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c' + ] + elif model_type in _SUPPORTED_INCEPTIONV4_VARIANTS: + desired_end_point_keys = [ + 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a', + 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_5e', + 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 'Mixed_6f', + 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c', 'Mixed_7d' + ] + elif model_type in _SUPPORTED_ALEXNETV2_VARIANTS: + ep = ['conv1', 'pool1', 'conv2', 'conv3', 'conv4', 'conv5', 'pool5'] + desired_end_point_keys = ['%s/%s' % (model_type, i) for i in ep] + elif model_type in _SUPPORTED_VGG_VARIANTS: + ep = [ + 'conv1/conv1_1', 'pool1', 'conv2/conv2_1', 'pool2', 'conv3/conv3_1', + 'conv3/conv3_2', 'pool3', 'conv4/conv4_1', 'conv4/conv4_2', 'pool4', + 'conv5/conv5_1', 'conv5/conv5_2', 'pool5' + ] + desired_end_point_keys = ['%s/%s' % (model_type, i) for i in ep] + elif model_type in _SUPPORTED_MOBILENETV1_VARIANTS: + desired_end_point_keys = [ + 'Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise', + 'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise', + 'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise', + 'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise', + 'Conv2d_12_pointwise', 'Conv2d_13_pointwise' + ] + else: + raise ValueError('Unsupported model_type %s.' % model_type) + + return desired_end_point_keys + + +def _model_graph_def(model_type='resnet_v1_50', arg_sc=None): + """Constructs a model graph, returning GraphDef and end-points. + + Args: + model_type: Type of model to be used. + arg_sc: Optional arg scope to use in constructing the graph. + + Returns: + graph_def: GraphDef of constructed graph. + end_points: A dictionary from components of the network to the corresponding + activations. + """ + if arg_sc is None: + arg_sc = {} + g = ops.Graph() + with g.as_default(): + with framework.arg_scope(arg_sc): + end_points = _construct_model(model_type) + + return g.as_graph_def(), end_points + + +def _model_rf(graphdef, + end_points, + desired_end_point_keys, + model_type='resnet_v1_50', + csv_writer=None): + """Computes receptive field information for a given CNN model. + + The information will be printed to stdout. If the RF parameters are the same + for the horizontal and vertical directions, it will be printed only once. + Otherwise, they are printed once for the horizontal and once for the vertical + directions. + + Args: + graphdef: GraphDef of given model. + end_points: A dictionary from components of the model to the corresponding + activations. + desired_end_point_keys: List of desired end points for which receptive field + information will be computed. + model_type: Type of model to be used, used only for printing purposes. + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for desired_end_point_key in desired_end_point_keys: + print('- %s:' % desired_end_point_key) + output_node_with_colon = end_points[desired_end_point_key].name + pos = output_node_with_colon.rfind(':') + output_node = output_node_with_colon[:pos] + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y + ) = receptive_field.compute_receptive_field_from_graph_def( + graphdef, _INPUT_NODE, output_node) + # If values are the same in horizontal/vertical directions, just report one + # of them. Otherwise, report both. + if (receptive_field_x == receptive_field_y) and ( + effective_stride_x == effective_stride_y) and ( + effective_padding_x == effective_padding_y): + print('Receptive field size = %5s, effective stride = %5s, effective ' + 'padding = %5s' % (str(receptive_field_x), str(effective_stride_x), + str(effective_padding_x))) + else: + print('Receptive field size: horizontal = %5s, vertical = %5s. ' + 'Effective stride: horizontal = %5s, vertical = %5s. Effective ' + 'padding: horizontal = %5s, vertical = %5s' % + (str(receptive_field_x), str(receptive_field_y), + str(effective_stride_x), str(effective_stride_y), + str(effective_padding_x), str(effective_padding_y))) + if csv_writer is not None: + csv_writer.writerow({ + 'CNN': model_type, + 'end_point': desired_end_point_key, + 'RF size hor': str(receptive_field_x), + 'RF size ver': str(receptive_field_y), + 'effective stride hor': str(effective_stride_x), + 'effective stride ver': str(effective_stride_y), + 'effective padding hor': str(effective_padding_x), + 'effective padding ver': str(effective_padding_y) + }) + + +def _process_model_rf(model_type='resnet_v1_50', csv_writer=None, arg_sc=None): + """Contructs model graph and desired end-points, and compute RF. + + The computed RF parameters are printed to stdout by the _model_rf function. + + Args: + model_type: Type of model to be used. + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + arg_sc: Optional arg scope to use in constructing the graph. + + """ + print('********************%s' % model_type) + graphdef, end_points = _model_graph_def(model_type, arg_sc) + desired_end_point_keys = _get_desired_end_point_keys(model_type) + _model_rf(graphdef, end_points, desired_end_point_keys, model_type, + csv_writer) + + +def _resnet_rf(csv_writer=None): + """Computes RF and associated parameters for resnet models. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_RESNET_VARIANTS: + arg_sc = resnet_v1.resnet_arg_scope() + _process_model_rf(model_type, csv_writer, arg_sc) + + +def _inception_resnet_v2_rf(csv_writer=None): + """Computes RF and associated parameters for the inception_resnet_v2 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_INCEPTIONRESNETV2_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _inception_v2_rf(csv_writer=None): + """Computes RF and associated parameters for the inception_v2 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_INCEPTIONV2_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _inception_v3_rf(csv_writer=None): + """Computes RF and associated parameters for the inception_v3 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_INCEPTIONV3_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _inception_v4_rf(csv_writer=None): + """Computes RF and associated parameters for the inception_v4 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_INCEPTIONV4_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _alexnet_v2_rf(csv_writer=None): + """Computes RF and associated parameters for the alexnet_v2 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_ALEXNETV2_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _vgg_rf(csv_writer=None): + """Computes RF and associated parameters for the vgg model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_VGG_VARIANTS: + _process_model_rf(model_type, csv_writer) + + +def _mobilenet_v1_rf(csv_writer=None): + """Computes RF and associated parameters for the mobilenet_v1 model. + + The computed values are written to stdout. + + Args: + csv_writer: A CSV writer for RF parameters, which is used if it is not None. + """ + for model_type in _SUPPORTED_MOBILENETV1_VARIANTS: + with slim.arg_scope( + [slim.batch_norm, slim.dropout], is_training=False) as arg_sc: + _process_model_rf(model_type, csv_writer, arg_sc) + + +def main(unused_argv): + # Configure CSV file which will be written, if desired. + if cmd_args.csv_path: + csv_file = open(cmd_args.csv_path, 'w') + field_names = [ + 'CNN', 'end_point', 'RF size hor', 'RF size ver', + 'effective stride hor', 'effective stride ver', 'effective padding hor', + 'effective padding ver' + ] + rf_writer = csv.DictWriter(csv_file, fieldnames=field_names) + rf_writer.writeheader() + else: + rf_writer = None + + # Compute RF parameters for each network architecture. + _alexnet_v2_rf(rf_writer) + _vgg_rf(rf_writer) + _inception_v2_rf(rf_writer) + _inception_v3_rf(rf_writer) + _inception_v4_rf(rf_writer) + _inception_resnet_v2_rf(rf_writer) + _mobilenet_v1_rf(rf_writer) + _resnet_rf(rf_writer) + + # Close CSV file, if it was opened. + if cmd_args.csv_path: + csv_file.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--csv_path', + type=str, + default='', + help="""\ + Path to CSV file that will be written with RF parameters.If empty, no + file will be written.\ + """) + cmd_args, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/receptive_field/python/util/examples/write_inception_resnet_v2_graph.py b/tensorflow/contrib/receptive_field/python/util/examples/write_inception_resnet_v2_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..793ae163d807fdda62c2025cb8176b96832cb61a --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/examples/write_inception_resnet_v2_graph.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simple script to write Inception-ResNet-v2 model to graph file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from nets import inception +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import graph_io +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import app + +cmd_args = None + + +def main(unused_argv): + # Model definition. + g = ops.Graph() + with g.as_default(): + images = array_ops.placeholder( + dtypes.float32, shape=(1, None, None, 3), name='input_image') + inception.inception_resnet_v2_base(images) + + graph_io.write_graph(g.as_graph_def(), cmd_args.graph_dir, + cmd_args.graph_filename) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--graph_dir', + type=str, + default='/tmp', + help='Directory where graph will be saved.') + parser.add_argument( + '--graph_filename', + type=str, + default='graph.pbtxt', + help='Filename of graph that will be saved.') + cmd_args, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py new file mode 100644 index 0000000000000000000000000000000000000000..8af4be16d6c17286287713a1fb6f5017355e3b32 --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/graph_compute_order.py @@ -0,0 +1,88 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library to compute order of computations in a graph. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + + +class GraphDefHelper(object): + """Helper class to collect node names and definitions. + + Example: + b = GraphDefHelper(graph_def) + # Prints node that produces given output. + print b.output_of['conv/foo/bar'] + """ + + def __init__(self, gd): + self.output_of = {} + for each in gd.node: + self.output_of[each.name] = each + + +# pylint: disable=invalid-name +_NodeEntry = collections.namedtuple('NodeEntry', field_names=['order', 'node']) + + +def _get_computed_nodes(g, output, seen): + """Traverses the graph in topological order. + + Args: + g: GraphDefHelper object. + output: current node. + seen: map of nodes we've already traversed. + Returns: + order in topological sort for 'output'. + """ + if output in seen: + return seen[output].order + node_def = g.output_of.get(output, None) + if node_def is None: + seen[output] = _NodeEntry(0, None) + return 0 + + r = 0 + for each in node_def.input: + # Parses name of input node. + if each.startswith('^'): + each = each[1:] + each = each.split(':')[0] + # Recursively computes ordering. + new_v = _get_computed_nodes(g, each, seen) + r = max(r, new_v + 1) + + seen[output] = _NodeEntry(r, node_def) + + return seen[output].order + + +def get_compute_order(graph_def): + """Computes order of computation for a given graph. + + Args: + graph_def: GraphDef object. + Returns: + map: name -> {order, node} + """ + helper = GraphDefHelper(graph_def) + seen = collections.defaultdict(_NodeEntry) + for each in graph_def.node: + _get_computed_nodes(helper, each.name, seen) + return seen diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py new file mode 100644 index 0000000000000000000000000000000000000000..db190a1a41668bff3f6db1c674192980db068838 --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py @@ -0,0 +1,485 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions to compute receptive field of a fully-convolutional network. + +Please refer to the following g3doc for detailed explanation on how this +computation is performed, and why it is important: +g3doc/photos/vision/features/delf/g3doc/rf_computation.md +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +from tensorflow.contrib.receptive_field.python.util import graph_compute_order +from tensorflow.contrib.util import make_ndarray +from tensorflow.python.platform import tf_logging as logging + +# White-listed layer operations, which do not affect the receptive field +# computation. +_UNCHANGED_RF_LAYER_OPS = [ + "Softplus", "Relu", "BiasAdd", "Mul", "Add", "Const", "Identity", + "VariableV2", "Sub", "Rsqrt", "ConcatV2" +] + +# Different ways in which padding modes may be spelled. +_VALID_PADDING = ["VALID", b"VALID"] +_SAME_PADDING = ["SAME", b"SAME"] + + +def _stride_size(node): + """Computes stride size given a TF node. + + Args: + node: Tensorflow node (NodeDef proto). + + Returns: + stride_x: Stride size for horizontal direction (integer). + stride_y: Stride size for vertical direction (integer). + """ + strides_attr = node.attr["strides"] + logging.vlog(4, "strides_attr = %s", strides_attr) + stride_y = strides_attr.list.i[1] + stride_x = strides_attr.list.i[2] + return stride_x, stride_y + + +def _conv_kernel_size(node, name_to_order_node): + """Computes kernel size given a TF convolution or pooling node. + + Args: + node: Tensorflow node (NodeDef proto). + name_to_order_node: Map from name to {order, node}. Output of + graph_compute_order.get_compute_order(). + + Returns: + kernel_size_x: Kernel size for horizontal direction (integer). + kernel_size_y: Kernel size for vertical direction (integer). + + Raises: + ValueError: If the weight layer node is invalid. + """ + weights_layer_read_name = node.input[1] + if not weights_layer_read_name.endswith("/read"): + raise ValueError( + "Weight layer's name input to conv layer does not end with '/read'") + weights_layer_param_name = weights_layer_read_name[:-5] + weights_node = name_to_order_node[weights_layer_param_name].node + if weights_node.op != "VariableV2": + raise ValueError("Weight layer is not of type VariableV2") + shape = weights_node.attr["shape"] + logging.vlog(4, "weight shape = %s", shape) + kernel_size_y = shape.shape.dim[0].size + kernel_size_x = shape.shape.dim[1].size + return kernel_size_x, kernel_size_y + + +def _padding_size_conv_pool(node, kernel_size, stride): + """Computes padding size given a TF convolution or pooling node. + + Args: + node: Tensorflow node (NodeDef proto). + kernel_size: Kernel size of node (integer). + stride: Stride size of node (integer). + + Returns: + padding: Padding size (integer). + + Raises: + ValueError: If padding is invalid. + """ + # In this case, we need to carefully consider the different TF padding modes. + # The padding depends on kernel size, and may depend on input size. If it + # depends on input size, we raise an exception. + padding_attr = node.attr["padding"] + logging.vlog(4, "padding_attr = %s", padding_attr) + if padding_attr.s in _VALID_PADDING: + padding = 0 + elif padding_attr.s in _SAME_PADDING: + if kernel_size == 1: + padding = 0 + elif stride == 1: + padding = int(math.floor((float(kernel_size) - 1) / 2)) + elif stride == 2 and kernel_size % 2 == 0: + padding = int(math.floor((float(kernel_size) - 1) / 2)) + else: + padding = None + logging.warning( + "Padding depends on input size, which means that the effective " + "padding may be different depending on the input image " + "dimensionality. In this case, alignment check will be skipped.") + else: + raise ValueError("Invalid padding operation %s" % padding_attr.s) + return padding + + +def _pool_kernel_size(node): + """Computes kernel size given a TF pooling node. + + Args: + node: Tensorflow node (NodeDef proto). + + Returns: + kernel_size_x: Kernel size for horizontal direction (integer). + kernel_size_y: Kernel size for vertical direction (integer). + + Raises: + ValueError: If pooling is invalid. + """ + ksize = node.attr["ksize"] + kernel_size_y = ksize.list.i[1] + kernel_size_x = ksize.list.i[2] + if ksize.list.i[0] != 1: + raise ValueError("pool ksize for first dim is not 1") + if ksize.list.i[3] != 1: + raise ValueError("pool ksize for last dim is not 1") + return kernel_size_x, kernel_size_y + + +def _padding_size_pad_layer(node, name_to_order_node): + """Computes padding size given a TF padding node. + + Args: + node: Tensorflow node (NodeDef proto). + name_to_order_node: Map from name to {order, node}. Output of + graph_compute_order.get_compute_order(). + + Returns: + padding_x: Padding size for horizontal direction (integer). + padding_y: Padding size for vertical direction (integer). + + Raises: + ValueError: If padding layer is invalid. + """ + paddings_layer_name = node.input[1] + if not paddings_layer_name.endswith("/paddings"): + raise ValueError("Padding layer name does not end with '/paddings'") + paddings_node = name_to_order_node[paddings_layer_name].node + if paddings_node.op != "Const": + raise ValueError("Padding op is not Const") + value = paddings_node.attr["value"] + t = make_ndarray(value.tensor) + padding_y = t[1][0] + padding_x = t[2][0] + if t[0][0] != 0: + raise ValueError("padding is not zero for first tensor dim") + if t[3][0] != 0: + raise ValueError("padding is not zero for last tensor dim") + return padding_x, padding_y + + +def _get_layer_params(node, name_to_order_node): + """Gets layer parameters relevant for RF computation. + + Currently, only these nodes are supported: + - Conv2D + - DepthwiseConv2dNative + - Pad + - MaxPool + - AvgPool + - all nodes listed in _UNCHANGED_RF_LAYER_OPS + + Args: + node: Tensorflow node (NodeDef proto). + name_to_order_node: Map from name to {order, node}. Output of + graph_compute_order.get_compute_order(). + + Returns: + kernel_size_x: Kernel size for horizontal direction (integer). + kernel_size_y: Kernel size for vertical direction (integer). + stride_x: Stride size for horizontal direction (integer). + stride_y: Stride size for vertical direction (integer). + padding_x: Padding size for horizontal direction (integer). + padding_y: Padding size for vertical direction (integer). + + Raises: + ValueError: If layer op is unknown. + """ + logging.vlog(3, "node.op = %s", node.op) + logging.vlog(4, "node = %s", node) + if node.op == "Conv2D" or node.op == "DepthwiseConv2dNative": + stride_x, stride_y = _stride_size(node) + kernel_size_x, kernel_size_y = _conv_kernel_size(node, name_to_order_node) + # Compute the padding for this node separately for each direction. + padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x) + padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y) + elif node.op == "Pad": + # Kernel and stride are simply 1 in this case. + kernel_size_x = 1 + kernel_size_y = 1 + stride_x = 1 + stride_y = 1 + padding_x, padding_y = _padding_size_pad_layer(node, name_to_order_node) + elif node.op == "MaxPool" or node.op == "AvgPool": + stride_x, stride_y = _stride_size(node) + kernel_size_x, kernel_size_y = _pool_kernel_size(node) + # Compute the padding for this node separately for each direction. + padding_x = _padding_size_conv_pool(node, kernel_size_x, stride_x) + padding_y = _padding_size_conv_pool(node, kernel_size_y, stride_y) + elif node.op in _UNCHANGED_RF_LAYER_OPS: + # These nodes do not modify the RF parameters. + kernel_size_x = 1 + kernel_size_y = 1 + stride_x = 1 + stride_y = 1 + padding_x = 0 + padding_y = 0 + else: + raise ValueError("Unknown layer op: %s" % node.op) + return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y + + +def _reverse_sort_by_order(name_to_order_node): + """Sorts map of name_to_order_node nodes in reverse order. + + The output is such that the nodes in name_to_order_node are sorted in + descending order of the "order" field. + + Args: + name_to_order_node: Map from name to {order, node}. Output of + graph_compute_order.get_compute_order(). + + Returns: + sorted_name_to_order_node: Sorted version of the input, in descending order. + """ + return sorted(name_to_order_node.items(), key=lambda x: -x[1].order) + + +def _get_rf_size_node_input(stride, kernel_size, rf_size_output): + """Computes RF size at the input of a given layer. + + Args: + stride: Stride of given layer (integer). + kernel_size: Kernel size of given layer (integer). + rf_size_output: RF size at output of given layer (integer). + + Returns: + rf_size_input: RF size at input of given layer (integer). + """ + return stride * rf_size_output + kernel_size - stride + + +def _get_effective_stride_node_input(stride, effective_stride_output): + """Computes effective stride at the input of a given layer. + + Args: + stride: Stride of given layer (integer). + effective_stride_output: Effective stride at output of given layer + (integer). + + Returns: + effective_stride_input: Effective stride at input of given layer + (integer). + """ + return stride * effective_stride_output + + +def _get_effective_padding_node_input(stride, padding, + effective_padding_output): + """Computes effective padding at the input of a given layer. + + Args: + stride: Stride of given layer (integer). + padding: Padding of given layer (integer). + effective_padding_output: Effective padding at output of given layer + (integer). + + Returns: + effective_padding_input: Effective padding at input of given layer + (integer). + """ + return stride * effective_padding_output + padding + + +def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): + """Computes receptive field (RF) parameters from a GraphDef object. + + Args: + graph_def: GraphDef object. + input_node: Name of the input node from graph. + output_node: Name of the output node from graph. + + Returns: + rf_size_x: Receptive field size of network in the horizontal direction, with + respect to specified input and output. + rf_size_y: Receptive field size of network in the vertical direction, with + respect to specified input and output. + effective_stride_x: Effective stride of network in the horizontal direction, + with respect to specified input and output. + effective_stride_y: Effective stride of network in the vertical direction, + with respect to specified input and output. + effective_padding_x: Effective padding of network in the horizontal + direction, with respect to specified input and output. + effective_padding_y: Effective padding of network in the vertical + direction, with respect to specified input and output. + + Raises: + ValueError: If network is not aligned or if either input or output nodes + cannot be found. For network criterion alignment, see + photos/vision/features/delf/g3doc/rf_computation.md + """ + # Computes order of computation for a given graph. + name_to_order_node = graph_compute_order.get_compute_order( + graph_def=graph_def) + + # Sort in reverse topological order. + order = _reverse_sort_by_order(name_to_order_node) + + # Dictionaries to keep track of receptive field, effective stride and + # effective padding of different nodes. + rf_sizes_x = {} + rf_sizes_y = {} + effective_strides_x = {} + effective_strides_y = {} + effective_paddings_x = {} + effective_paddings_y = {} + + # Initialize dicts for output_node. + rf_sizes_x[output_node] = 1 + rf_sizes_y[output_node] = 1 + effective_strides_x[output_node] = 1 + effective_strides_y[output_node] = 1 + effective_paddings_x[output_node] = 0 + effective_paddings_y[output_node] = 0 + + # Flag to denote if we found output node yet. If we have not, we skip nodes + # until the output node is found. + found_output_node = False + + # Flag to denote if padding is undefined. This happens when SAME padding mode + # is used in conjunction with stride and kernel sizes which make it such that + # the padding to be applied would depend on the input size. In this case, + # alignment checks are skipped, and the effective padding is None. + undefined_padding = False + + for _, (o, node) in order: + if node: + logging.vlog(3, "%10d %-100s %-20s" % (o, node.name[:90], node.op)) + else: + continue + + # When we find input node, we can stop. + if node.name == input_node: + break + + # Loop until we find the output node. All nodes before finding the output + # one are irrelevant, so they can be skipped. + if not found_output_node: + if node.name == output_node: + found_output_node = True + + if found_output_node: + if node.name not in rf_sizes_x: + assert node.name not in rf_sizes_y, ("Node %s is in rf_sizes_y, but " + "not in rf_sizes_x" % node.name) + # In this case, node is not relevant since it's not part of the + # computation we're interested in. + logging.vlog(3, "Irrelevant node %s, skipping it...", node.name) + continue + + # Get params for this layer. + kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y = ( + _get_layer_params(node, name_to_order_node)) + logging.vlog(3, "kernel_size_x = %s, kernel_size_y = %s, " + "stride_x = %s, stride_y = %s, " + "padding_x = %s, padding_y = %s" % + (kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, + padding_y)) + if padding_x is None or padding_y is None: + undefined_padding = True + + # Get parameters at input of this layer which may or may not be propagated + # to the input layers. + rf_size_input_x = _get_rf_size_node_input(stride_x, kernel_size_x, + rf_sizes_x[node.name]) + rf_size_input_y = _get_rf_size_node_input(stride_y, kernel_size_y, + rf_sizes_y[node.name]) + effective_stride_input_x = _get_effective_stride_node_input( + stride_x, effective_strides_x[node.name]) + effective_stride_input_y = _get_effective_stride_node_input( + stride_y, effective_strides_y[node.name]) + if not undefined_padding: + effective_padding_input_x = _get_effective_padding_node_input( + stride_x, padding_x, effective_paddings_x[node.name]) + effective_padding_input_y = _get_effective_padding_node_input( + stride_y, padding_y, effective_paddings_y[node.name]) + else: + effective_padding_input_x = None + effective_padding_input_y = None + + # Loop over this node's inputs and potentially propagate information down. + for inp_name in node.input: + logging.vlog(4, "inp_name = %s", inp_name) + inp_node = name_to_order_node[inp_name].node + logging.vlog(4, "inp_node = \n%s", inp_node) + if inp_node.name in rf_sizes_x: + assert inp_node.name in rf_sizes_y, ( + "Node %s is in rf_sizes_x, but " + "not in rf_sizes_y" % inp_node.name) + # This node was already discovered through a previous path, so we need + # to make sure that graph is aligned. This alignment check is skipped + # if the padding is not defined, since in this case alignment cannot + # be checked. + if not undefined_padding: + if effective_strides_x[inp_node.name] != effective_stride_input_x: + raise ValueError( + "Graph is not aligned since effective stride from different " + "paths is different in horizontal direction") + if effective_strides_y[inp_node.name] != effective_stride_input_y: + raise ValueError( + "Graph is not aligned since effective stride from different " + "paths is different in vertical direction") + if (rf_sizes_x[inp_node.name] - 1 + ) / 2 - effective_paddings_x[inp_node.name] != ( + rf_size_input_x - 1) / 2 - effective_padding_input_x: + raise ValueError( + "Graph is not aligned since center shift from different " + "paths is different in horizontal direction") + if (rf_sizes_y[inp_node.name] - 1 + ) / 2 - effective_paddings_y[inp_node.name] != ( + rf_size_input_y - 1) / 2 - effective_padding_input_y: + raise ValueError( + "Graph is not aligned since center shift from different " + "paths is different in vertical direction") + # Keep track of path with largest RF, for both directions. + if rf_sizes_x[inp_node.name] < rf_size_input_x: + rf_sizes_x[inp_node.name] = rf_size_input_x + effective_strides_x[inp_node.name] = effective_stride_input_x + effective_paddings_x[inp_node.name] = effective_padding_input_x + if rf_sizes_y[inp_node.name] < rf_size_input_y: + rf_sizes_y[inp_node.name] = rf_size_input_y + effective_strides_y[inp_node.name] = effective_stride_input_y + effective_paddings_y[inp_node.name] = effective_padding_input_y + else: + assert inp_node.name not in rf_sizes_y, ( + "Node %s is in rf_sizes_y, but " + "not in rf_sizes_x" % inp_node.name) + # In this case, it is the first time we encounter this node. So we + # propagate the RF parameters. + rf_sizes_x[inp_node.name] = rf_size_input_x + rf_sizes_y[inp_node.name] = rf_size_input_y + effective_strides_x[inp_node.name] = effective_stride_input_x + effective_strides_y[inp_node.name] = effective_stride_input_y + effective_paddings_x[inp_node.name] = effective_padding_input_x + effective_paddings_y[inp_node.name] = effective_padding_input_y + + if not found_output_node: + raise ValueError("Output node was not found") + if input_node not in rf_sizes_x: + raise ValueError("Input node was not found") + return (rf_sizes_x[input_node], rf_sizes_y[input_node], + effective_strides_x[input_node], effective_strides_y[input_node], + effective_paddings_x[input_node], effective_paddings_y[input_node]) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2771389250b1518f33ebadf3f1cfd23e653dab93 --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py @@ -0,0 +1,225 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for receptive_fields module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import slim +from tensorflow.contrib.receptive_field.python.util import receptive_field +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +def create_test_network_1(): + """Aligned network for test. + + The graph corresponds to the example from the second figure in + go/cnn-rf-computation#arbitrary-computation-graphs + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID') + # Right branch. + l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]]) + l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID') + l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID') + # Addition. + nn.relu(l1 + l3, name='output') + return g + + +def create_test_network_2(): + """Aligned network for test. + + The graph corresponds to a variation to the example from the second figure in + go/cnn-rf-computation#arbitrary-computation-graphs. Layers 2 and 3 are changed + to max-pooling operations. Since the functionality is the same as convolution, + the network is aligned and the receptive field size is the same as from the + network created using create_test_network_1(). + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID') + # Right branch. + l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]]) + l2 = slim.max_pool2d(l2_pad, [3, 3], stride=2, scope='L2', padding='VALID') + l3 = slim.max_pool2d(l2, [1, 1], stride=2, scope='L3', padding='VALID') + # Addition. + nn.relu(l1 + l3, name='output') + return g + + +def create_test_network_3(): + """Misaligned network for test. + + The graph corresponds to the example from the first figure in + go/cnn-rf-computation#arbitrary-computation-graphs + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1_pad = array_ops.pad(x, [[0, 0], [2, 1], [2, 1], [0, 0]]) + l1 = slim.conv2d(l1_pad, 1, [5, 5], stride=2, scope='L1', padding='VALID') + # Right branch. + l2 = slim.conv2d(x, 1, [3, 3], stride=1, scope='L2', padding='VALID') + l3 = slim.conv2d(l2, 1, [3, 3], stride=1, scope='L3', padding='VALID') + # Addition. + nn.relu(l1 + l3, name='output') + return g + + +def create_test_network_4(): + """Misaligned network for test. + + The graph corresponds to a variation from the example from the second figure + in go/cnn-rf-computation#arbitrary-computation-graphs. Layer 2 uses 'SAME' + padding, which makes its padding dependent on the input image dimensionality. + In this case, the effective padding will be undetermined, and the utility is + not able to check the network alignment. + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID') + # Right branch. + l2 = slim.conv2d(x, 1, [3, 3], stride=2, scope='L2', padding='SAME') + l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID') + # Addition. + nn.relu(l1 + l3, name='output') + return g + + +def create_test_network_5(): + """Single-path network for testing non-square kernels. + + The graph is similar to the right branch of the graph from + create_test_network_1(), except that the kernel sizes are changed to be + non-square. + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Two convolutional layers, where the first one has non-square kernel. + l1 = slim.conv2d(x, 1, [3, 5], stride=2, scope='L1', padding='VALID') + l2 = slim.conv2d(l1, 1, [3, 1], stride=2, scope='L2', padding='VALID') + # ReLU. + nn.relu(l2, name='output') + return g + + +class RfUtilsTest(test.TestCase): + + def testComputeRFFromGraphDefAligned(self): + graph_def = create_test_network_1().as_graph_def() + input_node = 'input_image' + output_node = 'output' + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node)) + self.assertEqual(receptive_field_x, 3) + self.assertEqual(receptive_field_y, 3) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, 1) + self.assertEqual(effective_padding_y, 1) + + def testComputeRFFromGraphDefAligned2(self): + graph_def = create_test_network_2().as_graph_def() + input_node = 'input_image' + output_node = 'output' + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node)) + self.assertEqual(receptive_field_x, 3) + self.assertEqual(receptive_field_y, 3) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, 1) + self.assertEqual(effective_padding_y, 1) + + def testComputeRFFromGraphDefUnaligned(self): + graph_def = create_test_network_3().as_graph_def() + input_node = 'input_image' + output_node = 'output' + with self.assertRaises(ValueError): + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node) + + def testComputeRFFromGraphDefUnaligned2(self): + graph_def = create_test_network_4().as_graph_def() + input_node = 'input_image' + output_node = 'output' + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node)) + self.assertEqual(receptive_field_x, 3) + self.assertEqual(receptive_field_y, 3) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, None) + self.assertEqual(effective_padding_y, None) + + def testComputeRFFromGraphDefNonSquareRF(self): + graph_def = create_test_network_5().as_graph_def() + input_node = 'input_image' + output_node = 'output' + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node)) + self.assertEqual(receptive_field_x, 5) + self.assertEqual(receptive_field_y, 7) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, 0) + self.assertEqual(effective_padding_y, 0) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py index 9aa1e0562844cf8fed0eadb038599f02d94d0cd6..6253f96315b74dda59e1e0460e62481fe1c590a1 100644 --- a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py +++ b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py @@ -163,7 +163,7 @@ class ResamplerTest(test.TestCase): data_channels = 3 warp_width = 2 warp_height = 6 - batch_size = 10 + batch_size = 3 warp = _make_warp(batch_size, warp_height, warp_width, dtype.as_numpy_dtype) data_shape = (batch_size, data_height, data_width, data_channels) diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 48c2c5a724b2c0f8a0ad6d8f38f672258d06dc48..f591f7c84e50660ccddbe13e31a32f6bc273c460 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -342,7 +342,8 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): num_units, forget_bias=1.0, clip_cell=True, - use_peephole=False): + use_peephole=False, + reuse=None): """Initialize the basic LSTM cell. Args: @@ -351,10 +352,14 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): clip_cell: boolean, whether to apply cell clipping. See `_lstm_block_cell()` for details. use_peephole: Whether to use peephole connections or not. + reuse: (optional) boolean describing whether to reuse variables in an + existing scope. If not `True`, and the existing scope already has the + given variables, an error is raised. When restoring from CudnnLSTM-trained checkpoints, must use CudnnCompatibleLSTMBlockCell instead. """ + super(LSTMBlockCell, self).__init__(_reuse=reuse) self._num_units = num_units self._forget_bias = forget_bias self._use_peephole = use_peephole diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 596c4f351ce252fb99d152fc018c9449cb051209..ebb7a218562dba74d8fffa3c438e388d7d79951f 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -234,7 +234,7 @@ cc_library( cc_test( name = "session_bundle_test", - size = "small", + size = "medium", srcs = ["session_bundle_test.cc"], data = [":session_bundle_half_plus_two"], # Link in all registered kernels. diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc index eb36d79e0f4ae183ed069e7fe26c4133855b96f4..6d997bac9ee8e0fe242455686cc00a016d9bd768 100644 --- a/tensorflow/contrib/session_bundle/session_bundle_test.cc +++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc @@ -171,7 +171,8 @@ void BasicTest(const string& export_path) { // SessionBundles. Concurrent with adding this test, we had a leak where the // TensorFlow Session was not being closed, which leaked memory. // TODO(b/31711147): Increase the SessionBundle ResourceLeakTest iterations and -// move outside of the test suite. +// move outside of the test suite; decrease test size back to small at the same +// time. TEST(LoadSessionBundleFromPath, ResourceLeakTest) { const string export_path = test_util::TestSrcDirPath(kExportPath); for (int i = 0; i < 100; i++) { diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD index a76e037c5f440d8ca16bea2d5a153f39488b5cfb..8f920c9b0310a3e7e826915f86d2dba53b719086 100644 --- a/tensorflow/contrib/slim/BUILD +++ b/tensorflow/contrib/slim/BUILD @@ -88,6 +88,8 @@ py_test( "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variables", + "//tensorflow/python/debug:debug_data", + "//tensorflow/python/debug:dumping_wrapper", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index c7614fd426202800e3c369a5e6fb0b79e43bc45c..5ee014a1f11a6b0d11857d209f27b134b737275d 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -551,6 +551,7 @@ def train(train_op, save_interval_secs=600, sync_optimizer=None, session_config=None, + session_wrapper=None, trace_every_n_steps=None): """Runs a training loop using a TensorFlow supervisor. @@ -607,6 +608,10 @@ def train(train_op, If left as `None`, gradient updates will be asynchronous. session_config: An instance of `tf.ConfigProto` that will be used to configure the `Session`. If left as `None`, the default will be used. + session_wrapper: A function that takes a `tf.Session` object as the only + argument and returns a wrapped session object that has the same methods + that the original object has, or `None`. Iff not `None`, the wrapped + object will be used for training. trace_every_n_steps: produce and save a `Timeline` in Chrome trace format and add it to the summaries every `trace_every_n_steps`. If None, no trace information will be produced or saved. @@ -736,6 +741,10 @@ def train(train_op, with sv.managed_session( master, start_standard_services=False, config=session_config) as sess: logging.info('Starting Session.') + if session_wrapper is not None: + logging.info( + 'Wrapping session with wrapper function: %s', session_wrapper) + sess = session_wrapper(sess) if is_chief: if logdir: sv.start_standard_services(sess) diff --git a/tensorflow/contrib/slim/python/slim/learning_test.py b/tensorflow/contrib/slim/python/slim/learning_test.py index 3ee2434c02d5d69835ed44576bff367c917b8964..4e816f9b11be2986d042f336bdc320ff47d8cc49 100644 --- a/tensorflow/contrib/slim/python/slim/learning_test.py +++ b/tensorflow/contrib/slim/python/slim/learning_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import glob import os import tempfile @@ -30,6 +31,8 @@ from tensorflow.contrib.losses.python.losses import loss_ops from tensorflow.contrib.slim.python.slim import learning from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.debug.lib import debug_data +from tensorflow.python.debug.wrappers import dumping_wrapper as dumping_wrapper_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -43,6 +46,7 @@ from tensorflow.python.training import gradient_descent from tensorflow.python.training import input as input_lib from tensorflow.python.training import saver as saver_lib + class ClipGradientNormsTest(test.TestCase): def clip_values(self, arr): @@ -489,6 +493,43 @@ class TrainTest(test.TestCase): self.assertIsNotNone(loss) self.assertLess(loss, .015) + def testTrainWithSessionWrapper(self): + """Test that slim.learning.train can take `session_wrapper` args. + + One of the applications of `session_wrapper` is the wrappers of TensorFlow + Debugger (tfdbg), which intercept methods calls to `tf.Session` (e.g., run) + to achieve debugging. `DumpingDebugWrapperSession` is used here for testing + purpose. + """ + dump_root = tempfile.mkdtemp() + def dumping_wrapper(sess): # pylint: disable=invalid-name + return dumping_wrapper_lib.DumpingDebugWrapperSession(sess, dump_root) + + with ops.Graph().as_default(): + random_seed.set_random_seed(0) + tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) + tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32) + + tf_predictions = LogisticClassifier(tf_inputs) + loss_ops.log_loss(tf_predictions, tf_labels) + total_loss = loss_ops.get_total_loss() + + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0) + + train_op = learning.create_train_op(total_loss, optimizer) + + loss = learning.train( + train_op, + None, + number_of_steps=1, + session_wrapper=dumping_wrapper) + self.assertIsNotNone(loss) + + run_root = glob.glob(os.path.join(dump_root, 'run_*'))[-1] + dump = debug_data.DebugDumpDir(run_root) + self.assertAllEqual( + 0, dump.get_tensors('global_step', 0, 'DebugIdentity')[0]) + def testTrainWithTrace(self): logdir = os.path.join( tempfile.mkdtemp(prefix=self.get_temp_dir()), 'tmp_logs') diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD index 598e6513aebe54224409fbdf0a6077c03ee3d2d1..865fb72a55b9a83b8354a100af843abaefc79980 100644 --- a/tensorflow/contrib/stateless/BUILD +++ b/tensorflow/contrib/stateless/BUILD @@ -21,6 +21,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":stateless_random_ops", + "//tensorflow/python:framework", "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py index 82e5d36ce44cbc9dc1133867f0396f6c2f0f9855..ca937546f50df46b7e5b1144dcbdc380cb04ca9b 100644 --- a/tensorflow/contrib/stateless/__init__.py +++ b/tensorflow/contrib/stateless/__init__.py @@ -34,5 +34,11 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.stateless.gen_stateless_random_ops import * +from tensorflow.python.framework import ops from tensorflow.python.util.all_util import remove_undocumented + +ops.NotDifferentiable("StatelessRandomNormal") +ops.NotDifferentiable("StatelessRandomUniform") +ops.NotDifferentiable("StatelessTruncatedNormal") + remove_undocumented(__name__) diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..527deab86a6ba1e5ccfe6aceb6d73d20aee3ebc2 --- /dev/null +++ b/tensorflow/contrib/summary/BUILD @@ -0,0 +1,62 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files([ + "LICENSE", +]) + +load( + "//tensorflow:tensorflow.bzl", + "py_test", + "tf_gen_op_wrapper_py", +) + +tf_gen_op_wrapper_py( + name = "gen_summary_ops", + out = "gen_summary_ops.py", + deps = ["//tensorflow/core:summary_ops_op_lib"], +) + +py_test( + name = "summary_ops_test", + srcs = ["summary_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":summary_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/eager:function", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "summary_ops", + srcs = ["summary_ops.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + ":gen_summary_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:summary_op_util", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ceaf83b70a76e8a1195b4c177f4764dc7ab792f2 --- /dev/null +++ b/tensorflow/contrib/summary/summary_ops.py @@ -0,0 +1,171 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Operations to emit summaries.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.summary import gen_summary_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import summary_op_util +from tensorflow.python.training import training_util + + +# Name for a collection which is expected to have at most a single boolean +# Tensor. If this tensor is True the summary ops will record summaries. +_SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries" + + +def should_record_summaries(): + """Returns boolean Tensor which is true if summaries should be recorded.""" + should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME) + if not should_record_collection: + return constant_op.constant(False) + if len(should_record_collection) != 1: + raise ValueError( + "More than one tensor specified for whether summaries " + "should be recorded: %s" % should_record_collection) + return should_record_collection[0] + + +# TODO(apassos) consider how to handle local step here. +def record_summaries_every_n_global_steps(n): + """Sets the should_record_summaries Tensor to true if global_step % n == 0.""" + collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) + collection_ref[:] = [training_util.get_global_step() % n == 0] + + +def always_record_summaries(): + """Sets the should_record_summaries Tensor to always true.""" + collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) + collection_ref[:] = [constant_op.constant(True)] + + +def never_record_summaries(): + """Sets the should_record_summaries Tensor to always false.""" + collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) + collection_ref[:] = [constant_op.constant(False)] + + +def create_summary_file_writer(logdir, + max_queue=None, + flush_secs=None, + filename_suffix=None, + name=None): + """Creates a summary file writer in the current context.""" + if max_queue is None: + max_queue = constant_op.constant(10) + if flush_secs is None: + flush_secs = constant_op.constant(120) + if filename_suffix is None: + filename_suffix = constant_op.constant("") + resource = gen_summary_ops.summary_writer(shared_name=name) + gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, + flush_secs, filename_suffix) + context.context().summary_writer_resource = resource + + +def _nothing(): + """Convenient else branch for when summaries do not record.""" + return False + + +def summary_writer_function(name, tensor, function, family=None): + """Helper function to write summaries. + + Args: + name: name of the summary + tensor: main tensor to form the summary + function: function taking a tag and a scope which writes the summary + family: optional, the summary's family + + Returns: + The result of writing the summary. + """ + def record(): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + function(tag, scope) + return True + + return control_flow_ops.cond(should_record_summaries(), record, _nothing) + + +def generic(name, tensor, metadata, family=None): + """Writes a tensor summary if possible.""" + + def function(tag, scope): + gen_summary_ops.write_summary(context.context().summary_writer_resource, + training_util.get_global_step(), tensor, + tag, metadata, name=scope) + return summary_writer_function(name, tensor, function, family=family) + + +def scalar(name, tensor, family=None): + """Writes a scalar summary if possible.""" + + def function(tag, scope): + gen_summary_ops.write_scalar_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, name=scope) + + return summary_writer_function(name, tensor, function, family=family) + + +def histogram(name, tensor, family=None): + """Writes a histogram summary if possible.""" + + def function(tag, scope): + gen_summary_ops.write_histogram_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, name=scope) + + return summary_writer_function(name, tensor, function, family=family) + + +def image(name, tensor, bad_color=None, max_images=3, family=None): + """Writes an image summary if possible.""" + + def function(tag, scope): + if bad_color is None: + bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) + gen_summary_ops.write_image_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), tag, tensor, bad_color_, max_images, + name=scope) + + return summary_writer_function(name, tensor, function, family=family) + + +def audio(name, tensor, sample_rate, max_outputs, family=None): + """Writes an audio summary if possible.""" + + def function(tag, scope): + gen_summary_ops.write_audio_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), + tag, + tensor, + sample_rate=sample_rate, + max_outputs=max_outputs, + name=scope) + + return summary_writer_function(name, tensor, function, family=family) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4b1f60ce4ebf8fadd99e01fa92a23e336f7badfd --- /dev/null +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -0,0 +1,77 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.contrib.summary import summary_ops +from tensorflow.core.util import event_pb2 +from tensorflow.python.eager import function +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.lib.io import tf_record +from tensorflow.python.platform import gfile +from tensorflow.python.training import training_util + + +class TargetTest(test_util.TensorFlowTestCase): + + def testShouldRecordSummary(self): + self.assertFalse(summary_ops.should_record_summaries().numpy()) + summary_ops.always_record_summaries() + self.assertTrue(summary_ops.should_record_summaries().numpy()) + + def testSummaryOps(self): + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0') + summary_ops.always_record_summaries() + summary_ops.generic('tensor', 1, '') + summary_ops.scalar('scalar', 2.0) + summary_ops.histogram('histogram', [1.0]) + summary_ops.image('image', [[[[1.0]]]]) + summary_ops.audio('audio', [[1.0]], 1.0, 1) + # The working condition of the ops is tested in the C++ test so we just + # test here that we're calling them correctly. + self.assertTrue(gfile.Exists(logdir)) + + def testDefunSummarys(self): + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1') + summary_ops.always_record_summaries() + + @function.defun + def write(): + summary_ops.scalar('scalar', 2.0) + + write() + + self.assertTrue(gfile.Exists(logdir)) + files = gfile.ListDirectory(logdir) + self.assertEqual(len(files), 1) + records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + self.assertEqual(len(records), 2) + event = event_pb2.Event() + event.ParseFromString(records[1]) + self.assertEqual(event.summary.value[0].simple_value, 2.0) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc index ccc412600c760e6d453d896121e99916129c196c..e5d1beae7f99cb1ef7a449aad81158156312f928 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc @@ -96,7 +96,12 @@ void SplitCollectionOperator::AddExample( } bool SplitCollectionOperator::IsInitialized(int32 node_id) const { - return stats_.at(node_id)->IsInitialized(); + auto it = stats_.find(node_id); + if (it == stats_.end()) { + LOG(WARNING) << "IsInitialized called with unknown node_id = " << node_id; + return false; + } + return it->second->IsInitialized(); } void SplitCollectionOperator::CreateAndInitializeCandidateWithExample( diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace.py b/tensorflow/contrib/tensorboard/plugins/trace/trace.py index 57f95dfce7269be54f69db84dc76f14719344869..07e5316b8b3f3bc8229eca65487796a41eeb0af0 100644 --- a/tensorflow/contrib/tensorboard/plugins/trace/trace.py +++ b/tensorflow/contrib/tensorboard/plugins/trace/trace.py @@ -38,7 +38,7 @@ TOKENS = LEFT_TOKENS + RIGHT_TOKENS def store_trace_info(output_file_path, - graph=ops.get_default_graph(), + graph=None, ignore_regex_fpaths=None): """Collects and stores trace information for a TensorFlow model. @@ -51,6 +51,8 @@ def store_trace_info(output_file_path, in this list will be ignored. Defaults to patterns that match the core tensorflow python library. """ + graph = graph or ops.get_default_graph() + if not ignore_regex_fpaths: ignore_regex_fpaths = TF_LIB_REGEX_FPATHS diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops.py b/tensorflow/contrib/text/python/ops/skip_gram_ops.py index 410ee517e03d5e9c973cbecebd54246b7c6163e6..7ed45031a3fa405c0f076d35e4c9e10562e23399 100644 --- a/tensorflow/contrib/text/python/ops/skip_gram_ops.py +++ b/tensorflow/contrib/text/python/ops/skip_gram_ops.py @@ -216,6 +216,7 @@ def skip_gram_sample_with_text_vocab(input_tensor, vocab_delimiter=",", vocab_min_count=0, vocab_subsampling=None, + corpus_size=None, min_skips=1, max_skips=5, start=0, @@ -267,6 +268,18 @@ def skip_gram_sample_with_text_vocab(input_tensor, frequently will be randomly down-sampled. Reasonable starting values may be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for more details. + corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the + total number of tokens in the corpus (e.g., sum of all the frequency + counts of `vocab_freq_file`). Used with `vocab_subsampling` for + down-sampling frequently occurring tokens. If this is specified, + `vocab_freq_file` and `vocab_subsampling` must also be specified. + If `corpus_size` is needed but not supplied, then it will be calculated + from `vocab_freq_file`. You might want to supply your own value if you + have already eliminated infrequent tokens from your vocabulary files + (where frequency < vocab_min_count) to save memory in the internal token + lookup table. Otherwise, the unused tokens' variables will waste memory. + The user-supplied `corpus_size` value must be greater than or equal to the + sum of all the frequency counts of `vocab_freq_file`. min_skips: `int` or scalar `Tensor` specifying the minimum window size to randomly use for each token. Must be >= 0 and <= `max_skips`. If `min_skips` and `max_skips` are both 0, the only label outputted will be @@ -316,7 +329,7 @@ def skip_gram_sample_with_text_vocab(input_tensor, # Iterates through the vocab file and calculates the number of vocab terms as # well as the total corpus size (by summing the frequency counts of all the # vocab terms). - corpus_size = 0.0 + calculated_corpus_size = 0.0 vocab_size = 0 with gfile.GFile(vocab_freq_file, mode="r") as f: reader = csv.reader(f, delimiter=vocab_delimiter) @@ -334,7 +347,15 @@ def skip_gram_sample_with_text_vocab(input_tensor, format(freq, row)) # Note: tokens whose frequencies are below vocab_min_count will still # contribute to the total corpus size used for vocab subsampling. - corpus_size += freq + calculated_corpus_size += freq + + if not corpus_size: + corpus_size = calculated_corpus_size + elif calculated_corpus_size - corpus_size > 1e-6: + raise ValueError( + "`corpus_size`={} must be greater than or equal to the sum of all the " + "frequency counts ({}) of `vocab_freq_file` ({}).".format( + corpus_size, calculated_corpus_size, vocab_freq_file)) vocab_freq_table = lookup.HashTable( lookup.TextFileInitializer( diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py index d989942f732164baac19ef577e298dcca7a3e076..84e36146d5ac18a239f98a827aeb7676cd38d23f 100644 --- a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py +++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py @@ -470,7 +470,7 @@ class SkipGramOpsTest(test.TestCase): self.assertAllEqual(expected_labels, labels.eval()) def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count, - vocab_freq_dtype): + vocab_freq_dtype, corpus_size=None): # The outputs are non-deterministic, so set random seed to help ensure that # the outputs remain constant for testing. random_seed.set_random_seed(42) @@ -499,6 +499,7 @@ class SkipGramOpsTest(test.TestCase): vocab_freq_dtype=vocab_freq_dtype, vocab_min_count=vocab_min_count, vocab_subsampling=0.05, + corpus_size=corpus_size, min_skips=1, max_skips=1, seed=123) @@ -523,10 +524,27 @@ class SkipGramOpsTest(test.TestCase): # the: 30 # to: 20 # universe: 2 + # + # corpus_size for the above vocab is 40+8+30+20+2 = 100. + text_vocab_freq_file = self._make_text_vocab_freq_file() self._text_vocab_subsample_vocab_helper( - vocab_freq_file=self._make_text_vocab_freq_file(), + vocab_freq_file=text_vocab_freq_file, vocab_min_count=3, vocab_freq_dtype=dtypes.int64) + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=text_vocab_freq_file, + vocab_min_count=3, + vocab_freq_dtype=dtypes.int64, + corpus_size=100) + + # The user-supplied corpus_size should not be less than the sum of all + # the frequency counts of vocab_freq_file, which is 100. + with self.assertRaises(ValueError): + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=text_vocab_freq_file, + vocab_min_count=3, + vocab_freq_dtype=dtypes.int64, + corpus_size=99) def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self): """Tests skip-gram sampling with text vocab and subsampling with floats.""" @@ -536,10 +554,27 @@ class SkipGramOpsTest(test.TestCase): # the: 0.3 # to: 0.2 # universe: 0.02 + # + # corpus_size for the above vocab is 0.4+0.08+0.3+0.2+0.02 = 1. + text_vocab_float_file = self._make_text_vocab_float_file() self._text_vocab_subsample_vocab_helper( - vocab_freq_file=self._make_text_vocab_float_file(), + vocab_freq_file=text_vocab_float_file, vocab_min_count=0.03, vocab_freq_dtype=dtypes.float32) + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=text_vocab_float_file, + vocab_min_count=0.03, + vocab_freq_dtype=dtypes.float32, + corpus_size=1.0) + + # The user-supplied corpus_size should not be less than the sum of all + # the frequency counts of vocab_freq_file, which is 1. + with self.assertRaises(ValueError): + self._text_vocab_subsample_vocab_helper( + vocab_freq_file=text_vocab_float_file, + vocab_min_count=0.03, + vocab_freq_dtype=dtypes.float32, + corpus_size=0.99) def test_skip_gram_sample_with_text_vocab_errors(self): """Tests various errors raised by skip_gram_sample_with_text_vocab().""" diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index 7d1325e0466ece5e0c6a44b32fb792d6bc472ef0..c952288704ac8c8f7b437e3f97f7bdcc42273fd4 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -39,7 +39,6 @@ py_library( deps = [ ":tpu_lib", ":tpu_py", - "//tensorflow/contrib/learn", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 0c860ad4d76a33b4c84fd40a999e6d971c77a232..a567d1bbb08377c518e029898e3d843adc3e6350 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -14,20 +14,32 @@ tf_proto_library_cc( visibility = ["//visibility:public"], ) -cc_binary( - name = "capture_tpu_profile", - srcs = ["capture_tpu_profile.cc"], - visibility = ["//tensorflow/contrib/tpu/profiler:__subpackages__"], +cc_library( + name = "dump_tpu_profile", + srcs = ["dump_tpu_profile.cc"], + hdrs = ["dump_tpu_profile.h"], deps = [ ":op_profile_proto_cc", ":tpu_profiler_proto_cc", ":trace_events_proto_cc", ":trace_events_to_json", "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + ], +) + +cc_binary( + name = "capture_tpu_profile", + srcs = ["capture_tpu_profile.cc"], + visibility = ["//tensorflow/contrib/tpu/profiler:__subpackages__"], + deps = [ + ":dump_tpu_profile", + ":tpu_profiler_proto_cc", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", + "//tensorflow/core/platform/cloud:gcs_file_system", "@grpc//:grpc++_unsecure", ], ) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index a0dc15249f796dc0c74384e023ae7c507df96e93..5b51a72ece848f0efcd5ace57fe0201a86e311a3 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -24,22 +24,12 @@ limitations under the License. #include #include -#include "tensorflow/contrib/tpu/profiler/op_profile.pb.h" +#include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h" #include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h" -#include "tensorflow/contrib/tpu/profiler/trace_events.pb.h" -#include "tensorflow/contrib/tpu/profiler/trace_events_to_json.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/io/compression.h" -#include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/core/util/events_writer.h" namespace tensorflow { namespace tpu { @@ -47,16 +37,6 @@ namespace { using ::tensorflow::TPUProfiler; -using ::grpc::ClientContext; -using ::tensorflow::io::JoinPath; -using ::tensorflow::protobuf::util::JsonOptions; -using ::tensorflow::protobuf::util::MessageToJsonString; - -constexpr char kProfilePluginDirectory[] = "plugins/profile/"; -constexpr char kJsonOpProfileFileName[] = "op_profile.json"; -constexpr char kProtoTraceFileName[] = "trace"; -constexpr char kJsonTraceFileName[] = "trace.json.gz"; -constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph."; constexpr uint64 kMaxEvents = 1000000; string GetCurrentTimeStampAsString() { @@ -66,65 +46,13 @@ string GetCurrentTimeStampAsString() { return s; } -Status WriteGzippedDataToFile(const string& filename, const string& data) { - std::unique_ptr file; - TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(filename, &file)); - io::ZlibCompressionOptions options = io::ZlibCompressionOptions::GZIP(); - io::ZlibOutputBuffer buffer(file.get(), options.input_buffer_size, - options.output_buffer_size, options); - TF_RETURN_IF_ERROR(buffer.Init()); - TF_RETURN_IF_ERROR(buffer.Append(data)); - TF_RETURN_IF_ERROR(buffer.Close()); - TF_RETURN_IF_ERROR(file->Close()); - return Status::OK(); -} - -// Dumps profile data to /plugins/profile//. -inline string CreateProfileRunDirectory(const string& logdir, - const string& run) { - string run_dir = JoinPath(logdir, kProfilePluginDirectory, run); - TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(run_dir)); - return run_dir; -} - -void DumpTraceToLogDirectory(StringPiece run_dir, const string& encoded_trace) { - string proto_path = JoinPath(run_dir, kProtoTraceFileName); - TF_CHECK_OK(WriteStringToFile(Env::Default(), proto_path, encoded_trace)); - LOG(INFO) << "Dumped raw-proto trace data to " << proto_path; - - string json_path = JoinPath(run_dir, kJsonTraceFileName); - Trace trace; - trace.ParseFromString(encoded_trace); - std::cout << "Trace contains " << trace.trace_events_size() << " events." - << std::endl; - TF_CHECK_OK(WriteGzippedDataToFile(json_path, TraceEventsToJson(trace))); - std::cout << "Dumped JSON trace data to " << json_path << std::endl; -} - -void DumpOpProfileToLogDirectory(StringPiece run_dir, - const tpu::op_profile::Profile& profile) { - string path = JoinPath(run_dir, kJsonOpProfileFileName); - string json; - JsonOptions options; - options.always_print_primitive_fields = true; - auto status = MessageToJsonString(profile, &json, options); - if (!status.ok()) { - std::cerr << "Failed to convert op profile to json. Skipping... " - << status.error_message() << std::endl; - return; - } - TF_CHECK_OK(WriteStringToFile(Env::Default(), path, json)); - std::cout << "Dumped json op profile data to " << path << std::endl; -} - ProfileResponse Profile(const string& service_addr, int duration_ms) { ProfileRequest request; request.set_duration_ms(duration_ms); request.set_max_events(kMaxEvents); std::cout << "Limiting the number of trace events to " << kMaxEvents << std::endl; - ProfileResponse response; - ClientContext context; + ::grpc::ClientContext context; ::grpc::ChannelArguments channel_args; // TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available. channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, @@ -132,20 +60,11 @@ ProfileResponse Profile(const string& service_addr, int duration_ms) { std::unique_ptr stub = TPUProfiler::NewStub(::grpc::CreateCustomChannel( service_addr, ::grpc::InsecureChannelCredentials(), channel_args)); + ProfileResponse response; TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response))); return response; } -void DumpGraph(StringPiece logdir, StringPiece run, const string& graph_def) { - // The graph plugin expects the graph in //. - string run_dir = JoinPath(logdir, strings::StrCat(kGraphRunPrefix, run)); - TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(run_dir)); - EventsWriter event_writer(JoinPath(run_dir, "events")); - Event event; - event.set_graph_def(graph_def); - event_writer.WriteEvent(event); -} - } // namespace } // namespace tpu } // namespace tensorflow @@ -176,35 +95,8 @@ int main(int argc, char** argv) { tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms); // Use the current timestamp as the run name. tensorflow::string run = tensorflow::tpu::GetCurrentTimeStampAsString(); - tensorflow::string run_dir = - tensorflow::tpu::CreateProfileRunDirectory(FLAGS_logdir, run); - // Ignore computation_graph for now. - if (response.encoded_trace().empty()) { - std::cout << "No trace event is collected during the " << duration_ms - << "ms interval." << std::endl; - } else { - LOG(INFO) << "Converting trace events to TraceViewer JSON."; - tensorflow::tpu::DumpTraceToLogDirectory(run_dir, response.encoded_trace()); - } - int num_graphs = response.computation_graph_size(); - if (num_graphs > 0) { - // The server might generates multiple graphs for one program; we simply - // pick the first one. - if (num_graphs > 1) { - std::cout << num_graphs - << " TPU program variants observed over the profiling period. " - << "One computation graph will be chosen arbitrarily." - << std::endl; - } - tensorflow::tpu::DumpGraph( - FLAGS_logdir, run, response.computation_graph(0).SerializeAsString()); - } - if (response.has_op_profile() && - (response.op_profile().has_by_program_structure() || - response.op_profile().has_by_category())) { - tensorflow::tpu::DumpOpProfileToLogDirectory(run_dir, - response.op_profile()); - } + TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile( + FLAGS_logdir, run, response, &std::cout)); // Print this at the end so that it's not buried in irrelevant LOG messages. std::cout << "NOTE: using the trace duration " << duration_ms << "ms." << std::endl diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc new file mode 100644 index 0000000000000000000000000000000000000000..7541a5291d123256e7f1d83cb6f6ef72a78ad99d --- /dev/null +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -0,0 +1,164 @@ +/* Copyright 2017 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h" + +#include +#include +#include + +#include "tensorflow/contrib/tpu/profiler/op_profile.pb.h" +#include "tensorflow/contrib/tpu/profiler/trace_events.pb.h" +#include "tensorflow/contrib/tpu/profiler/trace_events_to_json.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/compression.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + +namespace tensorflow { +namespace tpu { +namespace { + +using ::tensorflow::io::JoinPath; +using ::tensorflow::protobuf::util::JsonOptions; +using ::tensorflow::protobuf::util::MessageToJsonString; + +constexpr char kProfilePluginDirectory[] = "plugins/profile/"; +constexpr char kJsonOpProfileFileName[] = "op_profile.json"; +constexpr char kProtoTraceFileName[] = "trace"; +constexpr char kJsonTraceFileName[] = "trace.json.gz"; +constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph."; + +Status WriteGzippedDataToFile(const string& filename, const string& data) { + std::unique_ptr file; + TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(filename, &file)); + io::ZlibCompressionOptions options = io::ZlibCompressionOptions::GZIP(); + io::ZlibOutputBuffer buffer(file.get(), options.input_buffer_size, + options.output_buffer_size, options); + TF_RETURN_IF_ERROR(buffer.Init()); + TF_RETURN_IF_ERROR(buffer.Append(data)); + TF_RETURN_IF_ERROR(buffer.Close()); + TF_RETURN_IF_ERROR(file->Close()); + return Status::OK(); +} + +Status DumpTraceToLogDirectory(StringPiece run_dir, const string& encoded_trace, + std::ostream* os) { + string proto_path = JoinPath(run_dir, kProtoTraceFileName); + TF_RETURN_IF_ERROR( + WriteStringToFile(Env::Default(), proto_path, encoded_trace)); + LOG(INFO) << "Dumped raw-proto trace data to " << proto_path; + + string json_path = JoinPath(run_dir, kJsonTraceFileName); + Trace trace; + trace.ParseFromString(encoded_trace); + *os << "Trace contains " << trace.trace_events_size() << " events." + << std::endl; + TF_RETURN_IF_ERROR( + WriteGzippedDataToFile(json_path, TraceEventsToJson(trace))); + *os << "Dumped JSON trace data to " << json_path << std::endl; + return Status::OK(); +} + +Status DumpOpProfileToLogDirectory(StringPiece run_dir, + const tpu::op_profile::Profile& profile, + std::ostream* os) { + string path = JoinPath(run_dir, kJsonOpProfileFileName); + string json; + JsonOptions options; + options.always_print_primitive_fields = true; + auto status = MessageToJsonString(profile, &json, options); + if (!status.ok()) { + return errors::Internal( + "Failed to convert op profile to json. Skipping... ", + string(status.error_message())); + } + TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, json)); + *os << "Dumped json op profile data to " << path << std::endl; + return Status::OK(); +} + +Status DumpGraphEvents(const string& logdir, const string& run, + const ProfileResponse& response, std::ostream* os) { + int num_graphs = response.computation_graph_size(); + if (response.computation_graph_size() == 0) return Status::OK(); + // The server might generates multiple graphs for one program; we simply + // pick the first one. + if (num_graphs > 1) { + *os << num_graphs + << " TPU program variants observed over the profiling period. " + << "One computation graph will be chosen arbitrarily." << std::endl; + } + // The graph plugin expects the graph in //. + string run_dir = JoinPath(logdir, strings::StrCat(kGraphRunPrefix, run)); + TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(run_dir)); + EventsWriter event_writer(JoinPath(run_dir, "events")); + Event event; + // Add the computation graph. + event.set_graph_def(response.computation_graph(0).SerializeAsString()); + event_writer.WriteEvent(event); + *os << "Wrote a HLO graph to " << event_writer.FileName() << std::endl; + + if (response.has_hlo_metadata()) { + tensorflow::TaggedRunMetadata tagged_run_metadata; + tagged_run_metadata.set_tag(run); + tagged_run_metadata.set_run_metadata( + response.hlo_metadata().SerializeAsString()); + tensorflow::Event meta_event; + *meta_event.mutable_tagged_run_metadata() = tagged_run_metadata; + event_writer.WriteEvent(meta_event); + *os << "Wrote HLO ops run metadata to " << event_writer.FileName() + << std::endl; + } + return Status::OK(); +} + +} // namespace + +Status WriteTensorboardTPUProfile(const string& logdir, const string& run, + const ProfileResponse& response, + std::ostream* os) { + // Dumps profile data to /plugins/profile//. + string profile_run_dir = JoinPath(logdir, kProfilePluginDirectory, run); + TF_RETURN_IF_ERROR(Env::Default()->RecursivelyCreateDir(profile_run_dir)); + // Ignore computation_graph for now. + if (response.encoded_trace().empty()) { + *os << "No trace event is collected." << std::endl; + } else { + LOG(INFO) << "Converting trace events to TraceViewer JSON."; + TF_RETURN_IF_ERROR( + DumpTraceToLogDirectory(profile_run_dir, response.encoded_trace(), os)); + } + if (response.has_op_profile() && + (response.op_profile().has_by_program_structure() || + response.op_profile().has_by_category())) { + TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir, + response.op_profile(), os)); + } + + TF_RETURN_IF_ERROR(DumpGraphEvents(logdir, run, response, os)); + + return Status::OK(); +} + +} // namespace tpu +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h new file mode 100644 index 0000000000000000000000000000000000000000..65b92aa41867ed9e2e8b06c9e34dd99068bb459c --- /dev/null +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ + +#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tpu { + +// Dumps all profiling tool data in a TPU profile to a TensorBoard log directory +// with the given run name. This writes user-facing log messages to `os`. +// The following tools are supported: +// - Trace viewer +// - Op profile +// - HLO computation graph +Status WriteTensorboardTPUProfile(const string& logdir, const string& run, + const ProfileResponse& response, + std::ostream* os); + +} // namespace tpu +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TPU_PROFILER_DUMP_TPU_PROFILE_H_ diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto index 6911b649a04a914d752e5090048b77885340f8ca..840a43913ba0f159d3c495553ebdff79c0448e73 100644 --- a/tensorflow/contrib/tpu/profiler/op_profile.proto +++ b/tensorflow/contrib/tpu/profiler/op_profile.proto @@ -32,6 +32,18 @@ message Node { string expression = 2; // %multiply = [shape]multiply(operand1, operand2) string provenance = 3; // Typically the TensorFlow operation name. string category = 4; + // Describes the physical memory layout of the instruction's primary input. + // e.g. for a convolution, this analyzes the image and ignores the kernel. + LayoutAnalysis layout = 5; + message LayoutAnalysis { + // The physical data layout, from most-minor to most-major dimensions. + repeated Dimension dimensions = 1; + message Dimension { + int32 size = 1; // Size of the data in this dimension. + int32 alignment = 2; // Data must be padded to a multiple of alignment. + string semantics = 3; // What the dimension represents, e.g. "spatial". + } + } } } diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index d0a27f1a3d51ef036d02062c0af7d3076f9aaf3c..88e86eca3b63da4bf1d2f9340707dc4a50d28b16 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package tensorflow; import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/protobuf/config.proto"; import "tensorflow/contrib/tpu/profiler/op_profile.proto"; // The TPUProfiler service retrieves performance information about @@ -31,6 +32,10 @@ message ProfileResponse { // Graphs of programs executed on TPUs during the profiling period. repeated GraphDef computation_graph = 2; + // Performance profile that can be used to annotate HLO operations in the + // computation graph. + RunMetadata hlo_metadata = 5; + // Encoded Trace proto message that contains metadata about the trace captured // during the profiling period. Describes the devices and resources that // 'trace_events' refers to. @@ -40,4 +45,5 @@ message ProfileResponse { // If the trace covers multiple programs, the longest-running one is analyzed. // See op_profile.proto for the detailed semantics of the returned profile. tpu.op_profile.Profile op_profile = 4; + // next-field: 6 } diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 2cbb118ed6364ff172267008c7e19eb71298aca8..6748a765623a043c53ce34b7b5a3113823eb4149 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -55,6 +55,7 @@ from tensorflow.python.training import training _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. _BATCH_SIZE_KEY = 'batch_size' +_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] @@ -101,10 +102,12 @@ def _increase_eval_step_op(iterations_per_loop): use_locking=True) -def _tpu_job(run_config): +def _tpu_job(run_config, mode): # The tpu job is determined by the run_config. Right now, this method is # required as tpu_config is not part of the RunConfig. - return None if run_config.master in ['', 'local'] else 'tpu_worker' + master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL + else run_config.master) + return None if master in ['', 'local'] else 'tpu_worker' def _is_running_on_cpu(use_tpu, mode, eval_batch_size): @@ -264,9 +267,9 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): dequeue. """ - def __init__(self, run_config, enqueue_fn, dequeue_ops=None): + def __init__(self, run_config, mode, enqueue_fn, dequeue_ops=None): self._iterations = run_config.tpu_config.iterations_per_loop - self._tpu_job = _tpu_job(run_config) + self._tpu_job = _tpu_job(run_config, mode) self._enqueue_fn = enqueue_fn self._dequeue_ops = dequeue_ops @@ -898,7 +901,7 @@ class _EvalMetrics(object): """ num_shards = run_config.tpu_config.num_shards - job = _tpu_job(run_config) + job = _tpu_job(run_config, model_fn_lib.ModeKeys.EVAL) job_device = '' if job is None else ('/job:%s' % job) # For each i, dequeue_ops[i] is a list containing the tensors from all @@ -977,18 +980,20 @@ class TPUEstimator(estimator_lib.Estimator): Example (MNIST): ``` + # The metric Fn which runs on CPU. + def metric_fn(labels, logits): + predictions = tf.argmax(logits, 1) + return { + 'accuracy': tf.metrics.precision( + labels=labels, predictions=predictions), + } + + # Your model Fn which runs on TPU. def model_fn(features, labels, mode, config, params): ... logits = ... if mode = tf.estimator.ModeKeys.EVAL: - def metric_fn(labels, logits): - predictions = tf.argmax(logits, 1) - return { - 'precision': tf.metrics.precision( - labels=labels, predictions=predictions), - } - return tpu_estimator.TPUEstimatorSpec( mode=mode, loss=loss, @@ -1161,7 +1166,7 @@ class TPUEstimator(estimator_lib.Estimator): with ops.device('/device:CPU:0'): return input_fn(**kwargs) - job = _tpu_job(config) + job = _tpu_job(config, mode) def placement_function(index): if job is None: return '/replica:0/task:0/device:CPU:0' @@ -1189,13 +1194,14 @@ class TPUEstimator(estimator_lib.Estimator): # TODO(b/64607814): Ensure batch_axis works with nested structures. def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder, run_config, - batch_axis): + batch_axis, mode): """Utility to convert input_fn to enqueue and dequeue fns for TPU. Args: inputs_holder: An `_InputsHolder` holding features and labels. run_config: A `RunConfig` instance. batch_axis: A python list of batch dimensions. + mode: ModeKeys Returns: A tuple of (dequeue_fn, enqueue_fn) @@ -1238,7 +1244,7 @@ def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder, run_config, return infeed_queue.generate_enqueue_ops( sharded_inputs, tpu_ordinal_function=tpu_ordinal_function) else: - job = _tpu_job(run_config) + job = _tpu_job(run_config, mode) def placement_function(index): if job is None: return '/replica:0/task:0/device:CPU:0' @@ -1270,12 +1276,12 @@ def _augment_model_fn(model_fn, train_batch_size, eval_batch_size, use_tpu, num_shards=config.tpu_config.num_shards) dequeue_fn, enqueue_fn = _create_infeed_enqueue_ops_and_dequeue_fn( - inputs, config, batch_axis) + inputs, config, batch_axis, mode) if mode == model_fn_lib.ModeKeys.TRAIN: loss = _train_on_tpu_system(model_fn_wrapper, dequeue_fn) hooks = [ - TPUInfeedOutfeedSessionHook(config, enqueue_fn), + TPUInfeedOutfeedSessionHook(config, mode, enqueue_fn), training.LoggingTensorHook( {'loss': array_ops.identity(loss), 'step': training.get_global_step()}, @@ -1284,6 +1290,10 @@ def _augment_model_fn(model_fn, train_batch_size, eval_batch_size, use_tpu, summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) with ops.control_dependencies([loss]): update_ops = _sync_variables_ops() + + # Validate the TPU training graph to catch basic errors + _validate_tpu_training_graph() + return model_fn_lib.EstimatorSpec( mode, loss=loss, @@ -1313,7 +1323,7 @@ def _augment_model_fn(model_fn, train_batch_size, eval_batch_size, use_tpu, eval_metric_ops.to_metric_metric_ops_for_tpu( config, dummy_update_op)) hooks = [ - TPUInfeedOutfeedSessionHook(config, enqueue_fn, eval_update_ops), + TPUInfeedOutfeedSessionHook(config, mode, enqueue_fn, eval_update_ops), ] return model_fn_lib.EstimatorSpec( @@ -1353,10 +1363,28 @@ def _train_on_tpu_system(model_fn_wrapper, dequeue_fn): dequeue_fn) multi_tpu_train_steps_on_single_shard = (lambda: training_loop.repeat( # pylint: disable=g-long-lambda - iterations_per_loop, single_tpu_train_step, [_INITIAL_LOSS], name='loop')) + iterations_per_loop, single_tpu_train_step, [_INITIAL_LOSS], + name=b'loop')) (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard, inputs=[], num_shards=num_shards, outputs_from_all_shards=False) return loss + + +def _validate_tpu_training_graph(): + """Validate graph before running distributed training. + + Raises: + ValueError: If the graph seems invalid for running on device + """ + operations = ops.get_default_graph().get_operations() + + # Check if there is atleast one CrossReplicaSum operation in the graph + # This should be introduced by using the CrossShardOptimizer wrapper + cross_replica_sum_ops = [o for o in operations + if o.type == _CROSS_REPLICA_SUM_OP] + if not cross_replica_sum_ops: + raise ValueError( + 'CrossShardOptimizer must be used for model training on TPUs.') diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index e8c6c349c8c0444d10881a940b3ba2c1baff8cbd..8e3d869a51c440e00059851f05f6ed2fe5558416 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -23,7 +23,6 @@ py_library( "python/training/evaluation.py", "python/training/feeding_queue_runner.py", "python/training/hparam.py", - "python/training/python_input.py", "python/training/resample.py", "python/training/sampling_ops.py", "python/training/sequence_queueing_state_saver.py", @@ -226,23 +225,6 @@ py_test( ], ) -py_test( - name = "python_input_test", - size = "medium", - srcs = ["python/training/python_input_test.py"], - srcs_version = "PY2AND3", - tags = ["manual"], - deps = [ - ":training_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], -) - py_test( name = "evaluation_test", size = "small", diff --git a/tensorflow/contrib/training/__init__.py b/tensorflow/contrib/training/__init__.py index 87a70e6d164b56bee3f6f14392cf548f4b2b4802..da2de3e421b841937e4125168ea1ecea066ff841 100644 --- a/tensorflow/contrib/training/__init__.py +++ b/tensorflow/contrib/training/__init__.py @@ -36,7 +36,6 @@ See @{$python/contrib.training} guide. @@HParams @@HParamDef @@parse_values -@@python_input """ from __future__ import absolute_import @@ -55,7 +54,6 @@ from tensorflow.contrib.training.python.training.evaluation import SummaryAtEndH from tensorflow.contrib.training.python.training.evaluation import wait_for_new_checkpoint from tensorflow.contrib.training.python.training.feeding_queue_runner import FeedingQueueRunner from tensorflow.contrib.training.python.training.hparam import * -from tensorflow.contrib.training.python.training.python_input import python_input from tensorflow.contrib.training.python.training.resample import * from tensorflow.contrib.training.python.training.sampling_ops import * from tensorflow.contrib.training.python.training.sequence_queueing_state_saver import * diff --git a/tensorflow/contrib/training/python/training/python_input.py b/tensorflow/contrib/training/python/training/python_input.py deleted file mode 100644 index 7f5420a98a1afd3b417b302be04ec6f6747445cd..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/training/python/training/python_input.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Operations for asynchronously reading data from python into queues. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import threading - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import script_ops - - -def _process_yielded_dict(feature_values, keys, features, dtypes, shapes): - """Read feature_values from the generator and emit a proper output dict.""" - if not isinstance(feature_values, dict): - raise TypeError("generator must return dict, saw: %s" % feature_values) - - processed_values = {} - for pk in keys: - if feature_values.get(pk, None) is not None: - processed_values[pk] = np.asarray( - feature_values[pk], dtype=dtypes[pk].as_numpy_dtype) - check_shape = tensor_shape.TensorShape(processed_values[pk].shape) - if not shapes[pk].is_compatible_with(check_shape): - raise ValueError( - "Feature '%s' has shape %s that is incompatible with declared " - "shape: %s" % (pk, shapes[pk], check_shape)) - continue - if isinstance(features[pk], parsing_ops.FixedLenFeature): - if features[pk].default_value is not None: - processed_values[pk] = np.asarray( - features[pk].default_value, dtype=dtypes[pk].as_numpy_dtype) - elif isinstance(features[pk], parsing_ops.FixedLenSequenceFeature): - processed_values[pk] = np.empty( - [0] + features[pk].shape.aslist(), dtype=dtypes[pk].as_numpy_dtype) - else: - raise ValueError( - "Expected generator to return key '%s' with non-empty value" % pk) - - return processed_values - - -def python_input(generator, features, name=None): - """Easily feed data from a python generator into TensorFlow queues. - - Example usage: - - ```python - def generator(): - for i in range(3): - yield {"value": i} - - features = { - "value": tf.FixedLenFeature(shape=[], dtype=dtypes.int32) - } - - tensor_dict = tf.contrib.training.python_input(generator, features) - batched_dict = tf.train.batch( - tensor_dict, batch_size=2, allow_smaller_final_batch=True) - - s = tf.Session() - tf.train.start_queue_runners() - - batch1 = s.run(batched_dict) # returns {"value": np.array([0, 1])} - batch2 = s.run(batched_dict) # returns {"value": np.array([2])} - s.run(batched_dict) # error: Queue is closed (generator finished at i==3) - ``` - - Args: - generator: A python generator that takes no arguments, and yields dicts - containing a single minibatch entry one at a time. - features: A python `dict` mapping keys expected from the generator to - instances of `tf.FixedLenFeature`, or `tf.FixedLenSequenceFeature`. - name: (Optional) A name for the operations. - - Returns: - A dict mapping keys of the `features` dict to `Tensor` objects. - These `Tensor` objects are outputs of a queue that is fed by `generator`. - - Raises: - TypeError: If generator is not callable or features is not a dict. - TypeError: If any of features' values are not a Feature object. - NotImplementedError: If any of features' values are instances of - `SparseFeature` or `VarLenFeature` (these are not currently supported). - ValueError: If any FixedLenSequenceFeatures contain a default value - (this field is not supported). - ValueError: if any FixedLenSequenceFeatures have allow_missing=False - (this field is not supported). - """ - if not callable(generator): - raise TypeError("generator must be callable, saw: %s" % generator) - if not isinstance(features, dict): - raise TypeError("features must be a dict, saw: %s" - % type(features).__name__) - - with ops.name_scope(name, "python_input"): - shapes = {} - dtypes = {} - for k, v in features.items(): - if isinstance(v, parsing_ops.FixedLenFeature): - if v.default_value is not None: - value = ops.convert_to_tensor(v.default_value, dtype=v.dtype, name=k) - shapes[k] = value.shape - dtypes[k] = value.dtype - else: - tensor_shape.TensorShape(v.shape).assert_is_fully_defined() - shapes[k] = tensor_shape.TensorShape(v.shape) - dtypes[k] = v.dtype - elif isinstance(v, parsing_ops.VarLenFeature): - raise NotImplementedError("VarLenFeature not supported") - elif isinstance(v, parsing_ops.SparseFeature): - raise NotImplementedError("SparseFeature not supported") - elif isinstance(v, parsing_ops.FixedLenSequenceFeature): - if v.default_value is not None: - raise ValueError("FixedLenSequenceFeature with default value not " - "supported") - if not v.allow_missing: - raise ValueError("FixedLenSequenceFeature with allow_missing=False " - "not supported") - tensor_shape.TensorShape(v.shape).assert_is_fully_defined() - shapes[k] = tensor_shape.TensorShape([None]).concatenate(v.shape) - dtypes[k] = v.dtype - else: - raise TypeError( - "Expected value for features key '%s' to be one of " - "FixedLenFeature, VarLenFeature, SparseFeature, or " - "FixedLenSequenceFeature. Got: %s" % (k, v)) - - keys = list(shapes.keys()) - dtypes_list = [dtypes[pk] for pk in keys] - - counter = [0] - lock = threading.Lock() - iterator = iter(generator()) - - def generator_iter(): - """Iterate through generator output and return np.arrays to py_func.""" - with lock: - try: - feature_values = next(iterator) - counter[0] += 1 - except StopIteration as e: - raise StopIteration("Iteration finished. Processed %d entries (%s)" - % (counter[0], e)) - - processed_dict = _process_yielded_dict( - feature_values, keys, features, dtypes, shapes) - return [processed_dict[pk] for pk in keys] - - generator_pyfunc_values = script_ops.py_func( - generator_iter, inp=[], Tout=dtypes_list, stateful=True) - - pyfunc_input = {k: v for (k, v) in zip(keys, generator_pyfunc_values)} - for k, v in shapes.items(): - pyfunc_input[k].set_shape(v) - - return pyfunc_input - - -__all__ = ["python_input"] diff --git a/tensorflow/contrib/training/python/training/python_input_test.py b/tensorflow/contrib/training/python/training/python_input_test.py deleted file mode 100644 index afd0f38c2cd3b2ae915f1e860a277018aaeb9cfd..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/training/python/training/python_input_test.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.training.python_input.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from tensorflow.contrib.training.python.training import bucket_ops -from tensorflow.contrib.training.python.training import python_input -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import parsing_ops -from tensorflow.python.platform import test -from tensorflow.python.training import coordinator -from tensorflow.python.training import input as core_input -from tensorflow.python.training import queue_runner_impl - - -class PythonInputTest(test.TestCase): - - def testGenerator(self): - def simple_generator(): - for i in range(2): - yield {"value": i, "ignored": 3} - - simple_features = { - "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32) - } - tensors = python_input.python_input(simple_generator, simple_features) - self.assertEqual(["value"], tensors.keys()) - self.assertEqual(dtypes.int32, tensors["value"].dtype) - self.assertEqual((), tensors["value"].shape) - - with self.test_session() as sess: - self.assertEqual({"value": 0}, sess.run(tensors)) - self.assertEqual({"value": 1}, sess.run(tensors)) - with self.assertRaisesOpError("Iteration finished"): - sess.run(tensors) - - def testInvalidGenerator(self): - generator1 = lambda: iter([{"value": "a"}]) - int_features = { - "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32) - } - tensors1 = python_input.python_input(generator1, int_features) - - with self.test_session() as sess: - with self.assertRaisesOpError("invalid literal"): - # Can't convert a string to an integer - sess.run(tensors1) - - generator2 = lambda: iter([None]) - tensors2 = python_input.python_input(generator2, int_features) - - with self.test_session() as sess: - with self.assertRaisesOpError("generator must return dict"): - sess.run(tensors2) - - generator3 = lambda: iter([{"value": [1, 2]}]) - tensors3 = python_input.python_input(generator3, int_features) - - with self.test_session() as sess: - with self.assertRaisesOpError("incompatible with declared shape"): - sess.run(tensors3) - - def testGeneratorWorksWithBatching(self): - def simple_generator(): - for i in range(5): - yield {"value": i, "ignored": 3} - - simple_features = { - "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32) - } - tensors = python_input.python_input(simple_generator, simple_features) - - # Request batches of size 4 at a time, the final batch may be smaller. - batched_tensors = core_input.batch(tensors, batch_size=4, - allow_smaller_final_batch=True) - - self.assertEqual(["value"], batched_tensors.keys()) - self.assertEqual(dtypes.int32, batched_tensors["value"].dtype) - self.assertEqual([None], batched_tensors["value"].shape.as_list()) - - with self.test_session() as sess: - # The generator emits 5 items total. The first 4 are returned in - # the first session run; the final one is returned in the - # second. This works because allow_smaller_final_batch=True. - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) - r1 = sess.run(batched_tensors) - r2 = sess.run(batched_tensors) - self.assertAllEqual([0, 1, 2, 3], r1["value"]) - self.assertEqual([4], r2["value"]) - with self.assertRaisesOpError("Iteration finished"): - sess.run(tensors) - coord.request_stop() - for thread in threads: - thread.join() - - def testGeneratorWorksWithManyBatchingThreads(self): - def simple_generator(): - for i in range(5000): - yield {"value": i, "ignored": 3} - - simple_features = { - "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32) - } - tensors = python_input.python_input(simple_generator, simple_features) - - # Request batches of size 20 at a time, the final batch may be smaller. - _, batched_tensors = bucket_ops.bucket( - tensors, which_bucket=tensors["value"] % 5, - batch_size=20, num_buckets=5, num_threads=7, capacity=17, - allow_smaller_final_batch=True) - - self.assertEqual(["value"], batched_tensors.keys()) - self.assertEqual(dtypes.int32, batched_tensors["value"].dtype) - self.assertEqual([None], batched_tensors["value"].shape.as_list()) - - with self.test_session() as sess: - # The generator emits 5 items total. The first 4 are returned in - # the first session run; the final one is returned in the - # second. This works because allow_smaller_final_batch=True. - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) - results = [] - while True: - try: - r = sess.run(batched_tensors) - results.extend(r["value"].tolist()) - except errors.OutOfRangeError: - break - coord.request_stop() - for thread in threads: - thread.join() - self.assertEqual(sorted(results), - list(range(5000))) - - def testVaryingFieldsInGenerator(self): - def simple_generator(): - for i in range(2): - yield {"value": i, - "seqlen_value": np.ones((i, 1))} - - simple_features = { - "value": parsing_ops.FixedLenFeature(shape=[], dtype=dtypes.int32), - "seqlen_value": parsing_ops.FixedLenSequenceFeature( - shape=[1], dtype=dtypes.float32, allow_missing=True), - "empty_value": parsing_ops.FixedLenFeature( - default_value=[-1, -2], dtype=dtypes.int32, shape=[2]) - } - tensors = python_input.python_input(simple_generator, simple_features) - self.assertEqual( - set(["value", "seqlen_value", "empty_value"]), set(tensors.keys())) - self.assertEqual(dtypes.int32, tensors["value"].dtype) - self.assertEqual((), tensors["value"].shape) - self.assertEqual(dtypes.float32, tensors["seqlen_value"].dtype) - self.assertEqual([None, 1], tensors["seqlen_value"].shape.as_list()) - self.assertEqual(dtypes.int32, tensors["empty_value"].dtype) - self.assertEqual([2], tensors["empty_value"].shape) - - with self.test_session() as sess: - r1 = sess.run(tensors) - self.assertAllEqual(0, r1["value"]) - self.assertAllEqual(np.ones((0, 1)), r1["seqlen_value"]) - self.assertAllEqual([-1, -2], r1["empty_value"]) - - r2 = sess.run(tensors) - self.assertAllEqual(1, r2["value"]) - self.assertAllEqual([[1]], r2["seqlen_value"]) - self.assertAllEqual([-1, -2], r2["empty_value"]) - - with self.assertRaisesOpError("Iteration finished"): - sess.run(tensors) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index bf0c968ae25caf411444c6c278b15080a6e01e55..21952cfa59affa82b7f448b91b450ca5c089d9c8 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -122,6 +122,7 @@ load( "tf_additional_gpu_tracer_cuda_deps", "tf_pyclif_proto_library", "tf_jspb_proto_library", + "tf_nano_proto_library", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", @@ -212,6 +213,15 @@ tf_jspb_proto_library( deps = [":protos_all_cc"], ) +tf_nano_proto_library( + name = "protos_all_nano_proto", + field_style = "accessors", + generate_equals = 1, + generate_intdefs = 1, + visibility = ["//visibility:public"], + deps = [":protos_all_cc"], +) + exports_files([ "framework/types.proto", ]) @@ -556,6 +566,7 @@ tf_gen_op_libs( "state_ops", "stateless_random_ops", "string_ops", + "summary_ops", "training_ops", ], ) @@ -766,6 +777,7 @@ cc_library( "//tensorflow/core/kernels:state", "//tensorflow/core/kernels:stateless_random_ops", "//tensorflow/core/kernels:string", + "//tensorflow/core/kernels:summary_kernels", "//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:word2vec_kernels", ] + tf_additional_cloud_kernel_deps() + if_not_windows([ @@ -777,8 +789,10 @@ cc_library( ]) + if_mkl([ "//tensorflow/core/kernels:mkl_concat_op", "//tensorflow/core/kernels:mkl_conv_op", + "//tensorflow/core/kernels:mkl_cwise_ops_common", "//tensorflow/core/kernels:mkl_fused_batch_norm_op", "//tensorflow/core/kernels:mkl_identity_op", + "//tensorflow/core/kernels:mkl_input_conversion_op", "//tensorflow/core/kernels:mkl_lrn_op", "//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", @@ -1927,6 +1941,7 @@ tf_cuda_library( name = "gpu_runtime", srcs = [ "common_runtime/gpu/gpu_bfc_allocator.cc", + "common_runtime/gpu/gpu_cudamalloc_allocator.cc", "common_runtime/gpu/gpu_debug_allocator.cc", "common_runtime/gpu/gpu_device.cc", "common_runtime/gpu/gpu_device_factory.cc", @@ -1939,6 +1954,7 @@ tf_cuda_library( ], hdrs = [ "common_runtime/gpu/gpu_bfc_allocator.h", + "common_runtime/gpu/gpu_cudamalloc_allocator.h", "common_runtime/gpu/gpu_debug_allocator.h", "common_runtime/gpu/gpu_device.h", "common_runtime/gpu/gpu_init.h", @@ -2143,8 +2159,6 @@ tf_cc_tests( "platform/port_test.cc", "platform/profile_utils/cpu_utils_test.cc", "platform/subprocess_test.cc", - "platform/vmodule_benchmark_test.cc", - "platform/vmodule_test.cc", ], deps = [ ":lib", @@ -2369,6 +2383,7 @@ tf_cc_tests( "util/semver_test.cc", "util/sparse/sparse_tensor_test.cc", "util/stat_summarizer_test.cc", + "util/tensor_format_test.cc", "util/tensor_slice_reader_test.cc", "util/tensor_slice_set_test.cc", "util/tensor_slice_util_test.cc", @@ -2468,8 +2483,10 @@ tf_cc_test_mkl( "//tensorflow/core/kernels:mkl_aggregate_ops", "//tensorflow/core/kernels:mkl_concat_op", "//tensorflow/core/kernels:mkl_conv_op", + "//tensorflow/core/kernels:mkl_cwise_ops_common", "//tensorflow/core/kernels:mkl_fused_batch_norm_op", "//tensorflow/core/kernels:mkl_identity_op", + "//tensorflow/core/kernels:mkl_input_conversion_op", "//tensorflow/core/kernels:mkl_lrn_op", "//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", @@ -3079,7 +3096,6 @@ cc_test( srcs = ["example/example_parser_configuration_test.cc"], data = [":example_parser_configuration_testdata"], deps = [ - ":core", ":core_cpu", ":core_cpu_internal", ":direct_session_internal", diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc new file mode 100644 index 0000000000000000000000000000000000000000..70c2d96763e72909bd1d58ae637d8393f1368197 --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef GOOGLE_CUDA +#include "cuda/include/cuda.h" +#include "tensorflow/stream_executor/cuda/cuda_activation.h" +#endif // GOOGLE_CUDA + +#include "tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h" + +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace gpu = ::perftools::gputools; + +namespace tensorflow { + +GPUcudaMallocAllocator::GPUcudaMallocAllocator(VisitableAllocator* allocator, + int device_id) + : base_allocator_(allocator) { + stream_exec_ = GPUMachineManager()->ExecutorForDevice(device_id).ValueOrDie(); +} + +GPUcudaMallocAllocator::~GPUcudaMallocAllocator() { delete base_allocator_; } + +void* GPUcudaMallocAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { +#ifdef GOOGLE_CUDA + // allocate with cudaMalloc + gpu::cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_}; + CUdeviceptr rv = 0; + CUresult res = cuMemAlloc(&rv, num_bytes); + if (res != CUDA_SUCCESS) { + LOG(ERROR) << "cuMemAlloc failed to allocate " << num_bytes; + return nullptr; + } + return reinterpret_cast(rv); +#else + return nullptr; +#endif // GOOGLE_CUDA +} +void GPUcudaMallocAllocator::DeallocateRaw(void* ptr) { +#ifdef GOOGLE_CUDA + // free with cudaFree + CUresult res = cuMemFree(reinterpret_cast(ptr)); + if (res != CUDA_SUCCESS) { + LOG(ERROR) << "cuMemFree failed to free " << ptr; + } +#endif // GOOGLE_CUDA +} + +void GPUcudaMallocAllocator::AddAllocVisitor(Visitor visitor) { + return base_allocator_->AddAllocVisitor(visitor); +} + +void GPUcudaMallocAllocator::AddFreeVisitor(Visitor visitor) { + return base_allocator_->AddFreeVisitor(visitor); +} + +bool GPUcudaMallocAllocator::TracksAllocationSizes() { return false; } + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..23552b809a8a735aaeb8ac9643eccd0b0542f03b --- /dev/null +++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_CUDA_MALLOC_ALLOCATOR_H_ +#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_CUDA_MALLOC_ALLOCATOR_H_ + +#include + +#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// An allocator that wraps a GPU allocator and adds debugging +// functionality that verifies that users do not write outside their +// allocated memory. +class GPUcudaMallocAllocator : public VisitableAllocator { + public: + explicit GPUcudaMallocAllocator(VisitableAllocator* allocator, int device_id); + ~GPUcudaMallocAllocator() override; + string Name() override { return "gpu_debug"; } + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + void AddAllocVisitor(Visitor visitor) override; + void AddFreeVisitor(Visitor visitor) override; + bool TracksAllocationSizes() override; + + private: + VisitableAllocator* base_allocator_ = nullptr; // owned + + perftools::gputools::StreamExecutor* stream_exec_; // Not owned. + + TF_DISALLOW_COPY_AND_ASSIGN(GPUcudaMallocAllocator); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc index 6b3c58ac9c51ebc3a6a77185550076d29fb9a282..0675dbf3fcdc772f4d45025d296eaddbf4397271 100644 --- a/tensorflow/core/common_runtime/gpu/process_state.cc +++ b/tensorflow/core/common_runtime/gpu/process_state.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" +#include "tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h" #include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h" #include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/common_runtime/gpu/pool_allocator.h" @@ -48,6 +49,27 @@ namespace gpu = ::perftools::gputools; namespace tensorflow { +namespace { +bool useCudaMallocAllocator() { + const char* debug_allocator_str = std::getenv("TF_GPU_ALLOCATOR"); + if (debug_allocator_str != nullptr && + strcmp(debug_allocator_str, "cuda_malloc") == 0) + return true; + else + return false; +} + +bool useCudaMemoryGuardAllocator() { + const char* debug_allocator_str = std::getenv("TF_GPU_ALLOCATOR"); + if (debug_allocator_str != nullptr && + strcmp(debug_allocator_str, "memory_guard") == 0) + return true; + else + return false; +} + +} // namespace + ProcessState* ProcessState::instance_ = nullptr; /*static*/ ProcessState* ProcessState::singleton() { @@ -114,10 +136,14 @@ Allocator* ProcessState::GetGPUAllocator(const GPUOptions& options, int gpu_id, // If true, checks for memory overwrites by writing // distinctive patterns on both ends of allocated memory. - static const bool kGPUDebug = false; - if (kGPUDebug) { + if (useCudaMemoryGuardAllocator()) { gpu_allocator = new GPUDebugAllocator(gpu_allocator, gpu_id); gpu_allocator = new GPUNanResetAllocator(gpu_allocator, gpu_id); + } else if (useCudaMallocAllocator()) { + // If true, passes all allocation requests through to cudaMalloc + // useful for doing memory debugging with tools like cuda-memcheck + // **WARNING** probably will not work in a multi-gpu scenario + gpu_allocator = new GPUcudaMallocAllocator(gpu_allocator, gpu_id); } gpu_allocators_[gpu_id] = gpu_allocator; diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 005aabf9b822ce93dc292004e239eec8c37b7f08..f16da10d7afb6587618043a590243cfe973738c8 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -75,12 +75,12 @@ class MklCPUAllocator : public Allocator { // Hooks provided by this allocator for memory allocation routines from MKL static inline void* MallocHook(size_t size) { - VLOG(2) << "MklCPUAllocator: In MallocHook"; + VLOG(3) << "MklCPUAllocator: In MallocHook"; return cpu_allocator()->AllocateRaw(kAlignment, size); } static inline void FreeHook(void* ptr) { - VLOG(2) << "MklCPUAllocator: In FreeHook"; + VLOG(3) << "MklCPUAllocator: In FreeHook"; cpu_allocator()->DeallocateRaw(ptr); } diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc index 2a974d184028a416396f3d3acfb93f7f7fb3792a..363d3a0c9d387c60781d205c3f3e2e413b1dc98c 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/util.h" @@ -54,7 +53,6 @@ SimpleGraphExecutionState::SimpleGraphExecutionState( : stateful_placements_(options.stateful_placements), device_set_(options.device_set), session_options_(options.session_options), - costs_(true /*is_global*/), flib_def_(new FunctionLibraryDefinition(OpRegistry::Global(), graph_def->library())), graph_(nullptr) { @@ -258,19 +256,11 @@ Status SimpleGraphExecutionState::InitBaseGraph( // Save stateful placements before placing. RestoreStatefulNodes(new_graph.get()); - CostModel costs(true /*is_global*/); - { - mutex_lock l(mu_); - costs_.InitFromGraph(*new_graph); - costs.MergeFromGlobal(costs_); - } - GraphOptimizationPassOptions optimization_options; optimization_options.session_options = session_options_; optimization_options.graph = &new_graph; optimization_options.flib_def = flib_def_.get(); optimization_options.device_set = device_set_; - optimization_options.cost_model = &costs; TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::PRE_PLACEMENT, optimization_options)); @@ -420,14 +410,11 @@ Status SimpleGraphExecutionState::BuildGraph( new FunctionLibraryDefinition(*flib_def_)); // TODO(andydavis): Clarify optimization pass requirements around CostModel. - CostModel costs(true /*is_global*/); - costs.MergeFromGlobal(costs_); GraphOptimizationPassOptions optimization_options; optimization_options.session_options = session_options_; optimization_options.graph = &ng; optimization_options.flib_def = flib.get(); optimization_options.device_set = device_set_; - optimization_options.cost_model = &costs; TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h index c7f34a42d61689ea90da5a1fef84f2a56f535fd4..53eef8a07d532fa5f23cfe031e26288d1c078671 100644 --- a/tensorflow/core/common_runtime/simple_graph_execution_state.h +++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h @@ -25,19 +25,14 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/graph/costmodel.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { struct SessionOptions; -class StepStats; -class Timeline; namespace subgraph { struct RewriteGraphMetadata; @@ -167,7 +162,6 @@ class SimpleGraphExecutionState { // Returns the map of stateful placements as a map of // node name to placement string. std::unordered_map GetStatefulPlacements() const { - mutex_lock l(mu_); return stateful_placements_; } @@ -193,9 +187,6 @@ class SimpleGraphExecutionState { const DeviceSet* device_set_; // Not owned const SessionOptions* session_options_; // Not owned - mutable mutex mu_; - CostModel costs_ GUARDED_BY(mu_); - // Map from name to Node for the full graph in placed_. NodeNameToCostIdMap node_name_to_cost_id_map_; diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index d410a164eac00a01f6a6c967e4cf4637a31a3d37..ee126240747b76288e8fcbe9fb90a1c6ac623aab 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -40,8 +40,8 @@ static int ExtractGpuWithStreamAll(string device_name) { scanner.OneLiteral("lla:maerts/"); // Capture the digits if present scanner.RestartCapture().Many(strings::Scanner::DIGIT).StopCapture(); - // Check that the digits are preceded by the 'gpu:' string - scanner.OneLiteral(":upg"); + // Check that the digits are preceded by the 'device:GPU:' string + scanner.OneLiteral(":UPG:ecived"); StringPiece capture; bool matched = scanner.GetResult(nullptr, &capture); @@ -69,8 +69,8 @@ static int ExtractGpuWithoutStream(string device_name) { strings::Scanner scanner(device_name); // Capture the trailing digits if present scanner.RestartCapture().Many(strings::Scanner::DIGIT).StopCapture(); - // Check that the digits are preceded by the 'gpu:' string - scanner.OneLiteral(":upg"); + // Check that the digits are preceded by the 'device:GPU:' string + scanner.OneLiteral(":UPG:ecived"); StringPiece capture; bool matched = scanner.GetResult(nullptr, &capture); diff --git a/tensorflow/core/example/feature_util.cc b/tensorflow/core/example/feature_util.cc index 6f3cc6c6c5d451a3a0f18d6db7054b247e82e93b..f0593ede82fd987f6958d4d450dfcaf7d5ce273a 100644 --- a/tensorflow/core/example/feature_util.cc +++ b/tensorflow/core/example/feature_util.cc @@ -18,77 +18,129 @@ limitations under the License. namespace tensorflow { namespace internal { - -::tensorflow::Feature& ExampleFeature(const string& name, - ::tensorflow::Example* example) { - ::tensorflow::Features* features = example->mutable_features(); - return (*features->mutable_feature())[name]; +Feature& ExampleFeature(const string& name, Example* example) { + return *GetFeature(name, example); } -} // namespace internal +} // namespace internal template <> -bool ExampleHasFeature(const string& name, - const Example& example) { - auto it = example.features().feature().find(name); - return (it != example.features().feature().end()) && +bool HasFeature<>(const string& key, const Features& features) { + return (features.feature().find(key) != features.feature().end()); +} + +template <> +bool HasFeature(const string& key, const Features& features) { + auto it = features.feature().find(key); + return (it != features.feature().end()) && (it->second.kind_case() == Feature::KindCase::kInt64List); } template <> -bool ExampleHasFeature(const string& name, const Example& example) { - auto it = example.features().feature().find(name); - return (it != example.features().feature().end()) && +bool HasFeature(const string& key, const Features& features) { + auto it = features.feature().find(key); + return (it != features.feature().end()) && (it->second.kind_case() == Feature::KindCase::kFloatList); } template <> -bool ExampleHasFeature(const string& name, const Example& example) { - auto it = example.features().feature().find(name); - return (it != example.features().feature().end()) && +bool HasFeature(const string& key, const Features& features) { + auto it = features.feature().find(key); + return (it != features.feature().end()) && (it->second.kind_case() == Feature::KindCase::kBytesList); } +bool HasFeatureList(const string& key, + const SequenceExample& sequence_example) { + auto& feature_list = sequence_example.feature_lists().feature_list(); + return (feature_list.find(key) != feature_list.end()); +} + template <> const protobuf::RepeatedField& GetFeatureValues( - const string& name, const Example& example) { - return example.features().feature().at(name).int64_list().value(); + const Feature& feature) { + return feature.int64_list().value(); } template <> protobuf::RepeatedField* GetFeatureValues( - const string& name, Example* example) { - return internal::ExampleFeature(name, example) - .mutable_int64_list() - ->mutable_value(); + Feature* feature) { + return feature->mutable_int64_list()->mutable_value(); } template <> const protobuf::RepeatedField& GetFeatureValues( - const string& name, const Example& example) { - return example.features().feature().at(name).float_list().value(); + const Feature& feature) { + return feature.float_list().value(); } template <> -protobuf::RepeatedField* GetFeatureValues(const string& name, - Example* example) { - return internal::ExampleFeature(name, example) - .mutable_float_list() - ->mutable_value(); +protobuf::RepeatedField* GetFeatureValues(Feature* feature) { + return feature->mutable_float_list()->mutable_value(); } template <> const protobuf::RepeatedPtrField& GetFeatureValues( - const string& name, const Example& example) { - return example.features().feature().at(name).bytes_list().value(); + const Feature& feature) { + return feature.bytes_list().value(); +} + +template <> +protobuf::RepeatedPtrField* GetFeatureValues(Feature* feature) { + return feature->mutable_bytes_list()->mutable_value(); +} + +const protobuf::RepeatedPtrField& GetFeatureList( + const string& key, const SequenceExample& sequence_example) { + return sequence_example.feature_lists().feature_list().at(key).feature(); +} + +protobuf::RepeatedPtrField* GetFeatureList( + const string& feature_list_key, SequenceExample* sequence_example) { + return (*sequence_example->mutable_feature_lists() + ->mutable_feature_list())[feature_list_key] + .mutable_feature(); +} + +template <> +Features* GetFeatures(Features* proto) { + return proto; +} + +template <> +Features* GetFeatures(Example* proto) { + return proto->mutable_features(); } template <> -protobuf::RepeatedPtrField* GetFeatureValues(const string& name, - Example* example) { - return internal::ExampleFeature(name, example) - .mutable_bytes_list() - ->mutable_value(); +const Features& GetFeatures(const Features& proto) { + return proto; } +template <> +const Features& GetFeatures(const Example& proto) { + return proto.features(); +} + +template <> +const protobuf::RepeatedField& GetFeatureValues( + const Feature& feature); + +template <> +protobuf::RepeatedField* GetFeatureValues( + Feature* feature); + +template <> +const protobuf::RepeatedField& GetFeatureValues( + const Feature& feature); + +template <> +protobuf::RepeatedField* GetFeatureValues(Feature* feature); + +template <> +const protobuf::RepeatedPtrField& GetFeatureValues( + const Feature& feature); + +template <> +protobuf::RepeatedPtrField* GetFeatureValues(Feature* feature); } // namespace tensorflow diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h index 4004411cb17ee5308dc6f7f6f7e875e451d0d1ed..a87c2c9a57c7c80692359dc88be3aca2ce7779b6 100644 --- a/tensorflow/core/example/feature_util.h +++ b/tensorflow/core/example/feature_util.h @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// A set of lightweight wrappers which simplify access to Example features. +// A set of lightweight wrappers which simplify access to Feature protos. // // TensorFlow Example proto uses associative maps on top of oneof fields. +// SequenceExample proto uses associative map of FeatureList. // So accessing feature values is not very convenient. // // For example, to read a first value of integer feature "tag": @@ -42,9 +43,59 @@ limitations under the License. // (RepeatedPtrField for byte list). So refer to its documentation of // RepeatedField for full list of supported methods. // -// NOTE: It is also important to mention that due to the nature of oneof proto -// fields setting a feature of one type automatically clears all values stored -// as another type with the same feature name. +// NOTE: Due to the nature of oneof proto fields setting a feature of one type +// automatically clears all values stored as another type with the same feature +// key. +// +// This library also has tools to work with SequenceExample protos. +// +// To get a value from SequenceExample.context: +// int id = GetFeatureValues("tag", se.context()).Get(0); +// To add a value to the context: +// GetFeatureValues("tag", se.mutable_context())->Add(42); +// +// To add values to feature_lists: +// AppendFeatureValues({4.0}, +// GetFeatureList("movie_ratings", &se)->Add()); +// AppendFeatureValues({5.0, 3.0}, +// GetFeatureList("movie_ratings", &se)->Add()); +// This will create a feature list keyed as "images" with two features: +// feature_lists { +// feature_list { +// key: "images" +// value { +// feature { float_list { value: [4.0] } } +// feature { float_list { value: [5.0, 3.0] } } +// } +// } } +// +// Functions exposed by this library: +// HasFeature<[FeatureType]>(key, proto) -> bool +// Returns true if a feature with the specified key, and optionally +// FeatureType, belongs to the Features or Example proto. +// HasFeatureList(key, sequence_example) -> bool +// Returns true if SequenceExample has a feature_list with the key. +// GetFeatureValues(key, proto) -> RepeatedField +// Returns values for the specified key and the FeatureType. +// Supported types for the proto: Example, Features. +// GetFeatureList(key, sequence_example) -> RepeatedPtrField +// Returns Feature protos associated with a key. +// AppendFeatureValues(begin, end, feature) +// AppendFeatureValues(container or initializer_list, feature) +// Copies values into a Feature. +// AppendFeatureValues(begin, end, key, proto) +// AppendFeatureValues(container or initializer_list, key, proto) +// Copies values into Features and Example protos with the specified key. +// +// Auxiliary functions, it is unlikely you'll need to use them directly: +// GetFeatures(proto) -> Features +// A convenience function to get Features proto. +// Supported types for the proto: Example, Features. +// GetFeature(key, proto) -> Feature* +// Returns a Feature proto for the specified key, creates a new if +// necessary. Supported types for the proto: Example, Features. +// GetFeatureValues(feature) -> RepeatedField +// Returns values of the feature for the FeatureType. #ifndef TENSORFLOW_EXAMPLE_FEATURE_H_ #define TENSORFLOW_EXAMPLE_FEATURE_H_ @@ -62,10 +113,11 @@ namespace tensorflow { namespace internal { +// DEPRECATED: Use GetFeature instead. +// TODO(gorban): Update all clients in a followup CL. // Returns a reference to a feature corresponding to the name. // Note: it will create a new Feature if it is missing in the example. -::tensorflow::Feature& ExampleFeature(const string& name, - ::tensorflow::Example* example); +Feature& ExampleFeature(const string& name, Example* example); // Specializations of RepeatedFieldTrait define a type of RepeatedField // corresponding to a selected feature type. @@ -127,89 +179,135 @@ struct FeatureTrait< } // namespace internal -// Returns true if feature with the specified name belongs to the example proto. -// Doesn't check feature type. Note that specialized versions return false if -// the feature has a wrong type. -template -bool ExampleHasFeature(const string& name, const Example& example) { - return example.features().feature().find(name) != - example.features().feature().end(); -} +// Returns true if sequence_example has a feature_list with the specified key. +bool HasFeatureList(const string& key, const SequenceExample& sequence_example); + +// A family of template functions to return mutable Features proto from a +// container proto. Supported ProtoTypes: Example, Features. +template +Features* GetFeatures(ProtoType* proto); + +template +const Features& GetFeatures(const ProtoType& proto); // Base declaration of a family of template functions to return a read only -// repeated field corresponding to a feature with the specified name. +// repeated field of feature values. template const typename internal::RepeatedFieldTrait::Type& -GetFeatureValues(const string& name, const Example& example); +GetFeatureValues(const Feature& feature); + +// Returns a read only repeated field corresponding to a feature with the +// specified name and FeatureType. Supported ProtoTypes: Example, Features. +template +const typename internal::RepeatedFieldTrait::Type& +GetFeatureValues(const string& key, const ProtoType& proto) { + return GetFeatureValues(GetFeatures(proto).feature().at(key)); +} -// Base declaration of a family of template functions to return a mutable -// repeated field corresponding to a feature with the specified name. +// Returns a mutable repeated field of a feature values. template typename internal::RepeatedFieldTrait::Type* GetFeatureValues( - const string& name, Example* example); + Feature* feature); + +// Returns a mutable repeated field corresponding to a feature with the +// specified name and FeatureType. Supported ProtoTypes: Example, Features. +template +typename internal::RepeatedFieldTrait::Type* GetFeatureValues( + const string& key, ProtoType* proto) { + ::tensorflow::Feature& feature = + (*GetFeatures(proto)->mutable_feature())[key]; + return GetFeatureValues(&feature); +} + +// Returns a Feature proto for the specified key, creates a new if necessary. +// Supported types for the proto: Example, Features. +template +Feature* GetFeature(const string& key, ProtoType* proto) { + return &(*GetFeatures(proto)->mutable_feature())[key]; +} + +// Returns a repeated field with features corresponding to a feature_list key. +const protobuf::RepeatedPtrField& GetFeatureList( + const string& key, const SequenceExample& sequence_example); + +// Returns a mutable repeated field with features corresponding to a +// feature_list key. It will create a new FeatureList if necessary. +protobuf::RepeatedPtrField* GetFeatureList( + const string& feature_list_key, SequenceExample* sequence_example); -// Copies elements from the range, defined by [first, last) into a feature. template void AppendFeatureValues(IteratorType first, IteratorType last, - const string& name, Example* example) { + Feature* feature) { using FeatureType = typename internal::FeatureTrait< typename std::iterator_traits::value_type>::Type; - std::copy(first, last, protobuf::RepeatedFieldBackInserter( - GetFeatureValues(name, example))); + std::copy(first, last, + protobuf::RepeatedFieldBackInserter( + GetFeatureValues(feature))); +} + +template +void AppendFeatureValues(std::initializer_list container, + Feature* feature) { + AppendFeatureValues(container.begin(), container.end(), feature); } -// Copies all elements from the container into a feature. template -void AppendFeatureValues(const ContainerType& container, const string& name, - Example* example) { +void AppendFeatureValues(const ContainerType& container, Feature* feature) { using IteratorType = typename ContainerType::const_iterator; - AppendFeatureValues(container.begin(), container.end(), name, - example); + AppendFeatureValues(container.begin(), container.end(), + feature); } -// Copies all elements from the initializer list into a feature. -template +// Copies elements from the range, defined by [first, last) into the feature +// obtainable from the (proto, key) combination. +template +void AppendFeatureValues(IteratorType first, IteratorType last, + const string& key, ProtoType* proto) { + AppendFeatureValues(first, last, GetFeature(key, GetFeatures(proto))); +} + +// Copies all elements from the container into a feature. +template +void AppendFeatureValues(const ContainerType& container, const string& key, + ProtoType* proto) { + using IteratorType = typename ContainerType::const_iterator; + AppendFeatureValues(container.begin(), container.end(), key, + proto); +} + +// Copies all elements from the initializer list into a Feature contained by +// Features or Example proto. +template void AppendFeatureValues(std::initializer_list container, - const string& name, Example* example) { + const string& key, ProtoType* proto) { using IteratorType = typename std::initializer_list::const_iterator; - AppendFeatureValues(container.begin(), container.end(), name, - example); + AppendFeatureValues(container.begin(), container.end(), key, + proto); } -template <> -bool ExampleHasFeature(const string& name, - const Example& example); - -template <> -bool ExampleHasFeature(const string& name, const Example& example); - -template <> -bool ExampleHasFeature(const string& name, const Example& example); - -template <> -const protobuf::RepeatedField& GetFeatureValues( - const string& name, const Example& example); - -template <> -protobuf::RepeatedField* GetFeatureValues( - const string& name, Example* example); - -template <> -const protobuf::RepeatedField& GetFeatureValues( - const string& name, const Example& example); - -template <> -protobuf::RepeatedField* GetFeatureValues(const string& name, - Example* example); - -template <> -const protobuf::RepeatedPtrField& GetFeatureValues( - const string& name, const Example& example); +// Returns true if a feature with the specified key belongs to the Features. +// The template parameter pack accepts zero or one template argument - which +// is FeatureType. If the FeatureType not specified (zero template arguments) +// the function will not check the feature type. Otherwise it will return false +// if the feature has a wrong type. +template +bool HasFeature(const string& key, const Features& features); + +// Returns true if a feature with the specified key belongs to the Example. +// Doesn't check feature type if used without FeatureType, otherwise the +// specialized versions return false if the feature has a wrong type. +template +bool HasFeature(const string& key, const Example& example) { + return HasFeature(key, GetFeatures(example)); +}; -template <> -protobuf::RepeatedPtrField* GetFeatureValues(const string& name, - Example* example); +// DEPRECATED: use HasFeature instead. +// TODO(gorban): update all clients in a followup CL. +template +bool ExampleHasFeature(const string& key, const Example& example) { + return HasFeature(key, example); +} } // namespace tensorflow #endif // TENSORFLOW_EXAMPLE_FEATURE_H_ diff --git a/tensorflow/core/example/feature_util_test.cc b/tensorflow/core/example/feature_util_test.cc index eb7b90af1b2de8955413de8c4cee30168312f606..cd32dee306d1b9d589bcf490c32fbd69f86ed32a 100644 --- a/tensorflow/core/example/feature_util_test.cc +++ b/tensorflow/core/example/feature_util_test.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/core/example/feature_util.h" #include @@ -38,6 +37,16 @@ TEST(GetFeatureValuesInt64Test, ReadsASingleValue) { EXPECT_EQ(42, tag.Get(0)); } +TEST(GetFeatureValuesInt64Test, ReadsASingleValueFromFeature) { + Feature feature; + feature.mutable_int64_list()->add_value(42); + + auto values = GetFeatureValues(feature); + + ASSERT_EQ(1, values.size()); + EXPECT_EQ(42, values.Get(0)); +} + TEST(GetFeatureValuesInt64Test, WritesASingleValue) { Example example; @@ -48,25 +57,33 @@ TEST(GetFeatureValuesInt64Test, WritesASingleValue) { EXPECT_EQ(42, example.features().feature().at("tag").int64_list().value(0)); } +TEST(GetFeatureValuesInt64Test, WritesASingleValueToFeature) { + Feature feature; + + GetFeatureValues(&feature)->Add(42); + + ASSERT_EQ(1, feature.int64_list().value_size()); + EXPECT_EQ(42, feature.int64_list().value(0)); +} + TEST(GetFeatureValuesInt64Test, CheckUntypedFieldExistence) { Example example; - - EXPECT_FALSE(ExampleHasFeature("tag", example)); + ASSERT_FALSE(HasFeature("tag", example)); GetFeatureValues("tag", &example)->Add(0); - EXPECT_TRUE(ExampleHasFeature("tag", example)); + EXPECT_TRUE(HasFeature("tag", example)); } TEST(GetFeatureValuesInt64Test, CheckTypedFieldExistence) { Example example; GetFeatureValues("tag", &example)->Add(3.14); - ASSERT_FALSE(ExampleHasFeature("tag", example)); + ASSERT_FALSE(HasFeature("tag", example)); GetFeatureValues("tag", &example)->Add(42); - EXPECT_TRUE(ExampleHasFeature("tag", example)); + EXPECT_TRUE(HasFeature("tag", example)); auto tag_ro = GetFeatureValues("tag", example); ASSERT_EQ(1, tag_ro.size()); EXPECT_EQ(42, tag_ro.Get(0)); @@ -87,6 +104,16 @@ TEST(GetFeatureValuesInt64Test, CopyIterableToAField) { EXPECT_EQ(3, tag_ro.Get(2)); } +TEST(GetFeatureValuesFloatTest, ReadsASingleValueFromFeature) { + Feature feature; + feature.mutable_float_list()->add_value(3.14); + + auto values = GetFeatureValues(feature); + + ASSERT_EQ(1, values.size()); + EXPECT_NEAR(3.14, values.Get(0), kTolerance); +} + TEST(GetFeatureValuesFloatTest, ReadsASingleValue) { Example example; (*example.mutable_features()->mutable_feature())["tag"] @@ -99,6 +126,15 @@ TEST(GetFeatureValuesFloatTest, ReadsASingleValue) { EXPECT_NEAR(3.14, tag.Get(0), kTolerance); } +TEST(GetFeatureValuesFloatTest, WritesASingleValueToFeature) { + Feature feature; + + GetFeatureValues(&feature)->Add(3.14); + + ASSERT_EQ(1, feature.float_list().value_size()); + EXPECT_NEAR(3.14, feature.float_list().value(0), kTolerance); +} + TEST(GetFeatureValuesFloatTest, WritesASingleValue) { Example example; @@ -114,6 +150,20 @@ TEST(GetFeatureValuesFloatTest, WritesASingleValue) { TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) { Example example; + GetFeatureValues("tag", &example)->Add(42); + ASSERT_FALSE(HasFeature("tag", example)); + + GetFeatureValues("tag", &example)->Add(3.14); + + EXPECT_TRUE(HasFeature("tag", example)); + auto tag_ro = GetFeatureValues("tag", example); + ASSERT_EQ(1, tag_ro.size()); + EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance); +} + +TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistenceForDeprecatedMethod) { + Example example; + GetFeatureValues("tag", &example)->Add(42); ASSERT_FALSE(ExampleHasFeature("tag", example)); @@ -125,6 +175,16 @@ TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) { EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance); } +TEST(GetFeatureValuesStringTest, ReadsASingleValueFromFeature) { + Feature feature; + feature.mutable_bytes_list()->add_value("FOO"); + + auto values = GetFeatureValues(feature); + + ASSERT_EQ(1, values.size()); + EXPECT_EQ("FOO", values.Get(0)); +} + TEST(GetFeatureValuesStringTest, ReadsASingleValue) { Example example; (*example.mutable_features()->mutable_feature())["tag"] @@ -137,6 +197,15 @@ TEST(GetFeatureValuesStringTest, ReadsASingleValue) { EXPECT_EQ("FOO", tag.Get(0)); } +TEST(GetFeatureValuesStringTest, WritesASingleValueToFeature) { + Feature feature; + + *GetFeatureValues(&feature)->Add() = "FOO"; + + ASSERT_EQ(1, feature.bytes_list().value_size()); + EXPECT_EQ("FOO", feature.bytes_list().value(0)); +} + TEST(GetFeatureValuesStringTest, WritesASingleValue) { Example example; @@ -148,15 +217,15 @@ TEST(GetFeatureValuesStringTest, WritesASingleValue) { example.features().feature().at("tag").bytes_list().value(0)); } -TEST(GetFeatureValuesBytesTest, CheckTypedFieldExistence) { +TEST(GetFeatureValuesStringTest, CheckTypedFieldExistence) { Example example; GetFeatureValues("tag", &example)->Add(42); - ASSERT_FALSE(ExampleHasFeature("tag", example)); + ASSERT_FALSE(HasFeature("tag", example)); *GetFeatureValues("tag", &example)->Add() = "FOO"; - EXPECT_TRUE(ExampleHasFeature("tag", example)); + EXPECT_TRUE(HasFeature("tag", example)); auto tag_ro = GetFeatureValues("tag", example); ASSERT_EQ(1, tag_ro.size()); EXPECT_EQ("FOO", tag_ro.Get(0)); @@ -228,5 +297,146 @@ TEST(AppendFeatureValuesTest, StringVariablesUsingInitializerList) { EXPECT_EQ("BAZ", tag_ro.Get(2)); } +TEST(SequenceExampleTest, ReadsASingleValueFromContext) { + SequenceExample se; + (*se.mutable_context()->mutable_feature())["tag"] + .mutable_int64_list() + ->add_value(42); + + auto values = GetFeatureValues("tag", se.context()); + + ASSERT_EQ(1, values.size()); + EXPECT_EQ(42, values.Get(0)); +} + +TEST(SequenceExampleTest, WritesASingleValueToContext) { + SequenceExample se; + + GetFeatureValues("tag", se.mutable_context())->Add(42); + + ASSERT_EQ(1, se.context().feature().at("tag").int64_list().value_size()); + EXPECT_EQ(42, se.context().feature().at("tag").int64_list().value(0)); +} + +TEST(SequenceExampleTest, AppendFeatureValuesToContextSingleArg) { + SequenceExample se; + + AppendFeatureValues({1.1, 2.2, 3.3}, "tag", se.mutable_context()); + + auto tag_ro = GetFeatureValues("tag", se.context()); + ASSERT_EQ(3, tag_ro.size()); + EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance); + EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance); + EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance); +} + +TEST(SequenceExampleTest, CheckTypedFieldExistence) { + SequenceExample se; + + GetFeatureValues("tag", se.mutable_context())->Add(3.14); + ASSERT_FALSE(HasFeature("tag", se.context())); + + GetFeatureValues("tag", se.mutable_context())->Add(42); + + EXPECT_TRUE(HasFeature("tag", se.context())); + auto tag_ro = GetFeatureValues("tag", se.context()); + ASSERT_EQ(1, tag_ro.size()); + EXPECT_EQ(42, tag_ro.Get(0)); +} + +TEST(SequenceExampleTest, ReturnsExistingFeatureLists) { + SequenceExample se; + (*se.mutable_feature_lists()->mutable_feature_list())["tag"] + .mutable_feature() + ->Add(); + + auto feature = GetFeatureList("tag", se); + + ASSERT_EQ(1, feature.size()); +} + +TEST(SequenceExampleTest, CreatesNewFeatureLists) { + SequenceExample se; + + GetFeatureList("tag", &se)->Add(); + + EXPECT_EQ(1, se.feature_lists().feature_list().at("tag").feature_size()); +} + +TEST(SequenceExampleTest, CheckFeatureListExistence) { + SequenceExample se; + ASSERT_FALSE(HasFeatureList("tag", se)); + + GetFeatureList("tag", &se)->Add(); + + ASSERT_TRUE(HasFeatureList("tag", se)); +} + +TEST(SequenceExampleTest, AppendFeatureValuesWithInitializerList) { + SequenceExample se; + + AppendFeatureValues({1, 2, 3}, "ids", se.mutable_context()); + AppendFeatureValues({"cam1-0", "cam2-0"}, + GetFeatureList("images", &se)->Add()); + AppendFeatureValues({"cam1-1", "cam2-2"}, + GetFeatureList("images", &se)->Add()); + + EXPECT_EQ(se.DebugString(), + "context {\n" + " feature {\n" + " key: \"ids\"\n" + " value {\n" + " int64_list {\n" + " value: 1\n" + " value: 2\n" + " value: 3\n" + " }\n" + " }\n" + " }\n" + "}\n" + "feature_lists {\n" + " feature_list {\n" + " key: \"images\"\n" + " value {\n" + " feature {\n" + " bytes_list {\n" + " value: \"cam1-0\"\n" + " value: \"cam2-0\"\n" + " }\n" + " }\n" + " feature {\n" + " bytes_list {\n" + " value: \"cam1-1\"\n" + " value: \"cam2-2\"\n" + " }\n" + " }\n" + " }\n" + " }\n" + "}\n"); +} + +TEST(SequenceExampleTest, AppendFeatureValuesWithVectors) { + SequenceExample se; + + std::vector readings{1.0, 2.5, 5.0}; + AppendFeatureValues(readings, GetFeatureList("movie_ratings", &se)->Add()); + + EXPECT_EQ(se.DebugString(), + "feature_lists {\n" + " feature_list {\n" + " key: \"movie_ratings\"\n" + " value {\n" + " feature {\n" + " float_list {\n" + " value: 1\n" + " value: 2.5\n" + " value: 5\n" + " }\n" + " }\n" + " }\n" + " }\n" + "}\n"); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index e7092f549b21e9a3c72950bfb637b966936ef5ab..f5dadf76daf8d351e509c4ae538b31abf00d9566 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -117,16 +117,6 @@ class CPUAllocator : public Allocator { TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator); }; -namespace { -Allocator* MakeCpuAllocator() { - Allocator* allocator = new CPUAllocator; - if (cpu_allocator_collect_full_stats || LogMemory::IsEnabled()) { - allocator = new TrackingAllocator(allocator, true); - } - return allocator; -} -} // namespace - Allocator* cpu_allocator() { static Allocator* cpu_alloc = AllocatorRegistry::Global()->GetAllocator(); if (cpu_allocator_collect_full_stats && !cpu_alloc->TracksAllocationSizes()) { diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc index 1cbed62939fc45b02e5cc76a92d852cc26c2fb3f..9da4828bbad7b6333336dd1215441f5c5f62151a 100644 --- a/tensorflow/core/framework/cancellation.cc +++ b/tensorflow/core/framework/cancellation.cc @@ -23,7 +23,9 @@ namespace tensorflow { const CancellationToken CancellationManager::kInvalidToken = -1; CancellationManager::CancellationManager() - : is_cancelling_(false), is_cancelled_(0), next_cancellation_token_(0) {} + : is_cancelling_(false), + is_cancelled_(false), + next_cancellation_token_(0) {} void CancellationManager::StartCancel() { gtl::FlatMap callbacks_to_run; diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 4c65c86e5e28777c8f8b2b29db27bc72e344cc16..ab21f47282e2951bded1e0d5932385c148998e2f 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -206,15 +206,28 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) { Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(Conv2DShape(c)); - ShapeHandle bias_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &bias_shape)); - DimensionHandle bias_dim = c->Dim(bias_shape, 0); + string data_format_str, filter_format_str; + TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); + TF_RETURN_IF_ERROR(c->GetAttr("filter_format", &filter_format_str)); + TensorFormat data_format; + FormatFromString(data_format_str, &data_format); + FilterTensorFormat filter_format; + FilterFormatFromString(filter_format_str, &filter_format); + + constexpr int num_spatial_dims = 2; + const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); ShapeHandle filter_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); - DimensionHandle output_depth_dim = c->Dim(filter_shape, 3); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); + DimensionHandle output_depth_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'O')); int64 output_depth_dim_val = c->Value(output_depth_dim); + + ShapeHandle bias_shape; + // Bias should be a 1-D tensor. + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bias_shape)); + DimensionHandle bias_dim = c->Dim(bias_shape, 0); int64 bias_dim_val = c->Value(bias_dim); if (output_depth_dim_val != bias_dim_val) { @@ -223,6 +236,54 @@ Status FusedConvBiasActivationShape(shape_inference::InferenceContext* c) { ") and bias dimension (", bias_dim_val, ") do not match."); } + // Check side input shape matches the output shape. + ShapeHandle side_input_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &side_input_shape)); + if (c->Rank(side_input_shape) > 1) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused)); + } + + return Status::OK(); +} + +Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, + const ShapeHandle shape_handle, + const string& tensor_name, + shape_inference::InferenceContext* c) { + if (tensor_format == FORMAT_NCHW_VECT_C) { + // Check that the vect dim has size 4. + const int num_dims = c->Rank(shape_handle); + DimensionHandle vect_dim = c->Dim( + shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format)); + DimensionHandle unused_vect_dim; + TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim)); + } + + return Status::OK(); +} + +// Returns a new shape with the specified dims arranged in the specified +// format. The returned value is owned by this context. +// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth. +Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, + const std::vector& spatial, + DimensionOrConstant C, ShapeHandle* out, + shape_inference::InferenceContext* c) { + const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format); + std::vector dims_actual(num_dims); + dims_actual[GetTensorBatchDimIndex(num_dims, format)] = c->MakeDim(N); + int outer_c_index = GetTensorFeatureDimIndex(num_dims, format); + dims_actual[outer_c_index] = c->MakeDim(C); + if (format == FORMAT_NCHW_VECT_C) { + dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] = + c->MakeDim(4); + } + for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) { + dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] = + c->MakeDim(spatial[spatial_dim]); + } + *out = c->MakeShape(dims_actual); return Status::OK(); } @@ -283,24 +344,38 @@ Status ShapeFromDimensions(DimensionHandle batch_dim, } Status Conv2DShape(shape_inference::InferenceContext* c) { - string data_format_str; - Status s = c->GetAttr("data_format", &data_format_str); - if (!s.ok()) { + string data_format_str, filter_format_str; + if (!c->GetAttr("data_format", &data_format_str).ok()) { data_format_str = "NHWC"; } + if (!c->GetAttr("filter_format", &filter_format_str).ok()) { + filter_format_str = "HWIO"; + } TensorFormat data_format; if (!FormatFromString(data_format_str, &data_format)) { return errors::InvalidArgument("Invalid data format string: ", data_format_str); } + FilterTensorFormat filter_format; + if (!FilterFormatFromString(filter_format_str, &filter_format)) { + return errors::InvalidArgument("Invalid filter format string: ", + filter_format_str); + } + + constexpr int num_spatial_dims = 2; + const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); + ShapeHandle conv_input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape)); + TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape( + data_format, conv_input_shape, "conv_input", c)); - const int rank = GetTensorDimsFromSpatialDims(2, data_format); - ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C). ShapeHandle filter_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c)); + std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); @@ -312,38 +387,33 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { strides.size()); } - int32 stride_rows, stride_cols; - if (data_format == FORMAT_NCHW || data_format == FORMAT_NCHW_VECT_C) { - stride_rows = strides[2]; - stride_cols = strides[3]; - } else { - stride_rows = strides[1]; - stride_cols = strides[2]; - } + const int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + const int32 stride_cols = GetTensorDim(strides, data_format, 'W'); DimensionHandle batch_size_dim; DimensionHandle input_depth_dim; gtl::InlinedVector input_spatial_dims(2); - TF_RETURN_IF_ERROR(DimensionsFromShape(input_shape, data_format, + TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format, &batch_size_dim, &input_spatial_dims, &input_depth_dim, c)); - DimensionHandle output_depth_dim, filter_rows_dim, filter_cols_dim, - filter_input_depth_dim; - // If the input format is NCHW_VECT_C, the filter format is assumed to be - // OIHW_VECT_I, otherwise it is assumed to be HWIO. - if (data_format == FORMAT_NCHW_VECT_C) { - output_depth_dim = c->Dim(filter_shape, 0); - TF_RETURN_IF_ERROR(c->Multiply(c->Dim(filter_shape, 1), - c->Dim(filter_shape, 4), - &filter_input_depth_dim)); - filter_rows_dim = c->Dim(filter_shape, 2); - filter_cols_dim = c->Dim(filter_shape, 3); + DimensionHandle output_depth_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'O')); + DimensionHandle filter_rows_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'H')); + DimensionHandle filter_cols_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'W')); + DimensionHandle filter_input_depth_dim; + if (filter_format == FORMAT_OIHW_VECT_I) { + TF_RETURN_IF_ERROR(c->Multiply( + c->Dim(filter_shape, + GetFilterDimIndex(filter_format, 'I')), + c->Dim(filter_shape, + GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)), + &filter_input_depth_dim)); } else { - filter_rows_dim = c->Dim(filter_shape, 0); - filter_cols_dim = c->Dim(filter_shape, 1); - filter_input_depth_dim = c->Dim(filter_shape, 2); - output_depth_dim = c->Dim(filter_shape, 3); + filter_input_depth_dim = c->Dim( + filter_shape, GetFilterDimIndex(filter_format, 'I')); } // Check that the input tensor and the filter tensor agree on the input @@ -519,18 +589,27 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { } Status AvgPoolShape(shape_inference::InferenceContext* c) { + string data_format_str; + TensorFormat data_format; + Status s = c->GetAttr("data_format", &data_format_str); + if (s.ok()) { + FormatFromString(data_format_str, &data_format); + } else { + data_format = FORMAT_NHWC; + } + + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); - string data_format; - Status s = c->GetAttr("data_format", &data_format); + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); if (strides.size() != 4) { return errors::InvalidArgument( - "AvgPool requires the stride attribute to contain 4 values, but " - "got: ", + "AvgPool requires the stride attribute to contain 4 values, but got: ", strides.size()); } @@ -542,31 +621,20 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { kernel_sizes.size()); } - int32 stride_rows, stride_cols; - int32 kernel_rows, kernel_cols; - - if (s.ok() && data_format == "NCHW") { - // Canonicalize input shape to NHWC so the shape inference code below can - // process it. - auto dim = [&](char dimension) { - return c->Dim(input_shape, GetTensorDimIndex<2>(FORMAT_NCHW, dimension)); - }; - input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('C')}}); - stride_rows = strides[2]; - stride_cols = strides[3]; - kernel_rows = kernel_sizes[2]; - kernel_cols = kernel_sizes[3]; - } else { - stride_rows = strides[1]; - stride_cols = strides[2]; - kernel_rows = kernel_sizes[1]; - kernel_cols = kernel_sizes[2]; - } - - DimensionHandle batch_size_dim = c->Dim(input_shape, 0); - DimensionHandle in_rows_dim = c->Dim(input_shape, 1); - DimensionHandle in_cols_dim = c->Dim(input_shape, 2); - DimensionHandle output_depth_dim = c->Dim(input_shape, 3); + int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + int32 stride_cols = GetTensorDim(strides, data_format, 'W'); + int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); + int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); + + constexpr int num_spatial_dims = 2; + DimensionHandle batch_size_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'N')); + DimensionHandle in_rows_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'H')); + DimensionHandle in_cols_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'W')); + DimensionHandle depth_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'C')); Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); @@ -582,31 +650,35 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); ShapeHandle output_shape; - if (data_format == "NCHW") { - output_shape = c->MakeShape( - {batch_size_dim, output_depth_dim, output_rows, output_cols}); - } else { - output_shape = c->MakeShape( - {batch_size_dim, output_rows, output_cols, output_depth_dim}); - } - + TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, + {output_rows, output_cols}, depth_dim, + &output_shape, c)); c->set_output(0, output_shape); return Status::OK(); } Status MaxPoolShape(shape_inference::InferenceContext* c) { + string data_format_str; + TensorFormat data_format; + Status s = c->GetAttr("data_format", &data_format_str); + if (s.ok()) { + FormatFromString(data_format_str, &data_format); + } else { + data_format = FORMAT_NHWC; + } + + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); - string data_format; - Status s = c->GetAttr("data_format", &data_format); + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); if (strides.size() != 4) { return errors::InvalidArgument( - "MaxPool requires the stride attribute to contain 4 values, but " - "got: ", + "MaxPool requires the stride attribute to contain 4 values, but got: ", strides.size()); } @@ -618,35 +690,22 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) { kernel_sizes.size()); } - int32 stride_rows, stride_cols, stride_depth; - int32 kernel_rows, kernel_cols, kernel_depth; - - if (s.ok() && data_format == "NCHW") { - // Canonicalize input shape to NHWC so the shape inference code below can - // process it. - auto dim = [&](char dimension) { - return c->Dim(input_shape, GetTensorDimIndex<2>(FORMAT_NCHW, dimension)); - }; - input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('C')}}); - stride_depth = strides[1]; - stride_rows = strides[2]; - stride_cols = strides[3]; - kernel_depth = kernel_sizes[1]; - kernel_rows = kernel_sizes[2]; - kernel_cols = kernel_sizes[3]; - } else { - stride_rows = strides[1]; - stride_cols = strides[2]; - stride_depth = strides[3]; - kernel_rows = kernel_sizes[1]; - kernel_cols = kernel_sizes[2]; - kernel_depth = kernel_sizes[3]; - } - - DimensionHandle batch_size_dim = c->Dim(input_shape, 0); - DimensionHandle in_rows_dim = c->Dim(input_shape, 1); - DimensionHandle in_cols_dim = c->Dim(input_shape, 2); - DimensionHandle in_depth_dim = c->Dim(input_shape, 3); + int32 stride_depth = GetTensorDim(strides, data_format, 'C'); + int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + int32 stride_cols = GetTensorDim(strides, data_format, 'W'); + int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); + int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); + int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); + + constexpr int num_spatial_dims = 2; + DimensionHandle batch_size_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'N')); + DimensionHandle in_rows_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'H')); + DimensionHandle in_cols_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'W')); + DimensionHandle in_depth_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'C')); Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); @@ -660,26 +719,30 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); - output_shape = - c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); - if (data_format == "NCHW") { - // Convert output shape back to expected NCHW data format. - auto dim = [&](char dimension) { - return c->Dim(output_shape, GetTensorDimIndex<2>(FORMAT_NHWC, dimension)); - }; - output_shape = c->MakeShape({{dim('N'), dim('C'), dim('0'), dim('1')}}); - } + TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, + {output_rows, output_cols}, + output_depth, &output_shape, c)); c->set_output(0, output_shape); return Status::OK(); } Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { + string data_format_str; + TensorFormat data_format; + Status s = c->GetAttr("data_format", &data_format_str); + if (s.ok()) { + FormatFromString(data_format_str, &data_format); + } else { + data_format = FORMAT_NHWC; + } + + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); - string data_format; - Status s = c->GetAttr("data_format", &data_format); + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); std::vector kernel_sizes; std::vector strides; @@ -704,7 +767,8 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { } kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); auto kernel_sizes_vec = kernel_sizes_tensor->flat(); - std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), kernel_sizes.begin()); + std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), + kernel_sizes.begin()); const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); if (strides_tensor == nullptr) { @@ -728,35 +792,22 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { kernel_sizes.size()); } - int32 stride_rows, stride_cols, stride_depth; - int32 kernel_rows, kernel_cols, kernel_depth; - - if (s.ok() && data_format == "NCHW") { - // Canonicalize input shape to NHWC so the shape inference code below can - // process it. - auto dim = [&](char dimension) { - return c->Dim(input_shape, GetTensorDimIndex<2>(FORMAT_NCHW, dimension)); - }; - input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('C')}}); - stride_depth = strides[1]; - stride_rows = strides[2]; - stride_cols = strides[3]; - kernel_depth = kernel_sizes[1]; - kernel_rows = kernel_sizes[2]; - kernel_cols = kernel_sizes[3]; - } else { - stride_rows = strides[1]; - stride_cols = strides[2]; - stride_depth = strides[3]; - kernel_rows = kernel_sizes[1]; - kernel_cols = kernel_sizes[2]; - kernel_depth = kernel_sizes[3]; - } - - DimensionHandle batch_size_dim = c->Dim(input_shape, 0); - DimensionHandle in_rows_dim = c->Dim(input_shape, 1); - DimensionHandle in_cols_dim = c->Dim(input_shape, 2); - DimensionHandle in_depth_dim = c->Dim(input_shape, 3); + int32 stride_depth = GetTensorDim(strides, data_format, 'C'); + int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + int32 stride_cols = GetTensorDim(strides, data_format, 'W'); + int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); + int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); + int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); + + constexpr int num_spatial_dims = 2; + DimensionHandle batch_size_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'N')); + DimensionHandle in_rows_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'H')); + DimensionHandle in_cols_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'W')); + DimensionHandle in_depth_dim = c->Dim( + input_shape, GetTensorDimIndex(data_format, 'C')); Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); @@ -770,15 +821,9 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); - output_shape = - c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); - if (data_format == "NCHW") { - // Convert output shape back to expected NCHW data format. - auto dim = [&](char dimension) { - return c->Dim(output_shape, GetTensorDimIndex<2>(FORMAT_NHWC, dimension)); - }; - output_shape = c->MakeShape({{dim('N'), dim('C'), dim('0'), dim('1')}}); - } + TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, + {output_rows, output_cols}, + output_depth, &output_shape, c)); c->set_output(0, output_shape); return Status::OK(); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 416478f8542741f5be1baf8cd5e8d20ecc6a9fbc..ec9746b2af1ed0da348fbe7459c5d93d842b25d9 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/shape_inference_testutil.h" @@ -411,34 +412,35 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) { TEST(CommonShapeFnsTest, Conv2DShapeTest) { ShapeInferenceTestOp op("Conv2D"); auto set_op = [&op](const std::vector& strides, const string& padding, - const string& data_format) { + const string& data_format, const string& filter_format) { TF_CHECK_OK(NodeDefBuilder("test", "Conv2D") .Input("input", 0, DT_FLOAT) .Input("filter", 0, DT_FLOAT) .Attr("strides", strides) .Attr("padding", padding) .Attr("data_format", data_format) + .Attr("filter_format", filter_format) .Finalize(&op.node_def)); }; // 1x1 filter - set_op({{1, 1, 1, 1}}, "VALID", "NHWC"); + set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); // 2x2 filter - set_op({{1, 1, 1, 1}}, "VALID", "NHWC"); + set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,2,2,1];[2,2,1,1]", "[d0_0,1,1,d1_3]"); // 3x3 input, 1x1 filter, 2x2 stride - set_op({{1, 2, 2, 1}}, "VALID", "NHWC"); + set_op({{1, 2, 2, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,2,d1_3]"); // 3x3 input, 1x1 filter, 2x1 stride - set_op({{1, 2, 1, 1}}, "VALID", "NHWC"); + set_op({{1, 2, 1, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,3,3,1];[1,1,1,1]", "[d0_0,2,3,d1_3]"); // 4x4 input, 2x1 filter, 1x2 stride - set_op({{1, 1, 2, 1}}, "VALID", "NHWC"); + set_op({{1, 1, 2, 1}}, "VALID", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]"); // Invalid rank for input @@ -460,77 +462,76 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) { // Tests for NCHW // 1x1 filter - set_op({{1, 1, 1, 1}}, "VALID", "NCHW"); + set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO"); INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]"); // 2x2 filter - set_op({{1, 1, 1, 1}}, "VALID", "NCHW"); + set_op({{1, 1, 1, 1}}, "VALID", "NCHW", "HWIO"); INFER_OK(op, "[1,1,2,2];[2,2,1,1]", "[d0_0,d1_3,1,1]"); // 3x3 input, 1x1 filter, 2x2 stride - set_op({{1, 1, 2, 2}}, "VALID", "NCHW"); + set_op({{1, 1, 2, 2}}, "VALID", "NCHW", "HWIO"); INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,2]"); // 3x3 input, 1x1 filter, 2x1 stride - set_op({{1, 1, 2, 1}}, "VALID", "NCHW"); + set_op({{1, 1, 2, 1}}, "VALID", "NCHW", "HWIO"); INFER_OK(op, "[1,1,3,3];[1,1,1,1]", "[d0_0,d1_3,2,3]"); // 4x4 input, 2x1 filter, 1x2 stride - set_op({{1, 1, 1, 2}}, "VALID", "NCHW"); + set_op({{1, 1, 1, 2}}, "VALID", "NCHW", "HWIO"); INFER_OK(op, "[1,1,4,4];[2,1,1,1]", "[d0_0,d1_3,3,2]"); // Tests for NCHW_VECT_C // 1x1 filter - set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C"); + set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); INFER_OK(op, "[1,1,2,2,4];[4,1,1,1,4]", "[d0_0,1,2,2,4]"); // 2x2 filter - set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C"); + set_op({{1, 1, 1, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); INFER_OK(op, "[1,1,2,2,4];[4,1,2,2,4]", "[d0_0,1,1,1,4]"); // 3x3 input, 1x1 filter, 2x2 stride - set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C"); + set_op({{1, 1, 2, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); INFER_OK(op, "[1,1,3,3,4];[8,1,1,1,4]", "[d0_0,2,2,2,4]"); // 3x3 input, 1x1 filter, 2x1 stride - set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C"); + set_op({{1, 1, 2, 1}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); INFER_OK(op, "[1,1,3,3,4];[4,1,1,1,4]", "[d0_0,1,2,3,4]"); // 4x4 input, 2x1 filter, 1x2 stride - set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C"); + set_op({{1, 1, 1, 2}}, "VALID", "NCHW_VECT_C", "OIHW_VECT_I"); INFER_OK(op, "[1,1,4,4,4];[4,1,2,1,4]", "[d0_0,1,3,2,4]"); // Some tests for "SAME" padding // 4x4 input, 1x1 filter, 1x1 stride - set_op({{1, 1, 1, 1}}, "SAME", "NHWC"); + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); // 3x3 input, 2x2 filter, 1x1 stride - set_op({{1, 1, 1, 1}}, "SAME", "NHWC"); + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); // 4x4 input, 2x2 filter, 2x2 stride - set_op({{1, 2, 2, 1}}, "SAME", "NHWC"); + set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]"); // 4x4 input, 2x2 filter, 1x1 stride - set_op({{1, 1, 1, 1}}, "SAME", "NHWC"); + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]"); // With stride 1x1 and SAME, unknown dims don't matter - filter dims except // for output channels are ignored for output, so all inputs are carried // through to output. - set_op({{1, 1, 1, 1}}, "SAME", "NHWC"); + set_op({{1, 1, 1, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); - INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); - INFER_OK(op, "[1,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); + INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,d0_1,d0_2,d1_3]"); // With stride != 1, the input HW dims are divided to produce output dims. - set_op({{1, 2, 2, 1}}, "SAME", "NHWC"); + set_op({{1, 2, 2, 1}}, "SAME", "NHWC", "HWIO"); INFER_OK(op, "[?,4,4,1];[?,?,?,?]", "[d0_0,2,2,d1_3]"); INFER_OK(op, "[1,?,4,1];[?,?,?,?]", "[d0_0,?,2,d1_3]"); INFER_OK(op, "[1,4,?,1];[?,?,?,?]", "[d0_0,2,?,d1_3]"); @@ -696,8 +697,15 @@ TEST(CommonShapeFnsTest, AvgPool2DShapeTest) { set_op({{1, 1, 1, 2}}, {1, 1, 2, 1}, "VALID", "NCHW"); INFER_OK(op, "[1,1,4,4]", "[d0_0,d0_1,3,2]"); + // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C test + set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "VALID", "NCHW_VECT_C"); + INFER_OK(op, "[2,3,5,7,4]", "[d0_0,d0_1,4,6,4]"); + INFER_OK(op, "[5,7,?,?,4]", "[d0_0,d0_1,?,?,4]"); + INFER_OK(op, "[?,?,?,?,4]", "[d0_0,d0_1,?,?,4]"); + INFER_ERROR("Dimension must be 4 but is 3", op, "[2,5,7,11,3]"); + // Invalid rank for input - INFER_ERROR("must be rank 4", op, "[4,4]"); + INFER_ERROR("Shape must be rank", op, "[4,4]"); } TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { @@ -725,6 +733,55 @@ TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { // depth 3 stride, 1x1x1 filter, NCHW set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW"); INFER_OK(op, "[1,7,5,5]", "[d0_0,3,5,5]"); + + // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests + set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C"); + INFER_OK(op, "[2,3,5,7,4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[5,7,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[?,?,?,?,4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8]"); +} + +TEST(CommonShapeFnsTest, MaxPoolV22DShapeTest) { + ShapeInferenceTestOp op("MaxPoolV2"); + Tensor ksizes_tensor, strides_tensor; + auto set_op = [&op, &ksizes_tensor, &strides_tensor]( + const std::vector& strides, + const std::vector& ksizes, const string& padding, + const string& data_format) { + TF_CHECK_OK(NodeDefBuilder("test", "MaxPoolV2") + .Input("input", 0, DT_FLOAT) + .Input("ksize", 1, DT_INT32) + .Input("strides", 2, DT_INT32) + .Attr("padding", padding) + .Attr("data_format", data_format) + .Finalize(&op.node_def)); + ksizes_tensor = test::AsTensor(ksizes); + op.input_tensors.resize(3); + op.input_tensors[0] = nullptr; + op.input_tensors[1] = &ksizes_tensor; + strides_tensor = test::AsTensor(strides); + op.input_tensors[2] = &strides_tensor; + }; + + // Most of the functionality is tested by conv-like shapes, + // so we check the very-specific maxpooling features here, + // namely depthwise kernel and striding. + + // all 1 strides, depth 2 filter + set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC"); + INFER_OK(op, "[1,2,2,2];[4];[4]", "[d0_0,2,2,1]"); + + // depth 3 stride, 1x1x1 filter, NCHW + set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW"); + INFER_OK(op, "[1,7,5,5];[4];[4]", "[d0_0,3,5,5]"); + + // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests + set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C"); + INFER_OK(op, "[2,3,5,7,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[5,7,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[?,?,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8];[4];[4]"); } TEST(CommonShapeFnsTest, Pool3DShapeTest) { diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 0fdc2c820c77dd9eb78aa167137dfda435ccae87..b788d6b77785a55fc8ecfe6f1efb8bc963a0c960 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/function.h" +#include #include #include #include @@ -271,12 +272,17 @@ class FunctionInstantiationHelper { int nid = -1; const string node_name = input.substr(1); const string node_colon = node_name + ":"; - for (const auto& p : index_) { - if (p.first == node_name || - tensorflow::StringPiece(p.first).starts_with(node_colon)) { - nid = p.second.nid; + const string node_colon_bound = node_name + ";"; + // index_ is a map sorted lexicographically, so the key we are looking for + // must lie in the range [node_name, node_colon_bound). + auto it = index_.lower_bound(node_name); + while (it != index_.end() && it->first <= node_colon_bound) { + if (it->first == node_name || + tensorflow::StringPiece(it->first).starts_with(node_colon)) { + nid = it->second.nid; break; } + ++it; } if (nid == -1) { return errors::InvalidArgument("input[", i, "] == '", input, @@ -421,7 +427,7 @@ class FunctionInstantiationHelper { GetFunctionSignature get_function_; InstantiationResult& result_; // A small index for all names that can be used as a node's input arguments. - std::unordered_map index_; + std::map index_; // This contains information about a node in the new graph including the node // names and input nodes' indexes. struct NodeInfo { @@ -908,6 +914,13 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const { } Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { + bool added; + return AddFunctionDefHelper(fdef, &added); +} + +Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef, + bool* added) { + *added = false; std::unique_ptr* entry = &function_defs_[fdef.signature().name()]; if (*entry != nullptr) { @@ -927,10 +940,18 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { "' because an op with the same name already exists."); } entry->reset(new FunctionDefAndOpRegistration(fdef)); + *added = true; return Status::OK(); } Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) { + bool added; + return AddGradientDefHelper(grad, &added); +} + +Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad, + bool* added) { + *added = false; string* entry = &func_grad_[grad.function_name()]; if (!entry->empty()) { if (*entry != grad.gradient_func()) { @@ -943,35 +964,98 @@ Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) { return Status::OK(); } *entry = grad.gradient_func(); + *added = true; return Status::OK(); } -// TODO(skyewm): don't modify FunctionLibraryDefinition in case of error Status FunctionLibraryDefinition::AddLibrary( const FunctionLibraryDefinition& other) { + // Remember the funcs and grads that we added successfully so that + // we can roll them back on error. + std::vector funcs; + std::vector funcs_with_grads; + Status s; + bool added; for (auto iter : other.function_defs_) { - TF_RETURN_IF_ERROR(AddFunctionDef(iter.second->fdef)); + s = AddFunctionDefHelper(iter.second->fdef, &added); + if (!s.ok()) { + Remove(funcs, funcs_with_grads); + return s; + } + if (added) { + funcs.push_back(iter.second->fdef.signature().name()); + } } for (auto iter : other.func_grad_) { GradientDef grad; grad.set_function_name(iter.first); grad.set_gradient_func(iter.second); - TF_RETURN_IF_ERROR(AddGradientDef(grad)); + s = AddGradientDefHelper(grad, &added); + if (!s.ok()) { + Remove(funcs, funcs_with_grads); + return s; + } + if (added) { + funcs_with_grads.push_back(grad.function_name()); + } } return Status::OK(); } Status FunctionLibraryDefinition::AddLibrary( const FunctionDefLibrary& lib_def) { + // Remember the funcs and grads that we added successfully so that + // we can roll them back on error. + std::vector funcs; + std::vector funcs_with_grads; + Status s; + bool added; for (const FunctionDef& fdef : lib_def.function()) { - TF_RETURN_IF_ERROR(AddFunctionDef(fdef)); + s = AddFunctionDefHelper(fdef, &added); + if (!s.ok()) { + Remove(funcs, funcs_with_grads); + return s; + } + if (added) { + funcs.push_back(fdef.signature().name()); + } } for (const GradientDef& grad : lib_def.gradient()) { - TF_RETURN_IF_ERROR(AddGradientDef(grad)); + s = AddGradientDefHelper(grad, &added); + if (!s.ok()) { + Remove(funcs, funcs_with_grads); + return s; + } + if (added) { + funcs_with_grads.push_back(grad.function_name()); + } } return Status::OK(); } +void FunctionLibraryDefinition::RemoveFunction(const string& func) { + const auto& i = function_defs_.find(func); + DCHECK(i != function_defs_.end()); + function_defs_.erase(i); +} + +void FunctionLibraryDefinition::RemoveGradient(const string& func) { + const auto& i = func_grad_.find(func); + DCHECK(i != func_grad_.end()); + func_grad_.erase(i); +} + +void FunctionLibraryDefinition::Remove( + const std::vector& funcs, + const std::vector& funcs_with_grads) { + for (const string& f : funcs) { + RemoveFunction(f); + } + for (const string& f : funcs_with_grads) { + RemoveGradient(f); + } +} + string FunctionLibraryDefinition::FindGradient(const string& func) const { return gtl::FindWithDefault(func_grad_, func, ""); } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 87bc5520f46bf21556d6713e4ff5bbcde3018041..317707644b3f724cc34c44e24e46670391d05835 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -292,20 +292,24 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // 'fdef' already exists in this function library. // If 'fdef' is successfully added to the library, it will be accessible // from 'LookUp' and included in the proto returned by 'ToProto'. + // This operation is atomic. Status AddFunctionDef(const FunctionDef& fdef); // Adds gradient definition 'grad' to this function library. // This is a no-op if 'grad' already exists in this function library. // If 'grad' is successfully added, it will be accessible via 'FindGradient' // and included in the proto returned by 'ToProto'. + // This operation is atomic. Status AddGradientDef(const GradientDef& grad); // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. + // This operation is atomic. Status AddLibrary(const FunctionLibraryDefinition& other); // Adds the functions and gradients in 'lib_def' to this function library. // Duplicate functions and gradients are ignored. + // This operation is atomic. Status AddLibrary(const FunctionDefLibrary& lib_def); // If the gradient function for 'func' is specified explicitly in @@ -353,6 +357,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface { OpRegistrationData op_registration_data; }; + // Same as AddFunctionDef/AddGradientDef except these methods set + // `added` to true if the `fdef`/`grad` were actually added to this. + Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added); + Status AddGradientDefHelper(const GradientDef& grad, bool* added); + const OpRegistryInterface* const default_registry_; gtl::FlatMap> function_defs_; @@ -361,6 +370,18 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Helper function for GetAttr. Returns the FunctionDef* to get the // attr from. const FunctionDef* GetAttrImpl(const NodeDef& ndef) const; + + // Remove function `func` from the library. `func` must be in the library. + void RemoveFunction(const string& func); + + // Remove gradient of function `func` from the library. `func` must have + // a gradient. + void RemoveGradient(const string& func); + + // Remove all functions in `funcs` and all gradients of + // functions in `funcs_with_grads` from this library. + void Remove(const std::vector& funcs, + const std::vector& funcs_with_grads); }; // Forward declare. Defined in common_runtime/function.h diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index 8e15bf04ab9c02bfb0f28a16f37b3040e762fb3f..13955addb5ea25df112c94c4694e2b518d72cb73 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -1054,6 +1054,123 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) { TF_EXPECT_OK(lib_def.AddLibrary(lib_def)); } +GradientDef MakeGradDef(const string& f, const string& g) { + GradientDef grad; + grad.set_function_name(f); + grad.set_gradient_func(g); + return grad; +} + +TEST(FunctionLibraryDefinitionTest, AddLibrary_Atomic) { + // Create lib def containing two functions with equal names + FunctionDefLibrary proto; + const string x2_name = test::function::XTimesTwo().signature().name(); + const string x4_name = test::function::XTimesFour().signature().name(); + *proto.add_function() = test::function::XTimesTwo(); + FunctionDef fdef = test::function::XTimesFour(); + fdef.mutable_signature()->set_name(x2_name); + *proto.add_function() = fdef; + FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary()); + + // Try adding the two functions to lib_def + Status s = lib_def.AddLibrary(proto); + EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); + EXPECT_EQ( + "Cannot add function 'XTimesTwo' because a different function with " + "the same name already exists.", + s.error_message()); + + // Verify that none of the functions are added + EXPECT_TRUE(lib_def.Find(x2_name) == nullptr); + + // Fix the name in proto but add two gradient names for it + proto.mutable_function(1)->mutable_signature()->set_name(x4_name); + *proto.add_gradient() = MakeGradDef(x2_name, x4_name); + *proto.add_gradient() = MakeGradDef(x2_name, "SecondGradName"); + + // Try adding the library and check that nothing was added + s = lib_def.AddLibrary(proto); + EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); + EXPECT_EQ(s.error_message(), + "Cannot assign gradient function 'SecondGradName' to 'XTimesTwo' " + "because it already has gradient function 'XTimesFour'"); + EXPECT_TRUE(lib_def.Find(x2_name) == nullptr); + EXPECT_EQ(0, lib_def.ToProto().function_size()); + EXPECT_EQ(0, lib_def.ToProto().gradient_size()); +} + +TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_FuncConflict) { + const string x2_name = test::function::XTimesTwo().signature().name(); + const string x4_name = test::function::XTimesFour().signature().name(); + const string wx_name = test::function::WXPlusB().signature().name(); + + // Create FunctionLibraryDefinition with + // (func = XTimesTwo, grad = XTimesFour) + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + *proto.add_gradient() = MakeGradDef(x2_name, x4_name); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); + EXPECT_EQ(1, lib_def.ToProto().function_size()); + EXPECT_EQ(1, lib_def.ToProto().gradient_size()); + + // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo) + // and function (name = XTimesTwo, body = XTimeFour) + FunctionDefLibrary proto2; + *proto2.add_function() = test::function::WXPlusB(); + *proto2.add_gradient() = MakeGradDef(wx_name, x2_name); + *proto2.add_function() = test::function::XTimesFour(); + proto2.mutable_function(1)->mutable_signature()->set_name(x2_name); + FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); + + // Verify that adding lib_def2 will fail because of function conflict + // and WXPlusB is not added. + Status s = lib_def.AddLibrary(lib_def2); + EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); + EXPECT_EQ( + "Cannot add function 'XTimesTwo' because a different function " + "with the same name already exists.", + s.error_message()); + EXPECT_TRUE(lib_def.Find(wx_name) == nullptr); + EXPECT_EQ(1, lib_def.ToProto().function_size()); + EXPECT_EQ(1, lib_def.ToProto().gradient_size()); +} + +TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_GradConflict) { + const string x2_name = test::function::XTimesTwo().signature().name(); + const string x4_name = test::function::XTimesFour().signature().name(); + const string wx_name = test::function::WXPlusB().signature().name(); + + // Create FunctionLibraryDefinition with + // (func = XTimesTwo, grad = XTimesFour) + FunctionDefLibrary proto; + *proto.add_function() = test::function::XTimesTwo(); + *proto.add_gradient() = MakeGradDef(x2_name, x4_name); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto); + EXPECT_EQ(1, lib_def.ToProto().function_size()); + EXPECT_EQ(1, lib_def.ToProto().gradient_size()); + + // Create FunctionLibraryDefinition with (func = WXPlusB, grad = XTimesTwo) + // and (func = XTimesTwo, grad = WXPlusB) + FunctionDefLibrary proto2; + *proto2.add_function() = test::function::WXPlusB(); + *proto2.add_gradient() = MakeGradDef(wx_name, x2_name); + *proto2.add_function() = test::function::XTimesTwo(); + *proto2.add_gradient() = MakeGradDef(x2_name, wx_name); + FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto2); + + // Verify that adding lib_def2 will fail because of gradient conflict + // and WXPlusB is not added. + Status s = lib_def.AddLibrary(lib_def2); + EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); + EXPECT_EQ( + "Cannot assign gradient function 'WXPlusB' to 'XTimesTwo'" + " because it already has gradient function 'XTimesFour'", + s.error_message()); + EXPECT_TRUE(lib_def.Find(wx_name) == nullptr); + EXPECT_EQ(1, lib_def.ToProto().function_size()); + EXPECT_EQ(1, lib_def.ToProto().gradient_size()); +} + TEST(FunctionLibraryDefinitionTest, ToProto) { FunctionDefLibrary proto1; *proto1.add_function() = test::function::XTimesTwo(); diff --git a/tensorflow/core/framework/summary.proto b/tensorflow/core/framework/summary.proto index ba490333310f6df511f134af6a050a231fe148f8..55879f87831eb968ee900e01697fbb99ba4cfe99 100644 --- a/tensorflow/core/framework/summary.proto +++ b/tensorflow/core/framework/summary.proto @@ -42,7 +42,7 @@ message SummaryMetadata { // The content to store for the plugin. The best practice is for this to be // a binary serialized protocol buffer. - string content = 2; + bytes content = 2; } // Data that associates a summary with a certain plugin. diff --git a/tensorflow/core/framework/tensor_testutil.h b/tensorflow/core/framework/tensor_testutil.h index ab224aa7188699bb3c459ef27268e5dbeff0d5f5..4c216a84f04389f9a2ef761aa6b6cec2c20a0be8 100644 --- a/tensorflow/core/framework/tensor_testutil.h +++ b/tensorflow/core/framework/tensor_testutil.h @@ -166,10 +166,11 @@ struct Expector { static void Equal(const Tensor& x, const Tensor& y) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); - auto a = x.flat(); - auto b = y.flat(); - for (int i = 0; i < a.size(); ++i) { - ExpectEqual(a(i), b(i)); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + ExpectEqual(a[i], b[i]); } } }; @@ -182,10 +183,11 @@ struct Expector { static void Equal(const Tensor& x, const Tensor& y) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); - auto a = x.flat(); - auto b = y.flat(); - for (int i = 0; i < a.size(); ++i) { - ExpectEqual(a(i), b(i)); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + ExpectEqual(a[i], b[i]); } } @@ -199,10 +201,11 @@ struct Expector { static void Near(const Tensor& x, const Tensor& y, const double abs_err) { ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); AssertSameTypeDims(x, y); - auto a = x.flat(); - auto b = y.flat(); - for (int i = 0; i < a.size(); ++i) { - Near(a(i), b(i), abs_err, i); + const auto size = x.NumElements(); + const T* a = x.flat().data(); + const T* b = y.flat().data(); + for (int i = 0; i < size; ++i) { + Near(a[i], b[i], abs_err, i); } } }; diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index 11756c356aa1b390bccd578eb962cfd1330b4724..9cc7530459eac9e59406e5d780308b1032d0671c 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -88,7 +88,17 @@ bool DecodeUnaryVariant(Variant* variant) { if (decode_fn == nullptr) { return false; } - return (*decode_fn)(variant); + const string type_name = variant->TypeName(); + bool decoded = (*decode_fn)(variant); + if (!decoded) return false; + if (variant->TypeName() != type_name) { + LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: " + << type_name + << " but after decoding was: " << variant->TypeName() + << ". Treating this as a failure."; + return false; + } + return true; } // Add some basic registrations for use by others, e.g., for testing. @@ -101,15 +111,59 @@ string MaybeRemoveTFPrefix(const StringPiece& str) { } // namespace #define REGISTER_VARIANT_DECODE_TYPE(T) \ - REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, MaybeRemoveTFPrefix(TF_STR(T))); + REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T)); // No encode/decode registered for std::complex<> and Eigen::half // objects yet. -TF_CALL_INTEGRAL_TYPES(REGISTER_VARIANT_DECODE_TYPE); -TF_CALL_float(REGISTER_VARIANT_DECODE_TYPE); -TF_CALL_double(REGISTER_VARIANT_DECODE_TYPE); -TF_CALL_bool(REGISTER_VARIANT_DECODE_TYPE); +REGISTER_VARIANT_DECODE_TYPE(int); +REGISTER_VARIANT_DECODE_TYPE(float); +REGISTER_VARIANT_DECODE_TYPE(bool); +REGISTER_VARIANT_DECODE_TYPE(double); #undef REGISTER_VARIANT_DECODE_TYPE +// Special casing ZerosLikeFn per device. +UnaryVariantOpRegistry::VariantZerosLikeFn* +UnaryVariantOpRegistry::GetZerosLikeFn(const string& device, + const string& type_name) { + auto found = zeros_like_fns.find(std::make_pair(device, type_name)); + if (found == zeros_like_fns.end()) return nullptr; + return &found->second; +} + +void UnaryVariantOpRegistry::RegisterZerosLikeFn( + const string& device, const string& type_name, + const VariantZerosLikeFn& zeros_like_fn) { + CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantZerosLike"; + VariantZerosLikeFn* existing = GetZerosLikeFn(device, type_name); + CHECK_EQ(existing, nullptr) + << "Unary VariantZerosLikeFn for type_name: " << type_name + << " already registered for device type: " << device; + zeros_like_fns.insert( + std::pair, VariantZerosLikeFn>( + std::make_pair(device, type_name), zeros_like_fn)); +} + +namespace { + +template +Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, + T* t_out) { + *t_out = T(0); + return Status::OK(); +} +} // namespace + +#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \ + REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION( \ + DEVICE_CPU, T, TF_STR(T), ZerosLikeVariantPrimitiveType); + +// No zeros_like registered for std::complex<> or Eigen::half objects yet. +REGISTER_VARIANT_ZEROS_LIKE_TYPE(int); +REGISTER_VARIANT_ZEROS_LIKE_TYPE(float); +REGISTER_VARIANT_ZEROS_LIKE_TYPE(double); +REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool); + +#undef REGISTER_VARIANT_ZEROS_LIKE_TYPE + } // namespace tensorflow diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index 389b049fa0194de4e14ed51abfab9da9e275b4cb..37e54f82c0ff4d7da15d1754c78f8ed4d18a347d 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -19,11 +19,13 @@ limitations under the License. #include #include +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" namespace tensorflow { +class OpKernelContext; // A global UnaryVariantOpRegistry is used to hold callback functions // for different variant types. To be used by ShapeOp, RankOp, and // SizeOp, decoding, etc. @@ -32,6 +34,8 @@ class UnaryVariantOpRegistry { public: typedef std::function VariantShapeFn; typedef std::function VariantDecodeFn; + typedef std::function + VariantZerosLikeFn; // Add a shape lookup function to the registry. void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn); @@ -46,11 +50,29 @@ class UnaryVariantOpRegistry { // Returns nullptr if no decode function was found for the given TypeName. VariantDecodeFn* GetDecodeFn(const string& type_name); + // Add a zeros-like function to the registry. + void RegisterZerosLikeFn(const string& device, const string& type_name, + const VariantZerosLikeFn& zeros_like_fn); + + // Returns nullptr if no zeros-like function was found for the given + // device and TypeName. + VariantZerosLikeFn* GetZerosLikeFn(const string& device, + const string& type_name); + static UnaryVariantOpRegistry* Global(); private: std::unordered_map shape_fns; std::unordered_map decode_fns; + // Map std::pair to function. + struct PairHash { + template + std::size_t operator()(const std::pair& x) const { + return std::hash()(x.first) ^ std::hash()(x.second); + } + }; + std::unordered_map, VariantZerosLikeFn, PairHash> + zeros_like_fns; }; // Gets a TensorShape from a Tensor containing a scalar Variant. @@ -72,6 +94,28 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape); // bool DecodeUnaryVariant(Variant* variant); +// Sets *z_out = zeros_like(v). The variant v must have a registered +// ZerosLike function for the given Device. Returns an Internal error +// if v does not have a registered zeros_like function for this device, or if +// ZerosLike fails. +// +// REQUIRES: +// v_out is not null. +// +template +Status CreateZerosLikeVariant(OpKernelContext* ctx, const Variant& v, + Variant* v_out) { + const string& device = DeviceName::value; + UnaryVariantOpRegistry::VariantZerosLikeFn* zeros_like_fn = + UnaryVariantOpRegistry::Global()->GetZerosLikeFn(device, v.TypeName()); + if (zeros_like_fn == nullptr) { + return errors::Internal( + "No unary variant zeros_like function found for Variant type_name: ", + v.TypeName(), " for device type: ", device); + } + return (*zeros_like_fn)(ctx, v, v_out); +} + namespace variant_op_registry_fn_registration { template @@ -120,6 +164,34 @@ class UnaryVariantDecodeRegistration { } }; +template +class UnaryVariantZerosLikeRegistration { + typedef std::function + LocalVariantZerosLikeFn; + + public: + UnaryVariantZerosLikeRegistration( + const string& device, const string& type_name, + const LocalVariantZerosLikeFn& zeros_like_fn) { + auto wrapped_fn = [type_name, zeros_like_fn](OpKernelContext* ctx, + const Variant& v, + Variant* v_out) -> Status { + CHECK_NOTNULL(v_out); + *v_out = T(); + if (v.get() == nullptr) { + return errors::Internal( + "VariantZerosLikeFn: Could not access object, type_name: ", + type_name); + } + const T& t = *v.get(); + T* t_out = v_out->get(); + return zeros_like_fn(ctx, t, t_out); + }; + UnaryVariantOpRegistry::Global()->RegisterZerosLikeFn(device, type_name, + wrapped_fn); + } +}; + }; // namespace variant_op_registry_fn_registration // Register a unary shape variant function with the signature: @@ -151,6 +223,26 @@ class UnaryVariantDecodeRegistration { T> \ register_unary_variant_op_decoder_fn_##ctr(type_name) +// Register a unary zeros_like variant function with the signature: +// Status ZerosLikeFn(OpKernelContext* ctx, const T& t, T* t_out); +// to Variants having TypeName type_name, for device string device. +#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(device, T, type_name, \ + zeros_like_function) \ + REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, device, T, type_name, zeros_like_function) + +#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \ + ctr, device, T, type_name, zeros_like_function) \ + REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ(ctr, device, T, type_name, \ + zeros_like_function) + +#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ( \ + ctr, device, T, type_name, zeros_like_function) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantZerosLikeRegistration \ + register_unary_variant_op_decoder_fn_##ctr(device, type_name, \ + zeros_like_function) + } // end namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_ diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc index 86fef53dbe69a103146640ed5ee705085afd383e..4e79180217a58e95591120a3e316d80d5cbd8b40 100644 --- a/tensorflow/core/framework/variant_op_registry_test.cc +++ b/tensorflow/core/framework/variant_op_registry_test.cc @@ -15,13 +15,25 @@ limitations under the License. #include +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif + #include "tensorflow/core/framework/variant_op_registry.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + namespace { struct VariantValue { @@ -33,7 +45,24 @@ struct VariantValue { *s = TensorShape({-0xdeadbeef}); return Status::OK(); } + static Status CPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v, + VariantValue* v_out) { + if (v.early_exit) { + return errors::InvalidArgument("early exit zeros_like!"); + } + v_out->zeros_like_set = 1; // CPU + return Status::OK(); + } + static Status GPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v, + VariantValue* v_out) { + if (v.early_exit) { + return errors::InvalidArgument("early exit zeros_like!"); + } + v_out->zeros_like_set = 2; // GPU + return Status::OK(); + } bool early_exit; + int zeros_like_set; }; REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", @@ -41,6 +70,14 @@ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue"); +REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(DEVICE_CPU, VariantValue, + "TEST VariantValue", + VariantValue::CPUZerosLikeFn); + +REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(DEVICE_GPU, VariantValue, + "TEST VariantValue", + VariantValue::GPUZerosLikeFn); + } // namespace TEST(VariantOpShapeRegistryTest, TestBasic) { @@ -101,4 +138,67 @@ TEST(VariantOpDecodeRegistryTest, TestDuplicate) { "fjfjfj already registered"); } +TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetZerosLikeFn( + DEVICE_CPU, "YOU SHALL NOT PASS"), + nullptr); + + VariantValue vv_early_exit{true /* early_exit */, 0 /* zeros_like_set */}; + Variant v = vv_early_exit; + Variant v_out = VariantValue(); + + OpKernelContext* null_context_pointer = nullptr; + Status s0 = + CreateZerosLikeVariant(null_context_pointer, v, &v_out); + EXPECT_FALSE(s0.ok()); + EXPECT_TRUE( + StringPiece(s0.error_message()).contains("early exit zeros_like")); + + VariantValue vv_ok{false /* early_exit */, 0 /* zeros_like_set */}; + v = vv_ok; + TF_EXPECT_OK( + CreateZerosLikeVariant(null_context_pointer, v, &v_out)); + VariantValue* vv_out = CHECK_NOTNULL(v_out.get()); + EXPECT_EQ(vv_out->zeros_like_set, 1); // CPU +} + +#if GOOGLE_CUDA +TEST(VariantOpZerosLikeRegistryTest, TestBasicGPU) { + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetZerosLikeFn( + DEVICE_GPU, "YOU SHALL NOT PASS"), + nullptr); + + VariantValue vv_early_exit{true /* early_exit */, 0 /* zeros_like_set */}; + Variant v = vv_early_exit; + Variant v_out = VariantValue(); + + OpKernelContext* null_context_pointer = nullptr; + Status s0 = + CreateZerosLikeVariant(null_context_pointer, v, &v_out); + EXPECT_FALSE(s0.ok()); + EXPECT_TRUE( + StringPiece(s0.error_message()).contains("early exit zeros_like")); + + VariantValue vv_ok{false /* early_exit */, 0 /* zeros_like_set */}; + v = vv_ok; + TF_EXPECT_OK( + CreateZerosLikeVariant(null_context_pointer, v, &v_out)); + VariantValue* vv_out = CHECK_NOTNULL(v_out.get()); + EXPECT_EQ(vv_out->zeros_like_set, 2); // GPU +} +#endif // GOOGLE_CUDA + +TEST(VariantOpZerosLikeRegistryTest, TestDuplicate) { + UnaryVariantOpRegistry registry; + UnaryVariantOpRegistry::VariantZerosLikeFn f; + + registry.RegisterZerosLikeFn(DEVICE_CPU, "fjfjfj", f); + EXPECT_DEATH(registry.RegisterZerosLikeFn(DEVICE_CPU, "fjfjfj", f), + "fjfjfj already registered"); + + registry.RegisterZerosLikeFn(DEVICE_GPU, "fjfjfj", f); + EXPECT_DEATH(registry.RegisterZerosLikeFn(DEVICE_GPU, "fjfjfj", f), + "fjfjfj already registered"); +} + } // namespace tensorflow diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 7d938365c5a823482c1f097943f73647b51a3b19..a274c79970497cbe85ca3edc1099de8b72ba8eae 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -523,6 +523,17 @@ Status Graph::IsValidNode(const Node* node) const { return Status::OK(); } +Status Graph::IsValidOutputTensor(const Node* node, int idx) const { + TF_RETURN_IF_ERROR(IsValidNode(node)); + if (idx >= node->num_outputs()) { + return errors::InvalidArgument("Node '", node->name(), "' (type: '", + node->op_def().name(), + "', num of outputs: ", node->num_outputs(), + ") does not have ", "output ", idx); + } + return Status::OK(); +} + Node* Graph::AllocateNode(std::shared_ptr props, const Node* cost_node) { Node* node = nullptr; @@ -572,7 +583,7 @@ int Graph::InternDeviceName(const string& device_name) { } string Edge::DebugString() const { - return strings::Printf("Edge %d %s:%d -> %s:%d", id_, src_->name().c_str(), + return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(), src_output_, dst_->name().c_str(), dst_input_); } diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index bd388d90651c4ab25590e841dc4f3a56e4c939ac..25875185e4780e6bdcd4c62b11a70d553af34edf 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -261,10 +261,10 @@ class Node { // that a single `OutputTensor` can correspond to multiple `Edge`s if the output // is consumed by multiple destination nodes. struct OutputTensor { - Node* node; + const Node* node; int index; - OutputTensor(Node* n, int i) : node(n), index(i) {} + OutputTensor(const Node* n, int i) : node(n), index(i) {} OutputTensor() : node(nullptr), index(0) {} }; @@ -519,6 +519,10 @@ class Graph { // Returns OK if `node` is non-null and belongs to this graph Status IsValidNode(const Node* node) const; + // Returns OK if IsValidNode(`node`) and `idx` is less than + // node->num_outputs() + Status IsValidOutputTensor(const Node* node, int idx) const; + // TODO(josh11b): uint64 hash() const; private: diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index fb0661ce5cd5b8ca7e8e3ef8a273c22fc8c7dda7..9e58993fbb277d5949d771b10854cc578d9a1937 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -280,19 +280,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; csinfo_.mkl_conv2d_with_bias_backprop_bias = "_MklConv2DWithBiasBackpropBias"; - csinfo_.relu = "Relu"; - csinfo_.relu_grad = "ReluGrad"; - csinfo_.reshape = "Reshape"; - csinfo_.split = "Split"; + csinfo_.relu = "Relu"; + csinfo_.relu_grad = "ReluGrad"; + csinfo_.reshape = "Reshape"; + csinfo_.split = "Split"; + // Element-wise ops. Ensure you also add any new ops to IsOpElementWise + // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the + // MklInputConversion op is added before it. + csinfo_.add = "Add"; + csinfo_.maximum = "Maximum"; + csinfo_.mul = "Mul"; + csinfo_.squared_difference = "SquaredDifference"; + csinfo_.sub = "Sub"; + // End - element-wise ops. See note above. // NOTE: names are alphabetically sorted. rinfo_.push_back({csinfo_.addn, GetMklOpName(csinfo_.addn), CopyAttrsAddN, AddNRewrite, nullptr}); + rinfo_.push_back({csinfo_.add, + mkl_op_registry::GetMklOpName(csinfo_.add), + CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.avg_pool, - GetMklOpName(csinfo_.avg_pool), + mkl_op_registry::GetMklOpName(csinfo_.avg_pool), CopyAttrsPooling, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.avg_pool_grad, - GetMklOpName(csinfo_.avg_pool_grad), + mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad), CopyAttrsPooling, AlwaysRewrite, nullptr}); // BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending // on if context contains Conv2D. @@ -306,50 +318,62 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsBiasAddGrad, ContextMatchRewrite, &biasaddgrad_matmul_context_}); rinfo_.push_back({csinfo_.concat, - GetMklOpName(csinfo_.concat), + mkl_op_registry::GetMklOpName(csinfo_.concat), CopyAttrsConcat, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.concatv2, - GetMklOpName(csinfo_.concatv2), + mkl_op_registry::GetMklOpName(csinfo_.concatv2), CopyAttrsConcatV2, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.conv2d, - GetMklOpName(csinfo_.conv2d), + mkl_op_registry::GetMklOpName(csinfo_.conv2d), CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.conv2d_grad_filter, - GetMklOpName(csinfo_.conv2d_grad_filter), + mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.conv2d_grad_input, - GetMklOpName(csinfo_.conv2d_grad_input), + mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.fused_batch_norm, - GetMklOpName(csinfo_.fused_batch_norm), + mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.fused_batch_norm_grad, - GetMklOpName(csinfo_.fused_batch_norm_grad), + mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.identity, - GetMklOpName(csinfo_.identity), + mkl_op_registry::GetMklOpName(csinfo_.identity), CopyAttrsIdentity, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.lrn, - GetMklOpName(csinfo_.lrn), + mkl_op_registry::GetMklOpName(csinfo_.lrn), CopyAttrsLRN, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.lrn_grad, - GetMklOpName(csinfo_.lrn_grad), + mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), CopyAttrsLRN, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.max_pool, - GetMklOpName(csinfo_.max_pool), + mkl_op_registry::GetMklOpName(csinfo_.max_pool), CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr}); rinfo_.push_back({csinfo_.max_pool_grad, - GetMklOpName(csinfo_.max_pool_grad), + mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), CopyAttrsPooling, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.maximum, + mkl_op_registry::GetMklOpName(csinfo_.maximum), + CopyAttrsDataType, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.mul, + mkl_op_registry::GetMklOpName(csinfo_.mul), + CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.relu, - GetMklOpName(csinfo_.relu), - CopyAttrsRelu, AlwaysRewrite, nullptr}); + mkl_op_registry::GetMklOpName(csinfo_.relu), + CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.relu_grad, - GetMklOpName(csinfo_.relu_grad), - CopyAttrsRelu, AlwaysRewrite, nullptr}); + mkl_op_registry::GetMklOpName(csinfo_.relu_grad), + CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.reshape, - GetMklOpName(csinfo_.reshape), + mkl_op_registry::GetMklOpName(csinfo_.reshape), CopyAttrsReshape, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.squared_difference, + mkl_op_registry::GetMklOpName(csinfo_.squared_difference), + CopyAttrsDataType, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.sub, + mkl_op_registry::GetMklOpName(csinfo_.sub), + CopyAttrsDataType, AlwaysRewrite, nullptr}); // Add info about which ops to add workspace edge to and the slots. wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); @@ -433,6 +457,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { /// NOTE: names are alphabetically sorted. typedef struct { string addn; + string add; string avg_pool; string avg_pool_grad; string bias_add; @@ -450,15 +475,19 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string matmul; string max_pool; string max_pool_grad; + string maximum; string mkl_conv2d; string mkl_conv2d_grad_input; string mkl_conv2d_grad_filter; string mkl_conv2d_with_bias; string mkl_conv2d_with_bias_backprop_bias; + string mul; string relu; string relu_grad; string reshape; string split; + string squared_difference; + string sub; } ConstStringsInfo; private: @@ -506,15 +535,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return N; } - // Get the name of Mkl op from original TensorFlow op - // We prefix 'Mkl' to the original op to get Mkl op. - // TODO(nhasabni) We should move this to mkl_util.h. - inline string GetMklOpName(const string& name) const { - // Prefix that we add to Tensorflow op name to construct Mkl op name. - const char* const kMklOpPrefix = "_Mkl"; - return string(kMklOpPrefix) + name; - } - // Can op represented by node 'n' run on DEVICE_CPU? // Op can run on CPU with MKL if the runtime assigned device or the // user requested device contains device CPU, or both are empty. @@ -929,11 +949,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); @@ -1117,6 +1137,44 @@ int MklLayoutRewritePass::SetUpContiguousInputs( CHECK_NOTNULL(workspace_tensors); CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + // TODO(nhasabni): Temporary solution to connect filter input of + // BackpropInput with the converted filter from Conv2D. + bool do_connect_conv2d_backprop_input_filter = false; + Node* conv2d_node = nullptr; + // Filter node is 2nd input (slot index 1) of Conv2D. + int kConv2DFilterInputSlotIdx = 1; + int kConv2DBackpropInputFilterInputSlotIdx = 1; + int kConv2DFilterOutputSlotIdx = 1; + if (old_node->type_string() == csinfo_.conv2d_grad_input) { + // We need to find Conv2D node from Conv2DBackpropInput. + // For that let's first find filter node that is 2nd input (slot 1) + // of BackpropInput. + Node* filter_node = nullptr; + old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, &filter_node); + CHECK_NOTNULL(filter_node); + + // Now check which nodes receive from filter_node. Filter feeds as + // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias. + for (const Edge* e : filter_node->out_edges()) { + if (e->dst()->type_string() == csinfo_.mkl_conv2d && + e->dst_input() == kConv2DFilterInputSlotIdx + /* filter is 2nd input of Conv2D and _MklConv2D. */) { + if (conv2d_node != nullptr) { + VLOG(1) << "MklLayoutRewritePass: unusual case of same filter" + << " feeding multiple Conv2D nodes: " + << filter_node->DebugString(); + // We will not connect filter input of Conv2DBackpropInput + // to be safe here. + do_connect_conv2d_backprop_input_filter = false; + break; + } else { + conv2d_node = e->dst(); + do_connect_conv2d_backprop_input_filter = true; + } + } + } + } + // Number of input slots to original op // Input slots are represented by .Input() calls in REGISTER_OP. int old_node_input_slots = old_node->op_def().input_arg_size(); @@ -1140,7 +1198,13 @@ int MklLayoutRewritePass::SetUpContiguousInputs( nb->Input(new_node_inputs); nn_slot_idx++; } else { - nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); + // Special case for connecting filter input of Conv2DBackpropInput + if (do_connect_conv2d_backprop_input_filter && + iidx == kConv2DBackpropInputFilterInputSlotIdx) { + nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx); + } else { + nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); + } iidx++; nn_slot_idx++; } @@ -1175,9 +1239,17 @@ int MklLayoutRewritePass::SetUpContiguousInputs( } else { Node* mkl_node = nullptr; int mkl_node_output_slot = 0; - GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, - old_node_inputs[iidx].second, - &mkl_node, &mkl_node_output_slot); + // Special case for connecting filter input of Conv2DBackpropInput + if (do_connect_conv2d_backprop_input_filter && + iidx == kConv2DBackpropInputFilterInputSlotIdx) { + GetNodeProducingMklTensor(g, old_node, conv2d_node, + kConv2DFilterOutputSlotIdx, &mkl_node, + &mkl_node_output_slot); + } else { + GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, + old_node_inputs[iidx].second, &mkl_node, + &mkl_node_output_slot); + } nb->Input(mkl_node, mkl_node_output_slot); iidx++; nn_slot_idx++; @@ -1300,7 +1372,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); for (auto ws : wsinfo_) { if (orig_node->type_string() == ws.fwd_op && - mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), T)) { + mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { // If this op is a fwd op, then we need to check if there is an // edge from this node's fwd_slot to bwdop's bwd_slot. If there is // an edge, then we just add an attribute on this node for setting @@ -1326,7 +1398,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( nb->Attr("workspace_enabled", false); } } else if (orig_node->type_string() == ws.bwd_op && - mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), + mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { // If this op is a bwd op, then we need to add workspace edge and // it's Mkl tensor edge between its corresponding fwd op and this @@ -1342,7 +1414,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( if (e->src_output() == ws.fwd_slot && // We would have rewritten the forward op, so we need to use // GetMklOpName call to get its Mkl name. - e->src()->type_string() == GetMklOpName(ws.fwd_op) && + e->src()->type_string() == mkl_op_registry::GetMklOpName(ws.fwd_op) && e->dst_input() == ws.bwd_slot) { nb->Attr("workspace_enabled", true); CHECK_NOTNULL(ws_tensors); @@ -1507,8 +1579,8 @@ void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node, nb->Attr("data_format", data_format); } -void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node, - NodeBuilder* nb) { +void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, + NodeBuilder* nb) { DataType T; // Get all attributes from old node. @@ -1874,7 +1946,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr* g, } // Get all inputs. - const int num_inputs = orig_node->in_edges().size(); + int num_inputs = orig_node->in_edges().size(); + + // Drop count for control edges from inputs + for (const Edge* e : orig_node->in_edges()) { + if (e->IsControlEdge()) { + num_inputs--; + } + } + gtl::InlinedVector control_edges; gtl::InlinedVector, 4> inputs(num_inputs); FillInputs(orig_node, &control_edges, &inputs); @@ -1988,7 +2068,34 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // BiasAddGrad is not an Mkl layer, so we make an exception for it. if (n->type_string() != csinfo_.bias_add_grad) { - if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) { + if (!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) { + return nullptr; + } + } + + // For elementwise node, we reuse the Eigen implementation and pass the MKL + // metadata tensor through so we can avoid conversions. However, if all + // incoming edges are in TF format, we don't need all this overhead, so + // replace the elementwise node only if at least one of its parents is a MKL + // node. + // + // TODO(vrane): Add implementation for element-wise ops that doesn't reuse + // eigen code to reduce cross-library dependency. + if (mkl_op_registry::IsMklElementWiseOp( + mkl_op_registry::GetMklOpName(n->type_string()), T)) { + bool incoming_mkl_edge = false; + for (auto parent : n->in_edges()) { + if (mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(parent->src()->type_string()), T)) { + incoming_mkl_edge = true; + break; + } else { + VLOG(1) << "Non-MKL parent is: " << parent->src()->type_string(); + } + } + if (incoming_mkl_edge == false) { + VLOG(1) << "Skipping replacement of elementwise node which has no MKL " + "parents."; return nullptr; } } diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index 482e339802f0ba82bf0352007cf946d359900371..6a41e3965abc226c55b19b3183ba3cf8bd2ecaa8 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -133,19 +133,19 @@ TEST_F(MklLayoutPassTest, Basic) { InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }" - "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Mul);D(Mul)|" + "A(Input);B(Input);C(Zeta);D(Zeta)|" "A->C;A->D;B->C:1;B->D:1"); } // Test set 1: Conv2D + AddBias -// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering) -// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering) +// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved ordering) +// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous ordering) TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) { CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); InitGraph( @@ -166,18 +166,18 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['C', 'D'] }" "node { name: 'Y' op: 'Input'}" - "node { name: 'Z' op: 'Sub'" + "node { name: 'Z' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['E', 'Y']}"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);" - "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;" + "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->E;" "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;" "N->E:4;Y->Z:1"); } -// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved) -// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous) +// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved) +// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous) // Test for correct output slots selected TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) { CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); @@ -199,17 +199,17 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['C', 'D'] }" "node { name: 'Y' op: 'Input'}" - "node { name: 'Z' op: 'Sub'" + "node { name: 'Z' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['E', 'Y']}"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);" - "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;" + "M(_MklInput2);N(_MklInput2);Y(Input);Z(Zeta)|A->E;" "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;" "M:1->E:3;N:1->E:4;Y->Z:1"); } -// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y); +// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Zeta(E,Y); // This is a case of node rewrite followed by node merge. // We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D // with BiasAdd to produce _MklConv2DWithBias. @@ -231,12 +231,12 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['C', 'D'] }" "node { name: 'Y' op: 'Input'}" - "node { name: 'Z' op: 'Sub'" + "node { name: 'Z' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['E', 'Y']}"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|" + "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Zeta)|" "A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;" "DMT/_2->E:5;E->Z;Y->Z:1"); @@ -286,7 +286,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) { "M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3"); } -// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add). +// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Zeta). // Merge should not be done in such case. TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) { InitGraph( @@ -308,12 +308,12 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['D', 'E'] }" // Conv2D has two outputs. // No merge should happen. - "node { name: 'G' op: 'Add'" + "node { name: 'G' op: 'Zeta'" " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);" - "G(Add);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;" + "G(Zeta);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;" "E->F:1;E->G:1;M->C:2;N->C:3"); } @@ -362,7 +362,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Sub'" + "node { name: 'E' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['D', 'A']}" "node { name: 'F' op: 'Int32Input'}" @@ -387,7 +387,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) { " input: ['E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);" - "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" + "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" "I(_MklConv2DBackpropInput);J(_MklConv2DWithBiasBackpropBias);" "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G;B->D:1;" "B->I:1;C->D:2;D->E;DMT/_0->J:1;E->G:2;E->I:2;E->J;" @@ -413,7 +413,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Sub'" + "node { name: 'E' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['D', 'A']}" "node { name: 'F' op: 'Int32Input'}" @@ -438,7 +438,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) { " input: ['E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" - "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" + "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" "I(_MklConv2DBackpropInput);J(BiasAddGrad);" "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;" "B->I:1;C->D:2;D->E;E->G;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;" @@ -463,7 +463,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['B', 'A', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Sub'" + "node { name: 'E' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['D', 'A']}" "node { name: 'F' op: 'Int32Input'}" @@ -488,7 +488,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) { " input: ['E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" - "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" + "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" "I(_MklConv2DBackpropInput);J(BiasAddGrad);" "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;" "B->I:1;C->D:2;D->E;E->G:2;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;" @@ -512,7 +512,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Sub'" + "node { name: 'E' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['D', 'A']}" "node { name: 'F' op: 'Int32Input'}" @@ -529,7 +529,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) { " input: ['E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);" - "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);" + "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);" "H(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);" "O(_MklInput)|A->D;A->E:1;A->G;B->D:1;C->D:2;D->E;DMT/_0->H:1;" "E->G:2;E->H;E:control->DMT/_0:control;F->G:1;M->D:3;M->G:3;" @@ -553,7 +553,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Sub'" + "node { name: 'E' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['D', 'A']}" "node { name: 'F' op: 'Int32Input'}" @@ -570,7 +570,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) { " input: ['E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" - "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" + "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;" "C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;" "O->G:5"); @@ -593,7 +593,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['B', 'A', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Sub'" + "node { name: 'E' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['D', 'A']}" "node { name: 'F' op: 'Int32Input'}" @@ -610,7 +610,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) { " input: ['E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" - "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" + "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;" "C->D:2;D->E;E->G:2;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;" "O->G:5"); @@ -618,8 +618,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) { // No _MklConv2DWithBias in context, but _MklConv2D in context. // No rewrite for BiasAddGrad should happen. -// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved) -// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous) +// C=_MklConv2D(A,M,B,N); D=Zeta(C,A); E=BiasAddGrad(D) (for interleaved) +// C=_MklConv2D(A,B,M,N); D=Zeta(C,A); E=BiasAddGrad(D) (for contiguous) TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) { InitGraph( "node { name: 'A' op: 'Input'}" @@ -633,7 +633,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'M', 'N']}" - "node { name: 'D' op: 'Sub'" + "node { name: 'D' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'A']}" "node { name: 'E' op: 'BiasAddGrad'" @@ -641,21 +641,21 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);D(Sub);E(BiasAddGrad);" + "A(Input);B(Input);C(_MklConv2D);D(Zeta);E(BiasAddGrad);" "M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;" "M->C:2;N->C:3"); } // No Conv2D in the context for BiasAddGrad. No rewrite should happen. -// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D) +// C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D) TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D) { InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Add'" + "node { name: 'C' op: 'Polygamma'" " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B']}" - "node { name: 'D' op: 'Sub'" + "node { name: 'D' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'A']}" "node { name: 'E' op: 'BiasAddGrad'" @@ -663,13 +663,13 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|" + "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|" "A->C;A->D:1;B->C:1;C->D;D->E"); } // No Conv2D in the context for BiasAddGrad, but MatMul in context. // Rewrite should happen, but name of BiasAddGrad does not change. -// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D) +// C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D) TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) { InitGraph( "node { name: 'A' op: 'Input'}" @@ -679,7 +679,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) { " attr { key: 'transpose_a' value { b: false } }" " attr { key: 'transpose_b' value { b: false } }" " input: ['A', 'B']}" - "node { name: 'D' op: 'Sub'" + "node { name: 'D' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'A']}" "node { name: 'E' op: 'BiasAddGrad'" @@ -687,12 +687,12 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|" + "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|" "A->C;A->D:1;B->C:1;C->D;D->E"); } // Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests -// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D) +// C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D) TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) { InitGraph( "node { name: 'A' op: 'Input'}" @@ -702,7 +702,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) { " attr { key: 'transpose_a' value { b: false } }" " attr { key: 'transpose_b' value { b: false } }" " input: ['A', 'B']}" - "node { name: 'D' op: 'Sub'" + "node { name: 'D' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'A']}" "node { name: 'E' op: 'BiasAddGrad'" @@ -710,20 +710,20 @@ TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|" + "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|" "A->C;A->D:1;B->C:1;C->D;D->E"); } // No MatMul in the context for BiasAddGrad. No rewrite should happen. -// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D) +// C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D) TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Negative_NoMatMul) { InitGraph( "node { name: 'A' op: 'Input'}" "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Add'" + "node { name: 'C' op: 'Polygamma'" " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B']}" - "node { name: 'D' op: 'Sub'" + "node { name: 'D' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'A']}" "node { name: 'E' op: 'BiasAddGrad'" @@ -731,7 +731,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Negative_NoMatMul) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|" + "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|" "A->C;A->D:1;B->C:1;C->D;D->E"); } @@ -752,10 +752,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B']}" - "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['B', 'C'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);" + "A(Input);B(Input);C(_MklConv2D);D(Zeta);DMT/_0(Const);" "DMT/_1(Const)|A->C;A:control->DMT/_0:control;" "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" "DMT/_1->C:3"); @@ -781,14 +781,14 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'C']}" - "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;" + "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->C;A->D;" "A:control->DMT/_0:control;A:control->DMT/_1:control;" "A:control->DMT/_2:control;B->C:1;C->D:1;C->E;" - "C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2"); + "C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2"); } // Conv2D with INT32 which is not supported by Mkl @@ -803,10 +803,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B']}" - "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_HALF } }" + "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }" " input: ['B', 'C'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(HalfInput);B(HalfInput);C(Conv2D);D(Mul)|" + "A(HalfInput);B(HalfInput);C(Conv2D);D(Zeta)|" "A->C;B->C:1;B->D;C->D:1"); } @@ -822,11 +822,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C']}" - "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);" - "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|" + "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" "A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" "DMT/_1->D:4;DMT/_2->D:5"); @@ -844,11 +844,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['B', 'A', 'C']}" - "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);" - "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|" + "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" "A->D:1;A->E;B->D;B:control->DMT/_0:control;" "B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;" "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); @@ -869,11 +869,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'N' value { i: 2 } }" " input: ['A', 'B:0', 'B:1']}" - "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;A:control->DMT/_0:control;" + "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); } @@ -908,16 +908,16 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) { " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'N' value { i: 2 } }" " input: ['G', 'E', 'F']}" - "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'H'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" - "F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;" + "F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;" "A:control->DMT/_2:control;A:control->DMT/_3:control;" "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;" "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" - "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;" + "DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;" "G:control->DMT/_4:control;H->I:1"); } @@ -935,7 +935,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B']}" - "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D']}" "node { name: 'G' op: 'Const' " " attr { key: 'dtype' value { type: DT_INT32 } }" @@ -946,14 +946,14 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) { " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'N' value { i: 2 } }" " input: ['G', 'E', 'F']}" - "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'H'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);" - "H(_MklConcat);I(Mul)|A->E;A->I;A:control->DMT/_0:control;" + "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);" + "H(_MklConcat);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;" "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" - "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;" + "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;" "G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1"); } @@ -973,11 +973,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) { " attr { key: 'Tidx' value { type: DT_INT32 } }" " attr { key: 'N' value { i: 2 } }" " input: ['B:0', 'B:1', 'A']}" - "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;" + "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D:2;B->D;B:1->D:1;" "B:control->DMT/_0:control;B:control->DMT/_1:control;" "B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;" "DMT/_1->D:4;DMT/_2->D:5"); @@ -1014,17 +1014,17 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) { " attr { key: 'Tidx' value { type: DT_INT32 } }" " attr { key: 'N' value { i: 2 } }" " input: ['E', 'F', 'G']}" - "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'H'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" - "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;" + "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;" "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;" "C:control->DMT/_0:control;C:control->DMT/_1:control;" "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" - "DMT/_4->H:5;E->H;E:1->H:3;E:control->DMT/_4:control;F->H:1;" - "F:1->H:4;G->H:2;H->I:1"); + "DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;" + "F:2->H:4;G->H:2;H->I:1"); } // ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it @@ -1041,7 +1041,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B']}" - "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D']}" "node { name: 'G' op: 'Const' " " attr { key: 'dtype' value { type: DT_INT32 } }" @@ -1053,14 +1053,14 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) { " attr { key: 'Tidx' value { type: DT_INT32 } }" " attr { key: 'N' value { i: 2 } }" " input: ['E', 'F', 'G']}" - "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'H'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);" - "H(_MklConcatV2);I(Mul)|A->E;A->I;A:control->DMT/_0:control;" + "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);" + "H(_MklConcatV2);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;" "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" - "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;" + "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;" "E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;" "G->H:2;H->I:1"); } @@ -1071,10 +1071,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) { "node { name: 'B' op: 'Relu'" " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklRelu);C(Mul);DMT/_0(Const)|A->B;A->C;" + "A(Input);B(_MklRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;" "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); } @@ -1085,10 +1085,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) { "node { name: 'C' op: 'ReluGrad'" " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }" - "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'C'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklReluGrad);D(Mul);DMT/_0(Const);" + "A(Input);B(Input);C(_MklReluGrad);D(Zeta);DMT/_0(Const);" "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); } @@ -1102,10 +1102,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) { "node { name: 'C' op: 'ReluGrad'" " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }" - "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'C'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklRelu);C(_MklReluGrad);D(Mul);DMT/_0(Const);" + "A(Input);B(_MklRelu);C(_MklReluGrad);D(Zeta);DMT/_0(Const);" "DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;" "A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;" "DMT/_1->C:2"); @@ -1121,10 +1121,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklAvgPool);C(Mul);DMT/_0(Const)|A->B;A->C;" + "A(Input);B(_MklAvgPool);C(Zeta);DMT/_0(Const)|A->B;A->C;" "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); } @@ -1139,10 +1139,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" " input: ['A', 'B'] }" - "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['B', 'C'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);" + "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);" "DMT/_1(Const)|A->C;A:control->DMT/_0:control;" "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" "DMT/_1->C:3"); @@ -1166,10 +1166,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" " input: ['I', 'B'] }" - "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'C'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);" + "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);" "DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;" "B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;" "I:control->DMT/_1:control"); @@ -1188,12 +1188,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) { " attr { key: 'epsilon' value { f: 0.0001 } }" " attr { key: 'is_training' value { b: true } }" " input: ['A', 'B', 'C', 'D', 'E'] }" - "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'F'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);" - "F(_MklFusedBatchNormGrad);G(Mul)|A->F;A->G;" + "F(_MklFusedBatchNormGrad);G(Zeta)|A->F;A->G;" "A:control->DMT/_0:control;A:control->DMT/_1:control;" "A:control->DMT/_2:control;A:control->DMT/_3:control;" "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" @@ -1214,12 +1214,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) { " attr { key: 'epsilon' value { f: 0.0001 } }" " attr { key: 'is_training' value { b: true } }" " input: ['A', 'B', 'C', 'D', 'E'] }" - "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'F'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);" - "F(_MklFusedBatchNorm);G(Mul)|A->F;A->G;" + "F(_MklFusedBatchNorm);G(Zeta)|A->F;A->G;" "A:control->DMT/_0:control;A:control->DMT/_1:control;" "A:control->DMT/_2:control;A:control->DMT/_3:control;" "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" @@ -1268,12 +1268,12 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { " attr { key: 'depth_radius' value { i: 2 } }" " input: ['E', 'F', 'B'] }" "node { name: 'H' op: 'Input'}" - "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['H', 'G'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" - "I(Mul)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" + "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" "B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;" "C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;" "E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I"); @@ -1301,11 +1301,11 @@ TEST_F(MklLayoutPassTest, LRN_Positive) { " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'depth_radius' value { i: 2 } }" " input: ['C', 'D', 'B'] }" - "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);E(_MklLRNGrad);F(Mul)|" + "DMT/_2(Const);E(_MklLRNGrad);F(Zeta)|" "A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;" "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;" "D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1"); @@ -1323,10 +1323,10 @@ TEST_F(MklLayoutPassTest, LRN_Negative1) { " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'depth_radius' value { i: 2 } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|" + "A(Input);B(_MklLRN);C(Zeta);DMT/_0(Const)|" "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); } @@ -1344,11 +1344,11 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) { " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'depth_radius' value { i: 2 } }" " input: ['A', 'B', 'C'] }" - "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|" + "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|" "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" "A:control->DMT/_2:control;A:control->DMT/_3:control;" "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" @@ -1386,12 +1386,12 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) { " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'depth_radius' value { i: 2 } }" " input: ['C', 'B', 'D'] }" - "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['E', 'F'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);" - "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;" + "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Zeta)|A->B;" "A:control->DMT/_0:control;B->E:2;" "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;" "C:control->DMT/_1:control;C:control->DMT/_2:control;" @@ -1421,11 +1421,11 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" " input: ['C', 'B', 'D'] }" - "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'E'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|" + "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Zeta)|" "A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;" "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;" "D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1"); @@ -1444,10 +1444,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|" + "A(Input);B(_MklMaxPool);C(Zeta);DMT/_0(Const)|" "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); } @@ -1466,11 +1466,11 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" " input: ['A', 'B', 'C'] }" - "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'D'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|" + "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|" "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" "A:control->DMT/_2:control;A:control->DMT/_3:control;" "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" @@ -1489,10 +1489,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } // Test MaxPool handling for batch-wise pooling (NCHW) @@ -1507,10 +1507,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } // Test MaxPool handling for depth-wise pooling (NHWC) @@ -1525,10 +1525,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } // Test MaxPool handling for depth-wise pooling (NCHW) @@ -1543,10 +1543,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } // Test MaxPool handling for batch-wise pooling (NHWC) @@ -1561,10 +1561,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } // Test MaxPool handling for batch-wise pooling (NHWC) @@ -1579,10 +1579,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } // Test MaxPool handling for depth-wise pooling (NHWC) @@ -1597,10 +1597,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } // Test MaxPool handling for depth-wise pooling (NHWC) @@ -1615,10 +1615,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }"); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } ///////////////////////////////////////////////////////////////////// @@ -1636,10 +1636,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B']}" - "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['B', 'C'] }", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Conv2D);D(Mul)|A->C;B->C:1;B->D;C->D:1"); + "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); } TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { @@ -1657,7 +1657,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Sub'" + "node { name: 'E' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['D', 'A']}" "node { name: 'F' op: 'BiasAddGrad'" @@ -1666,7 +1666,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { " input: ['E'] }", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" - "E(Sub);F(BiasAddGrad);M(_MklInput);N(_MklInput);" + "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;" "M->D:3;N->D:4;O->D:5"); } @@ -1683,10 +1683,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'B', 'C']}" - "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'D'] }", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Mul)|" + "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" "A->D;A->E;B->D:1;C->D:2;D->E:1"); } @@ -1696,10 +1696,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { "node { name: 'B' op: 'Relu'" " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Relu);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); } TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { @@ -1709,10 +1709,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { "node { name: 'C' op: 'ReluGrad'" " attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }" - "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'C'] }", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(ReluGrad);D(Mul)|A->C;A->D;B->C:1;C->D:1"); + "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); } TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { @@ -1725,10 +1725,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); } TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { @@ -1741,10 +1741,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { " attr { key: 'padding' value { s: 'VALID' } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " input: ['A'] }" - "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'B'] }", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(AvgPool);C(Mul)|A->B;A->C;B->C:1"); + "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); } // Concat Op test: Concat with no Mkl layer feeding it @@ -1762,10 +1762,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'N' value { i: 2 } }" " input: ['A', 'B:0', 'B:1']}" - "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(Concat);E(Mul)|A->D;" + "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" "B->D:1;B:1->D:2;C->E;D->E:1"); } @@ -1784,10 +1784,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { " attr { key: 'Tidx' value { type: DT_INT32 } }" " attr { key: 'N' value { i: 2 } }" " input: ['B:0', 'B:1', 'A']}" - "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['C', 'D'] }", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(ConcatV2);E(Mul)|" + "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); } @@ -1804,11 +1804,11 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { " attr { key: 'epsilon' value { f: 0.0001 } }" " attr { key: 'is_training' value { b: true } }" " input: ['A', 'B', 'C', 'D', 'E'] }" - "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" " input: ['A', 'F'] }", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(Input);D(Input);E(Input);" - "F(FusedBatchNorm);G(Mul)|A->F;A->G;B->F:1;C->F:2;D->F:3;" + "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" "E->F:4;F->G:1"); } @@ -1832,12 +1832,12 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { " attr { key: 'data_format' value { s: 'NCHW' } }" " input: ['C', 'D'] }" "node { name: 'Y' op: 'Input'}" - "node { name: 'Z' op: 'Sub'" + "node { name: 'Z' op: 'Zeta'" " attr {key: 'T' value { type: DT_FLOAT } }" " input: ['E', 'Y']}", kGPUDevice); EXPECT_EQ(DoMklLayoutOptimizationPass(), "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" - "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->C;" + "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1"); } @@ -1853,7 +1853,7 @@ static void BM_MklLayoutRewritePass(int iters, int op_nodes) { random::SimplePhilox rnd(&philox); for (int op = 0; op < op_nodes; op++) { s += strings::Printf( - "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { " + "node { name: 'op%04d' op: 'Zeta' attr { key: 'T' value { " "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }", op, rnd.Uniform(10), rnd.Uniform(10)); } diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index 590b3d030fa212ec4f510ef35fb7a425f2aa2f9e..3f8b0e86d0c57c2896cc2ef4ac1eaf9d2ce66b60 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -64,6 +64,15 @@ namespace tensorflow { // in the Mkl format. Non-compliant ops accept inputs and outputs in the // TensorFlow format. // +// ADDENDUM: For element-wise ops, we may or may not need a conversion to +// take place before we hit the op. For this, we add a new op before each +// element-wise MKL op to deal with the inputs, called _MklInputConversion. +// This pass has been enhanced to add this capability. +// +// The _MklInputConversion op will check the inputs to the elementwise op and +// make sure that either both are in MKL format or both are in TF format, +// depending on their initial state and whether broadcast is needed or not. + class MklToTfConversionPass : public GraphOptimizationPass { public: MklToTfConversionPass() {} @@ -87,6 +96,16 @@ class MklToTfConversionPass : public GraphOptimizationPass { return mkl_op_registry::IsMklOp(op_name, T); } + // Is the input Op supported by Mkl-specific layout AND + // is it element-wise? + // + // @input op_name string of the op + // @input T Datatype to use for checking input op + // @return true if op is Mkl supported; false, otherwise. + inline bool IsMklElementWiseOp(const string& op_name, DataType T) const { + return mkl_op_registry::IsMklElementWiseOp(op_name, T); + } + // Insert layout conversion node on the edge pointed by 'e' from graph 'g'. // // Edge will be deleted once a call to this function is successful. @@ -96,6 +115,17 @@ class MklToTfConversionPass : public GraphOptimizationPass { // @return Success:OK() if insertion is successful, otherwise returns // appropriate error status code. Status InsertConversionNodeOnEdge(std::unique_ptr* g, Edge*); + + // For element-wise ops, we need to sanitize the inputs. For this, we add a + // new node at the input of the replacement element-wise node that checks + // the inputs and converts one/both of them as required. See the op code + // comments for details. + // + // Insert input conversion node as parent of 'n' from graph 'g'. + // + // @return Success:OK() if insertion is successful, otherwise returns + // appropriate error status code. + Status InsertInputConversionNode(std::unique_ptr* g, Node*); }; // We register MklToTf insertion for phase 2 in post-partition grouping @@ -171,6 +201,92 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge( return Status::OK(); } +Status MklToTfConversionPass::InsertInputConversionNode( + std::unique_ptr* g, Node* n) { + CHECK_NOTNULL(n); + + // Get the input nodes and edges + std::vector edges; + TF_CHECK_OK(n->input_edges(&edges)); + if (edges.size() != 4) { + return Status(error::Code::INVALID_ARGUMENT, + "MKL Binary Element-wise op should have exactly 2 data" + " inputs and 2 metadata inputs"); + } + + // Sanity check: ensure that both inputs are of the expected type, and the + // same type as input type + CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())), + BaseType(edges[1]->src()->output_type(edges[1]->src_output()))); + CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())), + BaseType(n->input_type(0))); + + // Check ordering of edges + for (uint i = 0; i < 4; i++) { + CHECK_EQ((edges[i]->dst_input() == i), true); + } + + // Build the conversion node and specify src as input. + Node* conversion_node = nullptr; + + TF_CHECK_OK( + NodeBuilder((*g)->NewName("MklInputConversion"), "_MklInputConversion") + .Input(edges[0]->src(), edges[0]->src_output()) + .Input(edges[1]->src(), edges[1]->src_output()) + .Input(edges[2]->src(), edges[2]->src_output()) + .Input(edges[3]->src(), edges[3]->src_output()) + .Device(n->def().device()) + .Attr("T", n->input_type(0)) + .Finalize(&**g, &conversion_node)); + + CHECK_NOTNULL(conversion_node); + + // Change the destination of any control edges to the InputConversion node + if (edges.size() != n->in_edges().size()) { + std::vector edges_to_remove; + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), conversion_node)); + edges_to_remove.push_back(e); + } + } + for (const Edge* e : edges_to_remove) { + (*g)->RemoveEdge(e); + } + } + + string data_format; + if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) == + Status::OK()) { + conversion_node->AddAttr("data_format", data_format); + } + + // Get assigned device from destination node and apply it to conversion node. + // We want conversion node to be on the same device as the destination node. + conversion_node->set_assigned_device_name(n->assigned_device_name()); + + // Set the Mkl op label for this op. + conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel); + + // Now that we have added edges from src->conversion_node, let's add edge from + // output of conversion_node to the element-wise node. + CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, n, edges[0]->dst_input())); + CHECK_NOTNULL((*g)->AddEdge(conversion_node, 1, n, edges[1]->dst_input())); + CHECK_NOTNULL((*g)->AddEdge(conversion_node, 2, n, edges[2]->dst_input())); + CHECK_NOTNULL((*g)->AddEdge(conversion_node, 3, n, edges[3]->dst_input())); + + VLOG(1) << "MklToTfConversionPass - InputConversion: Inserting input " + << "conversion node on: " << n->type_string() << " successful."; + + // Remove src->dst edge now. + (*g)->RemoveEdge(edges[0]); + (*g)->RemoveEdge(edges[1]); + (*g)->RemoveEdge(edges[2]); + (*g)->RemoveEdge(edges[3]); + + return Status::OK(); +} + bool MklToTfConversionPass::RunPass(std::unique_ptr* g) { bool result = false; @@ -239,6 +355,49 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr* g) { DumpGraph("After MklToTfConversionPass", &**g); + //--------------------------------------------------------------------------- + // Check all nodes and add an input-conversion-node if the node is an mkl + // element-wise node. + VLOG(1) << "Before running MklToTfConversionPass - InputConversion"; + + std::vector candidate_nodes; + std::vector order; + GetReversePostOrder(**g, &order); // This will give us topological sort. + + for (Node* n : order) { + // If node is not an op or it does not have a datatype, then skip. + DataType datatype; + if (!n->IsOp() || (GetNodeAttr(n->def(), "T", &datatype) != Status::OK())) { + continue; + } + if (IsMklElementWiseOp(n->type_string(), datatype)) { + // If the input node is an input-conversion op, skip + Node* input_node = nullptr; + TF_CHECK_OK(n->input_node(0, &input_node)); + DataType input_datatype; + if ((GetNodeAttr(n->def(), "T", &input_datatype) == Status::OK()) && + (input_node->type_string().compare("_MklInputConversion") == 0)) { + continue; + } + + VLOG(1) << "MklToTfConversionPass: InputConversion: Scheduled node " + << n->name() << " for inserting input conversion node"; + candidate_nodes.push_back(const_cast(n)); + } + } + + // Process all candidate edges and insert conversion nodes on them. + for (Node* n : candidate_nodes) { + // Even if we insert conversion node on a single node, we + // need to return true. + if (InsertInputConversionNode(g, n) == Status::OK()) { + VLOG(1) << "MklToTfConversionPass: Inserted conversion " + << "on node " << n->name(); + result = true; + } + } + DumpGraph("After MklToTfConversionPass - InputConversion", &**g); + // We need to return true even if we insert one conversion node // anywhere in the graph. return result; diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc index 90bef111648452f823a669cab3c063377ed7bdef..b01818f7461def79905df8e219c77922ac9f3e5f 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc @@ -173,13 +173,13 @@ TEST_F(MklToTfConversionPass, Positive) { EXPECT_EQ(DoRunMklToTfConversionPass(), "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);" "Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;" - "C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3"); + "C:2->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3"); } } // MklConv2D followed by MklToTf op followed by Non-Mkl layer. // C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for interleaved) -// C=MklConv2D(A,B,M,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for contiguous) +// C=MklConv2D(A,B,M,N); D=MklToTf(C:0, C:2) F=Sub(D,E) (for contiguous) // MklToTf node should not be inserted again. TEST_F(MklToTfConversionPass, Negative_DoubleInsert) { if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { @@ -226,7 +226,7 @@ TEST_F(MklToTfConversionPass, Negative_DoubleInsert) { "node { name: 'D' op: '_MklToTf'" " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['C:0', 'C:1']}" + " input: ['C:0', 'C:2']}" "node { name: 'E' op: 'Input'}" "node { name: 'F' op: 'Sub'" " attr {key: 'T' value { type: DT_FLOAT } }" @@ -234,7 +234,7 @@ TEST_F(MklToTfConversionPass, Negative_DoubleInsert) { EXPECT_EQ(DoRunMklToTfConversionPass(), "A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);" "F(Sub);M(_MklInput);N(_MklInput)|" - "A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3"); + "A->C;B->C:1;C->D;C:2->D:1;D->F;E->F:1;M->C:2;N->C:3"); } } diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 152d33fe051feae792feb24d1fcc95246246f768..c4ae5b79e4c02e2daeba1ef23a5a690d39cd7edb 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -91,6 +91,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler/inputs:utils", @@ -121,6 +122,7 @@ cc_test( "//tensorflow/cc:grad_testutil", "//tensorflow/cc:gradients", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 0ab6aff250b99c21341567145b26deb62556f232..1b1c88f2df41026f324081aa3425c2423e3370ff 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -396,6 +396,18 @@ Status GraphProperties::InferStatically() { } input_properties.push_back(properties); } + for (const auto& edge : node->in_edges()) { + if (!edge->src()->IsConstant()) { + continue; + } + const int input_id = edge->dst_input(); + if (input_id >= input_properties.size()) { + continue; + } + const NodeDef& node = edge->src()->def(); + const TensorProto& raw_val = node.attr().at("value").tensor(); + *input_properties[input_id].mutable_value() = raw_val; + } input_properties_[node->name()] = input_properties; // TODO(bsteiner): share this code with the input processing above. diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 954c5ead8fc1938f209252bedffbf480fcb8c2ac..461e58cf7362b2e380963de347bca638614fe07d 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -345,6 +345,15 @@ TEST_F(GraphPropertiesTest, MergeWithoutLoops) { EXPECT_EQ(DT_FLOAT, prop.dtype()); EXPECT_EQ(expected_outputs[i], PropToString(prop)); } + + // The "Less" node should be fed by 2 int32 scalar constant values. + const auto props = properties.GetInputProperties("Less"); + EXPECT_EQ(2, props.size()); + for (int i = 0; i < props.size(); ++i) { + EXPECT_EQ(DT_INT32, props[i].dtype()); + EXPECT_TRUE(props[i].has_value()); + EXPECT_EQ("int32: []", PropToString(props[i])); + } } TEST_F(GraphPropertiesTest, WhileLoop) { diff --git a/tensorflow/core/grappler/costs/op_performance_data.proto b/tensorflow/core/grappler/costs/op_performance_data.proto index 0d6b337d5a30f640a2295e108b5423b24bbc3e37..1a111b71dc5ee82650cd5c772dfce9abcb32931b 100644 --- a/tensorflow/core/grappler/costs/op_performance_data.proto +++ b/tensorflow/core/grappler/costs/op_performance_data.proto @@ -48,6 +48,16 @@ message OpInfo { DeviceProperties device = 4; } +message NormalDistribution { + double mu = 1; + double sigma = 2; +} + +message LogNormalDistribution { + double mu = 1; + double sigma = 2; +} + // Performance data for tensorflow operations message OpPerformance { // The op @@ -75,6 +85,12 @@ message OpPerformance { // Percentage of theoretical memory performance. double memory_efficiency = 8; + // Expected execution time, modeled using one of 2 possible distributions. + oneof execution_time { + NormalDistribution execution_time_normal = 10; + LogNormalDistribution execution_time_log_normal = 11; + }; + // Memory usage data for a tensorflow operation. message OpMemory { // The output information may have memory usage and output shapes. diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index b740e8a999ea1a866a54e50d6998a50dedf59893..33081061c01d5d715b3573de54e08ee95acc537c 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -37,6 +37,8 @@ limitations under the License. #include "tensorflow/core/grappler/inputs/utils.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/public/session_options.h" @@ -163,6 +165,104 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( return nullptr; } + // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op. + // The reason why they are difficult to handle is because they may not intend + // to initialize all variables that are required to run fetch nodes. We may + // have to run restore op first. + + // Try to find initializers from variables and tables as init ops. + for (const string& var_collection : + {"variables", "local_variables", "model_variables", + "trainable_variables"}) { + if (meta_graph.collection_def().count(var_collection) == 0) { + continue; + } + const CollectionDef& vars = meta_graph.collection_def().at(var_collection); + for (const auto& raw_var : vars.bytes_list().value()) { + VariableDef var; + var.ParseFromString(raw_var); + if (!var.initializer_name().empty()) { + new_item->init_ops.push_back(NodeName(var.initializer_name())); + } + } + } + + if (meta_graph.collection_def().count("table_initializer") > 0) { + const CollectionDef& inits = + meta_graph.collection_def().at("table_initializer"); + if (inits.has_node_list()) { + for (const auto& node : inits.node_list().value()) { + new_item->init_ops.push_back(NodeName(node)); + // Tables are initialized from files, which can take a long time. Add + // 30 minutes to the initialization time for each table to avoid + // timing out. + // TODO(bsteiner): adjust the timeout based on the file size. + new_item->expected_init_time += 30 * 60; + } + } + } + + // We keep the mapping from asset node to asset files. This should have been + // used as feed but since asset node is usually a constant node, we will fill + // the values of these constant nodes with their actual asset file paths. + std::unordered_map asset_node_to_value; + + // Assets file may have changed their directory, we assemble their new paths + // if assets_directory_override is set. We also make sure we still can + // access these asset files. + if (!cfg.assets_directory_override.empty()) { + if (meta_graph.collection_def().count("saved_model_assets") > 0) { + const CollectionDef& collection = + meta_graph.collection_def().at("saved_model_assets"); + const auto& any_assets = collection.any_list().value(); + for (const auto& any_asset : any_assets) { + AssetFileDef asset_file_def; + if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef") + .ok()) { + LOG(ERROR) << "Failed to parse AssetFile."; + continue; + } + string asset_filepath = io::JoinPath(cfg.assets_directory_override, + asset_file_def.filename()); + if (!FilesExist({asset_filepath}, nullptr)) { + LOG(ERROR) << "Can't access one or more of the asset files " + << asset_filepath << ", skipping this input"; + return nullptr; + } + asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] = + asset_filepath; + } + } + } else if (meta_graph.collection_def().count("asset_filepaths") > 0) { + const CollectionDef& file_paths = + meta_graph.collection_def().at("asset_filepaths"); + std::vector paths; + for (const auto& raw_path : file_paths.bytes_list().value()) { + paths.push_back(raw_path); + } + if (!FilesExist(paths, nullptr)) { + LOG(ERROR) << "Can't access one or more of the asset files, skipping " + "this input"; + return nullptr; + } + } + + if (meta_graph.collection_def().count("queue_runners") > 0) { + const CollectionDef& vars = meta_graph.collection_def().at("queue_runners"); + for (const auto& raw : vars.bytes_list().value()) { + QueueRunnerDef queue_runner; + if (!queue_runner.ParseFromString(raw)) { + LOG(ERROR) << "Could not parse queue_runners, skipping this input"; + return nullptr; + } + if (queue_runner.cancel_op_name().empty()) { + LOG(ERROR) << "Queue without a cancel op, skipping this input"; + return nullptr; + } + new_item->queue_runners.push_back(queue_runner); + } + } + for (auto& node : *new_item->graph.mutable_node()) { if (IsPlaceholder(node)) { if (node.attr().count("dtype") == 0) { @@ -248,6 +348,24 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( // inferring shapes and is a no-op when dynamically inferring shapes as // the Placeholder shape will match the shape passed from new_item->feed. *(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto; + } else if (IsConstant(node)) { + auto it = asset_node_to_value.find(node.name()); + if (it != asset_node_to_value.end()) { + auto iter = node.mutable_attr()->find("value"); + if (iter == node.attr().end()) { + LOG(ERROR) << "Value attribute expected in const op for asset files"; + return nullptr; + } + if (!iter->second.has_tensor() || + iter->second.tensor().string_val_size() != 1) { + LOG(INFO) << "Unexected AttrValue proto: " + << iter->second.DebugString(); + return nullptr; + } + LOG(INFO) << "Using asset file " << it->second << " for node " + << node.name(); + *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second; + } } // Erase the recorded result of any previous shape inference to start again @@ -268,71 +386,6 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( } } - for (const string& var_collection : - {"variables", "local_variables", "model_variables", - "trainable_variables"}) { - if (meta_graph.collection_def().count(var_collection) == 0) { - continue; - } - const CollectionDef& vars = meta_graph.collection_def().at(var_collection); - for (const auto& raw_var : vars.bytes_list().value()) { - VariableDef var; - var.ParseFromString(raw_var); - if (!var.initializer_name().empty()) { - new_item->init_ops.push_back(var.initializer_name()); - } - } - } - - if (meta_graph.collection_def().count("table_initializer") > 0) { - const CollectionDef& inits = - meta_graph.collection_def().at("table_initializer"); - if (inits.has_node_list()) { - for (const auto& node : inits.node_list().value()) { - new_item->init_ops.push_back(node); - // Tables are initialized from files, which can take a long time. Add 30 - // minutes to the initialization time for each table to avoid timing - // out. - // TODO(bsteiner): adjust the timeout based on the file size. - new_item->expected_init_time += 30 * 60; - } - } - } - - if (meta_graph.collection_def().count("queue_runners") > 0) { - const CollectionDef& vars = meta_graph.collection_def().at("queue_runners"); - for (const auto& raw : vars.bytes_list().value()) { - QueueRunnerDef queue_runner; - if (!queue_runner.ParseFromString(raw)) { - LOG(ERROR) << "Could parse queue_runners, skipping this input"; - return nullptr; - } - if (queue_runner.cancel_op_name().empty()) { - LOG(ERROR) << "Queue without a cancel op, skipping this input"; - return nullptr; - } - new_item->queue_runners.push_back(queue_runner); - } - } - - // Make sure we still can access the input files (aka "asset_filepaths") since - // these might have been moved or deleted, the cns cell might have been shut - // down, or we might be running as a user who does not have access to the - // files. - if (meta_graph.collection_def().count("asset_filepaths") > 0) { - const CollectionDef& file_paths = - meta_graph.collection_def().at("asset_filepaths"); - std::vector paths; - for (const auto& raw_path : file_paths.bytes_list().value()) { - paths.push_back(raw_path); - } - if (!FilesExist(paths, nullptr)) { - LOG(ERROR) - << "Can't access one or more of the asset files, skipping this input"; - return nullptr; - } - } - if (meta_graph.collection_def().count("savers") > 0) { const CollectionDef& savers = meta_graph.collection_def().at("savers"); for (const auto& raw : savers.bytes_list().value()) { diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index d385a1916effefa353b612adbb148d93f1eeca95..4ce5055e7a1e037814298bdadad2610d6b72ae39 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -45,6 +45,8 @@ struct ItemConfig { bool apply_optimizations; // If true, does inlining. bool inline_functions; + // If non-empty, override the directory of asset paths. + string assets_directory_override; }; // Factory method for creating a GrapplerItem from a MetaGraphDef. diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index 048870f9e51e8e3daa422e6f8ceab77bd42e83cd..4272179d3cbef35362dc3330b5d1b3076df9bdb1 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/grappler_item_builder.h" +#include "google/protobuf/any.pb.h" #include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/gradients/grad_testutil.h" #include "tensorflow/cc/ops/functional_ops.h" @@ -22,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -121,6 +123,136 @@ TEST_F(GrapplerItemBuilderTest, SymbolicGradientInlining) { CountOpsWithNames(with_inline, ops_of_inline)); } +TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest) { + MetaGraphDef meta_graph; + + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output var = + ops::Variable(s.WithOpName("var"), TensorShape(), DataType::DT_FLOAT); + Output filename_node = + ops::Const(s.WithOpName("filename"), string("model"), TensorShape()); + Output tensor_name = + ops::Const(s.WithOpName("tensorname"), string("var"), TensorShape()); + Output restore = ops::Restore(s.WithOpName("restore"), filename_node, + tensor_name, DataType::DT_FLOAT); + Output assign = ops::Assign(s.WithOpName("assign"), var, restore); + + TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def())); + + string temp_dir = testing::TmpDir(); + + Env *env = Env::Default(); + string filename = + io::JoinPath(temp_dir, "grappler_item_builder_test_filename"); + env->DeleteFile(filename).IgnoreError(); + std::unique_ptr file_to_write; + TF_CHECK_OK(env->NewWritableFile(filename, &file_to_write)); + TF_CHECK_OK(file_to_write->Close()); + TF_CHECK_OK(env->FileExists(filename)); + LOG(INFO) << filename; + + AssetFileDef asset_file_def; + *asset_file_def.mutable_tensor_info()->mutable_name() = "filename"; + *asset_file_def.mutable_filename() = "grappler_item_builder_test_filename"; + + (*meta_graph.mutable_collection_def())["saved_model_assets"] + .mutable_any_list() + ->add_value() + ->PackFrom(asset_file_def); + *((*meta_graph.mutable_collection_def())["train_op"] + .mutable_node_list() + ->add_value()) = "assign"; + + ItemConfig cfg; + cfg.assets_directory_override = temp_dir; + + std::unique_ptr item = + GrapplerItemFromMetaGraphDef("0", meta_graph, cfg); + ASSERT_TRUE(item != nullptr); + for (const NodeDef &node : item->graph.node()) { + if (node.name() == "filename") { + const auto iter = node.attr().find("value"); + ASSERT_TRUE(iter != node.attr().end()); + ASSERT_TRUE(iter->second.has_tensor()); + ASSERT_EQ(1, iter->second.tensor().string_val_size()); + + string tensor_string_val = iter->second.tensor().string_val(0); + EXPECT_EQ(tensor_string_val, filename); + } + } +} + +TEST_F(GrapplerItemBuilderTest, AssetFilepathOverrideTest_FileNotAccessible) { + MetaGraphDef meta_graph; + + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output var = + ops::Variable(s.WithOpName("var"), TensorShape(), DataType::DT_FLOAT); + Output filename_node1 = + ops::Const(s.WithOpName("filename1"), string("model1"), TensorShape()); + Output filename_node2 = + ops::Const(s.WithOpName("filename2"), string("model2"), TensorShape()); + Output tensor_name = + ops::Const(s.WithOpName("tensorname"), string("var"), TensorShape()); + Output restore1 = ops::Restore(s.WithOpName("restore1"), filename_node1, + tensor_name, DataType::DT_FLOAT); + Output restore2 = ops::Restore(s.WithOpName("restore2"), filename_node1, + tensor_name, DataType::DT_FLOAT); + Output assign1 = ops::Assign(s.WithOpName("assign1"), var, restore1); + Output assign2 = ops::Assign(s.WithOpName("assign2"), var, restore2); + + TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def())); + + string temp_dir = testing::TmpDir(); + + // Create the first AssetFileDef that has a valid file. + Env *env = Env::Default(); + string filename1 = + io::JoinPath(temp_dir, "grappler_item_builder_test_filename1"); + env->DeleteFile(filename1).IgnoreError(); + std::unique_ptr file_to_write; + TF_CHECK_OK(env->NewWritableFile(filename1, &file_to_write)); + TF_CHECK_OK(file_to_write->Close()); + TF_CHECK_OK(env->FileExists(filename1)); + + AssetFileDef asset_file_def1; + *asset_file_def1.mutable_tensor_info()->mutable_name() = "filename1"; + *asset_file_def1.mutable_filename() = "grappler_item_builder_test_filename1"; + + // Create the second AssetFileDef that has not a valid file. + string filename2 = + io::JoinPath(temp_dir, "grappler_item_builder_test_filename1"); + env->DeleteFile(filename2).IgnoreError(); + EXPECT_FALSE(env->FileExists(filename2).ok()); + + AssetFileDef asset_file_def2; + *asset_file_def2.mutable_tensor_info()->mutable_name() = "filename2"; + *asset_file_def2.mutable_filename() = "grappler_item_builder_test_filename2"; + + (*meta_graph.mutable_collection_def())["saved_model_assets"] + .mutable_any_list() + ->add_value() + ->PackFrom(asset_file_def1); + (*meta_graph.mutable_collection_def())["saved_model_assets"] + .mutable_any_list() + ->add_value() + ->PackFrom(asset_file_def2); + + *((*meta_graph.mutable_collection_def())["train_op"] + .mutable_node_list() + ->add_value()) = "assign1"; + *((*meta_graph.mutable_collection_def())["train_op"] + .mutable_node_list() + ->add_value()) = "assign2"; + + ItemConfig cfg; + cfg.assets_directory_override = temp_dir; + + std::unique_ptr item = + GrapplerItemFromMetaGraphDef("0", meta_graph, cfg); + ASSERT_TRUE(item == nullptr); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index a7515786a050e73b558981faeff842288c030566..659451e9913199164ef8a07562f6ed9f39b76063 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -275,6 +275,7 @@ cc_library( "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:frame", ], ) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d6684d9e890ca839a8d8dbcf0433c114d7dad837..d5f74017851bb2746f31b8f3e76cdb9843e0dd87 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -138,7 +138,10 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const { if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { return false; } - if (IsEnter(node) || IsPlaceholder(node)) { + if (IsEnter(node) || IsExit(node) || IsPlaceholder(node)) { + return false; + } + if (node.device().find("SPU") != string::npos) { return false; } const OpDef* op_def = nullptr; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 016f78fcc56893e282c510f593a80fa842c3b685..443c0b72abc3699b400d42c47a6904c074e4089a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -94,10 +94,6 @@ class DeviceSimple : public DeviceBase { std::unique_ptr eigen_device_; }; -string AsControlDependency(const NodeDef& node) { - return strings::StrCat("^", node.name()); -} - } // namespace ConstantFolding::ConstantFolding() { @@ -230,6 +226,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item, CHECK_LE(1, node.input_size()); string ctrl_dep = AddControlDependency(node.input(0)); node.set_input(0, ctrl_dep); + node_map_->AddOutput(NodeName(ctrl_dep), node.name()); } } } @@ -462,7 +459,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, return Status::OK(); } -Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output) { +Status ConstantFolding::FoldNode(NodeDef* node) { if (IsMerge(*node)) { // Merge nodes are special, in the sense that they execute as soon as one of // their input is ready. We can therefore fold a merge node iff it has at @@ -511,14 +508,15 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output) { "already present in the graph")); } - NodeDef* const_out = output->add_node(); + NodeDef* const_out = added_graph_.add_node(); *const_out = *input_node; const_out->set_name(const_out_name); const_out->set_device(node->device()); *const_out->add_input() = AsControlDependency(*node); node_map_->AddNode(const_out->name(), const_out); + node_map_->AddOutput(node->name(), const_out->name()); - NodeDef* const_index = output->add_node(); + NodeDef* const_index = added_graph_.add_node(); const_index->set_op("Const"); Tensor index(DT_INT32, TensorShape({})); index.flat()(0) = input_index; @@ -529,6 +527,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output) { const_index->set_device(node->device()); *const_index->add_input() = AsControlDependency(*node); node_map_->AddNode(const_index->name(), const_index); + node_map_->AddOutput(node->name(), const_index->name()); auto outputs = node_map_->GetOutputs(node->name()); for (auto& output : outputs) { @@ -538,8 +537,10 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output) { if (node_name == node->name()) { if (position == 0) { *output->mutable_input(i) = const_out->name(); + node_map_->AddOutput(const_out->name(), output->name()); } else if (position == 1) { *output->mutable_input(i) = const_index->name(); + node_map_->AddOutput(const_index->name(), output->name()); } else { // This is a control dependency (or an invalid edge since the // merge node has only 2 inputs): preserve them. @@ -565,13 +566,18 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output) { continue; } + // Forward control dependencies. for (const auto& input : node->input()) { - if (IsControlInput(input)) { + if (IsControlInput(input) && + std::find(const_node->input().begin(), const_node->input().end(), + input) == const_node->input().end()) { *const_node->add_input() = input; } else { NodeDef* input_node = node_map_->GetNode(input); for (const auto& fanin_of_input : input_node->input()) { - if (IsControlInput(fanin_of_input)) { + if (IsControlInput(fanin_of_input) && + std::find(const_node->input().begin(), const_node->input().end(), + fanin_of_input) == const_node->input().end()) { *const_node->add_input() = fanin_of_input; } } @@ -582,8 +588,15 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output) { // create new nodes otherwise. if (const_nodes.size() == 1) { node->set_op("Const"); + // Note we need to clear the inputs in NodeMap before we clear the inputs + // in the node, otherwise NodeMap would see empty inputs and effectively + // does nothing. + node_map_->RemoveInputs(node->name()); node->clear_input(); *node->mutable_input() = const_node->input(); + for (const auto& input : node->input()) { + node_map_->AddOutput(NodeName(input), node->name()); + } *node->mutable_attr() = const_node->attr(); break; } else { @@ -592,10 +605,13 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output) { return errors::AlreadyExists(strings::StrCat( const_node->name(), "already present in the graph")); } - NodeDef* added_node = output->add_node(); + NodeDef* added_node = added_graph_.add_node(); *added_node = *const_node; added_node->set_device(node->device()); node_map_->AddNode(added_node->name(), added_node); + for (const auto& input : added_node->input()) { + node_map_->AddOutput(NodeName(input), added_node->name()); + } // All the constant nodes encoding output values have the same control // dependencies (since these are the control dependencies of the node // we're trying to fold). Record one such constant node. @@ -614,11 +630,15 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output) { // Propagate control dependencies if possible. If not, we'll just // preserve the existing control dependencies. if (constant_output != nullptr) { + node_map_->UpdateInput(node_name, NodeName(output->input(i)), + constant_output->name()); *output->mutable_input(i) = AsControlDependency(*constant_output); } } else if (position < const_nodes.size() && !const_nodes[position].name().empty()) { // Replace alive outputs with the corresponding constant. + node_map_->UpdateInput(output->name(), NodeName(output->input(i)), + const_nodes[position].name()); *output->mutable_input(i) = const_nodes[position].name(); } else { // Leave this edge alone. @@ -629,6 +649,12 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output) { } } } + outputs = node_map_->GetOutputs(node->name()); + if (outputs.empty() && has_fetch_ && + nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) { + node_map_->RemoveInputs(node->name()); + node->clear_input(); + } } return Status::OK(); } @@ -648,12 +674,13 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { if (processed_nodes.count(node->name())) { continue; } - Status s = FoldNode(node, output); + // We need to record a copy of output nodes before FoldNode() modifies it. + std::set outputs = node_map_->GetOutputs(node->name()); + Status s = FoldNode(node); processed_nodes.insert(node->name()); if (!s.ok()) { VLOG(1) << "Failed to fold node " << node->name() << ": " << s; } else { - auto outputs = node_map_->GetOutputs(node->name()); for (auto& output : outputs) { if (IsFoldable(*output)) { queue.push_back(output); @@ -662,11 +689,24 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { } } - // Build the graph after constant folding. Note that we keep all processed - // nodes in the graph in case users need to fetch their values. + // Build the graph after constant folding. + for (const auto& node : added_graph_.node()) { + auto outputs = node_map_->GetOutputs(node.name()); + if (!outputs.empty()) { + auto added_node = output->add_node(); + *added_node = node; + } + } for (const auto& node : graph_.node()) { - auto added_node = output->add_node(); - *added_node = node; + // If no fetch nodes is provided, we conservatively + // keep all nodes in the original graph in case users need to fetch + // their values. + auto outputs = node_map_->GetOutputs(node.name()); + if (!outputs.empty() || !has_fetch_ || + nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { + auto added_node = output->add_node(); + *added_node = node; + } } return Status::OK(); } @@ -804,16 +844,12 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { graph_ = item.graph; node_map_.reset(new NodeMap(&graph_)); - for (const auto& node : item.fetch) { - nodes_to_preserve_.insert(NodeName(node)); - } - for (const auto& node : item.feed) { - nodes_to_preserve_.insert(NodeName(node.first)); - } + nodes_to_preserve_ = item.NodesToPreserve(); device_.reset(new DeviceSimple()); *output = GraphDef(); bool has_feed = !item.feed.empty(); + has_fetch_ = !item.fetch.empty(); GraphProperties properties(item); if (!has_feed) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 82ca5659101bbb53b1feb762718cf47aa9bf9988..0c1c40dfd343eff741f0fb336636c9671e76af05 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -60,7 +60,7 @@ class ConstantFolding : public GraphOptimizer { Status EvaluateOneFoldable(const NodeDef& node, std::vector* outputs); - Status FoldNode(NodeDef* node, GraphDef* output); + Status FoldNode(NodeDef* node); Status FoldGraph(GraphDef* output); @@ -73,7 +73,9 @@ class ConstantFolding : public GraphOptimizer { std::unique_ptr resource_mgr_; GraphDef graph_; std::unique_ptr node_map_; - std::set nodes_to_preserve_; + std::unordered_set nodes_to_preserve_; + GraphDef added_graph_; + bool has_fetch_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index a1d35da5acbe8aa5d9748216fa2263dff88ccb2b..0f7e7f1d494136d688cd4f3134561fcea7cea0f6 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -62,30 +62,27 @@ TEST_F(ConstantFoldingTest, SimpleFolding) { Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(4, output.node_size()); - - const NodeDef& node_a = output.node(0); - EXPECT_EQ("a", node_a.name()); + EXPECT_EQ(3, output.node_size()); - const NodeDef& node_b = output.node(1); + const NodeDef& node_b = output.node(0); EXPECT_EQ("b", node_b.name()); - const NodeDef& node_c = output.node(2); + const NodeDef& node_c = output.node(1); EXPECT_EQ("c", node_c.name()); EXPECT_EQ("Const", node_c.op()); EXPECT_EQ("/CPU:0", node_c.device()); - const NodeDef& node_d = output.node(3); + const NodeDef& node_d = output.node(2); EXPECT_EQ("d", node_d.name()); EXPECT_EQ("c", node_d.input(1)); EXPECT_EQ("", node_d.device()); - std::vector fetch = {"a", "b", "c", "d"}; + std::vector fetch = {"b", "c", "d"}; auto tensors_expected = EvaluateNodes(item.graph, fetch); auto tensors = EvaluateNodes(output, fetch); - EXPECT_EQ(4, tensors_expected.size()); - EXPECT_EQ(4, tensors.size()); - for (int i = 0; i < 4; i++) { + EXPECT_EQ(fetch.size(), tensors_expected.size()); + EXPECT_EQ(fetch.size(), tensors.size()); + for (int i = 0; i < fetch.size(); i++) { test::ExpectTensorEqual(tensors_expected[i], tensors[i]); } } @@ -98,10 +95,12 @@ TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) { auto b = ops::Unique(s.WithOpName("b"), {a}); Output c = ops::Identity(s.WithOpName("c"), {b.y}); Output d = ops::Identity(s.WithOpName("d"), {b.idx}); + Output e = ops::Identity(s.WithOpName("e"), {c}); + Output f = ops::Identity(s.WithOpName("f"), {d}); GrapplerItem item; - item.fetch.push_back("c"); - item.fetch.push_back("d"); + item.fetch.push_back("e"); + item.fetch.push_back("f"); TF_CHECK_OK(s.ToGraphDef(&item.graph)); ConstantFolding fold; @@ -109,36 +108,30 @@ TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) { Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(6, output.node_size()); - - const NodeDef& new_b_0 = output.node(0); - EXPECT_EQ("ConstantFolding/b-0", new_b_0.name()); - EXPECT_EQ("Const", new_b_0.op()); - - const NodeDef& new_b_1 = output.node(1); - EXPECT_EQ("ConstantFolding/b-1", new_b_1.name()); - EXPECT_EQ("Const", new_b_1.op()); - - const NodeDef& new_a = output.node(2); - EXPECT_EQ("a", new_a.name()); - - const NodeDef& new_b = output.node(3); - EXPECT_EQ("b", new_b.name()); + EXPECT_EQ(4, output.node_size()); - const NodeDef& new_c = output.node(4); + const NodeDef& new_c = output.node(0); EXPECT_EQ("c", new_c.name()); - EXPECT_EQ("ConstantFolding/b-0", new_c.input(0)); + EXPECT_EQ("Const", new_c.op()); - const NodeDef& new_d = output.node(5); + const NodeDef& new_d = output.node(1); EXPECT_EQ("d", new_d.name()); - EXPECT_EQ("ConstantFolding/b-1", new_d.input(0)); + EXPECT_EQ("Const", new_d.op()); + + const NodeDef& new_e = output.node(2); + EXPECT_EQ("e", new_e.name()); + EXPECT_EQ("c", new_e.input(0)); - std::vector fetch = {"a", "b", "c", "d"}; + const NodeDef& new_f = output.node(3); + EXPECT_EQ("f", new_f.name()); + EXPECT_EQ("d", new_f.input(0)); + + std::vector fetch = {"e", "f"}; auto tensors_expected = EvaluateNodes(item.graph, fetch); auto tensors = EvaluateNodes(output, fetch); - EXPECT_EQ(4, tensors_expected.size()); - EXPECT_EQ(4, tensors.size()); - for (int i = 0; i < 4; i++) { + EXPECT_EQ(fetch.size(), tensors_expected.size()); + EXPECT_EQ(fetch.size(), tensors.size()); + for (int i = 0; i < fetch.size(); i++) { test::ExpectTensorEqual(tensors_expected[i], tensors[i]); } } @@ -156,7 +149,7 @@ TEST_F(ConstantFoldingTest, ControlDependencies) { Output i3 = ops::Identity(scope.WithOpName("e"), {i2}); GrapplerItem item; - item.fetch.push_back("i3"); + item.fetch.push_back("e"); TF_CHECK_OK(scope.ToGraphDef(&item.graph)); ConstantFolding fold; @@ -164,8 +157,57 @@ TEST_F(ConstantFoldingTest, ControlDependencies) { Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + std::vector expected_nodes = {"dflt", "p1", "p2", "i2", "e"}; + EXPECT_EQ(output.node_size(), expected_nodes.size()); + int i = 0; int found = 0; for (const auto& node : output.node()) { + EXPECT_EQ(expected_nodes[i], output.node(i).name()); + i++; + if (node.name() == "i2") { + EXPECT_EQ("Const", node.op()); + ++found; + auto folded = EvaluateNodes(output, {"i2"}); + auto expected = EvaluateNodes(item.graph, {"i2"}); + EXPECT_EQ(1, expected.size()); + EXPECT_EQ(1, folded.size()); + test::ExpectTensorEqual(folded[0], expected[0]); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("^p1", node.input(0)); + EXPECT_EQ("^p2", node.input(1)); + } + } + EXPECT_EQ(1, found); +} + +TEST_F(ConstantFoldingTest, ControlDependenciesEmptyFetch) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1}); + Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1}); + Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1}); + Output c = + ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3}); + Output i1 = ops::Identity(scope.WithOpName("i1"), {c}); + Output i2 = + ops::Identity(scope.WithOpName("i2").WithControlDependencies(p2), {i1}); + Output i3 = ops::Identity(scope.WithOpName("e"), {i2}); + + GrapplerItem item; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding fold; + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + std::vector expected_nodes = {"dflt", "p1", "p2", "c", + "i1", "i2", "e"}; + EXPECT_EQ(output.node_size(), expected_nodes.size()); + int i = 0; + int found = 0; + for (const auto& node : output.node()) { + EXPECT_EQ(expected_nodes[i], output.node(i).name()); + i++; if (node.name() == "i1") { EXPECT_EQ("Const", node.op()); ++found; @@ -193,6 +235,43 @@ TEST_F(ConstantFoldingTest, ControlDependencies) { EXPECT_EQ(2, found); } +TEST_F(ConstantFoldingTest, ControlDependenciesDeduplicate) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + Output dflt = ops::Const(scope.WithOpName("dflt"), 3.14f, {1}); + Output p1 = ops::PlaceholderWithDefault(scope.WithOpName("p1"), dflt, {1}); + Output p2 = ops::PlaceholderWithDefault(scope.WithOpName("p2"), dflt, {1}); + Output c = + ops::Const(scope.WithOpName("c").WithControlDependencies(p1), 10, {3}); + Output i1 = ops::Identity(scope.WithOpName("i1") + .WithControlDependencies(p2) + .WithControlDependencies(p1), + {c}); + Output i2 = ops::Identity(scope.WithOpName("i2"), {i1}); + + GrapplerItem item; + item.fetch.push_back("i2"); + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding fold; + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + std::vector expected_nodes = {"dflt", "p1", "p2", "i1", "i2"}; + EXPECT_EQ(output.node_size(), expected_nodes.size()); + int i = 0; + for (const auto& node : output.node()) { + EXPECT_EQ(expected_nodes[i], output.node(i).name()); + i++; + if (node.name() == "i1") { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("^p1", node.input(0)); + EXPECT_EQ("^p2", node.input(1)); + } + } +} + TEST_F(ConstantFoldingTest, VariableNumberOfOutputs) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); // Add a DynamicPartition node to the graph @@ -290,6 +369,41 @@ TEST_F(ConstantFoldingTest, ShapeMaterialization) { Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + int found = 0; + for (const auto& node : output.node()) { + if (node.name() == "shape") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^v2", node.input(0)); + Tensor value; + CHECK(value.FromProto(node.attr().at("value").tensor())); + EXPECT_EQ(5, value.flat()(0)); + EXPECT_EQ(7, value.flat()(1)); + } + } + EXPECT_EQ(1, found); +} + +TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT); + Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT); + Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT); + Output rank = ops::Rank(scope.WithOpName("rank"), v1); + Output shape = ops::Shape(scope.WithOpName("shape"), v2); + Output size = ops::Size(scope.WithOpName("size"), v3); + Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank); + Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape); + + GrapplerItem item; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding fold; + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + int found = 0; for (const auto& node : output.node()) { if (node.name() == "size") { @@ -322,7 +436,7 @@ TEST_F(ConstantFoldingTest, ShapeMaterialization) { EXPECT_EQ(3, found); } -TEST_F(ConstantFoldingTest, SwitchNodes) { +TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT); ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL); @@ -345,9 +459,6 @@ TEST_F(ConstantFoldingTest, SwitchNodes) { {statically_known.output, never_generated.output}); GrapplerItem item; - item.fetch.push_back("m"); - item.fetch.push_back("m2"); - TF_CHECK_OK(scope.ToGraphDef(&item.graph)); ConstantFolding fold; @@ -355,31 +466,111 @@ TEST_F(ConstantFoldingTest, SwitchNodes) { Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + std::set present_nodes = {"v_in", "v_ctrl", + "switch", "i", + "p1", "p2", + "m", "false", + "constant", "switch2", + "i2", "i3", + "m2", "ConstantFoldingCtrl/switch_0", + "rank", "size"}; + std::set not_present_nodes = {"ConstantFolding/switch2-0"}; + EXPECT_EQ(present_nodes.size(), output.node_size()); + int found = 0; for (const auto& node : output.node()) { + EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end()); + EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end()); + present_nodes.erase(node.name()); + not_present_nodes.erase(node.name()); if (node.name() == "rank") { + ++found; EXPECT_EQ("Const", node.op()); EXPECT_EQ(1, node.input_size()); EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0)); } if (node.name() == "size") { + ++found; EXPECT_EQ("Const", node.op()); EXPECT_EQ(1, node.input_size()); EXPECT_EQ("^i", node.input(0)); } - if (node.name() == "ConstantFolding/switch2-0") { + if (node.name() == "i2") { + ++found; EXPECT_EQ("Const", node.op()); EXPECT_EQ(0, node.input_size()); } - if (node.name() == "ConstantFolding/i2") { + if (node.name() == "i3") { + ++found; + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("switch2:1", node.input(0)); + } + } + EXPECT_EQ(4, found); +} + +TEST_F(ConstantFoldingTest, SwitchNodes) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT); + ops::Variable v_ctrl(scope.WithOpName("v_ctrl"), {}, DT_BOOL); + ops::Switch s1(scope.WithOpName("switch"), v_in, v_ctrl); + ops::Rank rank(scope.WithOpName("rank"), s1.output_false); + ops::Identity i(scope.WithOpName("i"), s1.output_true); + ops::Size size(scope.WithOpName("size"), i); + ops::Square p1(scope.WithOpName("p1"), rank); + ops::Square p2(scope.WithOpName("p2"), size); + ops::Merge m(scope.WithOpName("m"), {p1.y, p2.y}); + + Output predicate = + ops::Const(scope.WithOpName("false"), false, TensorShape({})); + Output constant = + ops::Const(scope.WithOpName("constant"), 1.0f, TensorShape({1})); + ops::Switch s2(scope.WithOpName("switch2"), constant, predicate); + ops::Identity statically_known(scope.WithOpName("i2"), s2.output_false); + ops::Identity never_generated(scope.WithOpName("i3"), s2.output_true); + ops::Merge m2(scope.WithOpName("m2"), + {statically_known.output, never_generated.output}); + + GrapplerItem item; + item.fetch.push_back("m"); + item.fetch.push_back("m2"); + + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding fold; + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + std::set present_nodes = {"v_in", "v_ctrl", + "switch", "i", + "p1", "p2", + "m", "false", + "constant", "switch2", + "i2", "i3", + "m2", "ConstantFoldingCtrl/switch_0"}; + std::set not_present_nodes = {"rank", "size", + "ConstantFolding/switch2-0"}; + EXPECT_EQ(present_nodes.size(), output.node_size()); + + int found = 0; + for (const auto& node : output.node()) { + EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end()); + EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end()); + present_nodes.erase(node.name()); + not_present_nodes.erase(node.name()); + if (node.name() == "i2") { + ++found; EXPECT_EQ("Const", node.op()); EXPECT_EQ(0, node.input_size()); } if (node.name() == "i3") { + ++found; EXPECT_EQ("Identity", node.op()); EXPECT_EQ(1, node.input_size()); EXPECT_EQ("switch2:1", node.input(0)); } } + EXPECT_EQ(2, found); } TEST_F(ConstantFoldingTest, MergeNodes) { @@ -411,7 +602,7 @@ TEST_F(ConstantFoldingTest, MergeNodes) { ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index); GrapplerItem item; - item.fetch.push_back("out1, idx1, out2, idx2, out3, idx3"); + item.fetch = {"out1", "idx1", "out2", "idx2", "out3", "idx3"}; TF_CHECK_OK(scope.ToGraphDef(&item.graph)); ConstantFolding fold; diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index f469f9a9acd5e557aed0b987e16d9145befd0368..a4b0a60e1ffde796638da3bb56ff667104b81341 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/frame.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -95,10 +96,84 @@ bool IsNodeNCHWToNHWC(const string& node_name) { return false; } -class NodeProcessor { +class GraphProcessor { public: - NodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : graph_(graph), node_(node), node_map_(node_map) {} + GraphProcessor(GraphDef* graph, NodeMap* node_map) + : graph_(graph), node_map_(node_map) {} + + protected: + NodeDef* AddNodePermConst(const string& name, const string& device, + const std::vector& permutation) { + NodeDef* node = graph_->add_node(); + node_map_->AddNode(name, node); + node->set_name(name); + node->set_op("Const"); + node->set_device(device); + AttrValue attr_data_type; + attr_data_type.set_type(DT_INT32); + node->mutable_attr()->insert({"dtype", attr_data_type}); + AttrValue attr_tensor; + Tensor tensor(DT_INT32, TensorShape({4})); + for (int i = 0; static_cast(i) < permutation.size(); i++) { + tensor.flat()(i) = permutation[i]; + } + tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); + node->mutable_attr()->insert({"value", attr_tensor}); + return node; + } + + NodeDef* AddNodeConstScalar(const string& name, const string& device, + DataType dtype, int value) { + NodeDef* node = graph_->add_node(); + node_map_->AddNode(name, node); + node->set_name(name); + node->set_op("Const"); + node->set_device(device); + AttrValue attr_data_type; + attr_data_type.set_type(dtype); + node->mutable_attr()->insert({"dtype", attr_data_type}); + AttrValue attr_tensor; + Tensor tensor(dtype, TensorShape({})); + tensor.scalar()() = value; + tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); + node->mutable_attr()->insert({"value", attr_tensor}); + return node; + } + + NodeDef* AddNodeReductionConst(const string& name, const string& device) { + NodeDef* node = graph_->add_node(); + node_map_->AddNode(name, node); + node->set_name(name); + node->set_op("Const"); + node->set_device(device); + AttrValue attr_data_type; + attr_data_type.set_type(DT_INT32); + node->mutable_attr()->insert({"dtype", attr_data_type}); + + AttrValue attr_tensor; + Tensor tensor(DT_INT32, TensorShape({3})); + std::vector axis = {0, 2, 3}; + for (int i = 0; static_cast(i) < axis.size(); i++) { + tensor.flat()(i) = axis[i]; + } + tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); + node->mutable_attr()->insert({"value", attr_tensor}); + return node; + } + + GraphDef* graph_; + NodeMap* node_map_; + + private: +}; + +class NodeProcessor : public GraphProcessor { + public: + NodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : GraphProcessor(graph, node_map), + node_(node), + is_in_frame_(is_in_frame) {} virtual ~NodeProcessor() {} virtual Status ConvertNode() { if (ShouldProcess()) { @@ -229,14 +304,14 @@ class NodeProcessor { } NodeDef* AddNodeTranspose(const string& node_name, const string& input_name, - DataType data_type, + const string& const_name, DataType data_type, const TensorShapeProto& input_shape, bool NHWCToNCHW) { NodeDef* node = graph_->add_node(); node_map_->AddNode(node_name, node); node->set_name(node_name); *node->add_input() = input_name; - *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC; + *node->add_input() = const_name; node->set_op("Transpose"); node->set_device(node_->device()); AttrValue attr_data_type; @@ -276,8 +351,10 @@ class NodeProcessor { auto input_node = node_map_->GetNode(node_->input(pos)); TF_RETURN_IF_ERROR(HasAttribute(*node_, "T")); TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes")); + string const_name = GetOrAddNodePermNHWCToNCHW(pos); AddNodeTranspose( - node_name, node_->input(pos), node_->attr().at("T").type(), + node_name, node_->input(pos), const_name, + node_->attr().at("T").type(), input_node->attr().at("_output_shapes").list().shape(output_pos), true); node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name); @@ -289,6 +366,7 @@ class NodeProcessor { virtual Status AddLayoutTransposeToOutputs() { auto outputs = node_map_->GetOutputs(node_->name()); + string const_name = GetOrAddNodePermNCHWToNHWC(); for (const auto& output : outputs) { string base_name = strings::StrCat(node_->name(), "-", output->name()); string node_name = @@ -315,9 +393,9 @@ class NodeProcessor { } TF_RETURN_IF_ERROR(HasAttribute(*node_, "T")); TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes")); - AddNodeTranspose(node_name, node_->name(), node_->attr().at("T").type(), - node_->attr().at("_output_shapes").list().shape(0), - false); + AddNodeTranspose( + node_name, node_->name(), const_name, node_->attr().at("T").type(), + node_->attr().at("_output_shapes").list().shape(0), false); *it = node_name; node_map_->UpdateOutput(node_->name(), output->name(), node_name); node_map_->AddOutput(node_name, output->name()); @@ -327,11 +405,56 @@ class NodeProcessor { virtual Status CustomizedProcessing() { return Status::OK(); } - GraphDef* graph_; + NodeDef* AddNodePermNHWCToNCHW(const string& suffix, + const string& depended_node, + const string& device) { + auto const_node = AddNodePermConst( + strings::StrCat(kPermNHWCToNCHW, "-", suffix), device, {0, 3, 1, 2}); + // This is to ensure the transpose node and the const node are in the + // same frame. + *const_node->add_input() = AsControlDependency(depended_node); + return const_node; + } + + NodeDef* AddNodePermNCHWToNHWC(const string& suffix, + const string& depended_node, + const string& device) { + auto const_node = AddNodePermConst( + strings::StrCat(kPermNCHWToNHWC, "-", suffix), device, {0, 2, 3, 1}); + // This is to ensure the transpose node and the const node are in the same + // frame. + *const_node->add_input() = AsControlDependency(depended_node); + return const_node; + } + NodeDef* node_; - NodeMap* node_map_; + bool is_in_frame_; private: + string GetOrAddNodePermNHWCToNCHW(int pos) { + string const_name; + if (is_in_frame_) { + auto const_node = AddNodePermNHWCToNCHW( + node_->input(pos), NodeName(node_->input(pos)), node_->device()); + const_name = const_node->name(); + } else { + const_name = kPermNHWCToNCHW; + } + return const_name; + } + + string GetOrAddNodePermNCHWToNHWC() { + string const_name; + if (is_in_frame_) { + auto const_node = + AddNodePermNCHWToNHWC(node_->name(), node_->name(), node_->device()); + const_name = const_node->name(); + } else { + const_name = kPermNCHWToNHWC; + } + return const_name; + } + void UpdateTuple(AttrValue_ListValue* list) { int64 h = list->i(1); int64 w = list->i(2); @@ -344,8 +467,9 @@ class NodeProcessor { class AvgPoolGradProcessor : public NodeProcessor { public: - AvgPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : NodeProcessor(graph, node, node_map) {} + AvgPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : NodeProcessor(graph, node, node_map, is_in_frame) {} protected: std::vector GetInputPos() const override { @@ -357,8 +481,9 @@ class AvgPoolGradProcessor : public NodeProcessor { class BiasAddGradProcessor : public NodeProcessor { public: - BiasAddGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : NodeProcessor(graph, node, node_map) {} + BiasAddGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : NodeProcessor(graph, node, node_map, is_in_frame) {} protected: bool ShouldProcess() const override { @@ -377,8 +502,8 @@ class BiasAddGradProcessor : public NodeProcessor { class Conv2DProcessor : public NodeProcessor { public: Conv2DProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, - bool no_gemm) - : NodeProcessor(graph, node, node_map), no_gemm_(no_gemm) {} + bool no_gemm, bool is_in_frame) + : NodeProcessor(graph, node, node_map, is_in_frame), no_gemm_(no_gemm) {} protected: bool ShouldProcess() const override { @@ -447,8 +572,9 @@ class Conv2DProcessor : public NodeProcessor { class Conv2DBackpropFilterProcessor : public Conv2DProcessor { public: Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node, - NodeMap* node_map, bool no_gemm) - : Conv2DProcessor(graph, node, node_map, no_gemm) {} + NodeMap* node_map, bool no_gemm, + bool is_in_frame) + : Conv2DProcessor(graph, node, node_map, no_gemm, is_in_frame) {} protected: bool IsGemmUsed() const override { @@ -472,8 +598,9 @@ class Conv2DBackpropFilterProcessor : public Conv2DProcessor { class Conv2DBackpropInputProcessor : public Conv2DProcessor { public: Conv2DBackpropInputProcessor(GraphDef* graph, NodeDef* node, - NodeMap* node_map, bool no_gemm) - : Conv2DProcessor(graph, node, node_map, no_gemm) {} + NodeMap* node_map, bool no_gemm, + bool is_in_frame) + : Conv2DProcessor(graph, node, node_map, no_gemm, is_in_frame) {} protected: bool IsGemmUsed() const override { @@ -492,8 +619,9 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor { class FusedBatchNormGradProcessor : public NodeProcessor { public: - FusedBatchNormGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : NodeProcessor(graph, node, node_map) {} + FusedBatchNormGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : NodeProcessor(graph, node, node_map, is_in_frame) {} protected: std::vector GetInputPos() const override { @@ -504,8 +632,9 @@ class FusedBatchNormGradProcessor : public NodeProcessor { class MaxPoolGradProcessor : public NodeProcessor { public: - MaxPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : NodeProcessor(graph, node, node_map) {} + MaxPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : NodeProcessor(graph, node, node_map, is_in_frame) {} protected: std::vector GetInputPos() const override { @@ -516,8 +645,9 @@ class MaxPoolGradProcessor : public NodeProcessor { class AgnosticNodeProcessor : public NodeProcessor { public: - AgnosticNodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : NodeProcessor(graph, node, node_map) {} + AgnosticNodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : NodeProcessor(graph, node, node_map, is_in_frame) {} protected: bool ShouldProcess() const override { @@ -548,8 +678,9 @@ class AgnosticNodeProcessor : public NodeProcessor { class AddNProcessor : public AgnosticNodeProcessor { public: - AddNProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : AgnosticNodeProcessor(graph, node, node_map) {} + AddNProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {} protected: std::vector GetInputPos() const override { @@ -564,8 +695,9 @@ class AddNProcessor : public AgnosticNodeProcessor { class BinaryOpProcessor : public AgnosticNodeProcessor { public: - BinaryOpProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : AgnosticNodeProcessor(graph, node, node_map) { + BinaryOpProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) { is_4d_with_vector_ = Is4DOperateWithVector(); } @@ -672,8 +804,9 @@ class BinaryOpProcessor : public AgnosticNodeProcessor { class ConcatProcessor : public AgnosticNodeProcessor { public: - ConcatProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : AgnosticNodeProcessor(graph, node, node_map) { + ConcatProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) { // For Concat, the concat axis is the first input; for ConcatV2, // the last input. axis_node_pos_ = @@ -698,8 +831,9 @@ class ConcatProcessor : public AgnosticNodeProcessor { } Status CustomizedProcessing() override { - node_map_->AddOutput(kConcatConst, node_->name()); - *node_->mutable_input(axis_node_pos_) = kConcatConst; + string concat_const_name = GetOrAddNodeConcatConst(); + node_map_->AddOutput(concat_const_name, node_->name()); + *node_->mutable_input(axis_node_pos_) = concat_const_name; return Status::OK(); } @@ -712,12 +846,38 @@ class ConcatProcessor : public AgnosticNodeProcessor { } int axis_node_pos_; + + private: + NodeDef* AddNodeConcatConst(const string& suffix, const string& depended_node, + const string& device) { + auto const_node = AddNodeConstScalar( + strings::StrCat(kConcatConst, "-", suffix), device, DT_INT32, 1); + // This is to ensure the concat node and the const node are + // in the same frame. + *const_node->add_input() = AsControlDependency(depended_node); + return const_node; + } + + string GetOrAddNodeConcatConst() { + string const_name; + if (is_in_frame_) { + int value_node_pos = (axis_node_pos_ == 0) ? 1 : 0; + auto const_node = AddNodeConcatConst( + node_->name(), NodeName(node_->input(value_node_pos)), + node_->device()); + const_name = const_node->name(); + } else { + const_name = kConcatConst; + } + return const_name; + } }; class ReluGradProcessor : public AgnosticNodeProcessor { public: - ReluGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : AgnosticNodeProcessor(graph, node, node_map) {} + ReluGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {} protected: std::vector GetInputPos() const override { @@ -728,8 +888,9 @@ class ReluGradProcessor : public AgnosticNodeProcessor { class SliceProcessor : public AgnosticNodeProcessor { public: - SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : AgnosticNodeProcessor(graph, node, node_map) {} + SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {} protected: Status CustomizedProcessing() override { @@ -749,14 +910,62 @@ class SliceProcessor : public AgnosticNodeProcessor { } private: + NodeDef* AddNodeGatherAxisConst(const string& suffix, + const string& depended_node, + const string& device) { + auto const_node = AddNodeConstScalar( + strings::StrCat(kGatherAxisConst, "-", suffix), device, DT_INT32, 0); + // This is to ensure the Slice node and the const node are + // in the same frame. + *const_node->add_input() = AsControlDependency(depended_node); + return const_node; + } + + string GetOrAddNodeGatherAxisConst() { + string const_name; + if (is_in_frame_) { + auto const_node = AddNodeGatherAxisConst( + node_->name(), NodeName(node_->input(0)), node_->device()); + const_name = const_node->name(); + } else { + const_name = kGatherAxisConst; + } + return const_name; + } + + string GetOrAddNodePermNHWCToNCHW() { + string const_name; + if (is_in_frame_) { + auto const_node = AddNodePermNHWCToNCHW( + node_->name(), NodeName(node_->input(0)), node_->device()); + const_name = const_node->name(); + } else { + const_name = kPermNHWCToNCHW; + } + return const_name; + } + + string GetOrAddNodePermNCHWToNHWC() { + string const_name; + if (is_in_frame_) { + auto const_node = AddNodePermNCHWToNHWC( + node_->name(), NodeName(node_->input(0)), node_->device()); + const_name = const_node->name(); + } else { + const_name = kPermNCHWToNHWC; + } + return const_name; + } + void AddNodePermVec(const string& node_name, const string& input_name, DataType data_type, bool NHWCToNCHW) { NodeDef* node = graph_->add_node(); node_map_->AddNode(node_name, node); node->set_name(node_name); *node->add_input() = input_name; - *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC; - *node->add_input() = kGatherAxisConst; + *node->add_input() = NHWCToNCHW ? GetOrAddNodePermNHWCToNCHW() + : GetOrAddNodePermNCHWToNHWC(); + *node->add_input() = GetOrAddNodeGatherAxisConst(); node->set_op("GatherV2"); AttrValue attr_type_indices; @@ -782,8 +991,9 @@ class SliceProcessor : public AgnosticNodeProcessor { // before this optimization. class SliceProcessorConst : public AgnosticNodeProcessor { public: - SliceProcessorConst(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : AgnosticNodeProcessor(graph, node, node_map) {} + SliceProcessorConst(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {} protected: Status CustomizedProcessing() override { @@ -799,8 +1009,9 @@ class SliceProcessorConst : public AgnosticNodeProcessor { // example use case is in the gradient computation of Concat for InceptionV3. class SliceProcessorConcatOffset : public AgnosticNodeProcessor { public: - SliceProcessorConcatOffset(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : AgnosticNodeProcessor(graph, node, node_map) {} + SliceProcessorConcatOffset(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {} protected: Status CustomizedProcessing() override { @@ -849,8 +1060,9 @@ class SliceProcessorConcatOffset : public AgnosticNodeProcessor { class SqueezeProcessor : public AgnosticNodeProcessor { public: - SqueezeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : AgnosticNodeProcessor(graph, node, node_map) {} + SqueezeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {} protected: bool ShouldProcess() const override { @@ -898,8 +1110,9 @@ class SqueezeProcessor : public AgnosticNodeProcessor { class SumProcessor : public AgnosticNodeProcessor { public: - SumProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map) - : AgnosticNodeProcessor(graph, node, node_map) {} + SumProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map, + bool is_in_frame) + : AgnosticNodeProcessor(graph, node, node_map, is_in_frame) {} protected: bool ShouldProcess() const override { @@ -913,7 +1126,7 @@ class SumProcessor : public AgnosticNodeProcessor { Status CustomizedProcessing() override { node_map_->AddOutput(kReductionConst, node_->name()); - *node_->mutable_input(1) = kReductionConst; + *node_->mutable_input(1) = GetOrAddNodeReductionConst(); return Status::OK(); } @@ -938,6 +1151,29 @@ class SumProcessor : public AgnosticNodeProcessor { } return false; } + + NodeDef* AddNodeReductionConst(const string& suffix, + const string& depended_node, + const string& device) { + auto const_node = GraphProcessor::AddNodeReductionConst( + strings::StrCat(kReductionConst, "-", suffix), device); + // This is to ensure the Sum node and the const node are in the + // same frame. + *const_node->add_input() = AsControlDependency(depended_node); + return const_node; + } + + string GetOrAddNodeReductionConst() { + string const_name; + if (is_in_frame_) { + auto const_node = AddNodeReductionConst( + node_->name(), NodeName(node_->input(0)), node_->device()); + const_name = const_node->name(); + } else { + const_name = kReductionConst; + } + return const_name; + } }; struct TuningConfig { @@ -951,13 +1187,12 @@ struct TuningConfig { bool no_gemm; }; -class DataLayoutOptimizer { +class DataLayoutOptimizer : GraphProcessor { public: explicit DataLayoutOptimizer(const string& default_device, GraphDef* graph, - TuningConfig config) - : default_device_(default_device), - graph_(graph), - node_map_(graph_), + NodeMap* node_map, TuningConfig config) + : GraphProcessor(graph, node_map), + default_device_(default_device), config_(config) {} Status Optimize() { @@ -970,105 +1205,65 @@ class DataLayoutOptimizer { } private: - NodeDef* AddNodePermConst(const string& name, - const std::vector& permutation) { - NodeDef* node = graph_->add_node(); - node_map_.AddNode(name, node); - node->set_name(name); - node->set_op("Const"); - node->set_device(default_device_); - AttrValue attr_data_type; - attr_data_type.set_type(DT_INT32); - node->mutable_attr()->insert({"dtype", attr_data_type}); - AttrValue attr_tensor; - Tensor tensor(DT_INT32, TensorShape({4})); - for (int i = 0; static_cast(i) < permutation.size(); i++) { - tensor.flat()(i) = permutation[i]; - } - tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); - node->mutable_attr()->insert({"value", attr_tensor}); - return node; + NodeDef* AddNodePermNHWCToNCHW() { + return AddNodePermConst(kPermNHWCToNCHW, default_device_, {0, 3, 1, 2}); } - NodeDef* AddConstScalar(const char* name, DataType dtype, int value) { - NodeDef* node = graph_->add_node(); - node_map_.AddNode(name, node); - node->set_name(name); - node->set_op("Const"); - node->set_device(default_device_); - AttrValue attr_data_type; - attr_data_type.set_type(dtype); - node->mutable_attr()->insert({"dtype", attr_data_type}); - AttrValue attr_tensor; - Tensor tensor(dtype, TensorShape({})); - tensor.scalar()() = value; - tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); - node->mutable_attr()->insert({"value", attr_tensor}); - return node; + NodeDef* AddNodePermNCHWToNHWC() { + return AddNodePermConst(kPermNCHWToNHWC, default_device_, {0, 2, 3, 1}); } NodeDef* AddNodeConcatConst() { - return AddConstScalar(kConcatConst, DT_INT32, 1); + return AddNodeConstScalar(kConcatConst, default_device_, DT_INT32, 1); } - NodeDef* AddGatherAxisConst() { - return AddConstScalar(kGatherAxisConst, DT_INT32, 0); + NodeDef* AddNodeGatherAxisConst() { + return AddNodeConstScalar(kGatherAxisConst, default_device_, DT_INT32, 0); } NodeDef* AddNodeReductionConst() { - NodeDef* node = graph_->add_node(); - node_map_.AddNode(kReductionConst, node); - node->set_name(kReductionConst); - node->set_op("Const"); - node->set_device(default_device_); - AttrValue attr_data_type; - attr_data_type.set_type(DT_INT32); - node->mutable_attr()->insert({"dtype", attr_data_type}); - - AttrValue attr_tensor; - Tensor tensor(DT_INT32, TensorShape({3})); - std::vector axis = {0, 2, 3}; - for (int i = 0; static_cast(i) < axis.size(); i++) { - tensor.flat()(i) = axis[i]; - } - tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); - node->mutable_attr()->insert({"value", attr_tensor}); - return node; + return GraphProcessor::AddNodeReductionConst(kReductionConst, + default_device_); } // Expand all nodes which is in NHWC, but supports NCHW or is layout agnostic. Status Expand() { int node_size_original = graph_->node_size(); + std::unordered_map> frames; + IdentifyFrames(*graph_, &frames); + // This is the first pass where we expand the nodes which support NCHW. std::set ops_format_supported = GetOpsFormatSupported(); - for (int i = 0; i < graph_->node_size(); i++) { + for (int i = 0; i < node_size_original; i++) { if (ops_format_supported.find(graph_->node(i).op()) != ops_format_supported.end()) { auto node = graph_->mutable_node(i); + bool is_in_frame = !frames[node].empty(); std::unique_ptr node_processor; if (node->op().compare("AvgPoolGrad") == 0) { node_processor.reset( - new AvgPoolGradProcessor(graph_, node, &node_map_)); + new AvgPoolGradProcessor(graph_, node, node_map_, is_in_frame)); } else if (node->op().compare("BiasAddGrad") == 0) { node_processor.reset( - new BiasAddGradProcessor(graph_, node, &node_map_)); + new BiasAddGradProcessor(graph_, node, node_map_, is_in_frame)); } else if (node->op().compare("Conv2D") == 0) { - node_processor.reset( - new Conv2DProcessor(graph_, node, &node_map_, config_.no_gemm)); + node_processor.reset(new Conv2DProcessor( + graph_, node, node_map_, config_.no_gemm, is_in_frame)); } else if (node->op().compare("Conv2DBackpropFilter") == 0) { node_processor.reset(new Conv2DBackpropFilterProcessor( - graph_, node, &node_map_, config_.no_gemm)); + graph_, node, node_map_, config_.no_gemm, is_in_frame)); } else if (node->op().compare("Conv2DBackpropInput") == 0) { node_processor.reset(new Conv2DBackpropInputProcessor( - graph_, node, &node_map_, config_.no_gemm)); + graph_, node, node_map_, config_.no_gemm, is_in_frame)); } else if (node->op().compare("FusedBatchNormGrad") == 0) { - node_processor.reset( - new FusedBatchNormGradProcessor(graph_, node, &node_map_)); + node_processor.reset(new FusedBatchNormGradProcessor( + graph_, node, node_map_, is_in_frame)); } else if (node->op().compare("MaxPoolGrad") == 0) { node_processor.reset( - new MaxPoolGradProcessor(graph_, node, &node_map_)); + new MaxPoolGradProcessor(graph_, node, node_map_, is_in_frame)); } else { - node_processor.reset(new NodeProcessor(graph_, node, &node_map_)); + node_processor.reset( + new NodeProcessor(graph_, node, node_map_, is_in_frame)); } TF_RETURN_IF_ERROR(node_processor->ConvertNode()); } @@ -1078,54 +1273,57 @@ class DataLayoutOptimizer { // only needs to be performed if at least one node in the previous pass is // expanded. if (graph_->node_size() > node_size_original) { - NodeDef* n = AddNodePermConst(kPermNHWCToNCHW, {0, 3, 1, 2}); - n = AddNodePermConst(kPermNCHWToNHWC, {0, 2, 3, 1}); + NodeDef* n = AddNodePermNHWCToNCHW(); + n = AddNodePermNCHWToNHWC(); n = AddNodeConcatConst(); - n = AddGatherAxisConst(); + n = AddNodeGatherAxisConst(); n = AddNodeReductionConst(); std::set ops_format_agnostic = GetOpsFormatAgnostic(); for (int i = 0; i < graph_->node_size(); i++) { if (ops_format_agnostic.find(graph_->node(i).op()) != ops_format_agnostic.end()) { auto node = graph_->mutable_node(i); + bool is_in_frame = !frames[node].empty(); std::unique_ptr node_processor; if (node->op().compare("AddN") == 0) { - node_processor.reset(new AddNProcessor(graph_, node, &node_map_)); + node_processor.reset( + new AddNProcessor(graph_, node, node_map_, is_in_frame)); } else if (node->op().compare("Add") == 0 || node->op().compare("Mul") == 0 || node->op().compare("RealDiv") == 0 || node->op().compare("SquaredDifference") == 0 || node->op().compare("Sub") == 0) { node_processor.reset( - new BinaryOpProcessor(graph_, node, &node_map_)); + new BinaryOpProcessor(graph_, node, node_map_, is_in_frame)); } else if (node->op().compare("Concat") == 0 || node->op().compare("ConcatV2") == 0) { - node_processor.reset(new ConcatProcessor(graph_, node, &node_map_)); + node_processor.reset( + new ConcatProcessor(graph_, node, node_map_, is_in_frame)); } else if (node->op().compare("ReluGrad") == 0) { node_processor.reset( - new ReluGradProcessor(graph_, node, &node_map_)); + new ReluGradProcessor(graph_, node, node_map_, is_in_frame)); } else if (node->op().compare("Slice") == 0) { - auto input1 = node_map_.GetNode(NodeName(node->input(1))); - auto input2 = node_map_.GetNode(NodeName(node->input(2))); + auto input1 = node_map_->GetNode(NodeName(node->input(1))); + auto input2 = node_map_->GetNode(NodeName(node->input(2))); if (input1->op() == "ConcatOffset") { - node_processor.reset( - new SliceProcessorConcatOffset(graph_, node, &node_map_)); + node_processor.reset(new SliceProcessorConcatOffset( + graph_, node, node_map_, is_in_frame)); } else if (input1->op() == "Const" && input2->op() == "Const") { - node_processor.reset( - new SliceProcessorConst(graph_, node, &node_map_)); + node_processor.reset(new SliceProcessorConst( + graph_, node, node_map_, is_in_frame)); } else { node_processor.reset( - new SliceProcessor(graph_, node, &node_map_)); + new SliceProcessor(graph_, node, node_map_, is_in_frame)); } - } else if (node->op().compare("Squeeze") == 0) { node_processor.reset( - new SqueezeProcessor(graph_, node, &node_map_)); + new SqueezeProcessor(graph_, node, node_map_, is_in_frame)); } else if (node->op().compare("Sum") == 0) { - node_processor.reset(new SumProcessor(graph_, node, &node_map_)); - } else { node_processor.reset( - new AgnosticNodeProcessor(graph_, node, &node_map_)); + new SumProcessor(graph_, node, node_map_, is_in_frame)); + } else { + node_processor.reset(new AgnosticNodeProcessor( + graph_, node, node_map_, is_in_frame)); } TF_RETURN_IF_ERROR(node_processor->ConvertNode()); } @@ -1145,12 +1343,12 @@ class DataLayoutOptimizer { if (IsNodeNCHWToNHWC(node->input(0))) { const string& trans_first = node->input(0); const string& trans_second = node->name(); - auto outputs = node_map_.GetOutputs(trans_second); + auto outputs = node_map_->GetOutputs(trans_second); CHECK(outputs.size() == 1) << "There is always only a single output for a Transpose node, " << "due to the way it is added by NodeProcessor."; NodeDef* output = *outputs.begin(); - string input = node_map_.GetNode(trans_first)->input(0); + string input = node_map_->GetNode(trans_first)->input(0); for (int i = 0; i < output->input_size(); i++) { if (output->input(i).compare(trans_second) == 0) { *output->mutable_input(i) = input; @@ -1173,8 +1371,6 @@ class DataLayoutOptimizer { } string default_device_; - GraphDef* graph_; - NodeMap node_map_; TuningConfig config_; }; @@ -1231,8 +1427,9 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, default_device = cluster->GetDevices().begin()->first; } } + std::unique_ptr node_map(new NodeMap(output)); std::unique_ptr layout_optimizer( - new DataLayoutOptimizer(default_device, output, config)); + new DataLayoutOptimizer(default_device, output, node_map.get(), config)); status = layout_optimizer->Optimize(); // This is based on an empirical observation that if the introduced Transpose // nodes is more than 30, not using GEMM implementation would result in better @@ -1240,8 +1437,9 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (status.ok() && GetNumTranspose(*output) > 30) { *output = new_item.graph; config.no_gemm = true; - layout_optimizer.reset( - new DataLayoutOptimizer(default_device, output, config)); + node_map.reset(new NodeMap(output)); + layout_optimizer.reset(new DataLayoutOptimizer(default_device, output, + node_map.get(), config)); status = layout_optimizer->Optimize(); } diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 13efc5590299293d326e6fbec8615686a0d1b53d..4ac985b41b46629200851475a50f6e2ac8ccaf2d 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -66,7 +66,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizers.push_back( std::unique_ptr(new ConstantFolding())); } - if (cfg_.arithmetic_optimization() == RewriterConfig::ON) { + if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) { optimizers.push_back( std::unique_ptr(new ArithmeticOptimizer())); } @@ -138,7 +138,7 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, bool MetaOptimizerEnabled(const RewriterConfig& cfg) { return !cfg.disable_model_pruning() || cfg.optimize_tensor_layout() || cfg.constant_folding() != RewriterConfig::OFF || - cfg.arithmetic_optimization() == RewriterConfig::ON || + cfg.arithmetic_optimization() != RewriterConfig::OFF || cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 || !cfg.optimizers().empty(); } diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index add50d8b14c9804c86553afd03fb2992bf09f043..c8830e9b3c041901977acfa3b6a2a1b109bba28d 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -27,7 +27,9 @@ namespace grappler { NodeMap::NodeMap(GraphDef* graph) : graph_(graph) { for (int i = 0; i < graph_->node_size(); i++) { auto node = graph_->mutable_node(i); - nodes_.insert(std::make_pair(node->name(), node)); + auto rslt = nodes_.insert(std::make_pair(node->name(), node)); + // Check that the graph doesn't contain multiple nodes with the same name. + CHECK(rslt.second); for (const auto& input : node->input()) { outputs_[NodeName(input)].insert(nodes_[node->name()]); } @@ -52,18 +54,80 @@ const std::set& NodeMap::GetOutputs(const string& node_name) const { } void NodeMap::AddNode(const string& name, NodeDef* node) { - nodes_.insert(std::make_pair(name, node)); + auto ret = nodes_.insert(std::make_pair(name, node)); + CHECK(ret.second) << "Pair (" << name << "," << node + << ") is not inserted because a same key already exists."; } -void NodeMap::AddOutput(const string& node, const string& output) { - outputs_[node].insert(nodes_[output]); +void NodeMap::AddOutput(const string& node_name, const string& output_name) { + auto output_node = nodes_[output_name]; + CHECK(output_node) << "Output node " << output_name + << " is missing in NodeMap."; + outputs_[node_name].insert(output_node); } -void NodeMap::UpdateOutput(const string& node, const string& old_output, - const string& new_output) { - std::set& outputs = outputs_[node]; - outputs.erase(nodes_[old_output]); - outputs.insert(nodes_[new_output]); +void NodeMap::RemoveOutput(const string& node_name, const string& output_name) { + outputs_[node_name].erase(nodes_[output_name]); +} + +void NodeMap::UpdateInput(const string& node_name, const string& old_input_name, + const string& new_input_name) { + RemoveOutput(old_input_name, node_name); + AddOutput(new_input_name, node_name); +} + +void NodeMap::RemoveInputs(const string& node_name) { + auto node = nodes_[node_name]; + for (const auto& input : node->input()) { + RemoveOutput(NodeName(input), node->name()); + } +} + +void NodeMap::RemoveOutputs(const string& node_name) { + outputs_.erase(node_name); +} + +void NodeMap::UpdateOutput(const string& node_name, + const string& old_output_name, + const string& new_output_name) { + std::set& outputs = outputs_[node_name]; + outputs.erase(nodes_[old_output_name]); + outputs.insert(nodes_[new_output_name]); +} + +OutputMap::OutputMap(GraphDef* graph) : graph_(graph) { + for (int i = 0; i < graph_->node_size(); i++) { + auto node = graph_->mutable_node(i); + auto rslt = nodes_.insert(std::make_pair(node->name(), node)); + // Check that the graph doesn't contain multiple nodes with the same name. + CHECK(rslt.second); + for (const auto& input : node->input()) { + string input_node = NodeName(input); + if (outputs_[input_node].count(node) == 0) { + outputs_[input_node].insert(std::make_pair(node, 1)); + } else { + outputs_[input_node][node]++; + } + } + } +} + +NodeDef* OutputMap::GetNode(const string& name) const { + string node_name = NodeName(name); + auto it = nodes_.find(node_name); + if (it == nodes_.end()) { + return nullptr; + } + return it->second; +} + +const std::unordered_map& OutputMap::GetOutputs( + const string& node_name) const { + auto it = outputs_.find(node_name); + if (it == outputs_.end()) { + return empty_map_; + } + return it->second; } bool IsSameInput(const string& name1, const string& name2) { @@ -148,5 +212,13 @@ bool ExecuteWithTimeout(std::function fn, const int64 timeout_in_ms, return notified; } +string AsControlDependency(const NodeDef& node) { + return strings::StrCat("^", node.name()); +} + +string AsControlDependency(const string& node) { + return strings::StrCat("^", node); +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index a49791bad898ede28f37458a0ddf73e699ab25da..03f49c0ca254b0bbb618263c61bed10c137656a4 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -36,9 +36,14 @@ class NodeMap { // This method doesn't record the outputs of the added node; the outputs need // to be explicitly added by the AddOutput method. void AddNode(const string& name, NodeDef* node); - void AddOutput(const string& node, const string& output); - void UpdateOutput(const string& node, const string& old_output, - const string& new_output); + void UpdateInput(const string& node_name, const string& old_input_name, + const string& new_input_name); + void AddOutput(const string& node_name, const string& output_name); + void RemoveInputs(const string& node_name); + void RemoveOutput(const string& node_name, const string& output_name); + void RemoveOutputs(const string& node_name); + void UpdateOutput(const string& node_name, const string& old_output_name, + const string& new_output_name); private: GraphDef* graph_; @@ -47,6 +52,22 @@ class NodeMap { std::unordered_map> outputs_; }; +// A utility class to lookup a node's outputs and the number of times it +// presents in each output. +class OutputMap { + public: + explicit OutputMap(GraphDef* graph); + NodeDef* GetNode(const string& name) const; + const std::unordered_map& GetOutputs( + const string& node_name) const; + + private: + GraphDef* graph_; + std::unordered_map empty_map_; + std::unordered_map nodes_; + std::unordered_map> outputs_; +}; + // True iff 'name' refers to a control inputs, i.e. a node name prefixed with // the ^ character. bool IsControlInput(const string& name); @@ -80,6 +101,14 @@ string AddPrefixToNodeName(const string& name, const string& prefix); bool ExecuteWithTimeout(std::function fn, int64 timeout_in_ms, thread::ThreadPool* thread_pool); +// Returns the node name prefixed with conventional symbol '^' +// for control dependency, given a NodeDef. +string AsControlDependency(const NodeDef& node); +// +// Returns the node name prefixed with conventional symbol '^' +// for control dependency, given a node name +string AsControlDependency(const string& node); + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils/frame.cc b/tensorflow/core/grappler/utils/frame.cc index ff7be6f7014367828c33902fa0e11a1f36c8a57b..7655d0bee5a7fcd78b3896147f8eed82ad9d5bcf 100644 --- a/tensorflow/core/grappler/utils/frame.cc +++ b/tensorflow/core/grappler/utils/frame.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils/frame.h" #include #include +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" @@ -32,9 +33,10 @@ int IdentifyFrames( if (node.input_size() == 0) { std::vector empty; ready_nodes.emplace_back(&node, empty); + (*frames)[&node] = empty; } } - int frame_id = 0; + std::map name_to_id; while (!ready_nodes.empty()) { auto ready_node = ready_nodes.front(); for (const auto& fanout : node_map.GetOutputs(ready_node.first->name())) { @@ -44,18 +46,35 @@ int IdentifyFrames( frame_ids.pop_back(); } if (IsEnter(*fanout)) { - frame_ids.push_back(frame_id); - frame_id++; + CHECK(fanout->attr().count("frame_name")) + << "Missing frame name for the Enter node " << fanout->name(); + string name = fanout->attr().at("frame_name").s(); + int id; + if (name_to_id.count(name)) { + id = name_to_id[name]; + } else { + id = name_to_id.size(); + name_to_id[name] = id; + } + frame_ids.push_back(id); } ready_nodes.emplace_back(fanout, frame_ids); + (*frames)[fanout] = frame_ids; } else { - CHECK(ready_node.second == (*frames)[fanout]); + auto frame_ids_fanout = (*frames)[fanout]; + auto frame_ids_node = ready_node.second; + if (IsEnter(*fanout)) { + frame_ids_fanout.pop_back(); + } + if (IsExit(*ready_node.first)) { + frame_ids_node.pop_back(); + } + CHECK(frame_ids_node == frame_ids_fanout); } } - (*frames)[ready_node.first] = ready_node.second; ready_nodes.pop_front(); } - return frame_id; + return name_to_id.size(); } } // namespace grappler diff --git a/tensorflow/core/grappler/utils/frame_test.cc b/tensorflow/core/grappler/utils/frame_test.cc index e9b224ff4aff6eb046038535d8d7618e683827ab..30673eed7a983bf641d3b86b83311f94367da287 100644 --- a/tensorflow/core/grappler/utils/frame_test.cc +++ b/tensorflow/core/grappler/utils/frame_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/utils/frame.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -26,15 +27,25 @@ class IdentifyFramesTest : public ::testing::Test { protected: static NodeDef CreateNode(const string& name, const std::vector& inputs) { - return CreateNode(name, "", inputs); + return CreateNode(name, "", "", inputs); } static NodeDef CreateNode(const string& name, const string& op, const std::vector& inputs) { + return CreateNode(name, op, "", inputs); + } + static NodeDef CreateNode(const string& name, const string& op, + const string& frame, + const std::vector& inputs) { NodeDef node; node.set_name(name); if (!op.empty()) { node.set_op(op); } + if (!frame.empty()) { + AttrValue frame_name; + frame_name.set_s(frame); + node.mutable_attr()->insert({"frame_name", frame_name}); + } for (const string& input : inputs) { node.add_input(input); } @@ -42,17 +53,19 @@ class IdentifyFramesTest : public ::testing::Test { } }; -TEST_F(IdentifyFramesTest, WithLoop) { +TEST_F(IdentifyFramesTest, NestedLoop) { GraphDef graph; // Create a two-level nested loop *graph.add_node() = CreateNode("0", {}); - *graph.add_node() = CreateNode("1", "Enter", {"0"}); + *graph.add_node() = + CreateNode("1", "Enter", "map/while/while_context1", {"0"}); *graph.add_node() = CreateNode("2", {"1"}); *graph.add_node() = CreateNode("3", "Merge", {"2", "14"}); *graph.add_node() = CreateNode("4", {"3"}); *graph.add_node() = CreateNode("5", "Switch", {"4"}); *graph.add_node() = CreateNode("6", {"5"}); - *graph.add_node() = CreateNode("7", "Enter", {"6"}); + *graph.add_node() = + CreateNode("7", "Enter", "map/while/while_context2", {"6"}); *graph.add_node() = CreateNode("8", {"7"}); *graph.add_node() = CreateNode("9", "Merge", {"8", "12"}); *graph.add_node() = CreateNode("10", {"9"}); @@ -73,9 +86,97 @@ TEST_F(IdentifyFramesTest, WithLoop) { {"12", {0, 1}}, {"13", {0, 1}}, {"14", {0}}, {"15", {0}}, {"16", {0}}, {"17", {}}}; EXPECT_EQ(num_frames, 2); + EXPECT_EQ(frames.size(), expected.size()); + std::cout << "Number of frame: " << num_frames << std::endl; + for (const auto& node : frames) { + std::cout << node.first->name() << ": "; + EXPECT_EQ(node.second.size(), expected[node.first->name()].size()); + for (int i = 0; i < node.second.size(); i++) { + EXPECT_EQ(expected[node.first->name()][i], node.second[i]); + std::cout << node.second[i] << " "; + } + std::cout << std::endl; + } +} + +TEST_F(IdentifyFramesTest, MultipleInputsToEnter) { + GraphDef graph; + *graph.add_node() = CreateNode("0", {}); + *graph.add_node() = CreateNode("1", {}); + *graph.add_node() = + CreateNode("2", "Enter", "map/while/while_context", {"0", "1"}); + *graph.add_node() = CreateNode("3", "Exit", {"2"}); + + std::unordered_map> frames; + int num_frames = IdentifyFrames(graph, &frames); + std::unordered_map> expected = { + {"0", {}}, {"1", {}}, {"2", {0}}, {"3", {0}}}; + EXPECT_EQ(num_frames, 1); + EXPECT_EQ(frames.size(), expected.size()); + std::cout << "Number of frame: " << num_frames << std::endl; + for (const auto& node : frames) { + std::cout << node.first->name() << ": "; + EXPECT_EQ(node.second.size(), expected[node.first->name()].size()); + for (int i = 0; i < node.second.size(); i++) { + EXPECT_EQ(expected[node.first->name()][i], node.second[i]); + std::cout << node.second[i] << " "; + } + std::cout << std::endl; + } +} + +TEST_F(IdentifyFramesTest, ExitOutput) { + GraphDef graph; + *graph.add_node() = CreateNode("0", {}); + *graph.add_node() = + CreateNode("1", "Enter", "map/while/while_context", {"0"}); + *graph.add_node() = CreateNode("2", "Exit", {"1"}); + *graph.add_node() = CreateNode("3", {}); + *graph.add_node() = CreateNode("4", {"2", "3"}); + + std::unordered_map> frames; + int num_frames = IdentifyFrames(graph, &frames); + std::unordered_map> expected = { + {"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {}}, {"4", {}}}; + EXPECT_EQ(num_frames, 1); + EXPECT_EQ(frames.size(), expected.size()); + std::cout << "Number of frame: " << num_frames << std::endl; + for (const auto& node : frames) { + std::cout << node.first->name() << ": "; + EXPECT_EQ(node.second.size(), expected[node.first->name()].size()); + for (int i = 0; i < node.second.size(); i++) { + EXPECT_EQ(expected[node.first->name()][i], node.second[i]); + std::cout << node.second[i] << " "; + } + std::cout << std::endl; + } +} + +TEST_F(IdentifyFramesTest, MultipleEnterNodes) { + GraphDef graph; + *graph.add_node() = CreateNode("0", {}); + string frame = "map/while/while_context"; + *graph.add_node() = CreateNode("1", "Enter", frame, {"0"}); + *graph.add_node() = CreateNode("2", {"1"}); + *graph.add_node() = CreateNode("5", {}); + *graph.add_node() = CreateNode("4", "Enter", frame, {"5"}); + *graph.add_node() = CreateNode("3", {"4", "2"}); + *graph.add_node() = CreateNode("6", "Merge", {"3", "8"}); + *graph.add_node() = CreateNode("7", "Switch", {"6"}); + *graph.add_node() = CreateNode("8", "NextIteration", {"7"}); + *graph.add_node() = CreateNode("9", "Exit", {"7"}); + + std::unordered_map> frames; + int num_frames = IdentifyFrames(graph, &frames); + std::unordered_map> expected = { + {"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {0}}, {"4", {0}}, + {"5", {}}, {"6", {0}}, {"7", {0}}, {"8", {0}}, {"9", {0}}}; + EXPECT_EQ(num_frames, 1); + EXPECT_EQ(frames.size(), expected.size()); std::cout << "Number of frame: " << num_frames << std::endl; for (const auto& node : frames) { std::cout << node.first->name() << ": "; + EXPECT_EQ(node.second.size(), expected[node.first->name()].size()); for (int i = 0; i < node.second.size(); i++) { EXPECT_EQ(expected[node.first->name()][i], node.second[i]); std::cout << node.second[i] << " "; diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index a5a7f34db0fe0548de5e5bb13f372359f36f6f6f..77d4702d21e75b1689875eb17fbd2cda41aa1ba8 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -26,7 +26,7 @@ namespace grappler { // Kahn's algorithm is implemented. // For details, see https://en.wikipedia.org/wiki/Topological_sorting void TopologicalSort(GraphDef* graph) { - NodeMap node_map(graph); + OutputMap output_map(graph); std::vector ready_nodes; ready_nodes.reserve(graph->node_size()); int front = 0; @@ -41,7 +41,7 @@ void TopologicalSort(GraphDef* graph) { if (IsMerge(*node)) { ready_inputs[node] = 0; for (const auto& input : node->input()) { - if (IsNextIteration(*node_map.GetNode(input))) { + if (IsNextIteration(*output_map.GetNode(input))) { ready_inputs[node]++; } } @@ -52,8 +52,9 @@ void TopologicalSort(GraphDef* graph) { while (front != back) { auto ready_node = ready_nodes[front]; - for (const auto& fanout : node_map.GetOutputs(ready_node->name())) { - ready_inputs[fanout]++; + for (const auto& fanout_pair : output_map.GetOutputs(ready_node->name())) { + auto fanout = fanout_pair.first; + ready_inputs[fanout] += fanout_pair.second; if (ready_inputs[fanout] == fanout->input_size()) { ready_nodes.push_back(fanout); back++; @@ -70,6 +71,8 @@ void TopologicalSort(GraphDef* graph) { new_node->Swap(ready_nodes[i]); } graph->mutable_node()->Swap(new_graph.mutable_node()); + } else { + LOG(ERROR) << "The graph couldn't be sorted in topological order."; } } diff --git a/tensorflow/core/grappler/utils/topological_sort_test.cc b/tensorflow/core/grappler/utils/topological_sort_test.cc index 55f66b273496c53d9450626ee0c896e725415a48..dc99cb1052ce9db3035401a2cd75e838281fb748 100644 --- a/tensorflow/core/grappler/utils/topological_sort_test.cc +++ b/tensorflow/core/grappler/utils/topological_sort_test.cc @@ -89,6 +89,18 @@ TEST_F(TopologicalSortTest, WithIllegalLoop) { } } +TEST_F(TopologicalSortTest, DuplicatedInputs) { + GraphDef graph; + *graph.add_node() = CreateNode("2", {"1", "1"}); + *graph.add_node() = CreateNode("1", {}); + + TopologicalSort(&graph); + std::vector order = {"1", "2"}; + for (int i = 0; i < order.size(); i++) { + EXPECT_EQ(graph.node(i).name(), order[i]); + } +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 99b1acbc469a14d20e18cedbd68e8419de10b59f..9ebd47458f2a6376c3e5a22632272464baf70648 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -20,6 +20,7 @@ package_group( packages = [ "//learning/brain/contrib/...", "//learning/brain/research/sparse_matrix/...", + "//learning/faster_training/...", "//tensorflow/...", ], ) @@ -259,19 +260,13 @@ cc_library( cc_library( name = "conv_ops_gpu_hdrs", hdrs = ["conv_ops_gpu.h"], - deps = [ - ":eigen_helpers", - "//third_party/eigen3", - ], + deps = ["//third_party/eigen3"], ) cc_library( name = "gpu_util_hdrs", hdrs = ["gpu_utils.h"], - deps = [ - ":eigen_helpers", - "//third_party/eigen3", - ], + deps = ["//third_party/eigen3"], ) tf_cc_test( @@ -1085,6 +1080,12 @@ tf_kernel_library( ], ) +# Unlike gather_functor library, this does not include the CUDA code and deps. +cc_library( + name = "gather_functor_hdr", + hdrs = ["gather_functor.h"], +) + tf_kernel_library( name = "dense_update_functor", srcs = ["dense_update_functor.cc"], @@ -1801,6 +1802,7 @@ cc_library( ":draw_bounding_box_op", ":encode_jpeg_op", ":encode_png_op", + ":extract_jpeg_shape_op", ":non_max_suppression_op", ":random_crop_op", ":resize_area_op", @@ -1831,16 +1833,22 @@ tf_kernel_library( deps = IMAGE_DEPS, ) +cc_library( + name = "adjust_hsv_gpu_lib", + hdrs = ["adjust_hsv_gpu.cu.h"], + deps = ["//tensorflow/core:framework"], +) + tf_kernel_library( name = "adjust_hue_op", prefix = "adjust_hue_op", - deps = IMAGE_DEPS, + deps = IMAGE_DEPS + [":adjust_hsv_gpu_lib"], ) tf_kernel_library( name = "adjust_saturation_op", prefix = "adjust_saturation_op", - deps = IMAGE_DEPS, + deps = IMAGE_DEPS + [":adjust_hsv_gpu_lib"], ) tf_kernel_library( @@ -1891,6 +1899,12 @@ tf_kernel_library( deps = IMAGE_DEPS, ) +tf_kernel_library( + name = "extract_jpeg_shape_op", + prefix = "extract_jpeg_shape_op", + deps = IMAGE_DEPS, +) + tf_kernel_library( name = "non_max_suppression_op", prefix = "non_max_suppression_op", @@ -2568,14 +2582,17 @@ tf_kernel_library( tf_kernel_library( name = "reduction_ops", + srcs = ["reduction_ops_gpu_kernels.h"], prefix = "reduction_ops", - deps = MATH_DEPS, + deps = MATH_DEPS + if_cuda(["@cub_archive//:cub"]), ) tf_kernel_library( name = "segment_reduction_ops", prefix = "segment_reduction_ops", - deps = MATH_DEPS, + deps = MATH_DEPS + if_cuda([ + ":cuda_solvers", + ]), ) tf_kernel_library( @@ -3051,14 +3068,16 @@ tf_kernel_library( tf_kernel_library( name = "l2loss_op", prefix = "l2loss_op", + #srcs = ["reduction_ops_gpu_kernels.h"], deps = [ + ":reduction_ops", + "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:nn_grad", "//tensorflow/core:nn_ops_op_lib", - "//third_party/eigen3", - ], + ] + if_cuda(["@cub_archive//:cub"]), ) tf_cuda_cc_test( @@ -3327,6 +3346,19 @@ tf_kernel_library( deps = PARSING_DEPS, ) +tf_cc_test( + name = "parse_tensor_test", + srcs = ["parse_tensor_test.cc"], + deps = [ + ":ops_testutil", + ":parse_tensor_op", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_kernel_library( name = "string_to_number_op", prefix = "string_to_number_op", @@ -4639,6 +4671,7 @@ filegroup( "decode_image_op.*", "encode_png_op.*", "encode_jpeg_op.*", + "extract_jpeg_shape_op.*", "decode_jpeg_op.*", "decode_gif_op.*", "identity_reader_op.*", @@ -4648,7 +4681,10 @@ filegroup( "whole_file_read_ops.*", "sample_distorted_bounding_box_op.*", "ctc_loss_op.*", + "summary_interface.*", + "summary_kernels.*", "spectrogram_convert_test_data.cc", + "sql_dataset_ops.cc", # Excluded due to experimental status: "debug_ops.*", "scatter_nd_op*", @@ -5465,6 +5501,22 @@ tf_mkl_kernel_library( ], ) +tf_mkl_kernel_library( + name = "mkl_input_conversion_op", + hdrs = ["mkl_tfconv_op.h"], + prefix = "mkl_input_conversion", + deps = [ + ":bounds_check", + ":ops_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:nn_ops_op_lib", + "//third_party/mkl:intel_binary_blob", + ], +) + tf_mkl_kernel_library( name = "mkl_pooling_ops", srcs = [ @@ -5547,6 +5599,20 @@ tf_mkl_kernel_library( ], ) +tf_mkl_kernel_library( + name = "mkl_cwise_ops_common", + hdrs = [ + "cwise_ops.h", + "cwise_ops_common.h", + "cwise_ops_gradients.h", + ], + prefix = "mkl_cwise_ops_common", + deps = NN_DEPS + [ + "cwise_op", + "//third_party/mkl:intel_binary_blob", + ], +) + cc_library( name = "dataset", srcs = ["dataset.cc"], @@ -5555,6 +5621,21 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/util/tensor_bundle", + ], +) + +cc_library( + name = "dataset_utils", + srcs = ["dataset_utils.cc"], + hdrs = ["dataset_utils.h"], + deps = [ + ":captured_function", + ":dataset", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/util/tensor_bundle", ], ) @@ -5684,6 +5765,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -5698,6 +5780,22 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_kernel_library( + name = "sloppy_interleave_dataset_op", + srcs = ["sloppy_interleave_dataset_op.cc"], + deps = [ + ":captured_function", + ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -5863,6 +5961,28 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "sql_dataset_ops", + srcs = [ + "sql/driver_manager.cc", + "sql/sqlite_query_connection.cc", + "sql_dataset_ops.cc", + ], + hdrs = [ + "sql/driver_manager.h", + "sql/query_connection.h", + "sql/sqlite_query_connection.h", + ], + deps = [ + ":dataset", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@sqlite_archive//:sqlite", + ], +) + tf_kernel_library( name = "iterator_ops", srcs = ["iterator_ops.cc"], @@ -5912,7 +6032,9 @@ tf_kernel_library( ":repeat_dataset_op", ":shuffle_dataset_op", ":skip_dataset_op", + ":sloppy_interleave_dataset_op", ":sparse_tensor_slice_dataset_op", + ":sql_dataset_ops", ":take_dataset_op", ":tensor_dataset_op", ":tensor_slice_dataset_op", @@ -5920,6 +6042,43 @@ tf_kernel_library( ], ) +cc_library( + name = "summary_interface", + srcs = ["summary_interface.cc"], + hdrs = ["summary_interface.h"], + deps = [ + "//tensorflow/compiler/xla:util", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "summary_interface_test", + srcs = ["summary_interface_test.cc"], + deps = [ + ":summary_interface", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_kernel_library( + name = "summary_kernels", + srcs = ["summary_kernels.cc"], + deps = [ + ":summary_interface", + "//tensorflow/core:framework", + "//tensorflow/core:summary_ops_op_lib", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. These must be at the end for syncrepo. diff --git a/tensorflow/core/kernels/adjust_hsv_gpu.cu.h b/tensorflow/core/kernels/adjust_hsv_gpu.cu.h new file mode 100644 index 0000000000000000000000000000000000000000..c160ce2c3349fbd08a1d512e35a424dc00919628 --- /dev/null +++ b/tensorflow/core/kernels/adjust_hsv_gpu.cu.h @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace internal { + +typedef struct RgbTuple { + float r; + float g; + float b; +} RgbTuple; + +typedef struct HsvTuple { + float h; + float s; + float v; +} HsvTuple; + +inline __device__ HsvTuple rgb2hsv_cuda(const float r, const float g, + const float b) { + HsvTuple tuple; + const float M = fmaxf(r, fmaxf(g, b)); + const float m = fminf(r, fminf(g, b)); + const float chroma = M - m; + float h = 0.0f, s = 0.0f; + // hue + if (chroma > 0.0f) { + if (M == r) { + const float num = (g - b) / chroma; + const float sign = copysignf(1.0f, num); + h = ((sign < 0.0f) * 6.0f + sign * fmodf(sign * num, 6.0f)) / 6.0f; + } else if (M == g) { + h = ((b - r) / chroma + 2.0f) / 6.0f; + } else { + h = ((r - g) / chroma + 4.0f) / 6.0f; + } + } else { + h = 0.0f; + } + // saturation + if (M > 0.0) { + s = chroma / M; + } else { + s = 0.0f; + } + tuple.h = h; + tuple.s = s; + tuple.v = M; + return tuple; +} + +inline __device__ RgbTuple hsv2rgb_cuda(const float h, const float s, + const float v) { + RgbTuple tuple; + const float new_h = h * 6.0f; + const float chroma = v * s; + const float x = chroma * (1.0f - fabsf(fmodf(new_h, 2.0f) - 1.0f)); + const float new_m = v - chroma; + const bool between_0_and_1 = new_h >= 0.0f && new_h < 1.0f; + const bool between_1_and_2 = new_h >= 1.0f && new_h < 2.0f; + const bool between_2_and_3 = new_h >= 2.0f && new_h < 3.0f; + const bool between_3_and_4 = new_h >= 3.0f && new_h < 4.0f; + const bool between_4_and_5 = new_h >= 4.0f && new_h < 5.0f; + const bool between_5_and_6 = new_h >= 5.0f && new_h < 6.0f; + tuple.r = chroma * (between_0_and_1 || between_5_and_6) + + x * (between_1_and_2 || between_4_and_5) + new_m; + tuple.g = chroma * (between_1_and_2 || between_2_and_3) + + x * (between_0_and_1 || between_3_and_4) + new_m; + tuple.b = chroma * (between_3_and_4 || between_4_and_5) + + x * (between_2_and_3 || between_5_and_6) + new_m; + return tuple; +} + +template +__global__ void adjust_hsv_nhwc(const int64 number_elements, + const float* const __restrict__ input, + float* const output, + const float* const hue_delta, + const float* const saturation_scale, + const float* const value_scale) { + // multiply by 3 since we're dealing with contiguous RGB bytes for each pixel + // (NHWC) + const int64 idx = (blockDim.x * blockIdx.x + threadIdx.x) * 3; + // bounds check + if (idx > number_elements - 1) { + return; + } + if (!AdjustHue && !AdjustSaturation && !AdjustV) { + output[idx] = input[idx]; + output[idx + 1] = input[idx + 1]; + output[idx + 2] = input[idx + 2]; + return; + } + const HsvTuple hsv = rgb2hsv_cuda(input[idx], input[idx + 1], input[idx + 2]); + float new_h = hsv.h; + float new_s = hsv.s; + float new_v = hsv.v; + // hue adjustment + if (AdjustHue) { + const float delta = *hue_delta; + new_h = fmodf(hsv.h + delta, 1.0f); + if (new_h < 0.0f) { + new_h = fmodf(1.0f + new_h, 1.0f); + } + } + // saturation adjustment + if (AdjustSaturation && saturation_scale != nullptr) { + const float scale = *saturation_scale; + new_s = fminf(1.0f, fmaxf(0.0f, hsv.s * scale)); + } + // value adjustment + if (AdjustV && value_scale != nullptr) { + const float scale = *value_scale; + new_v = hsv.v * scale; + } + const RgbTuple rgb = hsv2rgb_cuda(new_h, new_s, new_v); + output[idx] = rgb.r; + output[idx + 1] = rgb.g; + output[idx + 2] = rgb.b; +} + +} // namespace internal +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_ADJUST_HSV_GPU_CU_H_ diff --git a/tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc b/tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc index 865583c1c309b1779d165d7c6c296c85690dcefa..a4fe5f755cafb6f30a28e87ea7febf0535c68a70 100644 --- a/tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc +++ b/tensorflow/core/kernels/adjust_hue_op_gpu.cu.cc @@ -16,104 +16,11 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/adjust_hsv_gpu.cu.h" #include "tensorflow/core/kernels/adjust_hue_op.h" #include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { -namespace internal { - -namespace { -typedef struct RgbTuple { - float r; - float g; - float b; -} RgbTuple; - -typedef struct HsvTuple { - float h; - float s; - float v; -} HsvTuple; -} // namespace - -__device__ HsvTuple rgb2hsv_cuda(const float r, const float g, const float b) { - HsvTuple tuple; - const float M = fmaxf(r, fmaxf(g, b)); - const float m = fminf(r, fminf(g, b)); - const float chroma = M - m; - float h = 0.0f, s = 0.0f; - // hue - if (chroma > 0.0f) { - if (M == r) { - const float num = (g - b) / chroma; - const float sign = copysignf(1.0f, num); - h = ((sign < 0.0f) * 6.0f + sign * fmodf(sign * num, 6.0f)) / 6.0f; - } else if (M == g) { - h = ((b - r) / chroma + 2.0f) / 6.0f; - } else { - h = ((r - g) / chroma + 4.0f) / 6.0f; - } - } else { - h = 0.0f; - } - // saturation - if (M > 0.0) { - s = chroma / M; - } else { - s = 0.0f; - } - tuple.h = h; - tuple.s = s; - tuple.v = M; - return tuple; -} - -__device__ RgbTuple hsv2rgb_cuda(const float h, const float s, const float v) { - RgbTuple tuple; - const float new_h = h * 6.0f; - const float chroma = v * s; - const float x = chroma * (1.0f - fabsf(fmodf(new_h, 2.0f) - 1.0f)); - const float new_m = v - chroma; - const bool between_0_and_1 = new_h >= 0.0f && new_h < 1.0f; - const bool between_1_and_2 = new_h >= 1.0f && new_h < 2.0f; - const bool between_2_and_3 = new_h >= 2.0f && new_h < 3.0f; - const bool between_3_and_4 = new_h >= 3.0f && new_h < 4.0f; - const bool between_4_and_5 = new_h >= 4.0f && new_h < 5.0f; - const bool between_5_and_6 = new_h >= 5.0f && new_h < 6.0f; - tuple.r = chroma * (between_0_and_1 || between_5_and_6) + - x * (between_1_and_2 || between_4_and_5) + new_m; - tuple.g = chroma * (between_1_and_2 || between_2_and_3) + - x * (between_0_and_1 || between_3_and_4) + new_m; - tuple.b = chroma * (between_3_and_4 || between_4_and_5) + - x * (between_2_and_3 || between_5_and_6) + new_m; - return tuple; -} - -__global__ void adjust_hue_nhwc(const int64 number_elements, - const float* const __restrict__ input, - float* const output, - const float* const hue_delta) { - // multiply by 3 since we're dealing with contiguous RGB bytes for each pixel - // (NHWC) - const int64 idx = (blockDim.x * blockIdx.x + threadIdx.x) * 3; - // bounds check - if (idx > number_elements - 1) { - return; - } - const float delta = hue_delta[0]; - const HsvTuple hsv = rgb2hsv_cuda(input[idx], input[idx + 1], input[idx + 2]); - // hue adjustment - float new_h = fmodf(hsv.h + delta, 1.0f); - if (new_h < 0.0f) { - new_h = fmodf(1.0f + new_h, 1.0f); - } - const RgbTuple rgb = hsv2rgb_cuda(new_h, hsv.s, hsv.v); - output[idx] = rgb.r; - output[idx + 1] = rgb.g; - output[idx + 2] = rgb.b; -} -} // namespace internal namespace functor { @@ -126,8 +33,9 @@ void AdjustHueGPU::operator()(GPUDevice* device, const int64 number_of_elements, const int threads_per_block = config.thread_per_block; const int block_count = (number_of_elements + threads_per_block - 1) / threads_per_block; - internal::adjust_hue_nhwc<<>>( - number_of_elements, input, output, delta); + internal::adjust_hsv_nhwc + <<>>( + number_of_elements, input, output, delta, nullptr, nullptr); } } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/adjust_saturation_op.cc b/tensorflow/core/kernels/adjust_saturation_op.cc index 34a65815148d247acf5e6b7df680988eedf2c7f9..4643d4e6efda2157458a557819873c8cb7546e1a 100644 --- a/tensorflow/core/kernels/adjust_saturation_op.cc +++ b/tensorflow/core/kernels/adjust_saturation_op.cc @@ -12,6 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif + +#include "tensorflow/core/kernels/adjust_saturation_op.h" #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" @@ -206,4 +213,35 @@ class AdjustSaturationOp : public AdjustSaturationOpBase { REGISTER_KERNEL_BUILDER(Name("AdjustSaturation").Device(DEVICE_CPU), AdjustSaturationOp); +#if GOOGLE_CUDA +template <> +class AdjustSaturationOp : public AdjustSaturationOpBase { + public: + explicit AdjustSaturationOp(OpKernelConstruction* context) + : AdjustSaturationOpBase(context) {} + + void DoCompute(OpKernelContext* context, + const ComputeOptions& options) override { + const Tensor* input = options.input; + const Tensor* scale = options.scale; + Tensor* output = options.output; + const int64 number_of_elements = input->NumElements(); + GPUDevice device = context->eigen_gpu_device(); + const auto stream = device.stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + if (number_of_elements > 0) { + const float* input_data = input->flat().data(); + const float* scale_data = scale->flat().data(); + float* const output_data = output->flat().data(); + functor::AdjustSaturationGPU()(&device, number_of_elements, input_data, + scale_data, output_data); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("AdjustSaturation").Device(DEVICE_GPU), + AdjustSaturationOp); + +#endif + } // namespace tensorflow diff --git a/tensorflow/core/kernels/adjust_saturation_op.h b/tensorflow/core/kernels/adjust_saturation_op.h new file mode 100644 index 0000000000000000000000000000000000000000..05c45c07c31fccab224d1d53d9028b2524648ecb --- /dev/null +++ b/tensorflow/core/kernels/adjust_saturation_op.h @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef _TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H +#define _TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +struct AdjustSaturationGPU { + void operator()(GPUDevice* device, const int64 number_of_elements, + const float* const input, const float* const scale, + float* const output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // _TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H diff --git a/tensorflow/core/kernels/adjust_saturation_op_gpu.cu.cc b/tensorflow/core/kernels/adjust_saturation_op_gpu.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..37cfb26a47b01ca15cdb6287243a16490bb34bfb --- /dev/null +++ b/tensorflow/core/kernels/adjust_saturation_op_gpu.cu.cc @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/adjust_hsv_gpu.cu.h" +#include "tensorflow/core/kernels/adjust_saturation_op.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { + +namespace functor { + +void AdjustSaturationGPU::operator()(GPUDevice* device, + const int64 number_of_elements, + const float* const input, + const float* const scale, + float* const output) { + const auto stream = device->stream(); + const CudaLaunchConfig config = + GetCudaLaunchConfig(number_of_elements, *device); + const int threads_per_block = config.thread_per_block; + const int block_count = + (number_of_elements + threads_per_block - 1) / threads_per_block; + internal::adjust_hsv_nhwc + <<>>( + number_of_elements, input, output, nullptr, scale, nullptr); +} +} // namespace functor +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc index e07ca5e0c4cfde810766f7ef2034ee2c27ed432f..ddc2d457b0e3132a3331e4ef72fa47e1ccc0e14f 100644 --- a/tensorflow/core/kernels/bias_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc @@ -142,9 +142,9 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop, int group_size) { // Initialize the shared memory. typedef typename AccumulatorType::type AccT; - __shared__ AccT s_data[32]; - int32 s_data_size = sizeof(s_data) / sizeof(T); - for (int32 index = threadIdx.x; index < s_data_size; index += blockDim.x) { + const int32 kSDataSize = 32; + __shared__ AccT s_data[kSDataSize]; + for (int32 index = threadIdx.x; index < kSDataSize; index += blockDim.x) { s_data[index] = AccT(0); } __syncthreads(); diff --git a/tensorflow/core/kernels/captured_function.cc b/tensorflow/core/kernels/captured_function.cc index 8e6e2db07e681ed0afd540b7e5e5eac4d16b5ccb..b9de1e303bf98bdfa65e3f3698a998146dd741ea 100644 --- a/tensorflow/core/kernels/captured_function.cc +++ b/tensorflow/core/kernels/captured_function.cc @@ -59,8 +59,24 @@ Status CapturedFunction::Create( } else if (!s.ok()) { \ return s; \ } \ - TF_RETURN_IF_ERROR(device->resource_manager()->Create( \ - input_handle.container(), input_handle.name(), resource)); \ + ResourceType* already_created_resource; \ + /* Look up the resource in the this function's resource manager, in case \ + * it has already been created. */ \ + s = device->resource_manager()->Lookup(input_handle.container(), \ + input_handle.name(), \ + &already_created_resource); \ + if (s.ok()) { \ + CHECK_EQ(resource, already_created_resource); \ + resource->Unref(); \ + already_created_resource->Unref(); \ + } else { \ + if (errors::IsNotFound(s)) { \ + TF_RETURN_IF_ERROR(device->resource_manager()->Create( \ + input_handle.container(), input_handle.name(), resource)); \ + } else { \ + return s; \ + } \ + } \ continue; \ } @@ -105,8 +121,7 @@ Status CapturedFunction::Create( Status CapturedFunction::Run(FunctionLibraryRuntime::Options f_opts, gtl::ArraySlice args, - std::vector* rets, const string& prefix) { - port::Tracing::TraceMe activity(prefix, "::Run"); + std::vector* rets) { Notification n; Status s; auto done_callback = [&n, &s](Status func_status) { @@ -128,17 +143,15 @@ Status CapturedFunction::Run(FunctionLibraryRuntime::Options f_opts, void CapturedFunction::RunAsync(FunctionLibraryRuntime::Options f_opts, gtl::ArraySlice args, - std::vector* rets, const string& prefix, + std::vector* rets, FunctionLibraryRuntime::DoneCallback done) { - auto activity = new port::Tracing::TraceMe(prefix, "::RunAsync"); auto c_mgr = new CancellationManager; f_opts.cancellation_manager = c_mgr; FunctionLibraryRuntime::DoneCallback wrapped_done = std::bind( - [activity, c_mgr](FunctionLibraryRuntime::DoneCallback done, - // Begin unbound arguments. - Status s) { + [c_mgr](FunctionLibraryRuntime::DoneCallback done, + // Begin unbound arguments. + Status s) { delete c_mgr; - delete activity; done(s); }, std::move(done), std::placeholders::_1); diff --git a/tensorflow/core/kernels/captured_function.h b/tensorflow/core/kernels/captured_function.h index b5a18fd90ec77a5d1b73f027b6ce08f4ec3763f0..f0aca4d23920d4fe42c21a53e41aa92d4196e080 100644 --- a/tensorflow/core/kernels/captured_function.h +++ b/tensorflow/core/kernels/captured_function.h @@ -61,12 +61,10 @@ class CapturedFunction { std::unique_ptr* out_function); Status Run(FunctionLibraryRuntime::Options f_opts, - gtl::ArraySlice args, std::vector* rets, - const string& prefix); + gtl::ArraySlice args, std::vector* rets); void RunAsync(FunctionLibraryRuntime::Options f_opts, gtl::ArraySlice args, std::vector* rets, - const string& prefix, FunctionLibraryRuntime::DoneCallback done); const Device* device() const { return device_; } diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index f0052619f0c6353d4b7e9126800b4d927a96713c..cdc11452827e4b0a34c386a99bd0b316c4acc51f 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -17,6 +17,10 @@ limitations under the License. #define EIGEN_USE_THREADS +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif + #include "tensorflow/core/kernels/constant_op.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -26,13 +30,14 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/platform/macros.h" #ifdef TENSORFLOW_USE_SYCL #include "tensorflow/core/common_runtime/sycl/sycl_util.h" -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL namespace tensorflow { @@ -40,9 +45,8 @@ ConstantOp::ConstantOp(OpKernelConstruction* ctx) : OpKernel(ctx), tensor_(ctx->output_type(0)) { const TensorProto* proto = nullptr; OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto)); - OP_REQUIRES_OK(ctx, - ctx->device()->MakeTensorFromProto( - *proto, AllocatorAttributes(), &tensor_)); + OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto( + *proto, AllocatorAttributes(), &tensor_)); OP_REQUIRES( ctx, ctx->output_type(0) == tensor_.dtype(), errors::InvalidArgument("Type mismatch between value (", @@ -85,9 +89,9 @@ REGISTER_KERNEL(GPU, bool); #endif #ifdef TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(D, TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("Const").Device(DEVICE_##D).TypeConstraint("dtype"), \ +#define REGISTER_SYCL_KERNEL(D, TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Const").Device(DEVICE_##D).TypeConstraint("dtype"), \ ConstantOp); REGISTER_SYCL_KERNEL(SYCL, float); REGISTER_SYCL_KERNEL(SYCL, double); @@ -194,18 +198,18 @@ struct FillFunctor { void operator()(const SYCLDevice& d, typename TTypes::Flat out, typename TTypes::ConstScalar in) { #if !defined(EIGEN_HAS_INDEX_LIST) - Eigen::array rank1{1}; + Eigen::array rank1{1}; #else - Eigen::IndexList> rank1; + Eigen::IndexList > rank1; #endif - const int size = out.dimension(0); - Eigen::array broadcast_dims{size}; + const int size = out.dimension(0); + Eigen::array broadcast_dims{size}; - To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims); + To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims); } }; -} -#endif // TENSORFLOW_USE_SYCL +} // namespace functor +#endif // TENSORFLOW_USE_SYCL #define REGISTER_KERNEL(D, TYPE) \ REGISTER_KERNEL_BUILDER(Name("Fill") \ @@ -219,6 +223,7 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL); // TODO(b/28917570): Add a test for this. Currently python 3 is not happy about // the conversion from uint8 to quint8. REGISTER_KERNEL(CPU, quint8); +REGISTER_KERNEL(CPU, quint16); #undef REGISTER_CPU_KERNEL #ifdef TENSORFLOW_USE_SYCL @@ -272,11 +277,23 @@ class ZerosLikeOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor& input = ctx->input(0); - Tensor* out = nullptr; - OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( - {0}, 0, input.shape(), &out)); - functor::SetZeroFunctor f; - f(ctx->eigen_device(), out->flat()); + const Device& d = ctx->eigen_device(); + if (std::is_same::value) { + OP_REQUIRES(ctx, input.dims() == 0, + errors::InvalidArgument( + "ZerosLike of non-unary Variant not supported.")); + const Variant& v = input.scalar()(); + Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({})); + Variant* out_v = &(out.scalar()()); + OP_REQUIRES_OK(ctx, CreateZerosLikeVariant(ctx, v, out_v)); + ctx->set_output(0, out); + } else { + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {0}, 0, input.shape(), &out)); + functor::SetZeroFunctor f; + f(d, out->flat()); + } } }; @@ -287,6 +304,7 @@ class ZerosLikeOp : public OpKernel { #define REGISTER_CPU(type) REGISTER_KERNEL(type, CPU) TF_CALL_POD_STRING_TYPES(REGISTER_CPU); +REGISTER_CPU(Variant); #undef REGISTER_CPU #ifdef TENSORFLOW_USE_SYCL @@ -314,6 +332,14 @@ REGISTER_KERNEL_BUILDER(Name("ZerosLike") .TypeConstraint("T") .HostMemory("y"), ZerosLikeOp); +// TODO(ebrevdo): Once rendezvous has been properly set up for +// Variants, we'll no longer need a HostMemory attribute for this case. +REGISTER_KERNEL_BUILDER(Name("ZerosLike") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("x") + .HostMemory("y"), + ZerosLikeOp); #endif // GOOGLE_CUDA #undef REGISTER_KERNEL diff --git a/tensorflow/core/kernels/constant_op_gpu.cu.cc b/tensorflow/core/kernels/constant_op_gpu.cu.cc index 56bd918f3c3377bad8eab2d520a6ea586ce98149..d1a1e34ec365da444a8465b34dd67f8865d29f5e 100644 --- a/tensorflow/core/kernels/constant_op_gpu.cu.cc +++ b/tensorflow/core/kernels/constant_op_gpu.cu.cc @@ -95,6 +95,7 @@ DEFINE_SETZERO_GPU(float); DEFINE_SETZERO_GPU(double); DEFINE_SETZERO_GPU(complex64); DEFINE_SETZERO_GPU(complex128); +DEFINE_SETZERO_GPU(int32); DEFINE_SETZERO_GPU(int64); #undef DEFINE_SETZERO_GPU @@ -113,6 +114,7 @@ DEFINE_SETONE_GPU(float); DEFINE_SETONE_GPU(double); DEFINE_SETONE_GPU(complex64); DEFINE_SETONE_GPU(complex128); +DEFINE_SETONE_GPU(int32); DEFINE_SETONE_GPU(int64); #undef DEFINE_SETONE_GPU diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h index 4bb0b7f3b41b5b16bbf112634461a4737f16bc71..8de8f1b2650431f3b96a6edd48a6645f9d9ffcf0 100644 --- a/tensorflow/core/kernels/conv_2d.h +++ b/tensorflow/core/kernels/conv_2d.h @@ -225,13 +225,13 @@ struct PadInput { const std::array& padding_right, typename TTypes::Tensor out, TensorFormat format) { - Eigen::array, NDIMS> padding; - padding[GetTensorDimIndex(format, 'N')] = std::make_pair(0, 0); + Eigen::array, NDIMS> padding; + padding[GetTensorDimIndex(format, 'N')] = {0, 0}; for (int i = 0; i < NDIMS - 2; ++i) { - padding[GetTensorDimIndex(format, '0' + i)] = - std::make_pair(padding_left[i], padding_right[i]); + padding[GetTensorDimIndex(format, '0' + i)] = { + padding_left[i], padding_right[i]}; } - padding[GetTensorDimIndex(format, 'C')] = std::make_pair(0, 0); + padding[GetTensorDimIndex(format, 'C')] = {0, 0}; out.device(d) = in.pad(padding); } }; diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 65514937f4ee543b7fa849e76fe8a39cf0b8cdb1..8eb705b2e5f2b833c80c2450522529abfc35de48 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -91,6 +91,20 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +template +struct LaunchConv2DBackpropInputOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& input, + int row_stride, int col_stride, const Padding& padding, + Tensor* filter_backprop, TensorFormat data_format) { + const CPUDevice& d = ctx->eigen_device(); + functor::SpatialConvolutionBackwardInput()( + d, filter_backprop->tensor(), input.tensor(), + out_backprop.tensor(), filter_backprop->dim_size(0), + filter_backprop->dim_size(1), row_stride, col_stride); + } +}; + #ifdef TENSORFLOW_USE_LIBXSMM template struct LaunchXsmmBackwardFilter { @@ -237,11 +251,9 @@ class Conv2DFastBackpropFilterOp : public OpKernel { } #endif - functor::SpatialConvolutionBackwardKernel()( - context->eigen_device(), filter_backprop->tensor(), - input.tensor(), out_backprop.tensor(), - dims.spatial_dims[0].filter_size, dims.spatial_dims[1].filter_size, - dims.spatial_dims[0].stride, dims.spatial_dims[1].stride); + LaunchConv2DBackpropInputOp()( + context, false, false, out_backprop, input, dims.spatial_dims[0].stride, + dims.spatial_dims[1].stride, padding_, filter_backprop, data_format_); } private: @@ -495,15 +507,10 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); use_cudnn_ &= CanUseCudnn(); cudnn_use_autotune_ = CudnnUseAutotune(); - cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization(); OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } void Compute(OpKernelContext* context) override { - using perftools::gputools::dnn::AlgorithmConfig; - using perftools::gputools::dnn::AlgorithmType; - using perftools::gputools::dnn::ProfileResult; - using perftools::gputools::dnn::kDefaultAlgorithm; const Tensor& input = context->input(0); const Tensor& filter_sizes = context->input(1); const Tensor& out_backprop = context->input(2); @@ -512,352 +519,373 @@ class Conv2DSlowBackpropFilterOp : public OpKernel { errors::InvalidArgument( "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ", filter_sizes.dims())); - const TensorShape& input_shape = input.shape(); TensorShape filter_shape; OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( filter_sizes.vec(), &filter_shape)); - ConvBackpropDimensions dims; - OP_REQUIRES_OK(context, - ConvBackpropComputeDimensions( - "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2, - input.shape(), filter_shape, out_backprop.shape(), - strides_, padding_, data_format_, &dims)); - Tensor* filter_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape, &filter_backprop)); - const int padding_rows = - (padding_ == VALID) - ? 0 - : std::max(0, (dims.spatial_dims[0].output_size - 1) * - dims.spatial_dims[0].stride + - dims.spatial_dims[0].filter_size - - dims.spatial_dims[0].input_size); - const int padding_cols = - (padding_ == VALID) - ? 0 - : std::max(0, (dims.spatial_dims[1].output_size - 1) * - dims.spatial_dims[1].stride + - dims.spatial_dims[1].filter_size - - dims.spatial_dims[1].input_size); - - // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only - // calling it when that is true. Remove this check when (if?) cuDNN starts - // supporting different padding. - bool rows_odd = (padding_rows % 2 != 0); - bool cols_odd = (padding_cols % 2 != 0); - - auto* stream = context->op_device_context()->stream(); - OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - - if (!use_cudnn_) { - context->SetStatus(errors::Unimplemented( - "Conv2DBackprop for GPU is not currently supported " - "without cudnn")); - return; - } + // For now we take the stride from the second and third dimensions only (we + // do not support striding on the batch or depth dimension). + const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); + const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - if (!cudnn_disable_conv_1x1_optimization_ && - dims.spatial_dims[0].filter_size == 1 && - dims.spatial_dims[1].filter_size == 1 && - dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 && - data_format_ == FORMAT_NHWC) { - const uint64 m = dims.in_depth; - const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size * - dims.spatial_dims[1].input_size; - const uint64 n = dims.out_depth; - - // The shape of output backprop is - // [batch, out_rows, out_cols, out_depth] - // From cublas's perspective, it is: n x k - auto a_ptr = AsDeviceMemory(out_backprop.template flat().data(), - out_backprop.template flat().size()); - - // The shape of input is - // [batch, in_rows, in_cols, in_depth], - // From cublas's perspective, it is: m x k - auto b_ptr = AsDeviceMemory(input.template flat().data(), - input.template flat().size()); - - // the shape of the filter backprop from the conv_2d should be - // [1, 1, in_depth, out_depth] - // From cublas's perspective, it is: n x m - auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), - filter_backprop->template flat().size()); - - bool blas_launch_status = - stream - ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose, - perftools::gputools::blas::Transpose::kTranspose, - n, m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, - ", n=", n, ", k=", k)); - } - return; - } else if (dims.spatial_dims[0].filter_size == - dims.spatial_dims[0].input_size && - dims.spatial_dims[1].filter_size == - dims.spatial_dims[1].input_size && - padding_ == VALID && data_format_ == FORMAT_NHWC) { - // The input data and filter have the same height/width, so call cublas - // directly. - const uint64 m = dims.spatial_dims[0].input_size * - dims.spatial_dims[1].input_size * dims.in_depth; - const uint64 k = dims.batch_size; - const uint64 n = dims.out_depth; - - auto a_ptr = AsDeviceMemory(input.template flat().data(), - input.template flat().size()); - auto b_ptr = AsDeviceMemory(out_backprop.template flat().data(), - out_backprop.template flat().size()); - auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), - filter_backprop->template flat().size()); - - bool blas_launch_status = - stream - ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose, - perftools::gputools::blas::Transpose::kTranspose, - n, m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, - ", n=", n, ", k=", k)); - } - return; - } + launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input, + stride_rows, stride_cols, padding_, filter_backprop, + data_format_); + } - Tensor compatible_input; - if (rows_odd || cols_odd) { - // If a padding dimension is odd, we have one more element on the right - // side or the bottom side. This is unsupported in cudnn. Therefore, - // we pad that extra element and make it compatible. - OP_REQUIRES_OK( - context, - context->allocate_temp( - DataTypeToEnum::value, - ShapeFromFormat(data_format_, dims.batch_size, - dims.spatial_dims[0].input_size + rows_odd, - dims.spatial_dims[1].input_size + cols_odd, - dims.in_depth), - &compatible_input)); - - functor::PadInput()( - context->template eigen_device(), - To32Bit(input.tensor()), {{0, 0}}, {{rows_odd, cols_odd}}, - To32Bit(compatible_input.tensor()), data_format_); - } else { - compatible_input = input; + private: + std::vector strides_; + Padding padding_; + bool use_cudnn_; + TensorFormat data_format_; + LaunchConv2DBackpropFilterOp launcher_; + bool cudnn_use_autotune_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp); +}; + +template +void LaunchConv2DBackpropFilterOp::operator()( + OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& input, int row_stride, + int col_stride, const Padding& padding, Tensor* filter_backprop, + TensorFormat data_format) { + using perftools::gputools::dnn::AlgorithmConfig; + using perftools::gputools::dnn::AlgorithmType; + using perftools::gputools::dnn::ProfileResult; + + std::vector strides(4, 1); + strides[GetTensorDimIndex(data_format, 'H')] = row_stride; + strides[GetTensorDimIndex(data_format, 'W')] = col_stride; + TensorShape filter_shape = filter_backprop->shape(); + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions( + "Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2, + input.shape(), filter_shape, out_backprop.shape(), + strides, padding, data_format, &dims)); + + const int padding_rows = + (padding == VALID) + ? 0 + : std::max(0, (dims.spatial_dims[0].output_size - 1) * + dims.spatial_dims[0].stride + + dims.spatial_dims[0].filter_size - + dims.spatial_dims[0].input_size); + const int padding_cols = + (padding == VALID) + ? 0 + : std::max(0, (dims.spatial_dims[1].output_size - 1) * + dims.spatial_dims[1].stride + + dims.spatial_dims[1].filter_size - + dims.spatial_dims[1].input_size); + + // TODO(zhengxq): cuDNN only supports equal padding on both sides, so only + // calling it when that is true. Remove this check when (if?) cuDNN starts + // supporting different padding. + bool rows_odd = (padding_rows % 2 != 0); + bool cols_odd = (padding_cols % 2 != 0); + + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + + if (!use_cudnn) { + ctx->SetStatus(errors::Unimplemented( + "Conv2DBackprop for GPU is not currently supported " + "without cudnn")); + return; + } + + bool cudnn_disable_conv_1x1_optimization_ = CudnnDisableConv1x1Optimization(); + if (!cudnn_disable_conv_1x1_optimization_ && + dims.spatial_dims[0].filter_size == 1 && + dims.spatial_dims[1].filter_size == 1 && + dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 && + data_format == FORMAT_NHWC) { + const uint64 m = dims.in_depth; + const uint64 k = dims.batch_size * dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size; + const uint64 n = dims.out_depth; + + // The shape of output backprop is + // [batch, out_rows, out_cols, out_depth] + // From cublas's perspective, it is: n x k + auto a_ptr = AsDeviceMemory(out_backprop.template flat().data(), + out_backprop.template flat().size()); + + // The shape of input is + // [batch, in_rows, in_cols, in_depth], + // From cublas's perspective, it is: m x k + auto b_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + + // the shape of the filter backprop from the conv_2d should be + // [1, 1, in_depth, out_depth] + // From cublas's perspective, it is: n x m + auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), + filter_backprop->template flat().size()); + + bool blas_launch_status = + stream + ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose, + perftools::gputools::blas::Transpose::kTranspose, n, + m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); } + return; + } else if (dims.spatial_dims[0].filter_size == + dims.spatial_dims[0].input_size && + dims.spatial_dims[1].filter_size == + dims.spatial_dims[1].input_size && + padding == VALID && data_format == FORMAT_NHWC) { + // The input data and filter have the same height/width, so call cublas + // directly. + const uint64 m = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * dims.in_depth; + const uint64 k = dims.batch_size; + const uint64 n = dims.out_depth; + + auto a_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + auto b_ptr = AsDeviceMemory(out_backprop.template flat().data(), + out_backprop.template flat().size()); + auto c_ptr = AsDeviceMemory(filter_backprop->template flat().data(), + filter_backprop->template flat().size()); + + bool blas_launch_status = + stream + ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose, + perftools::gputools::blas::Transpose::kTranspose, n, + m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } + return; + } - CHECK(padding_rows >= 0 && padding_cols >= 0) - << "Negative row or col paddings: (" << padding_rows << ", " - << padding_cols << ")"; - perftools::gputools::dnn::BatchDescriptor input_desc; - input_desc.set_count(dims.batch_size) - .set_height(GetTensorDim(compatible_input, data_format_, 'H')) - .set_width(GetTensorDim(compatible_input, data_format_, 'W')) - .set_feature_map_count(dims.in_depth) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::BatchDescriptor output_desc; - output_desc.set_count(dims.batch_size) - .set_height(dims.spatial_dims[0].output_size) - .set_width(dims.spatial_dims[1].output_size) - .set_feature_map_count(dims.out_depth) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::FilterDescriptor filter_desc; - filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size) - .set_input_filter_width(dims.spatial_dims[1].filter_size) - .set_input_feature_map_count(dims.in_depth) - .set_output_feature_map_count(dims.out_depth); - perftools::gputools::dnn::ConvolutionDescriptor conv_desc; - conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride) - .set_horizontal_filter_stride(dims.spatial_dims[1].stride) - .set_zero_padding_height(padding_rows / 2) - .set_zero_padding_width(padding_cols / 2); - - // NOTE(zhengxq): - // cuDNN only supports the following layouts : - // Input : B x D x R x C - // Filter : OD x ID x R x C - // Whereas, we have - // Input : B x R x C x D - // Filter : R x C x ID x OD - // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C) - // The first TransformDepth performs - // (B x R x C x D) => (B x D x R x C). - // Since the tensor returned from cuDNN is B x D x R x C also, - // the second TransformDepth performs - // (B x D x R x C) => (B x R x C x D). - - Tensor pre_transformed_filter_backprop; - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum::value, - TensorShape({dims.out_depth, dims.in_depth, - dims.spatial_dims[0].filter_size, - dims.spatial_dims[1].filter_size}), - &pre_transformed_filter_backprop)); - - Tensor transformed_out_backprop; - if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = ShapeFromFormat( - FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size, - dims.spatial_dims[1].output_size, dims.out_depth); - if (dims.out_depth > 1) { - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum::value, nchw_shape, - &transformed_out_backprop)); - functor::NHWCToNCHW()( - context->eigen_device(), out_backprop.tensor(), - transformed_out_backprop.tensor()); - } else { - // If depth <= 1, just reshape. - CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); - } + Tensor compatible_input; + if (rows_odd || cols_odd) { + // If a padding dimension is odd, we have one more element on the right + // side or the bottom side. This is unsupported in cudnn. Therefore, + // we pad that extra element and make it compatible. + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DataTypeToEnum::value, + ShapeFromFormat(data_format, dims.batch_size, + dims.spatial_dims[0].input_size + rows_odd, + dims.spatial_dims[1].input_size + cols_odd, + dims.in_depth), + &compatible_input)); + + functor::PadInput()( + ctx->template eigen_device(), To32Bit(input.tensor()), + {{0, 0}}, {{rows_odd, cols_odd}}, + To32Bit(compatible_input.tensor()), data_format); + } else { + compatible_input = input; + } + + CHECK(padding_rows >= 0 && padding_cols >= 0) + << "Negative row or col paddings: (" << padding_rows << ", " + << padding_cols << ")"; + perftools::gputools::dnn::BatchDescriptor input_desc; + input_desc.set_count(dims.batch_size) + .set_height(GetTensorDim(compatible_input, data_format, 'H')) + .set_width(GetTensorDim(compatible_input, data_format, 'W')) + .set_feature_map_count(dims.in_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc; + output_desc.set_count(dims.batch_size) + .set_height(dims.spatial_dims[0].output_size) + .set_width(dims.spatial_dims[1].output_size) + .set_feature_map_count(dims.out_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size) + .set_input_filter_width(dims.spatial_dims[1].filter_size) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc; + conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride) + .set_horizontal_filter_stride(dims.spatial_dims[1].stride) + .set_zero_padding_height(padding_rows / 2) + .set_zero_padding_width(padding_cols / 2); + + // NOTE(zhengxq): + // cuDNN only supports the following layouts : + // Input : B x D x R x C + // Filter : OD x ID x R x C + // Whereas, we have + // Input : B x R x C x D + // Filter : R x C x ID x OD + // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C) + // The first TransformDepth performs + // (B x R x C x D) => (B x D x R x C). + // Since the tensor returned from cuDNN is B x D x R x C also, + // the second TransformDepth performs + // (B x D x R x C) => (B x R x C x D). + + Tensor pre_transformed_filter_backprop; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({dims.out_depth, dims.in_depth, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[1].filter_size}), + &pre_transformed_filter_backprop)); + + Tensor transformed_out_backprop; + if (data_format == FORMAT_NHWC) { + TensorShape nchw_shape = ShapeFromFormat( + FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size, + dims.spatial_dims[1].output_size, dims.out_depth); + if (dims.out_depth > 1) { + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum::value, nchw_shape, + &transformed_out_backprop)); + functor::NHWCToNCHW()( + ctx->eigen_device(), out_backprop.tensor(), + transformed_out_backprop.tensor()); } else { - transformed_out_backprop = out_backprop; + // If depth <= 1, just reshape. + CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); } + } else { + transformed_out_backprop = out_backprop; + } - Tensor transformed_input; - if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = ShapeFromFormat( - FORMAT_NCHW, GetTensorDim(compatible_input, data_format_, 'N'), - GetTensorDim(compatible_input, data_format_, 'H'), - GetTensorDim(compatible_input, data_format_, 'W'), - GetTensorDim(compatible_input, data_format_, 'C')); - if (nchw_shape.dim_size(1) > 1) { - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum::value, - nchw_shape, &transformed_input)); - functor::NHWCToNCHW()( - context->eigen_device(), - const_cast(compatible_input).tensor(), - transformed_input.tensor()); - } else { - // If depth <= 1, just reshape. - CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape)); - } + Tensor transformed_input; + if (data_format == FORMAT_NHWC) { + TensorShape nchw_shape = ShapeFromFormat( + FORMAT_NCHW, GetTensorDim(compatible_input, data_format, 'N'), + GetTensorDim(compatible_input, data_format, 'H'), + GetTensorDim(compatible_input, data_format, 'W'), + GetTensorDim(compatible_input, data_format, 'C')); + if (nchw_shape.dim_size(1) > 1) { + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, + nchw_shape, &transformed_input)); + functor::NHWCToNCHW()( + ctx->eigen_device(), + const_cast(compatible_input).tensor(), + transformed_input.tensor()); } else { - transformed_input = compatible_input; + // If depth <= 1, just reshape. + CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape)); } + } else { + transformed_input = compatible_input; + } - auto out_backprop_ptr = - AsDeviceMemory(transformed_out_backprop.template flat().data(), - transformed_out_backprop.template flat().size()); - auto filter_backprop_ptr = AsDeviceMemory( - pre_transformed_filter_backprop.template flat().data(), - pre_transformed_filter_backprop.template flat().size()); - auto input_ptr = - AsDeviceMemory(transformed_input.template flat().data(), - transformed_input.template flat().size()); - - static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit( - "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default - ); - int device_id = stream->parent()->device_ordinal(); - DataType dtype = input.dtype(); - ConvParameters conv_parameters = { - dims.batch_size, // batch - dims.in_depth, // in_depths - {{input_desc.height(), // in_rows - input_desc.width()}}, // in_cols - dims.out_depth, // out_depths - {{dims.spatial_dims[0].filter_size, // filter_rows - dims.spatial_dims[1].filter_size}}, // filter_cols - {{dims.spatial_dims[0].stride, // stride_rows - dims.spatial_dims[1].stride}}, // stride_cols - {{padding_rows, // padding_rows - padding_cols}}, // padding_cols - dtype, // tensor datatype - device_id, // device_id - }; - AlgorithmConfig algorithm_config; - if (cudnn_use_autotune_ && !AutoTuneConvBwdFilter::GetInstance()->Find( - conv_parameters, &algorithm_config)) { - std::vector algorithms; - CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms( - conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), &algorithms)); - ProfileResult best_result; - ProfileResult best_result_no_scratch; - for (auto profile_algorithm : algorithms) { - // TODO(zhengxq): profile each algorithm multiple times to better - // accuracy. - CudnnScratchAllocator scratch_allocator( - ConvolveBackwardFilterScratchSize, context); - ProfileResult profile_result; - bool cudnn_launch_status = - stream - ->ThenConvolveBackwardFilterWithAlgorithm( - input_desc, input_ptr, output_desc, out_backprop_ptr, - conv_desc, filter_desc, &filter_backprop_ptr, - &scratch_allocator, AlgorithmConfig(profile_algorithm), - &profile_result) - .ok(); - if (cudnn_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalByteSize() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_no_scratch.elapsed_time_in_ms()) { - best_result_no_scratch = profile_result; - } + auto out_backprop_ptr = + AsDeviceMemory(transformed_out_backprop.template flat().data(), + transformed_out_backprop.template flat().size()); + auto filter_backprop_ptr = + AsDeviceMemory(pre_transformed_filter_backprop.template flat().data(), + pre_transformed_filter_backprop.template flat().size()); + auto input_ptr = AsDeviceMemory(transformed_input.template flat().data(), + transformed_input.template flat().size()); + + static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit( + "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default + ); + int device_id = stream->parent()->device_ordinal(); + DataType dtype = input.dtype(); + ConvParameters conv_parameters = { + dims.batch_size, // batch + dims.in_depth, // in_depths + {{input_desc.height(), // in_rows + input_desc.width()}}, // in_cols + dims.out_depth, // out_depths + {{dims.spatial_dims[0].filter_size, // filter_rows + dims.spatial_dims[1].filter_size}}, // filter_cols + {{dims.spatial_dims[0].stride, // stride_rows + dims.spatial_dims[1].stride}}, // stride_cols + {{padding_rows, // padding_rows + padding_cols}}, // padding_cols + dtype, // tensor datatype + device_id, // device_id + }; + AlgorithmConfig algorithm_config; + if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find( + conv_parameters, &algorithm_config)) { + std::vector algorithms; + CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms( + conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), &algorithms)); + ProfileResult best_result; + ProfileResult best_result_no_scratch; + for (auto profile_algorithm : algorithms) { + // TODO(zhengxq): profile each algorithm multiple times to better + // accuracy. + CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, + ctx); + ProfileResult profile_result; + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardFilterWithAlgorithm( + input_desc, input_ptr, output_desc, out_backprop_ptr, + conv_desc, filter_desc, &filter_backprop_ptr, + &scratch_allocator, AlgorithmConfig(profile_algorithm), + &profile_result) + .ok(); + if (cudnn_launch_status) { + if (profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalByteSize() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_no_scratch.elapsed_time_in_ms()) { + best_result_no_scratch = profile_result; } } } - OP_REQUIRES(context, - best_result.is_valid() || best_result_no_scratch.is_valid(), - errors::NotFound("No algorithm worked!")); - if (best_result.is_valid()) { - algorithm_config.set_algorithm(best_result.algorithm()); - } - if (best_result_no_scratch.is_valid()) { - algorithm_config.set_algorithm_no_scratch( - best_result_no_scratch.algorithm()); - } - AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters, - algorithm_config); } - CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, - context); - bool cudnn_launch_status = - stream - ->ThenConvolveBackwardFilterWithAlgorithm( - input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc, - filter_desc, &filter_backprop_ptr, &scratch_allocator, - algorithm_config, nullptr) - .ok(); - - if (!cudnn_launch_status) { - context->SetStatus(errors::Internal( - "cuDNN Backward Filter function launch failure : input shape(", - input_shape.DebugString(), ") filter shape(", - filter_shape.DebugString(), ")")); - return; + OP_REQUIRES(ctx, + best_result.is_valid() || best_result_no_scratch.is_valid(), + errors::NotFound("No algorithm worked!")); + if (best_result.is_valid()) { + algorithm_config.set_algorithm(best_result.algorithm()); } - - auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::ReverseTransformFilter()( - context->eigen_device(), - toConstTensor(pre_transformed_filter_backprop).template tensor(), - filter_backprop->tensor()); + if (best_result_no_scratch.is_valid()) { + algorithm_config.set_algorithm_no_scratch( + best_result_no_scratch.algorithm()); + } + AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters, + algorithm_config); + } + CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, + ctx); + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardFilterWithAlgorithm( + input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc, + filter_desc, &filter_backprop_ptr, &scratch_allocator, + algorithm_config, nullptr) + .ok(); + + if (!cudnn_launch_status) { + ctx->SetStatus(errors::Internal( + "cuDNN Backward Filter function launch failure : input shape(", + input.shape().DebugString(), ") filter shape(", + filter_shape.DebugString(), ")")); + return; } - private: - std::vector strides_; - Padding padding_; - bool use_cudnn_; - TensorFormat data_format_; - bool cudnn_use_autotune_; - bool cudnn_disable_conv_1x1_optimization_; - - TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropFilterOp); -}; + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::ReverseTransformFilter()( + ctx->eigen_device(), + toConstTensor(pre_transformed_filter_backprop).template tensor(), + filter_backprop->tensor()); +} // Forward declarations of the functor specializations for GPU. namespace functor { diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index a5a9549a2f98a8c7a8053778cb337084a45510cb..ce561aa99c218a57b641e8647bbb1d57e0226fe0 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -97,29 +97,17 @@ typedef Eigen::GpuDevice GPUDevice; // for CPU for now since nvcc times out when trying to compile them. // TODO(yangke): enable them for GPUs when we have a faster compiler. -template -struct LaunchBackwardInputConvolution { - bool operator()(OpKernelContext* context, const Device&, - typename TTypes::Tensor, - typename TTypes::ConstTensor, - typename TTypes::ConstTensor, int, int, int, int, - TensorFormat) const { - return false; - } -}; - -template <> -struct LaunchBackwardInputConvolution { - bool operator()(OpKernelContext* context, const CPUDevice& d, - typename TTypes::Tensor input_backward, - typename TTypes::ConstTensor kernel, - typename TTypes::ConstTensor output_backward, - int input_rows, int input_cols, int row_stride, - int col_stride, TensorFormat data_format) const { - functor::SpatialConvolutionBackwardInput()( - d, input_backward, kernel, output_backward, input_rows, input_cols, - row_stride, col_stride); - return true; +template +struct LaunchConv2DBackpropInputOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& filter, + int row_stride, int col_stride, const Padding& padding, + Tensor* in_backprop, TensorFormat data_format) { + const CPUDevice& d = ctx->eigen_device(); + functor::SpatialConvolutionBackwardInput()( + d, in_backprop->tensor(), filter.tensor(), + out_backprop.tensor(), in_backprop->dim_size(1), + in_backprop->dim_size(2), row_stride, col_stride); } }; @@ -268,11 +256,10 @@ class Conv2DFastBackpropInputOp : public OpKernel { } #endif - LaunchBackwardInputConvolution()( - context, context->eigen_device(), in_backprop->tensor(), - filter.tensor(), out_backprop.tensor(), - dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size, - dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, data_format_); + LaunchConv2DBackpropInputOp()( + context, false, false, out_backprop, filter, + dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, padding_, + in_backprop, data_format_); } private: @@ -600,10 +587,6 @@ class Conv2DSlowBackpropInputOp : public OpKernel { } void Compute(OpKernelContext* context) override { - using perftools::gputools::dnn::AlgorithmConfig; - using perftools::gputools::dnn::AlgorithmType; - using perftools::gputools::dnn::ProfileResult; - using perftools::gputools::dnn::kDefaultAlgorithm; const Tensor& input_sizes = context->input(0); const Tensor& filter = context->input(1); const Tensor& out_backprop = context->input(2); @@ -615,351 +598,372 @@ class Conv2DSlowBackpropInputOp : public OpKernel { TensorShape input_shape; OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( input_sizes.vec(), &input_shape)); - const TensorShape& filter_shape = filter.shape(); - - ConvBackpropDimensions dims; - OP_REQUIRES_OK( - context, ConvBackpropComputeDimensions( - "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2, - input_shape, filter_shape, out_backprop.shape(), strides_, - padding_, data_format_, &dims)); Tensor* in_backprop = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); - const int padding_rows = - (padding_ == VALID) - ? 0 - : std::max(0, (dims.spatial_dims[0].output_size - 1) * - dims.spatial_dims[0].stride + - dims.spatial_dims[0].filter_size - - dims.spatial_dims[0].input_size); - const int padding_cols = - (padding_ == VALID) - ? 0 - : std::max(0, (dims.spatial_dims[1].output_size - 1) * - dims.spatial_dims[1].stride + - dims.spatial_dims[1].filter_size - - dims.spatial_dims[1].input_size); - - // TODO(keveman): cuDNN only supports equal padding on both sides, so only - // calling it when that is true. Remove this check when (if?) cuDNN starts - // supporting different padding. - bool rows_odd = (padding_rows % 2 != 0); - bool cols_odd = (padding_cols % 2 != 0); - - auto* stream = context->op_device_context()->stream(); - OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - - if (!use_cudnn_) { - context->SetStatus(errors::Unimplemented( - "Conv2DBackpropInput for GPU is not currently supported " - "without cudnn")); - return; - } + // For now we take the stride from the second and third dimensions only (we + // do not support striding on the batch or depth dimension). + const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); + const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - if (dims.spatial_dims[0].filter_size == 1 && - dims.spatial_dims[1].filter_size == 1 && - dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 && - data_format_ == FORMAT_NHWC) { - // 1x1 filter, so call cublas directly. - const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size * - dims.spatial_dims[1].input_size; - const uint64 k = dims.out_depth; - const uint64 n = dims.in_depth; - - auto a_ptr = AsDeviceMemory(out_backprop.template flat().data(), - out_backprop.template flat().size()); - auto b_ptr = AsDeviceMemory(filter.template flat().data(), - filter.template flat().size()); - auto c_ptr = AsDeviceMemory(in_backprop->template flat().data(), - in_backprop->template flat().size()); - - auto transpose = perftools::gputools::blas::Transpose::kTranspose; - auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; - - bool blas_launch_status = - stream - ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k, - a_ptr, k, 0.0f, &c_ptr, n) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, - ", n=", n, ", k=", k)); - } - return; - } else if (dims.spatial_dims[0].filter_size == - dims.spatial_dims[0].input_size && - dims.spatial_dims[1].filter_size == - dims.spatial_dims[1].input_size && - padding_ == VALID && data_format_ == FORMAT_NHWC) { - // The input data and filter have the same height/width, so call cublas - // directly. - const uint64 m = dims.batch_size; - const uint64 k = dims.out_depth; - const uint64 n = dims.spatial_dims[0].input_size * - dims.spatial_dims[1].input_size * dims.in_depth; - - auto a_ptr = AsDeviceMemory(out_backprop.template flat().data(), - out_backprop.template flat().size()); - auto b_ptr = AsDeviceMemory(filter.template flat().data(), - filter.template flat().size()); - auto c_ptr = AsDeviceMemory(in_backprop->template flat().data(), - in_backprop->template flat().size()); - - auto transpose = perftools::gputools::blas::Transpose::kTranspose; - auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; - - bool blas_launch_status = - stream - ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k, - a_ptr, k, 0.0f, &c_ptr, n) - .ok(); - if (!blas_launch_status) { - context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, - ", n=", n, ", k=", k)); - } - return; - } + launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter, + stride_rows, stride_cols, padding_, in_backprop, data_format_); + } - TensorShape compatible_input_shape; - if (rows_odd || cols_odd) { - // If a padding dimension is odd, we have one more element on the right - // side or the bottom side. This is unsupported in cudnn. Therefore, - // we pad that extra element and make it compatible. - compatible_input_shape = ShapeFromFormat( - data_format_, dims.batch_size, - dims.spatial_dims[0].input_size + rows_odd, - dims.spatial_dims[1].input_size + cols_odd, dims.in_depth); - } else { - compatible_input_shape = input_shape; + private: + std::vector strides_; + Padding padding_; + bool use_cudnn_; + TensorFormat data_format_; + LaunchConv2DBackpropInputOp launcher_; + bool cudnn_use_autotune_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp); +}; + +template +void LaunchConv2DBackpropInputOp::operator()( + OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& filter, int row_stride, + int col_stride, const Padding& padding, Tensor* in_backprop, + TensorFormat data_format) { + using perftools::gputools::dnn::AlgorithmConfig; + using perftools::gputools::dnn::AlgorithmType; + using perftools::gputools::dnn::ProfileResult; + + std::vector strides(4, 1); + strides[GetTensorDimIndex(data_format, 'H')] = row_stride; + strides[GetTensorDimIndex(data_format, 'W')] = col_stride; + TensorShape input_shape = in_backprop->shape(); + + const TensorShape& filter_shape = filter.shape(); + ConvBackpropDimensions dims; + OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions( + "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2, + input_shape, filter_shape, out_backprop.shape(), + strides, padding, data_format, &dims)); + + const int padding_rows = + (padding == VALID) + ? 0 + : std::max(0, (dims.spatial_dims[0].output_size - 1) * + dims.spatial_dims[0].stride + + dims.spatial_dims[0].filter_size - + dims.spatial_dims[0].input_size); + const int padding_cols = + (padding == VALID) + ? 0 + : std::max(0, (dims.spatial_dims[1].output_size - 1) * + dims.spatial_dims[1].stride + + dims.spatial_dims[1].filter_size - + dims.spatial_dims[1].input_size); + + // TODO(keveman): cuDNN only supports equal padding on both sides, so only + // calling it when that is true. Remove this check when (if?) cuDNN starts + // supporting different padding. + bool rows_odd = (padding_rows % 2 != 0); + bool cols_odd = (padding_cols % 2 != 0); + + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + + if (!use_cudnn) { + ctx->SetStatus(errors::Unimplemented( + "Conv2DBackpropInput for GPU is not currently supported " + "without cudnn")); + return; + } + + if (dims.spatial_dims[0].filter_size == 1 && + dims.spatial_dims[1].filter_size == 1 && + dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 && + data_format == FORMAT_NHWC) { + // 1x1 filter, so call cublas directly. + const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size; + const uint64 k = dims.out_depth; + const uint64 n = dims.in_depth; + + auto a_ptr = AsDeviceMemory(out_backprop.template flat().data(), + out_backprop.template flat().size()); + auto b_ptr = AsDeviceMemory(filter.template flat().data(), + filter.template flat().size()); + auto c_ptr = AsDeviceMemory(in_backprop->template flat().data(), + in_backprop->template flat().size()); + + auto transpose = perftools::gputools::blas::Transpose::kTranspose; + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + + bool blas_launch_status = + stream + ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); } + return; + } else if (dims.spatial_dims[0].filter_size == + dims.spatial_dims[0].input_size && + dims.spatial_dims[1].filter_size == + dims.spatial_dims[1].input_size && + padding == VALID && data_format == FORMAT_NHWC) { + // The input data and filter have the same height/width, so call cublas + // directly. + const uint64 m = dims.batch_size; + const uint64 k = dims.out_depth; + const uint64 n = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * dims.in_depth; + + auto a_ptr = AsDeviceMemory(out_backprop.template flat().data(), + out_backprop.template flat().size()); + auto b_ptr = AsDeviceMemory(filter.template flat().data(), + filter.template flat().size()); + auto c_ptr = AsDeviceMemory(in_backprop->template flat().data(), + in_backprop->template flat().size()); + + auto transpose = perftools::gputools::blas::Transpose::kTranspose; + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + + bool blas_launch_status = + stream + ->ThenBlasGemm(transpose, no_transpose, n, m, k, 1.0f, b_ptr, k, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } + return; + } - CHECK(padding_rows >= 0 && padding_cols >= 0) - << "Negative row or col paddings: (" << padding_rows << ", " - << padding_cols << ")"; - perftools::gputools::dnn::BatchDescriptor input_desc; - input_desc.set_count(dims.batch_size) - .set_height(GetTensorDim(compatible_input_shape, data_format_, 'H')) - .set_width(GetTensorDim(compatible_input_shape, data_format_, 'W')) - .set_feature_map_count(dims.in_depth) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::BatchDescriptor output_desc; - output_desc.set_count(dims.batch_size) - .set_height(dims.spatial_dims[0].output_size) - .set_width(dims.spatial_dims[1].output_size) - .set_feature_map_count(dims.out_depth) - .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); - perftools::gputools::dnn::FilterDescriptor filter_desc; - filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size) - .set_input_filter_width(dims.spatial_dims[1].filter_size) - .set_input_feature_map_count(dims.in_depth) - .set_output_feature_map_count(dims.out_depth); - perftools::gputools::dnn::ConvolutionDescriptor conv_desc; - conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride) - .set_horizontal_filter_stride(dims.spatial_dims[1].stride) - .set_zero_padding_height(padding_rows / 2) - .set_zero_padding_width(padding_cols / 2); - - // NOTE(keveman): - // cuDNN only supports the following layouts : - // Input : B x D x R x C - // Filter : OD x ID x R x C - // Whereas, we have - // Input : B x R x C x D - // Filter : R x C x ID x OD - // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C) - // The first TransformDepth performs - // (B x R x C x D) => (B x D x R x C). - // Since the tensor returned from cuDNN is B x D x R x C also, - // the second TransformDepth performs - // (B x D x R x C) => (B x R x C x D). - Tensor transformed_filter; - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum::value, - TensorShape({dims.out_depth, dims.in_depth, - dims.spatial_dims[0].filter_size, - dims.spatial_dims[1].filter_size}), - &transformed_filter)); - - functor::TransformFilter()( - context->eigen_device(), To32Bit(filter.tensor()), - To32Bit(transformed_filter.tensor())); - - Tensor transformed_out_backprop; - if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = ShapeFromFormat( - FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size, - dims.spatial_dims[1].output_size, dims.out_depth); - if (dims.out_depth > 1) { - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum::value, nchw_shape, - &transformed_out_backprop)); - functor::NHWCToNCHW()( - context->eigen_device(), out_backprop.tensor(), - transformed_out_backprop.tensor()); - } else { - // If depth <= 1, then just reshape. - CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); - } + TensorShape compatible_input_shape; + if (rows_odd || cols_odd) { + // If a padding dimension is odd, we have one more element on the right + // side or the bottom side. This is unsupported in cudnn. Therefore, + // we pad that extra element and make it compatible. + compatible_input_shape = ShapeFromFormat( + data_format, dims.batch_size, + dims.spatial_dims[0].input_size + rows_odd, + dims.spatial_dims[1].input_size + cols_odd, dims.in_depth); + } else { + compatible_input_shape = input_shape; + } + + CHECK(padding_rows >= 0 && padding_cols >= 0) + << "Negative row or col paddings: (" << padding_rows << ", " + << padding_cols << ")"; + perftools::gputools::dnn::BatchDescriptor input_desc; + input_desc.set_count(dims.batch_size) + .set_height(GetTensorDim(compatible_input_shape, data_format, 'H')) + .set_width(GetTensorDim(compatible_input_shape, data_format, 'W')) + .set_feature_map_count(dims.in_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::BatchDescriptor output_desc; + output_desc.set_count(dims.batch_size) + .set_height(dims.spatial_dims[0].output_size) + .set_width(dims.spatial_dims[1].output_size) + .set_feature_map_count(dims.out_depth) + .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX); + perftools::gputools::dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size) + .set_input_filter_width(dims.spatial_dims[1].filter_size) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); + perftools::gputools::dnn::ConvolutionDescriptor conv_desc; + conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride) + .set_horizontal_filter_stride(dims.spatial_dims[1].stride) + .set_zero_padding_height(padding_rows / 2) + .set_zero_padding_width(padding_cols / 2); + + // NOTE(keveman): + // cuDNN only supports the following layouts : + // Input : B x D x R x C + // Filter : OD x ID x R x C + // Whereas, we have + // Input : B x R x C x D + // Filter : R x C x ID x OD + // TransformFilter performs (R x C x ID x OD) => (OD x ID x R x C) + // The first TransformDepth performs + // (B x R x C x D) => (B x D x R x C). + // Since the tensor returned from cuDNN is B x D x R x C also, + // the second TransformDepth performs + // (B x D x R x C) => (B x R x C x D). + Tensor transformed_filter; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({dims.out_depth, dims.in_depth, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[1].filter_size}), + &transformed_filter)); + + functor::TransformFilter()( + ctx->eigen_device(), To32Bit(filter.tensor()), + To32Bit(transformed_filter.tensor())); + + Tensor transformed_out_backprop; + if (data_format == FORMAT_NHWC) { + TensorShape nchw_shape = ShapeFromFormat( + FORMAT_NCHW, dims.batch_size, dims.spatial_dims[0].output_size, + dims.spatial_dims[1].output_size, dims.out_depth); + if (dims.out_depth > 1) { + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum::value, nchw_shape, + &transformed_out_backprop)); + functor::NHWCToNCHW()( + ctx->eigen_device(), out_backprop.tensor(), + transformed_out_backprop.tensor()); } else { - transformed_out_backprop = out_backprop; + // If depth <= 1, then just reshape. + CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); } + } else { + transformed_out_backprop = out_backprop; + } - Tensor pre_transformed_in_backprop; - OP_REQUIRES_OK( - context, - context->allocate_temp( - DataTypeToEnum::value, - ShapeFromFormat( - FORMAT_NCHW, - GetTensorDim(compatible_input_shape, data_format_, 'N'), - GetTensorDim(compatible_input_shape, data_format_, 'H'), - GetTensorDim(compatible_input_shape, data_format_, 'W'), - GetTensorDim(compatible_input_shape, data_format_, 'C')), - &pre_transformed_in_backprop)); - - auto out_backprop_ptr = - AsDeviceMemory(transformed_out_backprop.template flat().data(), - transformed_out_backprop.template flat().size()); - auto filter_ptr = - AsDeviceMemory(transformed_filter.template flat().data(), - transformed_filter.template flat().size()); - auto in_backprop_ptr = - AsDeviceMemory(pre_transformed_in_backprop.template flat().data(), - pre_transformed_in_backprop.template flat().size()); - - static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit( - "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default - ); - CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, - context); - int device_id = stream->parent()->device_ordinal(); - DataType dtype = out_backprop.dtype(); - ConvParameters conv_parameters = { - dims.batch_size, // batch - dims.in_depth, // in_depths - {{input_desc.height(), // in_rows - input_desc.width()}}, // in_cols - dims.out_depth, // out_depths - {{dims.spatial_dims[0].filter_size, // filter_rows - dims.spatial_dims[1].filter_size}}, // filter_cols - {{dims.spatial_dims[0].stride, // stride_rows - dims.spatial_dims[1].stride}}, // stride_cols - {{padding_rows, // padding_rows - padding_cols}}, // padding_cols - dtype, // tensor data type - device_id, // device_id - }; - AlgorithmConfig algorithm_config; - if (cudnn_use_autotune_ && !AutoTuneConvBwdData::GetInstance()->Find( - conv_parameters, &algorithm_config)) { - std::vector algorithms; - CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms( - conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), &algorithms)); - ProfileResult best_result; - ProfileResult best_result_no_scratch; - for (auto profile_algorithm : algorithms) { - // TODO(zhengxq): profile each algorithm multiple times to better - // accuracy. - CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, - context); - ProfileResult profile_result; - bool cudnn_launch_status = - stream - ->ThenConvolveBackwardDataWithAlgorithm( - filter_desc, filter_ptr, output_desc, out_backprop_ptr, - conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator, - AlgorithmConfig(profile_algorithm), &profile_result) - .ok(); - if (cudnn_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalByteSize() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_no_scratch.elapsed_time_in_ms()) { - best_result_no_scratch = profile_result; - } + Tensor pre_transformed_in_backprop; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DataTypeToEnum::value, + ShapeFromFormat( + FORMAT_NCHW, + GetTensorDim(compatible_input_shape, data_format, 'N'), + GetTensorDim(compatible_input_shape, data_format, 'H'), + GetTensorDim(compatible_input_shape, data_format, 'W'), + GetTensorDim(compatible_input_shape, data_format, 'C')), + &pre_transformed_in_backprop)); + + auto out_backprop_ptr = + AsDeviceMemory(transformed_out_backprop.template flat().data(), + transformed_out_backprop.template flat().size()); + auto filter_ptr = + AsDeviceMemory(transformed_filter.template flat().data(), + transformed_filter.template flat().size()); + auto in_backprop_ptr = + AsDeviceMemory(pre_transformed_in_backprop.template flat().data(), + pre_transformed_in_backprop.template flat().size()); + + static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit( + "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default + ); + CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx); + int device_id = stream->parent()->device_ordinal(); + DataType dtype = out_backprop.dtype(); + ConvParameters conv_parameters = { + dims.batch_size, // batch + dims.in_depth, // in_depths + {{input_desc.height(), // in_rows + input_desc.width()}}, // in_cols + dims.out_depth, // out_depths + {{dims.spatial_dims[0].filter_size, // filter_rows + dims.spatial_dims[1].filter_size}}, // filter_cols + {{dims.spatial_dims[0].stride, // stride_rows + dims.spatial_dims[1].stride}}, // stride_cols + {{padding_rows, // padding_rows + padding_cols}}, // padding_cols + dtype, // tensor data type + device_id, // device_id + }; + AlgorithmConfig algorithm_config; + if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find( + conv_parameters, &algorithm_config)) { + std::vector algorithms; + CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms( + conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), &algorithms)); + ProfileResult best_result; + ProfileResult best_result_no_scratch; + for (auto profile_algorithm : algorithms) { + // TODO(zhengxq): profile each algorithm multiple times to better + // accuracy. + CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, + ctx); + ProfileResult profile_result; + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardDataWithAlgorithm( + filter_desc, filter_ptr, output_desc, out_backprop_ptr, + conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator, + AlgorithmConfig(profile_algorithm), &profile_result) + .ok(); + if (cudnn_launch_status) { + if (profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalByteSize() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_no_scratch.elapsed_time_in_ms()) { + best_result_no_scratch = profile_result; } } } - OP_REQUIRES(context, - best_result.is_valid() || best_result_no_scratch.is_valid(), - errors::NotFound("No algorithm worked!")); - if (best_result.is_valid()) { - algorithm_config.set_algorithm(best_result.algorithm()); - } - if (best_result_no_scratch.is_valid()) { - algorithm_config.set_algorithm_no_scratch( - best_result_no_scratch.algorithm()); - } - AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters, - algorithm_config); - } - bool cudnn_launch_status = - stream - ->ThenConvolveBackwardDataWithAlgorithm( - filter_desc, filter_ptr, output_desc, out_backprop_ptr, - conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator, - algorithm_config, nullptr) - .ok(); - - if (!cudnn_launch_status) { - context->SetStatus(errors::Internal( - "cuDNN Backward Data function launch failure : input shape(", - input_shape.DebugString(), ") filter shape(", - filter_shape.DebugString(), ")")); - return; } - - if (rows_odd || cols_odd) { - Tensor in_backprop_remove_padding; - OP_REQUIRES_OK( - context, - context->allocate_temp( - DataTypeToEnum::value, - ShapeFromFormat(FORMAT_NCHW, - GetTensorDim(input_shape, data_format_, 'N'), - GetTensorDim(input_shape, data_format_, 'H'), - GetTensorDim(input_shape, data_format_, 'W'), - GetTensorDim(input_shape, data_format_, 'C')), - &in_backprop_remove_padding)); - - // Remove the padding for odd rows or cols. - functor::PadInput()( - context->template eigen_device(), - To32Bit(const_cast(pre_transformed_in_backprop) - .tensor()), - {{0, 0}}, {{-rows_odd, -cols_odd}}, - To32Bit(in_backprop_remove_padding.tensor()), FORMAT_NCHW); - - pre_transformed_in_backprop = in_backprop_remove_padding; + OP_REQUIRES(ctx, + best_result.is_valid() || best_result_no_scratch.is_valid(), + errors::NotFound("No algorithm worked!")); + if (best_result.is_valid()) { + algorithm_config.set_algorithm(best_result.algorithm()); } - - if (data_format_ == FORMAT_NHWC) { - auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::NCHWToNHWC()( - context->eigen_device(), - toConstTensor(pre_transformed_in_backprop).template tensor(), - in_backprop->tensor()); - } else { - *in_backprop = pre_transformed_in_backprop; + if (best_result_no_scratch.is_valid()) { + algorithm_config.set_algorithm_no_scratch( + best_result_no_scratch.algorithm()); } + AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters, + algorithm_config); + } + bool cudnn_launch_status = + stream + ->ThenConvolveBackwardDataWithAlgorithm( + filter_desc, filter_ptr, output_desc, out_backprop_ptr, conv_desc, + input_desc, &in_backprop_ptr, &scratch_allocator, + algorithm_config, nullptr) + .ok(); + + if (!cudnn_launch_status) { + ctx->SetStatus(errors::Internal( + "cuDNN Backward Data function launch failure : input shape(", + input_shape.DebugString(), ") filter shape(", + filter_shape.DebugString(), ")")); + return; } - private: - std::vector strides_; - Padding padding_; - bool use_cudnn_; - TensorFormat data_format_; - bool cudnn_use_autotune_; + if (rows_odd || cols_odd) { + Tensor in_backprop_remove_padding; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DataTypeToEnum::value, + ShapeFromFormat(FORMAT_NCHW, + GetTensorDim(input_shape, data_format, 'N'), + GetTensorDim(input_shape, data_format, 'H'), + GetTensorDim(input_shape, data_format, 'W'), + GetTensorDim(input_shape, data_format, 'C')), + &in_backprop_remove_padding)); + + // Remove the padding for odd rows or cols. + functor::PadInput()( + ctx->template eigen_device(), + To32Bit(const_cast(pre_transformed_in_backprop) + .tensor()), + {{0, 0}}, {{-rows_odd, -cols_odd}}, + To32Bit(in_backprop_remove_padding.tensor()), FORMAT_NCHW); + + pre_transformed_in_backprop = in_backprop_remove_padding; + } - TF_DISALLOW_COPY_AND_ASSIGN(Conv2DSlowBackpropInputOp); -}; + if (data_format == FORMAT_NHWC) { + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::NCHWToNHWC()( + ctx->eigen_device(), + toConstTensor(pre_transformed_in_backprop).template tensor(), + in_backprop->tensor()); + } else { + *in_backprop = pre_transformed_in_backprop; + } +} // Forward declarations of the functor specializations for GPU. namespace functor { diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h index 3ea9510afbac1876502e69b29904ef3a1c225f28..2926bb3a86751f9351bcb21e770e73766487d761 100644 --- a/tensorflow/core/kernels/conv_grad_ops.h +++ b/tensorflow/core/kernels/conv_grad_ops.h @@ -168,6 +168,43 @@ limitations under the License. namespace tensorflow { +// Forward declaration. +class OpKernelContext; + +template +struct LaunchConv2DBackpropInputOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& filter, + int row_stride, int col_stride, const Padding& padding, + Tensor* in_backprop, TensorFormat data_format); +}; + +template +struct LaunchConv2DBackpropFilterOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& input, + int row_stride, int col_stride, const Padding& padding, + Tensor* filter_backprop, TensorFormat data_format); +}; + +#ifdef GOOGLE_CUDA +template +struct LaunchConv2DBackpropInputOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_stride, + int col_stride, const Padding& padding, Tensor* output, + TensorFormat data_format); +}; + +template +struct LaunchConv2DBackpropFilterOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& input, + int row_stride, int col_stride, const Padding& padding, + Tensor* filter_backprop, TensorFormat data_format); +}; +#endif // GOOGLE_CUDA + // Information about a single spatial dimension for a convolution // backpropagation. struct ConvBackpropSpatialDimension { diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 2c77a389527fa356ece3f72d9f9d4687e088c382..bbb9e36fc9dce1b5964839dfb858b5156c0e662b 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -58,10 +58,10 @@ typedef Eigen::GpuDevice GPUDevice; namespace { template struct LaunchGeneric { - static void launch(OpKernelContext* ctx, const Tensor& input, - const Tensor& filter, int row_stride, int col_stride, - const Eigen::PaddingType& padding, Tensor* output, - TensorFormat data_format) { + void operator()(OpKernelContext* ctx, const Tensor& input, + const Tensor& filter, int row_stride, int col_stride, + const Padding& padding, Tensor* output, + TensorFormat data_format) { CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only " "supports NHWC tensor format for now."; if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 && @@ -86,8 +86,7 @@ struct LaunchGeneric { filter.shaped({filter.dim_size(2), filter.dim_size(3)}), dim_pair); } else if (filter.dim_size(0) == input.dim_size(1) && - filter.dim_size(1) == input.dim_size(2) && - padding == Eigen::PADDING_VALID) { + filter.dim_size(1) == input.dim_size(2) && padding == VALID) { // If the input data and filter have the same height/width, // the 2D convolution is reduced to matrix multiplication. const int k = // Length of reduction dimension. @@ -104,28 +103,26 @@ struct LaunchGeneric { functor::SpatialConvolution()( ctx->eigen_device(), output->tensor(), input.tensor(), filter.tensor(), row_stride, col_stride, - padding); + BrainPadding2EigenPadding(padding)); } } }; } // namespace template -class LaunchConv2DOp { - public: - void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, - int col_stride, const Eigen::PaddingType& padding, Tensor* output, - TensorFormat data_format) { +struct LaunchConv2DOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_stride, + int col_stride, const Padding& padding, Tensor* output, + TensorFormat data_format) { if (data_format != FORMAT_NHWC) { ctx->SetStatus( errors::Unimplemented("Generic conv implementation only supports " "NHWC tensor format for now.")); return; } - LaunchGeneric::launch(ctx, input, filter, row_stride, - col_stride, padding, output, - data_format); + LaunchGeneric()(ctx, input, filter, row_stride, col_stride, + padding, output, data_format); } }; @@ -387,9 +384,8 @@ class Conv2DOp : public BinaryOp { return; } - launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter, - stride_rows, stride_cols, - BrainPadding2EigenPadding(padding_), output, data_format_); + launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter, + stride_rows, stride_cols, padding_, output, data_format_); } private: @@ -445,10 +441,10 @@ typedef AutoTuneSingleton -void LaunchConv2DOp::launch( +void LaunchConv2DOp::operator()( OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, const Tensor& input_param, const Tensor& filter, int row_stride, - int col_stride, const Eigen::PaddingType& padding, Tensor* output, + int col_stride, const Padding& padding, Tensor* output, TensorFormat data_format) { using perftools::gputools::dnn::AlgorithmConfig; using perftools::gputools::dnn::AlgorithmType; @@ -492,8 +488,8 @@ void LaunchConv2DOp::launch( } return; } else if (filter.dim_size(0) == input.dim_size(1) && - filter.dim_size(1) == input.dim_size(2) && - padding == Eigen::PADDING_VALID && data_format == FORMAT_NHWC) { + filter.dim_size(1) == input.dim_size(2) && padding == VALID && + data_format == FORMAT_NHWC) { // The input data and filter have the same height/width, so call cublas // directly. const uint64 m = input.dim_size(0); @@ -533,7 +529,7 @@ void LaunchConv2DOp::launch( const int64 out_depths = GetTensorDim(*output, data_format, 'C'); const int64 patch_rows = filter.dim_size(0); const int64 patch_cols = filter.dim_size(1); - if (padding == Eigen::PADDING_SAME) { + if (padding == SAME) { // Total padding on rows and cols is // Pr = (R' - 1) * S + Kr - R // Pc = (C' - 1) * S + Kc - C diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h index 60091fc27fdc73a1814d505d4bf55851413b099f..e29271dff278afbc1ff2c947c161824615640b66 100644 --- a/tensorflow/core/kernels/conv_ops.h +++ b/tensorflow/core/kernels/conv_ops.h @@ -32,14 +32,23 @@ namespace tensorflow { class OpKernelContext; template -class LaunchConv2DOp { - public: - void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, - int col_stride, const Eigen::PaddingType& padding, Tensor* output, - TensorFormat data_format); +struct LaunchConv2DOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_stride, + int col_stride, const Padding& padding, Tensor* output, + TensorFormat data_format); }; +#ifdef GOOGLE_CUDA +template +struct LaunchConv2DOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_stride, + int col_stride, const Padding& padding, Tensor* output, + TensorFormat data_format); +}; +#endif // GOOGLE_CUDA + // Used to keep track of persistent memory buffers used within the op. // It uses malloc and free to avoid the time cost of initializing the memory. template @@ -55,17 +64,6 @@ struct Im2ColBufferResource : public ResourceBase { string DebugString() { return "Im2ColBufferResource"; } }; -#ifdef GOOGLE_CUDA -template -class LaunchConv2DOp { - public: - void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, - const Tensor& input, const Tensor& filter, int row_stride, - int col_stride, const Eigen::PaddingType& padding, Tensor* output, - TensorFormat data_format); -}; -#endif // GOOGLE_CUDA - } // namespace tensorflow #endif // TENSORFLOW_KERNELS_CONV_OPS_H diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 168cf37bc77ef34d29c810e020a7baeecf92f7cc..c852dc9991c2e879c8fa6a64b2bd8b5141606409 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -92,11 +92,11 @@ class ConvParameters { ConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, int64 out_depths, const SpatialArray& filter, const SpatialArray& stride, const SpatialArray& padding, - const DataType& dtype, int device_id) + DataType dtype, int device_id) : batch_(batch), in_depths_(in_depths), - in_(in), out_depths_(out_depths), + in_(in), filter_(filter), stride_(stride), padding_(padding), @@ -130,7 +130,8 @@ class ConvParameters { "(", str_util::Join(filter_, ", "), "), ", "(", str_util::Join(stride_, ", "), "), ", "(", str_util::Join(padding_, ", "), "), ", - dtype_, ", ", device_id_); + dtype_, ", ", + device_id_); // clang-format on } @@ -150,26 +151,28 @@ class ConvParameters { } } - private: - typedef std::tuple - ParameterDataType; + protected: + using ParameterDataType = + std::tuple; ParameterDataType get_data_as_tuple() const { return std::make_tuple(batch_, in_depths_, in_, out_depths_, filter_, stride_, padding_, dtype_, device_id_); } + uint64 hash_code_; + + private: int64 batch_; int64 in_depths_; - SpatialArray in_; int64 out_depths_; + SpatialArray in_; SpatialArray filter_; SpatialArray stride_; SpatialArray padding_; DataType dtype_; int device_id_; - uint64 hash_code_; }; typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index 2307c2de0e63b06b64dfd711498d5106fdc663b3..3d4670c9bae413d187366264026522e9c8dbbd55 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -556,6 +556,7 @@ template struct functor::NCHWToNHWC; template struct functor::NCHWToNHWC; template struct functor::NCHWToNHWC; +template struct functor::PadInput; template struct functor::PadInput; template struct functor::PadInput; diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc index 56181a686cac27bf36e239726508a889c7970d68..45cc2fbbb8be8401c54f8024920f309178365747 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/crop_and_resize_op.cc @@ -19,59 +19,98 @@ limitations under the License. #include "tensorflow/core/kernels/crop_and_resize_op.h" +#include +#include + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/platform/cuda.h" #include "tensorflow/core/platform/stream_executor.h" + +using ::perftools::gputools::cuda::ScopedActivateExecutorContext; #endif // GOOGLE_CUDA namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +using Callback = std::function; + +namespace { -static inline void ParseAndCheckBoxSizes(OpKernelContext* context, - const Tensor& boxes, - const Tensor& box_ind, - int* num_boxes) { - if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) { +static inline Status ParseAndCheckBoxSizes(const Tensor& boxes, + const Tensor& box_index, + int* num_boxes) { + if (boxes.NumElements() == 0 && box_index.NumElements() == 0) { *num_boxes = 0; - return; + return Status::OK(); } // The shape of 'boxes' is [num_boxes, 4]. - OP_REQUIRES(context, boxes.dims() == 2, - errors::InvalidArgument("boxes must be 2-D", - boxes.shape().DebugString())); + if (boxes.dims() != 2) { + return errors::InvalidArgument("boxes must be 2-D", + boxes.shape().DebugString()); + } *num_boxes = boxes.dim_size(0); - OP_REQUIRES(context, boxes.dim_size(1) == 4, - errors::InvalidArgument("boxes must have 4 columns")); - - // The shape of 'box_ind' is [num_boxes]. - OP_REQUIRES(context, box_ind.dims() == 1, - errors::InvalidArgument("box_ind must be 1-D", - box_ind.shape().DebugString())); - OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes, - errors::InvalidArgument("box_ind has incompatible shape")); + if (boxes.dim_size(1) != 4) { + return errors::InvalidArgument("boxes must have 4 columns"); + } + // The shape of 'box_index' is [num_boxes]. + if (box_index.dims() != 1) { + return errors::InvalidArgument("box_index must be 1-D", + box_index.shape().DebugString()); + } + if (box_index.dim_size(0) != *num_boxes) { + return errors::InvalidArgument("box_index has incompatible shape"); + } + return Status::OK(); } -// Verifies that all values in box_ind are in [0, batch). +// Conditionally calls the compute callback if all values in box_index are in +// [0, batch_size) then calls done. template -inline void CheckValidBoxInd( - OpKernelContext* context, - typename TTypes::ConstTensor box_ind_data, int batch); +inline void RunIfBoxIndexIsValid( + OpKernelContext* context, typename TTypes::ConstTensor box_index, + int batch_size, const Callback& compute, const Callback& done); + +// Specialization of CheckValidBoxIndex for a CPUDevice. +template <> +inline void RunIfBoxIndexIsValid( + OpKernelContext* context, typename TTypes::ConstTensor box_index, + int batch_size, const Callback& compute, const Callback& done) { + const int num_boxes = box_index.dimension(0); + for (int b = 0; b < num_boxes; ++b) { + OP_REQUIRES_ASYNC( + context, FastBoundsCheck(box_index(b), batch_size), + errors::OutOfRange("box_index has values outside [0, batch_size)"), + done); + } + if (compute) { + compute(); + } + if (done) { + done(); + } +} + +} // namespace template -class CropAndResizeOp : public OpKernel { +class CropAndResizeOp : public AsyncOpKernel { public: - explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) { + explicit CropAndResizeOp(OpKernelConstruction* context) + : AsyncOpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", @@ -80,69 +119,77 @@ class CropAndResizeOp : public OpKernel { &extrapolation_value_)); } - void Compute(OpKernelContext* context) override { - // The shape of 'image' is [batch, image_height, image_width, channels]. + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + // The shape of 'image' is [batch_size, image_height, image_width, + // channels]. const Tensor& image = context->input(0); - OP_REQUIRES(context, image.dims() == 4, - errors::InvalidArgument("input image must be 4-D", - image.shape().DebugString())); - - const int batch = image.dim_size(0); - const int image_height = image.dim_size(1); - const int image_width = image.dim_size(2); - const int depth = image.dim_size(3); - OP_REQUIRES(context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive")); - // The shape of 'boxes' is [num_boxes, 4]. const Tensor& boxes = context->input(1); - - // The shape of 'box_ind' is [num_boxes]. - const Tensor& box_ind = context->input(2); - - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); - + // The shape of 'box_index' is [num_boxes]. + const Tensor& box_index = context->input(2); // The shape of 'crop_size' is [2]. const Tensor& crop_size = context->input(3); - OP_REQUIRES(context, crop_size.dims() == 1, - errors::InvalidArgument("crop_size must be 1-D", - crop_size.shape().DebugString())); - OP_REQUIRES(context, crop_size.dim_size(0) == 2, - errors::InvalidArgument("crop_size must have two elements", - crop_size.shape().DebugString())); - + // Validate inputs dimensions. + OP_REQUIRES_ASYNC(context, image.dims() == 4, + errors::InvalidArgument("input image must be 4-D", + image.shape().DebugString()), + done); + const int batch_size = image.dim_size(0); + const int image_height = image.dim_size(1); + const int image_width = image.dim_size(2); + const int depth = image.dim_size(3); + OP_REQUIRES_ASYNC( + context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive"), done); + int num_boxes = 0; + OP_REQUIRES_OK_ASYNC( + context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); + + OP_REQUIRES_ASYNC(context, crop_size.dims() == 1, + errors::InvalidArgument("crop_size must be 1-D", + crop_size.shape().DebugString()), + done); + OP_REQUIRES_ASYNC( + context, crop_size.dim_size(0) == 2, + errors::InvalidArgument("crop_size must have two elements", + crop_size.shape().DebugString()), + done); + + // Copy and validate crop sizes. auto crop_size_vec = crop_size.vec(); const int crop_height = internal::SubtleMustCopy(crop_size_vec(0)); const int crop_width = internal::SubtleMustCopy(crop_size_vec(1)); - OP_REQUIRES(context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("crop dimensions must be positive")); + OP_REQUIRES_ASYNC( + context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("crop dimensions must be positive"), done); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK( + OP_REQUIRES_OK_ASYNC( context, context->allocate_output( 0, TensorShape({num_boxes, crop_height, crop_width, depth}), - &output)); - - typename TTypes::ConstTensor image_data = image.tensor(); - typename TTypes::ConstTensor boxes_data = - boxes.tensor(); - typename TTypes::ConstTensor box_ind_data = - box_ind.tensor(); - typename TTypes::Tensor crops_data = output->tensor(); - - CheckValidBoxInd(context, box_ind_data, batch); - - bool status = functor::CropAndResize()( - context, image_data, boxes_data, box_ind_data, extrapolation_value_, - crops_data); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeKernel.")); - } + &output), + done); + + auto compute_callback = [this, context, output]() { + const Tensor& image = context->input(0); + const Tensor& boxes = context->input(1); + const Tensor& box_index = context->input(2); + const bool status = functor::CropAndResize()( + context, image.tensor(), boxes.tensor(), + box_index.tensor(), extrapolation_value_, + output->tensor()); + if (!status) { + context->SetStatus( + errors::Internal("Failed launch CropAndResizeKernel.")); + } + }; + + RunIfBoxIndexIsValid(context, box_index.tensor(), + batch_size, std::move(compute_callback), + std::move(done)); } private: @@ -156,10 +203,10 @@ struct CropAndResize { bool operator()(const OpKernelContext* context, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_ind, + typename TTypes::ConstTensor box_index, float extrapolation_value, typename TTypes::Tensor crops) { - const int batch = image.dimension(0); + const int batch_size = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -176,8 +223,8 @@ struct CropAndResize { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_ind(b); - if (b_in < 0 || b_in >= batch) { + const int32 b_in = box_index(b); + if (!FastBoundsCheck(b_in, batch_size)) { continue; } @@ -255,89 +302,94 @@ struct CropAndResize { return true; } }; + } // namespace functor template -class CropAndResizeGradImageOp : public OpKernel { +class CropAndResizeGradImageOp : public AsyncOpKernel { public: explicit CropAndResizeGradImageOp(OpKernelConstruction* context) - : OpKernel(context) { + : AsyncOpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", errors::InvalidArgument("method must be 'bilinear'", method)); } - void Compute(OpKernelContext* context) override { + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); - - OP_REQUIRES(context, grads.dims() == 4, - errors::InvalidArgument("grads image must be 4-D", - grads.shape().DebugString())); - const int crop_height = grads.dim_size(1); - const int crop_width = grads.dim_size(2); - OP_REQUIRES(context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("grads dimensions must be positive")); - // The shape of 'boxes' is [num_boxes, 4]. const Tensor& boxes = context->input(1); - - // The shape of 'box_ind' is [num_boxes]. - const Tensor& box_ind = context->input(2); - - int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); - - OP_REQUIRES( - context, grads.dim_size(0) == num_boxes, - errors::InvalidArgument("boxes and grads have incompatible shape")); - + // The shape of 'box_index' is [num_boxes]. + const Tensor& box_index = context->input(2); // The shape of 'image_size' is [4]. const Tensor& image_size = context->input(3); - OP_REQUIRES(context, image_size.dims() == 1, - errors::InvalidArgument("image_size must be 1-D", - image_size.shape().DebugString())); - OP_REQUIRES(context, image_size.dim_size(0) == 4, - errors::InvalidArgument("image_size must have 4 elements", - image_size.shape().DebugString())); + // Validate input shapes. + OP_REQUIRES_ASYNC(context, grads.dims() == 4, + errors::InvalidArgument("grads image must be 4-D", + grads.shape().DebugString()), + done); + const int crop_height = grads.dim_size(1); + const int crop_width = grads.dim_size(2); + OP_REQUIRES_ASYNC( + context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("grads dimensions must be positive"), done); + int num_boxes = 0; + OP_REQUIRES_OK_ASYNC( + context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); + OP_REQUIRES_ASYNC( + context, grads.dim_size(0) == num_boxes, + errors::InvalidArgument("boxes and grads have incompatible shape"), + done); + + OP_REQUIRES_ASYNC(context, image_size.dims() == 1, + errors::InvalidArgument("image_size must be 1-D", + image_size.shape().DebugString()), + done); + OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4, + errors::InvalidArgument("image_size must have 4 elements", + image_size.shape().DebugString()), + done); auto image_size_vec = image_size.vec(); - const int batch = internal::SubtleMustCopy(image_size_vec(0)); + const int batch_size = internal::SubtleMustCopy(image_size_vec(0)); const int image_height = internal::SubtleMustCopy(image_size_vec(1)); const int image_width = internal::SubtleMustCopy(image_size_vec(2)); const int depth = internal::SubtleMustCopy(image_size_vec(3)); - - OP_REQUIRES(context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive")); - OP_REQUIRES( + OP_REQUIRES_ASYNC( + context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive"), done); + OP_REQUIRES_ASYNC( context, grads.dim_size(3) == depth, - errors::InvalidArgument("image_size and grads are incompatible")); + errors::InvalidArgument("image_size and grads are incompatible"), done); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output( - 0, TensorShape({batch, image_height, image_width, depth}), - &output)); - - typename TTypes::ConstTensor grads_data = - grads.tensor(); - typename TTypes::ConstTensor boxes_data = - boxes.tensor(); - typename TTypes::ConstTensor box_ind_data = - box_ind.tensor(); - typename TTypes::Tensor output_data = output->tensor(); - - CheckValidBoxInd(context, box_ind_data, batch); - - bool status = functor::CropAndResizeBackpropImage()( - context->eigen_device(), grads_data, boxes_data, box_ind_data, - output_data); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeBackpropImageKernel.")); - } + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_output( + 0, TensorShape({batch_size, image_height, image_width, depth}), + &output), + done); + + auto compute_callback = [context, output]() { + const Tensor& grads = context->input(0); + const Tensor& boxes = context->input(1); + const Tensor& box_index = context->input(2); + const bool status = functor::CropAndResizeBackpropImage()( + context->eigen_device(), grads.tensor(), + boxes.tensor(), box_index.tensor(), + output->tensor()); + if (!status) { + context->SetStatus(errors::Internal( + "Failed launch CropAndResizeBackpropImage kernel.")); + } + }; + + RunIfBoxIndexIsValid(context, box_index.tensor(), + batch_size, std::move(compute_callback), + std::move(done)); } }; @@ -348,9 +400,9 @@ struct CropAndResizeBackpropImage { bool operator()(const CPUDevice& d, typename TTypes::ConstTensor grads, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_ind, + typename TTypes::ConstTensor box_index, typename TTypes::Tensor grads_image) { - const int batch = grads_image.dimension(0); + const int batch_size = grads_image.dimension(0); const int image_height = grads_image.dimension(1); const int image_width = grads_image.dimension(2); @@ -367,8 +419,8 @@ struct CropAndResizeBackpropImage { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_ind(b); - if (b_in < 0 || b_in >= batch) { + const int32 b_in = box_index(b); + if (!FastBoundsCheck(b_in, batch_size)) { continue; } @@ -419,83 +471,90 @@ struct CropAndResizeBackpropImage { return true; } }; + } // namespace functor template -class CropAndResizeGradBoxesOp : public OpKernel { +class CropAndResizeGradBoxesOp : public AsyncOpKernel { public: explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context) - : OpKernel(context) { + : AsyncOpKernel(context) { string method; OP_REQUIRES_OK(context, context->GetAttr("method", &method)); OP_REQUIRES(context, method == "bilinear", errors::InvalidArgument("method must be 'bilinear'", method)); } - void Compute(OpKernelContext* context) override { + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth]. const Tensor& grads = context->input(0); + // The shape of 'boxes' is [num_boxes, 4]. + const Tensor& boxes = context->input(2); + // The shape of 'box_index' is [num_boxes]. + const Tensor& box_index = context->input(3); + // The shape of 'image' is [batch_size, image_height, image_width, depth]. + const Tensor& image = context->input(1); - OP_REQUIRES(context, grads.dims() == 4, - errors::InvalidArgument("grads image must be 4-D", - grads.shape().DebugString())); - + // Validate input shapes. + OP_REQUIRES_ASYNC(context, grads.dims() == 4, + errors::InvalidArgument("grads image must be 4-D", + grads.shape().DebugString()), + done); const int crop_height = grads.dim_size(1); const int crop_width = grads.dim_size(2); const int depth = grads.dim_size(3); - OP_REQUIRES(context, crop_height > 0 && crop_width > 0, - errors::InvalidArgument("grads dimensions must be positive")); - - // The shape of 'image' is [batch, image_height, image_width, depth]. - const Tensor& image = context->input(1); - OP_REQUIRES(context, image.dims() == 4, - errors::InvalidArgument("input image must be 4-D", - image.shape().DebugString())); - - const int batch = image.dim_size(0); + OP_REQUIRES_ASYNC( + context, crop_height > 0 && crop_width > 0, + errors::InvalidArgument("grads dimensions must be positive"), done); + + OP_REQUIRES_ASYNC(context, image.dims() == 4, + errors::InvalidArgument("input image must be 4-D", + image.shape().DebugString()), + done); + const int batch_size = image.dim_size(0); const int image_height = image.dim_size(1); const int image_width = image.dim_size(2); - OP_REQUIRES(context, image_height > 0 && image_width > 0, - errors::InvalidArgument("image dimensions must be positive")); - OP_REQUIRES(context, image.dim_size(3) == depth, - errors::InvalidArgument("image, grads depth differ")); - - // The shape of 'boxes' is [num_boxes, 4]. - const Tensor& boxes = context->input(2); - - // The shape of 'box_ind' is [num_boxes]. - const Tensor& box_ind = context->input(3); + OP_REQUIRES_ASYNC( + context, image_height > 0 && image_width > 0, + errors::InvalidArgument("image dimensions must be positive"), done); + OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth, + errors::InvalidArgument("image, grads depth differ"), + done); int num_boxes = 0; - ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes); + OP_REQUIRES_OK_ASYNC( + context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done); - OP_REQUIRES( + OP_REQUIRES_ASYNC( context, grads.dim_size(0) == num_boxes, - errors::InvalidArgument("boxes and grads have incompatible shape")); + errors::InvalidArgument("boxes and grads have incompatible shape"), + done); // Allocate output tensor. Tensor* output = nullptr; - OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({num_boxes, 4}), &output)); - - typename TTypes::ConstTensor grads_data = - grads.tensor(); - typename TTypes::ConstTensor image_data = image.tensor(); - typename TTypes::ConstTensor boxes_data = - boxes.tensor(); - typename TTypes::ConstTensor box_ind_data = - box_ind.tensor(); - typename TTypes::Tensor output_data = output->tensor(); - - CheckValidBoxInd(context, box_ind_data, batch); - - bool status = functor::CropAndResizeBackpropBoxes()( - context->eigen_device(), grads_data, image_data, boxes_data, - box_ind_data, output_data); - if (!status) { - context->SetStatus( - errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel.")); - } + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_output(0, TensorShape({num_boxes, 4}), &output), + done); + + auto compute_callback = [context, output]() { + const Tensor& grads = context->input(0); + const Tensor& image = context->input(1); + const Tensor& boxes = context->input(2); + const Tensor& box_index = context->input(3); + const bool status = functor::CropAndResizeBackpropBoxes()( + context->eigen_device(), grads.tensor(), + image.tensor(), boxes.tensor(), + box_index.tensor(), output->tensor()); + if (!status) { + context->SetStatus(errors::Internal( + "Failed launch CropAndResizeBackpropBoxes kernel.")); + } + }; + + RunIfBoxIndexIsValid(context, box_index.tensor(), + batch_size, std::move(compute_callback), + std::move(done)); } }; @@ -507,9 +566,9 @@ struct CropAndResizeBackpropBoxes { typename TTypes::ConstTensor grads, typename TTypes::ConstTensor image, typename TTypes::ConstTensor boxes, - typename TTypes::ConstTensor box_ind, + typename TTypes::ConstTensor box_index, typename TTypes::Tensor grads_boxes) { - const int batch = image.dimension(0); + const int batch_size = image.dimension(0); const int image_height = image.dimension(1); const int image_width = image.dimension(2); @@ -526,8 +585,8 @@ struct CropAndResizeBackpropBoxes { const float y2 = boxes(b, 2); const float x2 = boxes(b, 3); - const int32 b_in = box_ind(b); - if (b_in < 0 || b_in >= batch) { + const int32 b_in = box_index(b); + if (!FastBoundsCheck(b_in, batch_size)) { continue; } @@ -609,30 +668,19 @@ struct CropAndResizeBackpropBoxes { return true; } }; -} // namespace functor -// Specialization of CheckValidBoxInd for a CPUDevice. -template <> -inline void CheckValidBoxInd( - OpKernelContext* context, typename TTypes::ConstTensor box_ind, - int batch) { - const int num_boxes = box_ind.dimension(0); - for (int b = 0; b < num_boxes; ++b) { - OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch, - errors::OutOfRange("box_ind has values outside [0, batch)")); - } -} +} // namespace functor -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T") \ - .HostMemory("crop_size"), \ - CropAndResizeOp); \ - \ - REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .HostMemory("crop_size"), \ + CropAndResizeOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ CropAndResizeGradBoxesOp); TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); @@ -654,50 +702,93 @@ TF_CALL_double(REGISTER_KERNEL); #if GOOGLE_CUDA -// Forward declaration of the CheckValidBoxIndHelper specialization for GPU. +// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU. namespace functor { template <> -void CheckValidBoxIndHelper::operator()( - const GPUDevice& d, typename TTypes::ConstTensor box_ind, - int batch, typename TTypes::Tensor isvalid); -extern template struct CheckValidBoxIndHelper; +void CheckValidBoxIndexHelper::operator()( + const GPUDevice& d, typename TTypes::ConstTensor box_index, + int batch_size, typename TTypes::Tensor isvalid); +extern template struct CheckValidBoxIndexHelper; } // namespace functor -// Specialization of CheckValidBoxInd for a GPUDevice. +namespace { + +// Specialization of CheckValidBoxIndex for a GPUDevice. template <> -inline void CheckValidBoxInd( - OpKernelContext* context, typename TTypes::ConstTensor box_ind, - int batch) { - const int num_boxes = box_ind.dimension(0); +inline void RunIfBoxIndexIsValid( + OpKernelContext* context, typename TTypes::ConstTensor box_index, + int batch_size, const Callback& compute, const Callback& done) { + const int num_boxes = box_index.dimension(0); if (num_boxes == 0) { + compute(); + done(); return; } - Tensor isvalid_tensor; - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum::value, - TensorShape({}), &isvalid_tensor)); - typename TTypes::Tensor isvalid = isvalid_tensor.tensor(); + Tensor isvalid_dev_tensor; + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, TensorShape({}), + &isvalid_dev_tensor), + done); + typename TTypes::Tensor isvalid_dev = + isvalid_dev_tensor.tensor(); - functor::CheckValidBoxIndHelper()( - context->eigen_device(), box_ind, batch, isvalid); + // Run the actual box check on the device. + functor::CheckValidBoxIndexHelper()( + context->eigen_device(), box_index, batch_size, isvalid_dev); + // Copy the result back to the host. auto* stream = context->op_device_context()->stream(); - OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - - bool isvalid_host = false; - perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(), - sizeof(bool)); - stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool)); - stream->BlockHostUntilDone(); - - OP_REQUIRES(context, stream->ok(), - errors::Internal("cudaMemcpy from device to host failed")); - - OP_REQUIRES(context, isvalid_host, - errors::OutOfRange("box_ind has values outside [0, batch)")); + OP_REQUIRES_ASYNC(context, stream, + errors::Internal("No GPU stream available."), done); + Tensor isvalid_host_tensor; + // Use pinned host memory on the host to avoid unnecessary + // synchronization. + AllocatorAttributes alloc_attr; + alloc_attr.set_on_host(true); + alloc_attr.set_gpu_compatible(true); + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, TensorShape({}), + &isvalid_host_tensor, alloc_attr), + done); + perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(), + sizeof(bool)); + const bool status = + stream + ->ThenMemcpy( + isvalid_host_tensor.scalar().data() /* destination */, + wrapped /* source */, sizeof(bool)) + .ok(); + OP_REQUIRES_ASYNC( + context, status, + errors::Internal("Failed to launch copy of isvalid from device to host."), + done); + + // We capture both temporary tensors to prevent them from being deallocated + // when ComputeAsync returns and before the closure runs. + TensorReference isvalid_dev_ref(isvalid_dev_tensor); + auto wrapped_callback = [context, isvalid_host_tensor, isvalid_dev_ref, + compute, done]() { + auto stream = context->op_device_context()->stream(); + ScopedActivateExecutorContext scoped_activation{stream->parent()}; + const bool isvalid = isvalid_host_tensor.scalar()(); + isvalid_dev_ref.Unref(); + OP_REQUIRES_ASYNC( + context, isvalid, + errors::OutOfRange("box_index has values outside [0, batch_size)"), + done); + compute(); + done(); + }; + + context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + stream, wrapped_callback); } +} // namespace + #define REGISTER_KERNEL(T) \ REGISTER_KERNEL_BUILDER(Name("CropAndResize") \ .Device(DEVICE_GPU) \ diff --git a/tensorflow/core/kernels/crop_and_resize_op.h b/tensorflow/core/kernels/crop_and_resize_op.h index 84d7a5e03b8d34bccf921087a2b412aa5f269e72..b6b1dbd7b0c3d44729f1a8d88f1d562062f61410 100644 --- a/tensorflow/core/kernels/crop_and_resize_op.h +++ b/tensorflow/core/kernels/crop_and_resize_op.h @@ -55,12 +55,12 @@ struct CropAndResizeBackpropBoxes { }; template -struct CheckValidBoxIndHelper { - // Checks if all values in box_ind are in [0, batch). +struct CheckValidBoxIndexHelper { + // Checks if all values in box_index are in [0, batch). void operator()(const Device& d, - typename TTypes::ConstTensor box_ind, int batch, + typename TTypes::ConstTensor box_index, int batch, typename TTypes::Tensor isvalid) { - isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all(); + isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all(); } }; diff --git a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc index 1726e4a816895cf5b4ceaf0357ce3ea30e764f34..d12787d5244d12f27dd03a32fecb6f2713af2c29 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc @@ -442,7 +442,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); #undef DEFINE_GPU_SPECS -template struct CheckValidBoxIndHelper; +template struct CheckValidBoxIndexHelper; } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/crop_and_resize_op_test.cc b/tensorflow/core/kernels/crop_and_resize_op_test.cc index 1bf28d4d0032a1277ef3832ebefbe3c2b02f14c7..22c659b587be4b981948fba19e091b6ecbf34481 100644 --- a/tensorflow/core/kernels/crop_and_resize_op_test.cc +++ b/tensorflow/core/kernels/crop_and_resize_op_test.cc @@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) { Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE( - StringPiece(s.ToString()).contains("box_ind has incompatible shape")) + StringPiece(s.ToString()).contains("box_index has incompatible shape")) << s; } @@ -264,7 +264,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) { Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(StringPiece(s.ToString()) - .contains("box_ind has values outside [0, batch)")) + .contains("box_index has values outside [0, batch_size)")) << s; } diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc index a1f60019141dc7b55046c0f9f84089b277cb69fd..fb03adb7a5336919c85c4685f4cc7e7a8180892d 100644 --- a/tensorflow/core/kernels/ctc_loss_op.cc +++ b/tensorflow/core/kernels/ctc_loss_op.cc @@ -91,7 +91,14 @@ class CTCLossOp : public OpKernel { OP_REQUIRES(ctx, batch_size != 0, errors::InvalidArgument("batch_size must not be 0")); - TensorShape labels_shape({batch_size, max_time}); + // Figure out the maximum label length to use as sparse tensor dimension. + auto labels_indices_t = labels_indices->matrix(); + int64 max_label_len = 0; + for (int i = 0; i < labels_indices->dim_size(0); i++) { + max_label_len = std::max(max_label_len, labels_indices_t(i, 1) + 1); + } + + TensorShape labels_shape({batch_size, max_label_len}); std::vector order{0, 1}; sparse::SparseTensor labels_sp(*labels_indices, *labels_values, labels_shape, order); diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc index 3a8ccfe6b74b1c8658f33fa61723162aba05b530..5c6b5eec829427b91e2d53b795083ee3f83fd401 100644 --- a/tensorflow/core/kernels/cuda_solvers.cc +++ b/tensorflow/core/kernels/cuda_solvers.cc @@ -30,10 +30,13 @@ #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/cuda.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" +using ::perftools::gputools::cuda::ScopedActivateExecutorContext; + namespace tensorflow { namespace { @@ -148,7 +151,12 @@ Status CudaSolver::CopyLapackInfoToHostAsync( // This callback checks that all batch items in all calls were processed // successfully and passes status to the info_checker_callback accordingly. auto wrapped_info_checker_callback = - [info_checker_callback](std::vector host_lapack_infos) { + [](OpKernelContext* context, + std::function&)> + info_checker_callback, + std::vector host_lapack_infos) { + auto stream = context->op_device_context()->stream(); + ScopedActivateExecutorContext scoped_activation{stream->parent()}; Status status; for (const auto& host_lapack_info : host_lapack_infos) { for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) { @@ -166,8 +174,10 @@ Status CudaSolver::CopyLapackInfoToHostAsync( } info_checker_callback(status, host_lapack_infos); }; + auto cb = - std::bind(wrapped_info_checker_callback, std::move(host_lapack_infos)); + std::bind(wrapped_info_checker_callback, context_, + std::move(info_checker_callback), std::move(host_lapack_infos)); auto stream = context_->op_device_context()->stream(); context_->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( stream, std::move(cb)); diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h index ac6119d8a2168c2b43ecd1dc246b57566a485672..0fd6450f98248fa6d5af0b880e18956b78ae70fc 100644 --- a/tensorflow/core/kernels/cuda_solvers.h +++ b/tensorflow/core/kernels/cuda_solvers.h @@ -313,6 +313,9 @@ class ScratchSpace { int64 size() const { return scratch_tensor_.NumElements(); } const string& debug_info() const { return debug_info_; } + Tensor& tensor() { return scratch_tensor_; } + const Tensor& tensor() const { return scratch_tensor_; } + // Returns true if this ScratchSpace is in host memory. bool on_host() const { return on_host_; } diff --git a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc index b9e42b4d00d4bd76eee5a1e3d22b024efc06ba0b..af6c094d7ac6435cc0836230271786aea951e158 100644 --- a/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc +++ b/tensorflow/core/kernels/cuda_solvers_gpu.cu.cc @@ -51,55 +51,57 @@ namespace { // Hacks around missing support for complex arithmetic in nvcc. template -__host__ __device__ inline Scalar Multiply(Scalar x, Scalar y) { +__device__ inline Scalar Multiply(Scalar x, Scalar y) { return x * y; } template <> -__host__ __device__ inline cuComplex Multiply(cuComplex x, cuComplex y) { +__device__ inline cuComplex Multiply(cuComplex x, cuComplex y) { return cuCmulf(x, y); } template <> -__host__ __device__ inline cuDoubleComplex Multiply(cuDoubleComplex x, - cuDoubleComplex y) { +__device__ inline cuDoubleComplex Multiply(cuDoubleComplex x, + cuDoubleComplex y) { return cuCmul(x, y); } template -__host__ __device__ inline Scalar Negate(Scalar x) { +__device__ inline Scalar Negate(Scalar x) { return -x; } template <> -__host__ __device__ inline cuComplex Negate(cuComplex x) { +__device__ inline cuComplex Negate(cuComplex x) { return make_cuComplex(-cuCrealf(x), -cuCimagf(x)); } template <> -__host__ __device__ inline cuDoubleComplex Negate(cuDoubleComplex x) { +__device__ inline cuDoubleComplex Negate(cuDoubleComplex x) { return make_cuDoubleComplex(-cuCreal(x), -cuCimag(x)); } template -__host__ __device__ inline bool IsFinite(Scalar x) { - return isfinite(x); +__device__ inline bool IsFinite(Scalar x) { + return Eigen::numext::isfinite(x); } template <> -__host__ __device__ inline bool IsFinite(cuComplex x) { - return isfinite(cuCrealf(x)) && isfinite(cuCimagf(x)); +__device__ inline bool IsFinite(cuComplex x) { + return Eigen::numext::isfinite(cuCrealf(x)) && + Eigen::numext::isfinite(cuCimagf(x)); } template <> -__host__ __device__ inline bool IsFinite(cuDoubleComplex x) { - return isfinite(cuCreal(x)) && isfinite(cuCimag(x)); +__device__ inline bool IsFinite(cuDoubleComplex x) { + return Eigen::numext::isfinite(cuCreal(x)) && + Eigen::numext::isfinite(cuCimag(x)); } template struct Const { template - __host__ __device__ static inline Scalar make_const(const RealScalar x) { + __device__ static inline Scalar make_const(const RealScalar x) { return Scalar(x); } }; @@ -107,7 +109,7 @@ struct Const { template <> struct Const { template - __host__ __device__ static inline cuComplex make_const(const RealScalar x) { + __device__ static inline cuComplex make_const(const RealScalar x) { return make_cuComplex(x, 0.0f); } }; @@ -115,8 +117,7 @@ struct Const { template <> struct Const { template - __host__ __device__ static inline cuDoubleComplex make_const( - const RealScalar x) { + __device__ static inline cuDoubleComplex make_const(const RealScalar x) { return make_cuDoubleComplex(x, 0.0f); } }; diff --git a/tensorflow/core/kernels/cwise_op_sub.cc b/tensorflow/core/kernels/cwise_op_sub.cc index eb173c7040d435879abdb4d5cbae5f20a720199f..6adaecba04bfcf1b42a760d712eece493131ade2 100644 --- a/tensorflow/core/kernels/cwise_op_sub.cc +++ b/tensorflow/core/kernels/cwise_op_sub.cc @@ -18,7 +18,10 @@ limitations under the License. namespace tensorflow { REGISTER7(BinaryOp, CPU, "Sub", functor::sub, float, Eigen::half, double, int32, int64, complex64, complex128); -#if defined(__ANDROID_TYPES_SLIM__) +#if !defined(__ANDROID_TYPES_SLIM__) +// Sub op for int8, uint8, int16, uint16 +REGISTER4(BinaryOp, CPU, "Sub", functor::sub, int8, uint8, int16, uint16); +#else // We only register the first type when we have multi-argument calls in the // case where we're trying to reduce executable size, but it turns out that the // int32 version of this op is needed, so explicitly include it. diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc index 192a4f732ef884952e24c0aa3648041955233b43..693c6467ac592e3357e5b06a620a64b3829bc938 100644 --- a/tensorflow/core/kernels/cwise_ops_common.cc +++ b/tensorflow/core/kernels/cwise_ops_common.cc @@ -20,7 +20,9 @@ namespace tensorflow { BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in) : OpKernel(ctx) { +#ifndef INTEL_MKL OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out})); +#endif } void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) { diff --git a/tensorflow/core/kernels/dataset.cc b/tensorflow/core/kernels/dataset.cc index f99684b1ca32848caf2f255811be65dc61e72490..2bfbdc1cd932a45461d30e16bfc880e868bfb3ac 100644 --- a/tensorflow/core/kernels/dataset.cc +++ b/tensorflow/core/kernels/dataset.cc @@ -52,4 +52,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx, MakeDataset(ctx, input, another_input, output); } +const char IteratorBase::kIteratorExhausted[] = "ITERATOR_EXHAUSTED"; + } // namespace tensorflow diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h index 9bfc5c1e969d3d520a405ba41510ccf18b66d43d..aa97d340415f8e39170b226f501b8ee5e3ca8212 100644 --- a/tensorflow/core/kernels/dataset.h +++ b/tensorflow/core/kernels/dataset.h @@ -19,7 +19,11 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/util/tensor_bundle/naming.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" // Polymorphic datasets should support all primitive TensorFlow // types. Use this macro to expand `m(T)` once for each primitive type @@ -86,6 +90,10 @@ class IteratorContext { // range of outputs is typically represented by an `DatasetBase`, // defined below. class IteratorBase { + protected: + class IteratorBundleReader; + class IteratorBundleWriter; + public: virtual ~IteratorBase() {} @@ -115,6 +123,118 @@ class IteratorBase { // (and possibly partially defined) shapes of each tuple component // in the outputs of this iterator. virtual const std::vector& output_shapes() const = 0; + + // Saves the state of this iterator. + virtual Status SaveState(OpKernelContext* ctx, StringPiece path) { + BundleWriter bundle_writer(ctx->env(), path); + IteratorBundleWriter writer(&bundle_writer); + if (is_exhausted_) { + LOG(INFO) << "Iterator exhausted. Nothing to save."; + TF_RETURN_IF_ERROR( + writer.WriteScalar(kIteratorExhausted, kIteratorExhausted)); + } else { + TF_RETURN_IF_ERROR(SaveStateInternal(ctx, &writer)); + } + TF_RETURN_IF_ERROR(bundle_writer.Finish()); + return Status::OK(); + } + + // Restores the state of this iterator. + virtual Status RestoreState(OpKernelContext* ctx, StringPiece& path) { + if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) { + return errors::NotFound( + "Failed to restore Iterator state. No file found at ", + MetaFilename(path)); + } + BundleReader bundle_reader(ctx->env(), path); + if (bundle_reader.Contains(kIteratorExhausted)) { + LOG(INFO) << "Iterator exhausted. Nothing to restore."; + is_exhausted_ = true; + return Status::OK(); + } else { + IteratorBundleReader reader(&bundle_reader); + return RestoreStateInternal(ctx, &reader); + } + } + + protected: + class IteratorBundleReader { + public: + IteratorBundleReader(BundleReader* bundle_reader) + : bundle_reader_(bundle_reader) {} + + // Reads a scalar value. + template + Status ReadScalar(T* val, const string& key) { + Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); + TF_RETURN_IF_ERROR(Lookup(StringPiece(key), &val_t)); + *val = val_t.scalar()(); + return Status::OK(); + } + + // Restores the state of a parent iterator recursively. + Status RestoreParentState(OpKernelContext* ctx, + const std::unique_ptr& parent) { + return parent->RestoreStateInternal(ctx, this); + } + + private: + Status Lookup(StringPiece key, Tensor* val) { + return bundle_reader_->Lookup(key, val); + } + + BundleReader* bundle_reader_; + }; + + class IteratorBundleWriter { + public: + IteratorBundleWriter(BundleWriter* bundle_writer) + : bundle_writer_(bundle_writer) {} + + // Writes a scalar value. + template + Status WriteScalar(const T val, const string& key) { + Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); + val_t.scalar()() = val; + TF_RETURN_IF_ERROR(Add(StringPiece(key), val_t)); + return Status::OK(); + } + + // Saves the state of a parent iterator recursively. + Status SaveParentState(OpKernelContext* ctx, + const std::unique_ptr& parent) { + return parent->SaveStateInternal(ctx, this); + } + + private: + Status Add(StringPiece key, const Tensor& val) { + return bundle_writer_->Add(key, val); + } + + BundleWriter* bundle_writer_; + }; + + // Saves the state of this iterator. + // Note: Contents written to `writer` may not get flushed to disk + // until the call to `SaveState` in the leaf iterator is finished. + // Must be overridden by sub-classes. + virtual Status SaveStateInternal(OpKernelContext* ctx, + IteratorBundleWriter* writer) { + return errors::Unimplemented("SaveState not implemented."); + } + + // Restores the state of this iterator. + // + // Must be overridden by sub-classes. + virtual Status RestoreStateInternal(OpKernelContext* ctx, + IteratorBundleReader* reader) { + return errors::Unimplemented("RestoreState not implemented"); + } + + bool is_exhausted_ = false; // Whether the iterator has been exhausted. + + private: + static const char kIteratorExhausted[]; }; // Represents a (potentially infinite) range of outputs, where each @@ -182,6 +302,10 @@ class DatasetIterator : public IteratorBase { Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence) final { port::Tracing::TraceMe activity(params_.prefix); + if (is_exhausted_) { + *end_of_sequence = true; + return Status::OK(); + } return GetNextInternal(ctx, out_tensors, end_of_sequence); } @@ -190,6 +314,11 @@ class DatasetIterator : public IteratorBase { std::vector* out_tensors, bool* end_of_sequence) = 0; + protected: + string full_name(const string& name) { + return strings::StrCat(prefix(), ":", name); + } + private: Params params_; }; diff --git a/tensorflow/core/kernels/dataset_utils.cc b/tensorflow/core/kernels/dataset_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..f320b3b09c6fe58883114cd09d16a4571613ac66 --- /dev/null +++ b/tensorflow/core/kernels/dataset_utils.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/dataset_utils.h" + +namespace tensorflow { + +namespace dataset { + +Status MakeIteratorFromInputElement( + IteratorContext* ctx, const std::vector& input_element, + int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, + std::unique_ptr* out_iterator) { + FunctionLibraryRuntime::Options opts; + opts.runner = ctx->runner(); + // Choose a step ID that is guaranteed not to clash with any + // Session-generated step ID. DirectSession only generates + // non-negative step IDs (contiguous, starting from 0), and + // MasterSession generates 56-bit random step IDs whose MSB + // is always 0, so a negative random step ID should suffice. + opts.step_id = CapturedFunction::generate_step_id(); + ScopedStepContainer step_container( + opts.step_id, [captured_func, ctx](const string& name) { + captured_func->resource_manager()->Cleanup(name).IgnoreError(); + }); + opts.step_container = &step_container; + std::vector return_values; + TF_RETURN_IF_ERROR(captured_func->Run(opts, input_element, &return_values)); + + if (!(return_values.size() == 1 && return_values[0].dtype() == DT_RESOURCE && + TensorShapeUtils::IsScalar(return_values[0].shape()))) { + return errors::InvalidArgument( + "Function must return a single scalar of dtype DT_RESOURCE."); + } + + // Retrieve the dataset that was created in `f`. + DatasetBase* returned_dataset; + const ResourceHandle& dataset_resource = + return_values[0].scalar()(); + + // NOTE(mrry): We cannot use the core `LookupResource()` or + // `DeleteResource()` functions, because we have an + // `IteratorContext*` and not an `OpKernelContext*`, so we + // replicate the necessary functionality here. + auto type_index = MakeTypeIndex(); + if (type_index.hash_code() != dataset_resource.hash_code()) { + return errors::InvalidArgument("Function must return a Dataset resource."); + } + TF_RETURN_IF_ERROR(captured_func->resource_manager()->Lookup( + dataset_resource.container(), dataset_resource.name(), + &returned_dataset)); + core::ScopedUnref unref_dataset(returned_dataset); + + // Create an iterator for the dataset that was returned by + // `f`. This transfers ownership of the dataset to the + // iterator, so we can delete it from the resource manager. + *out_iterator = returned_dataset->MakeIterator( + strings::StrCat(prefix, "[", thread_index, "]")); + TF_RETURN_IF_ERROR(captured_func->resource_manager()->Delete( + dataset_resource.container(), dataset_resource.name())); + return Status::OK(); +} + +} // namespace dataset + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/dataset_utils.h b/tensorflow/core/kernels/dataset_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..eea2b8802b813808f752659a469c3818a52162d2 --- /dev/null +++ b/tensorflow/core/kernels/dataset_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/captured_function.h" +#include "tensorflow/core/kernels/dataset.h" + +namespace tensorflow { + +namespace dataset { + +Status MakeIteratorFromInputElement( + IteratorContext* ctx, const std::vector& input_element, + int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, + std::unique_ptr* out_iterator); + +} // namespace dataset + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DATASET_UTILS_H_ diff --git a/tensorflow/core/kernels/debug_ops_test.cc b/tensorflow/core/kernels/debug_ops_test.cc index 89bcbc9c373210a416201b55f4623bee1d1c94d3..37c94865942ff7df15f0a660130310c60ebc3d8e 100644 --- a/tensorflow/core/kernels/debug_ops_test.cc +++ b/tensorflow/core/kernels/debug_ops_test.cc @@ -573,7 +573,8 @@ TEST_F(DebugNumericSummaryOpTest, UInt8Success) { TEST_F(DebugNumericSummaryOpTest, BoolSuccess) { TF_ASSERT_OK(Init(DT_BOOL)); - AddInputFromArray(TensorShape({2, 3}), {0, 0, 1, 1, 1, 0}); + AddInputFromArray(TensorShape({2, 3}), + {false, false, true, true, true, false}); TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_DOUBLE, TensorShape({16})); diff --git a/tensorflow/core/kernels/depthwise_conv_grad_op.cc b/tensorflow/core/kernels/depthwise_conv_grad_op.cc index 00d7f5640829c38eb1b73d337835436eba2387b6..9804d7d38e1a811ea30136697a11d085e3533552 100644 --- a/tensorflow/core/kernels/depthwise_conv_grad_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_grad_op.cc @@ -361,19 +361,15 @@ static void ComputeBackpropInput(const DepthwiseArgs& args, } } -// Kernels to compute the input backprop for depthwise convolution. -template -struct LaunchDepthwiseConvBackpropInputOp; - // Computes the depthwise conv2d backprop input of 'out_backprop' by // 'depthwise_filter' and stores the result in 'in_backprop'. template struct LaunchDepthwiseConvBackpropInputOp { typedef typename Eigen::internal::packet_traits::type Packet; - static void launch(OpKernelContext* ctx, const DepthwiseArgs& args, - const T* out_backprop, const T* depthwise_filter, - T* in_backprop, TensorFormat data_format) { + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* out_backprop, const T* depthwise_filter, + T* in_backprop, TensorFormat data_format) { OP_REQUIRES( ctx, data_format == FORMAT_NHWC, errors::Unimplemented( @@ -514,27 +510,8 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args, #if GOOGLE_CUDA -template -struct DepthwiseConv2dBackpropInputGPULaunch { - static void Run(const GPUDevice& d, const DepthwiseArgs args, - const T* out_backprop, const T* filter, T* in_backprop, - TensorFormat data_format); -}; - -template -struct LaunchDepthwiseConvBackpropInputOp { - static void launch(OpKernelContext* ctx, const DepthwiseArgs args, - const T* out_backprop, const T* filter, T* in_backprop, - TensorFormat data_format) { - const GPUDevice& d = ctx->eigen_device(); - DepthwiseConv2dBackpropInputGPULaunch().Run( - d, args, out_backprop, filter, in_backprop, data_format); - auto stream = ctx->op_device_context()->stream(); - OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for " - "DepthwiseConv2dBackpropInp" - "utGPULaunch failed")); - } -}; +extern template struct LaunchDepthwiseConvBackpropInputOp; +extern template struct LaunchDepthwiseConvBackpropInputOp; #endif // GOOGLE_CUDA @@ -598,7 +575,7 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel { if (input_shape.num_elements() == 0) { return; } - LaunchDepthwiseConvBackpropInputOp::launch( + LaunchDepthwiseConvBackpropInputOp()( context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr, data_format_); } @@ -744,9 +721,9 @@ template struct LaunchDepthwiseConvBackpropFilterOp { typedef typename Eigen::internal::packet_traits::type Packet; - static void launch(OpKernelContext* ctx, const DepthwiseArgs& args, - const T* out_backprop, const T* input, T* filter_backprop, - TensorFormat data_format) { + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* out_backprop, const T* input, T* filter_backprop, + TensorFormat data_format) { OP_REQUIRES( ctx, data_format == FORMAT_NHWC, errors::Unimplemented( @@ -907,35 +884,8 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args, #if GOOGLE_CUDA -template -struct DepthwiseConv2dBackpropFilterGPULaunch { - static void Run(const GPUDevice& d, const DepthwiseArgs args, - const T* out_backprop, const T* input, T* filter_backprop, - TensorFormat data_format); -}; - -template -struct LaunchDepthwiseConvBackpropFilterOp { - static void launch(OpKernelContext* ctx, const DepthwiseArgs args, - const T* out_backprop, const T* input, T* filter_backprop, - TensorFormat data_format) { - const GPUDevice& d = ctx->eigen_device(); - auto stream = ctx->op_device_context()->stream(); - - // Initialize the results to 0. - int num_filter_backprop = - args.filter_rows * args.filter_cols * args.out_depth; - perftools::gputools::DeviceMemoryBase filter_bp_ptr(filter_backprop, - num_filter_backprop); - stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T)); - - DepthwiseConv2dBackpropFilterGPULaunch().Run( - d, args, out_backprop, input, filter_backprop, data_format); - OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for " - "DepthwiseConv2dBackpropFil" - "terGPULaunch failed")); - } -}; +extern template struct LaunchDepthwiseConvBackpropFilterOp; +extern template struct LaunchDepthwiseConvBackpropFilterOp; #endif // GOOGLE_CUDA @@ -1001,7 +951,7 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel { if (filter_shape.num_elements() == 0) { return; } - LaunchDepthwiseConvBackpropFilterOp::launch( + LaunchDepthwiseConvBackpropFilterOp()( context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr, data_format_); } diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index ccd33c08612b272afae8a8206f346ac5f9fd06b1..bbeeaf789544a45ced75148064be0b39c7457053 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -54,9 +54,6 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -template -struct LaunchDepthwiseConvOp; - // Computes the vectorized product of 'input_buffer' and 'filter' and stores // result in 'output' at location specified by 'out_r' and 'out_c'. // @@ -156,9 +153,9 @@ template struct LaunchDepthwiseConvOp { typedef typename Eigen::internal::packet_traits::type Packet; - static void launch(OpKernelContext* ctx, const DepthwiseArgs& args, - const T* input, const T* depthwise_filter, T* output, - TensorFormat data_format) { + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* input, const T* depthwise_filter, T* output, + TensorFormat data_format) { OP_REQUIRES( ctx, data_format == FORMAT_NHWC, errors::Unimplemented( @@ -248,27 +245,9 @@ extern template class LaunchConv2DOp; #if GOOGLE_CUDA -template -struct DepthwiseConv2dGPULaunch { - static void Run(const GPUDevice& d, const DepthwiseArgs args, const T* input, - const T* filter, T* output, TensorFormat data_format); -}; - -template -struct LaunchDepthwiseConvOp { - static void launch(OpKernelContext* ctx, const DepthwiseArgs args, - const T* input, const T* filter, T* output, - TensorFormat data_format) { - const GPUDevice& d = ctx->eigen_device(); - DepthwiseConv2dGPULaunch().Run(d, args, input, filter, output, - data_format); - auto stream = ctx->op_device_context()->stream(); - OP_REQUIRES( - ctx, stream->ok(), - errors::Internal( - "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed")); - } -}; +// Extern template instantiated in depthwise_conv_op_gpu.cc. +extern template struct LaunchDepthwiseConvOp; +extern template struct LaunchDepthwiseConvOp; // Extern template instantiated in conv_ops.cc. extern template class LaunchConv2DOp; @@ -368,10 +347,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp { TensorShape out_shape = ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth); OP_REQUIRES( - context, out_shape.num_elements() <= 2147483647, - errors::InvalidArgument("total number of outputs should be within the " - "range of int which is used in the GPU kernel", - in_depth, " vs ", filter.dim_size(2))); + context, + (!std::is_same::value || + FastBoundsCheck(out_shape.num_elements(), + std::numeric_limits::max())), + errors::InvalidArgument("Output elements too large for GPU kernel")); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); @@ -392,9 +372,8 @@ class DepthwiseConv2dNativeOp : public BinaryOp { // If in_depth==1, this operation is just a standard convolution, so // invoke that op. if (std::is_same::value && in_depth == 1) { - launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter, - stride_, stride_, BrainPadding2EigenPadding(padding_), - output, data_format_); + launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter, + stride_, stride_, padding_, output, data_format_); return; } @@ -416,8 +395,8 @@ class DepthwiseConv2dNativeOp : public BinaryOp { auto input_ptr = input.template flat().data(); auto filter_ptr = filter.template flat().data(); auto output_ptr = output->template flat().data(); - LaunchDepthwiseConvOp::launch( - context, args, input_ptr, filter_ptr, output_ptr, data_format_); + LaunchDepthwiseConvOp()(context, args, input_ptr, filter_ptr, + output_ptr, data_format_); } private: diff --git a/tensorflow/core/kernels/depthwise_conv_op.h b/tensorflow/core/kernels/depthwise_conv_op.h index 1960b02bbea4834c8e0e70e155ebc5f92538547d..aa5b5c76f6ac13d7d1dbc5bfb62710cde538621a 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.h +++ b/tensorflow/core/kernels/depthwise_conv_op.h @@ -56,6 +56,53 @@ struct DepthwiseArgs { out_depth(0) {} }; +// Forward declaration. +class OpKernelContext; + +template +struct LaunchDepthwiseConvOp { + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* input, const T* filter, T* output, + TensorFormat data_format); +}; + +template +struct LaunchDepthwiseConvBackpropInputOp { + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* out_backprop, const T* filter, T* in_backprop, + TensorFormat data_format); +}; + +template +struct LaunchDepthwiseConvBackpropFilterOp { + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* out_backprop, const T* input, T* filter_backprop, + TensorFormat data_format); +}; + +#if GOOGLE_CUDA +template +struct LaunchDepthwiseConvOp { + void operator()(OpKernelContext* ctx, const DepthwiseArgs args, + const T* input, const T* filter, T* output, + TensorFormat data_format); +}; + +template +struct LaunchDepthwiseConvBackpropInputOp { + void operator()(class OpKernelContext* ctx, const DepthwiseArgs& args, + const T* out_backprop, const T* filter, T* in_backprop, + TensorFormat data_format); +}; + +template +struct LaunchDepthwiseConvBackpropFilterOp { + void operator()(class OpKernelContext* ctx, const DepthwiseArgs& args, + const T* out_backprop, const T* input, T* filter_backprop, + TensorFormat data_format); +}; +#endif + } // namespace tensorflow namespace tensorflow { diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index f63a99a73088a3c5c339cf36e960b4771694013e..fcfcd188d2d41dacb766a71c2148092754528172 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -17,6 +17,7 @@ limitations under the License. #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/kernels/depthwise_conv_op.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/cuda_kernel_helper.h" @@ -689,21 +690,27 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args, // A simple launch pad to launch the Cuda kernel for depthwise convolution. template -struct DepthwiseConv2dGPULaunch { - static void Run(const GpuDevice& d, const DepthwiseArgs args, const T* input, - const T* filter, T* output, TensorFormat data_format) { - if (args.filter_rows == 3 && args.filter_cols == 3) { - LaunchDepthwiseConv2dGPU(d, args, input, filter, output, +void LaunchDepthwiseConvOp::operator()(OpKernelContext* ctx, + const DepthwiseArgs args, + const T* input, + const T* filter, T* output, + TensorFormat data_format) { + const GPUDevice& d = ctx->eigen_device(); + if (args.filter_rows == 3 && args.filter_cols == 3) { + LaunchDepthwiseConv2dGPU(d, args, input, filter, output, + data_format); + } else { + LaunchDepthwiseConv2dGPU(d, args, input, filter, output, data_format); - } else { - LaunchDepthwiseConv2dGPU(d, args, input, filter, output, - data_format); - } } -}; + auto stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream->ok(), + errors::Internal( + "Launch of gpu kernel for DepthwiseConv2dGPULaunch failed")); +} -template struct DepthwiseConv2dGPULaunch; -template struct DepthwiseConv2dGPULaunch; +template struct LaunchDepthwiseConvOp; +template struct LaunchDepthwiseConvOp; // A Cuda kernel to compute the depthwise convolution backprop w.r.t. input. template -struct DepthwiseConv2dBackpropInputGPULaunch { - static void Run(const GpuDevice& d, const DepthwiseArgs args, - const T* out_backprop, const T* filter, T* in_backprop, - TensorFormat data_format) { - if (args.filter_rows == 3 && args.filter_cols == 3) { - LaunchDepthwiseConv2dBackpropInputGPU( - d, args, out_backprop, filter, in_backprop, data_format); - } else { - LaunchDepthwiseConv2dBackpropInputGPU( - d, args, out_backprop, filter, in_backprop, data_format); - } +void LaunchDepthwiseConvBackpropInputOp::operator()( + OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, + const T* filter, T* in_backprop, TensorFormat data_format) { + const GPUDevice& d = ctx->eigen_device(); + if (args.filter_rows == 3 && args.filter_cols == 3) { + LaunchDepthwiseConv2dBackpropInputGPU( + d, args, out_backprop, filter, in_backprop, data_format); + } else { + LaunchDepthwiseConv2dBackpropInputGPU( + d, args, out_backprop, filter, in_backprop, data_format); } -}; + auto stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream->ok(), + errors::Internal("Launch of gpu kernel for " + "DepthwiseConv2dBackpropInp" + "utGPULaunch failed")); +} -template struct DepthwiseConv2dBackpropInputGPULaunch; -template struct DepthwiseConv2dBackpropInputGPULaunch; +template struct LaunchDepthwiseConvBackpropInputOp; +template struct LaunchDepthwiseConvBackpropInputOp; // A Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. template -struct DepthwiseConv2dBackpropFilterGPULaunch { - static void Run(const GpuDevice& d, const DepthwiseArgs args, - const T* out_backprop, const T* input, T* filter_backprop, - TensorFormat data_format) { - if (args.filter_rows == 3 && args.filter_cols == 3) { - LaunchDepthwiseConv2dBackpropFilterGPU( - d, args, out_backprop, input, filter_backprop, data_format); - } else { - LaunchDepthwiseConv2dBackpropFilterGPU( - d, args, out_backprop, input, filter_backprop, data_format); - } +void LaunchDepthwiseConvBackpropFilterOp::operator()( + OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, + const T* input, T* filter_backprop, TensorFormat data_format) { + const GPUDevice& d = ctx->eigen_device(); + auto stream = ctx->op_device_context()->stream(); + + // Initialize the results to 0. + int num_filter_backprop = + args.filter_rows * args.filter_cols * args.out_depth; + perftools::gputools::DeviceMemoryBase filter_bp_ptr(filter_backprop, + num_filter_backprop); + stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T)); + + if (args.filter_rows == 3 && args.filter_cols == 3) { + LaunchDepthwiseConv2dBackpropFilterGPU( + d, args, out_backprop, input, filter_backprop, data_format); + } else { + LaunchDepthwiseConv2dBackpropFilterGPU( + d, args, out_backprop, input, filter_backprop, data_format); } -}; + OP_REQUIRES(ctx, stream->ok(), + errors::Internal("Launch of gpu kernel for " + "DepthwiseConv2dBackpropFil" + "terGPULaunch failed")); +} -template struct DepthwiseConv2dBackpropFilterGPULaunch; -template struct DepthwiseConv2dBackpropFilterGPULaunch; +template struct LaunchDepthwiseConvBackpropFilterOp; +template struct LaunchDepthwiseConvBackpropFilterOp; } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/extract_jpeg_shape_op.cc b/tensorflow/core/kernels/extract_jpeg_shape_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..60d798af56737c6abb322a971b31ae596ea96ec6 --- /dev/null +++ b/tensorflow/core/kernels/extract_jpeg_shape_op.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/image_ops.cc + +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/jpeg/jpeg_mem.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// Extract the shape of a JPEG image. +template +class ExtractJpegShapeOp : public OpKernel { + public: + explicit ExtractJpegShapeOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Get input content. + const Tensor& contents = context->input(0); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents.shape()), + errors::InvalidArgument("contents must be scalar, got shape ", + contents.shape().DebugString())); + const StringPiece input = contents.scalar()(); + OP_REQUIRES(context, input.size() <= std::numeric_limits::max(), + errors::InvalidArgument("JPEG contents are too large for int: ", + input.size())); + + // Call GetImageInfo to get image shape. + int width, height, components; + OP_REQUIRES( + context, + jpeg::GetImageInfo(input.data(), input.size(), &width, &height, + &components), + errors::InvalidArgument("Invalid JPEG data, size ", input.size())); + // Allocate tensor and set shape size. + Tensor* image_shape = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({3}), &image_shape)); + auto image_shape_data = image_shape->tensor(); + image_shape_data(0) = height; + image_shape_data(1) = width; + image_shape_data(2) = components; + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("ExtractJpegShape") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("output_type"), \ + ExtractJpegShapeOp) + +TF_CALL_int32(REGISTER_KERNELS); +TF_CALL_int64(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/filter_dataset_op.cc b/tensorflow/core/kernels/filter_dataset_op.cc index 05d6f6cae1597e25a4302330119f01bdcaa160da..1e1a3ad75dfd243dde3092c9b7e56738c954b1e3 100644 --- a/tensorflow/core/kernels/filter_dataset_op.cc +++ b/tensorflow/core/kernels/filter_dataset_op.cc @@ -120,8 +120,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { Notification n; Status ret; std::vector result; - ret = dataset()->captured_func_->Run(opts, *out_tensors, &result, - prefix()); + ret = dataset()->captured_func_->Run(opts, *out_tensors, &result); if (!ret.ok()) { return ret; diff --git a/tensorflow/core/kernels/flat_map_dataset_op.cc b/tensorflow/core/kernels/flat_map_dataset_op.cc index a60a7a5ff68a1a6d150db324df76653d9341f187..a87e54bf3102328d5b2c70486b3dde66db75cfb0 100644 --- a/tensorflow/core/kernels/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/flat_map_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/kernels/captured_function.h" +#include "tensorflow/core/kernels/dataset_utils.h" namespace tensorflow { @@ -125,58 +126,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - FunctionLibraryRuntime::Options opts; - opts.runner = ctx->runner(); - opts.step_id = CapturedFunction::generate_step_id(); - ScopedStepContainer step_container( - opts.step_id, [this, ctx](const string& name) { - dataset() - ->captured_func_->resource_manager() - ->Cleanup(name) - .IgnoreError(); - }); - opts.step_container = &step_container; - std::vector return_values; - TF_RETURN_IF_ERROR(dataset()->captured_func_->Run( - opts, args, &return_values, prefix())); - - if (!(return_values.size() == 1 && - return_values[0].dtype() == DT_RESOURCE && - TensorShapeUtils::IsScalar(return_values[0].shape()))) { - return errors::InvalidArgument( - "`f` must return a single scalar of dtype DT_RESOURCE."); - } - - // Retrieve the dataset that was created in `f`. - DatasetBase* returned_dataset; - const ResourceHandle& dataset_resource = - return_values[0].scalar()(); - - // NOTE(mrry): We cannot use the core `LookupResource()` or - // `DeleteResource()` functions, because we have an - // `IteratorContext*` and not an `OpKernelContext*`, so we - // replicate the necessary functionality here. - auto type_index = MakeTypeIndex(); - if (type_index.hash_code() != dataset_resource.hash_code()) { - return errors::InvalidArgument( - "`f` must return a Dataset resource."); - } - TF_RETURN_IF_ERROR( - dataset()->captured_func_->resource_manager()->Lookup( - dataset_resource.container(), dataset_resource.name(), - &returned_dataset)); - core::ScopedUnref unref_dataset(returned_dataset); - - // Create an iterator for the dataset that was returned by - // `f`. This transfers ownership of the dataset to the - // iterator, so we can delete it from the resource manager. - current_element_iterator_ = returned_dataset->MakeIterator( - strings::StrCat(prefix(), "[", element_index_++, "]")); - TF_RETURN_IF_ERROR( - dataset() - ->captured_func_->resource_manager() - ->Delete(dataset_resource.container(), - dataset_resource.name())); + TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + ctx, args, element_index_++, dataset()->captured_func_.get(), + prefix(), ¤t_element_iterator_)); } while (true); } diff --git a/tensorflow/core/kernels/gemm_functors.h b/tensorflow/core/kernels/gemm_functors.h index 7c224bcab65495e59153cbd2900671f79aabac84..4b30c1f17fc8d6bb537316be1760ffae319cbf21 100644 --- a/tensorflow/core/kernels/gemm_functors.h +++ b/tensorflow/core/kernels/gemm_functors.h @@ -33,11 +33,20 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" +// Apple provides an optimized BLAS library that is better than Eigen for their +// devices, so use that if possible. #if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV) #include -#define USE_ACCELERATE_GEMM +#define USE_CBLAS_GEMM #endif // __APPLE__ +// Older Raspberry Pi systems don't have NEON SIMD acceleration, so Eigen falls +// back to scalar code, but OpenBLAS has much faster support so prefer that. +#if defined(RASPBERRY_PI) && defined(USE_GEMM_FOR_CONV) && defined(USE_OPENBLAS) +#include +#define USE_CBLAS_GEMM +#endif + // A readable but slow implementation of matrix multiplication, useful for // debugging and understanding the algorithm. Use instead of FastGemmFunctor in // the Im2ColConvFunctor template definition inside the op registration to @@ -94,9 +103,8 @@ class FastGemmFunctor { } }; -// If we have Apple's Accelerate framework, use their implementation of GEMM to -// get a performance boost for float. -#if defined(USE_ACCELERATE_GEMM) +// If we have a fast CBLAS library, use its implementation through a wrapper. +#if defined(USE_CBLAS_GEMM) template <> class FastGemmFunctor { public: @@ -107,4 +115,4 @@ class FastGemmFunctor { lda, b, ldb, 0.0f, c, ldc); } }; -#endif // USE_ACCELERATE_GEMM +#endif // USE_CBLAS_GEMM diff --git a/tensorflow/core/kernels/group_by_window_dataset_op.cc b/tensorflow/core/kernels/group_by_window_dataset_op.cc index 89e07228996b448f2923dd81bd7c2343891c602e..a4f9608b1fa8bac8eab8b1bce4494c8846d72278 100644 --- a/tensorflow/core/kernels/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/group_by_window_dataset_op.cc @@ -36,20 +36,14 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { graph_def_version_(ctx->graph_def_version()) { OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size_func", &window_size_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - int64 window_size = 0; - OP_REQUIRES_OK( - ctx, ParseScalarArgument(ctx, "window_size", &window_size)); - OP_REQUIRES( - ctx, window_size > 0, - errors::InvalidArgument("Window size must be greater than zero.")); - - // Get captured inputs for the key and reduce functions. + // Get captured inputs for the key, reduce, and window_size functions. OpInputList key_func_other_argument_inputs; OP_REQUIRES_OK(ctx, ctx->input_list("key_func_other_arguments", &key_func_other_argument_inputs)); @@ -67,6 +61,16 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { for (const Tensor& t : reduce_func_other_argument_inputs) { reduce_func_other_arguments.push_back(t); } + OpInputList window_size_func_other_argument_inputs; + OP_REQUIRES_OK(ctx, + ctx->input_list("window_size_func_other_arguments", + &window_size_func_other_argument_inputs)); + std::vector window_size_func_other_arguments; + window_size_func_other_arguments.reserve( + window_size_func_other_argument_inputs.size()); + for (const Tensor& t : window_size_func_other_argument_inputs) { + window_size_func_other_arguments.push_back(t); + } // TODO(mrry): Refactor CapturedFunction to share the runtime // state between multiple functions? std::unique_ptr captured_key_func; @@ -79,24 +83,30 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { ctx, CapturedFunction::Create(ctx, reduce_func_, graph_def_version_, std::move(reduce_func_other_arguments), &captured_reduce_func)); - - *output = new Dataset(input, window_size, std::move(captured_key_func), - std::move(captured_reduce_func), output_types_, - output_shapes_); + std::unique_ptr captured_window_size_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create( + ctx, window_size_func_, graph_def_version_, + std::move(window_size_func_other_arguments), + &captured_window_size_func)); + + *output = new Dataset( + input, std::move(captured_key_func), std::move(captured_reduce_func), + std::move(captured_window_size_func), output_types_, output_shapes_); } private: class Dataset : public DatasetBase { public: - Dataset(const DatasetBase* input, int64 window_size, + Dataset(const DatasetBase* input, std::unique_ptr captured_key_func, std::unique_ptr captured_reduce_func, + std::unique_ptr captured_window_size_func, const DataTypeVector& output_types, const std::vector& output_shapes) : input_(input), - window_size_(window_size), captured_key_func_(std::move(captured_key_func)), captured_reduce_func_(std::move(captured_reduce_func)), + captured_window_size_func_(std::move(captured_window_size_func)), output_types_(output_types), output_shapes_(output_shapes) { input_->Ref(); @@ -171,7 +181,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { // group. std::vector key_func_output; TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Run( - opts, next_input_element, &key_func_output, prefix())); + opts, next_input_element, &key_func_output)); if (key_func_output.size() != 1 || key_func_output[0].dtype() != DT_INT64 || @@ -182,10 +192,44 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } const int64 key = key_func_output[0].scalar()(); + if (window_sizes_.find(key) == window_sizes_.end()) { + // Run window_size function + FunctionLibraryRuntime::Options opts2; + opts2.step_id = CapturedFunction::generate_step_id(); + opts2.runner = ctx->runner(); + ScopedStepContainer step_container2( + opts2.step_id, [this, ctx](const string& name) { + dataset() + ->captured_window_size_func_->resource_manager() + ->Cleanup(name) + .IgnoreError(); + }); + opts2.step_container = &step_container2; + + // Run the window size function on the key to identify its + // window size. + std::vector window_size_func_output; + TF_RETURN_IF_ERROR(dataset()->captured_window_size_func_->Run( + opts2, key_func_output, &window_size_func_output)); + + if (window_size_func_output.size() != 1 || + window_size_func_output[0].dtype() != DT_INT64 || + window_size_func_output[0].NumElements() != 1) { + // TODO(mrry): Support non-int64 window sizes. + return errors::InvalidArgument( + "`window_size_func` must return a scalar int64."); + } + const int64 window_size = + window_size_func_output[0].scalar()(); + window_sizes_[key] = window_size; + } + + const int64 window_size = window_sizes_[key]; + std::vector>& group = groups_[key]; group.push_back(std::move(next_input_element)); - if (group.size() == dataset()->window_size_) { + if (group.size() == window_size) { TF_RETURN_IF_ERROR(StartFlushingGroup(ctx, key)); break; } @@ -259,8 +303,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { {std::move(key_arg), std::move(group_dataset_arg)}); std::vector return_values; - TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Run( - opts, args, &return_values, prefix())); + TF_RETURN_IF_ERROR( + dataset()->captured_reduce_func_->Run(opts, args, &return_values)); if (!(return_values.size() == 1 && return_values[0].dtype() == DT_RESOURCE && @@ -297,6 +341,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { bool end_of_input_ GUARDED_BY(mu_) = false; std::map>> groups_ GUARDED_BY(mu_); std::unique_ptr current_group_iterator_ GUARDED_BY(mu_); + std::map window_sizes_ GUARDED_BY(mu_); }; // A resource name for the temporary window dataset that is @@ -304,9 +349,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { static constexpr const char* kWindowResourceName = "__window_dataset"; const DatasetBase* const input_; - const int64 window_size_; const std::unique_ptr captured_key_func_; const std::unique_ptr captured_reduce_func_; + const std::unique_ptr captured_window_size_func_; const DataTypeVector output_types_; const std::vector output_shapes_; }; @@ -316,6 +361,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { std::vector output_shapes_; const NameAttrList* key_func_; const NameAttrList* reduce_func_; + const NameAttrList* window_size_func_; }; REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/interleave_dataset_op.cc b/tensorflow/core/kernels/interleave_dataset_op.cc index 90907e49bdf7ba6143d95292cc2889be21bb00c0..7b148b74c9800993f281b860b170c1772452563b 100644 --- a/tensorflow/core/kernels/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/interleave_dataset_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/kernels/captured_function.h" +#include "tensorflow/core/kernels/dataset_utils.h" namespace tensorflow { @@ -168,8 +169,9 @@ class InterleaveDatasetOp : public OpKernel { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, &args, &end_of_input_)); if (!end_of_input_) { - TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( - ctx, args, ¤t_elements_[cycle_index_])); + TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + ctx, args, cycle_index_, dataset()->captured_func_.get(), + prefix(), ¤t_elements_[cycle_index_])); ++num_open_; } } else { @@ -182,62 +184,6 @@ class InterleaveDatasetOp : public OpKernel { } private: - Status MakeIteratorFromInputElement( - IteratorContext* ctx, const std::vector& input_element, - std::unique_ptr* out_iterator) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - FunctionLibraryRuntime::Options opts; - opts.runner = ctx->runner(); - opts.step_id = CapturedFunction::generate_step_id(); - ScopedStepContainer step_container( - opts.step_id, [this, ctx](const string& name) { - dataset() - ->captured_func_->resource_manager() - ->Cleanup(name) - .IgnoreError(); - }); - opts.step_container = &step_container; - std::vector return_values; - TF_RETURN_IF_ERROR(dataset()->captured_func_->Run( - opts, input_element, &return_values, prefix())); - - if (!(return_values.size() == 1 && - return_values[0].dtype() == DT_RESOURCE && - TensorShapeUtils::IsScalar(return_values[0].shape()))) { - return errors::InvalidArgument( - "`f` must return a single scalar of dtype DT_RESOURCE."); - } - - // Retrieve the dataset that was created in `f`. - DatasetBase* returned_dataset; - const ResourceHandle& dataset_resource = - return_values[0].scalar()(); - - // NOTE(mrry): We cannot use the core `LookupResource()` or - // `DeleteResource()` functions, because we have an - // `IteratorContext*` and not an `OpKernelContext*`, so we - // replicate the necessary functionality here. - auto type_index = MakeTypeIndex(); - if (type_index.hash_code() != dataset_resource.hash_code()) { - return errors::InvalidArgument("`f` must return a Dataset resource."); - } - TF_RETURN_IF_ERROR( - dataset()->captured_func_->resource_manager()->Lookup( - dataset_resource.container(), dataset_resource.name(), - &returned_dataset)); - core::ScopedUnref unref_dataset(returned_dataset); - - // Create an iterator for the dataset that was returned by - // `f`. This transfers ownership of the dataset to the - // iterator, so we can delete it from the resource manager. - *out_iterator = returned_dataset->MakeIterator( - strings::StrCat(prefix(), "[", cycle_index_, "]")); - TF_RETURN_IF_ERROR( - dataset()->captured_func_->resource_manager()->Delete( - dataset_resource.container(), dataset_resource.name())); - return Status::OK(); - } - mutex mu_; const std::unique_ptr input_impl_ GUARDED_BY(mu_); std::vector> current_elements_ diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index c0e4f91991383ff6ac366912901b3a4c0b32bfe4..7f0e11872abd9698c7d820ffbd44187909165b93 100644 --- a/tensorflow/core/kernels/iterator_ops.cc +++ b/tensorflow/core/kernels/iterator_ops.cc @@ -89,6 +89,31 @@ class IteratorResource : public ResourceBase { } } + Status SaveState(OpKernelContext* ctx, StringPiece path) { + std::shared_ptr captured_iterator(iterator_); + if (captured_iterator) { + return captured_iterator->SaveState(ctx, path); + } else { + return errors::FailedPrecondition( + "SaveState() failed because the iterator has not been initialized. " + "Ensure that you have run the initializer operation for this " + "iterator before getting the next element."); + } + } + + Status RestoreState(OpKernelContext* ctx, StringPiece path) { + std::shared_ptr captured_iterator(iterator_); + if (captured_iterator) { + return captured_iterator->RestoreState(ctx, path); + } else { + return errors::FailedPrecondition( + "RestoreState() failed because the iterator has not been " + "initialized. " + "Ensure that you have run the initializer operation for this " + "iterator before getting the next element."); + } + } + // Transfers ownership of iterator to this. This method is thread-safe. Status set_iterator(std::unique_ptr iterator) { if (iterator) { @@ -161,6 +186,32 @@ class MakeIteratorOp : public OpKernel { } }; +class SaveIteratorOp : public OpKernel { + public: + explicit SaveIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + IteratorResource* iterator_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); + const string& path = ctx->input(1).scalar()(); + OP_REQUIRES_OK(ctx, iterator_resource->SaveState(ctx, path)); + } +}; + +class RestoreIteratorOp : public OpKernel { + public: + explicit RestoreIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + IteratorResource* iterator_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); + const string& path = ctx->input(1).scalar()(); + OP_REQUIRES_OK(ctx, iterator_resource->RestoreState(ctx, path)); + } +}; + class OneShotIteratorOp : public AsyncOpKernel { public: explicit OneShotIteratorOp(OpKernelConstruction* ctx) @@ -504,6 +555,10 @@ class IteratorFromStringHandleOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp); REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU), MakeIteratorOp); +REGISTER_KERNEL_BUILDER(Name("SaveIterator").Device(DEVICE_CPU), + SaveIteratorOp); +REGISTER_KERNEL_BUILDER(Name("RestoreIterator").Device(DEVICE_CPU), + RestoreIteratorOp); REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU), OneShotIteratorOp); REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/l2loss_op.cc b/tensorflow/core/kernels/l2loss_op.cc index 9875cd027d5a6d1bd6e22c74e58ad47d8a15c3a3..f8ed9351579ff8cbeeb5f45030e8ff278fa75101 100644 --- a/tensorflow/core/kernels/l2loss_op.cc +++ b/tensorflow/core/kernels/l2loss_op.cc @@ -27,10 +27,9 @@ limitations under the License. namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -typedef Eigen::GpuDevice GPUDevice; -template -class L2LossOp : public OpKernel { +template +class L2LossOp : public OpKernel { public: explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {} @@ -42,8 +41,9 @@ class L2LossOp : public OpKernel { Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); - functor::L2Loss()(context->eigen_device(), - input.flat(), output->scalar()); + const CPUDevice& d = context->eigen_device(); + output->scalar().device(d) = + (input.flat().square() * static_cast(0.5)).sum(); } }; @@ -57,33 +57,4 @@ REGISTER_KERNEL(double); REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL -#if GOOGLE_CUDA -// Forward declarations of the functor specializations for GPU. -namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void L2Loss::operator()(const GPUDevice& d, \ - typename TTypes::ConstTensor input, \ - typename TTypes::Scalar output); \ - extern template struct L2Loss; - -DECLARE_GPU_SPEC(float); -DECLARE_GPU_SPEC(double); -DECLARE_GPU_SPEC(Eigen::half); -#undef DECLARE_GPU_SPEC -} // namespace functor - -// Registration of the GPU implementations. -#define REGISTER_GPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("L2Loss").Device(DEVICE_GPU).TypeConstraint("T"), \ - L2LossOp); - -REGISTER_GPU_KERNEL(float); -REGISTER_GPU_KERNEL(double); -REGISTER_GPU_KERNEL(Eigen::half); -#undef REGISTER_GPU_KERNEL - -#endif // GOOGLE_CUDA - } // namespace tensorflow diff --git a/tensorflow/core/kernels/l2loss_op.h b/tensorflow/core/kernels/l2loss_op.h index f7204cefdd418ba4ab33b33d001611ec0f91ac70..4953aa237cd75e4e352a49fbc839f7a937fdbf78 100644 --- a/tensorflow/core/kernels/l2loss_op.h +++ b/tensorflow/core/kernels/l2loss_op.h @@ -15,25 +15,19 @@ limitations under the License. #ifndef TENSORFLOW_KERNELS_L2LOSS_OP_H_ #define TENSORFLOW_KERNELS_L2LOSS_OP_H_ -// Functor definition for L2LossOp, must be compilable by nvcc. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" namespace tensorflow { -namespace functor { -// Functor used by L2LossOp to do the computations. template -struct L2Loss { - void operator()(const Device& d, typename TTypes::ConstTensor input, - typename TTypes::Scalar output) { - // We flatten the input tensor and reduce on dimension 0, producing - // a single number which is Mul(Sum(x^2), 0.5). - output.device(d) = (input.square() * static_cast(0.5)).sum(); - } +struct L2LossOp : public OpKernel { + explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) {} }; -} // namespace functor } // namespace tensorflow #endif // TENSORFLOW_KERNELS_L2LOSS_OP_H_ diff --git a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc index 420df37086555dfc9d7f89170e0a04210283a075..73b6472254cf9e8526d6d26cbf24cc1e398d3208 100644 --- a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc +++ b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc @@ -21,12 +21,55 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/reduction_ops_common.h" +#include "tensorflow/core/kernels/reduction_ops_gpu_kernels.h" + namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; -template struct functor::L2Loss; -template struct functor::L2Loss; -template struct functor::L2Loss; + +// TODO(eriche): can add specialization for half2 +template +struct squareHalf { + __host__ __device__ T operator()(const T& x) const { + return static_cast(0.5) * x * x; + } +}; + +template +class L2LossOp : public OpKernel { + public: + explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // The input tensor can be of any number of dimensions, even though it's + // 2D in most typical applications. + const Tensor& input = context->input(0); + // The output is a single number. + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &output)); + typedef cub::TransformInputIterator, T*> inputIterType; + inputIterType input_itr((T*)input.flat().data(), squareHalf()); + typedef const Eigen::array::Tensor::Index, 1>& ReductionAxes; + + Constants constants; + functor::ReduceImpl( + context, (T*)output->flat().data(), input_itr, 1, + input.flat().size(), 1, 1, 0, constants.kZero, cub::Sum(), T(0)); + } +}; + +// Registration of the GPU implementations. +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("L2Loss").Device(DEVICE_GPU).TypeConstraint("T"), \ + L2LossOp); + +REGISTER_GPU_KERNEL(float); +REGISTER_GPU_KERNEL(double); +REGISTER_GPU_KERNEL(Eigen::half); +#undef REGISTER_GPU_KERNEL } // namespace tensorflow diff --git a/tensorflow/core/kernels/map_dataset_op.cc b/tensorflow/core/kernels/map_dataset_op.cc index 68bcb6a1b679931b12ab58f2bbe83f47821d42ba..10f4c2b82c88bc669d57cfa9106deadc2c9fd674 100644 --- a/tensorflow/core/kernels/map_dataset_op.cc +++ b/tensorflow/core/kernels/map_dataset_op.cc @@ -122,8 +122,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { opts.runner = ctx->runner(); // TODO(mrry): Avoid blocking a threadpool thread. We will need to // stack-rip the iterators and use async kernels. - Status s = - dataset()->captured_func_->Run(opts, args, out_tensors, prefix()); + Status s = dataset()->captured_func_->Run(opts, args, out_tensors); if (errors::IsOutOfRange(s)) { // `f` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index ef7338e0e0d634e70e2830048623f2d67d8e272f..00884d09814c283c93fd5a12f544db4084af54e9 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -97,8 +97,12 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { errors::InvalidArgument( "Conv2DCustomBackpropInput: size must be 4-dim")); - MklSizesToTFSizes(context, data_format, mkl_context.filter_shape, - &filter_shape); + const int64* filter_sizes = + (const int64*)mkl_context.filter_shape.GetSizes(); + const int64 filter_dims = mkl_context.filter_shape.GetDimension(); + + OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( + filter_sizes, filter_dims, &filter_shape)); } else { filter_shape = filter.shape(); } diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 203e6946314e5379d7d40e4efd2bb01adca05fe2..7099aa13071fd1ecee22ae14bc828db4aebf2299 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -265,6 +265,28 @@ class MklConv2DOp : public OpKernel { sizeof(T)); AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape, mkl_output_mkl_shape); + // Filter output to be used in the backprop_input + TensorShape mkl_filter_output_tf_shape; + MklShape mkl_filter_output_mkl_shape; + mkl_filter_output_mkl_shape.SetMklTensor(true); + mkl_filter_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd, + dnnResourceFilter); + + size_t filter_sizes[4] = {filter.dim_size(0), filter.dim_size(1), + filter.dim_size(2), filter.dim_size(3)}; + mkl_filter_output_mkl_shape.SetTfLayout(filter.dims(), filter_sizes, + mkl_context.filter_strides); + + mkl_filter_output_mkl_shape.SetTfDimOrder(mkl_context.filter_dims, + data_format_); + mkl_filter_output_tf_shape.AddDim( + dnnLayoutGetMemorySize_F32(static_cast( + mkl_filter_output_mkl_shape.GetMklLayout())) / + sizeof(T)); + AllocateOutputSetMklShape(context, 1, &mkl_context.output_filter, + mkl_filter_output_tf_shape, + mkl_filter_output_mkl_shape); + mkl_context.conv_res[dnnResourceDst] = static_cast(output->flat().data()); @@ -303,6 +325,7 @@ class MklConv2DOp : public OpKernel { dnnPrimitive_t prim_fwd; void* conv_res[dnnResourceNumber]; dnnLayout_t lt_filter, lt_bias, lt_input; + Tensor* output_filter = nullptr; // Create MKL dnnLayout_t objects for tensors coming into the layer void MklCreateInputLayouts(OpKernelContext* context) { @@ -383,8 +406,13 @@ class MklConv2DOp : public OpKernel { CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_filter, lt_filter, mkl_lt_internal_filter), E_SUCCESS); - AllocTmpBuffer(context, mkl_tmp_filter_buf_tensor, - mkl_lt_internal_filter, &mkl_buf_convert_filter); +<<<<<<< HEAD + mkl_buf_convert_filter = const_cast(static_cast( + output_filter->flat().data())); +======= + mkl_buf_convert_filter = const_cast( + static_cast(output_filter->flat().data())); +>>>>>>> e722358e7e96dd2aa20d7e2c56336e76845daa6a CHECK_EQ( dnnConversionExecute_F32(mkl_prim_convert_filter, mkl_buf_filter, mkl_buf_convert_filter), diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc new file mode 100644 index 0000000000000000000000000000000000000000..7fc633c2542f4b3af34d9719c1d9f74519fb583c --- /dev/null +++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc @@ -0,0 +1,88 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0(the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL + +// See docs in ../ops/math_ops.cc. + +#define EIGEN_USE_THREADS +#include +#include + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +#include "tensorflow/core/util/mkl_util.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class MklBinaryOp : public BinaryOp { + public: + explicit MklBinaryOp(OpKernelConstruction* context) + : BinaryOp(context) {} + + void Compute(OpKernelContext* context) override { + auto in0 = context->input(0); + auto in1 = context->input(1); + VLOG(1) << "Shapes (start mklbinaryop compute): " + << in0.shape().DebugString() << " _and_ " + << in1.shape().DebugString(); + + // Call the TensorFlow BinaryOp Compute method + BinaryOp::Compute(context); + + auto out = context->mutable_output(0); + VLOG(1) << "Shapes (output): " << out->shape().DebugString(); + + // Pass input shape through to ouput shape + ForwardMklMetaDataInToOut(context, 0, 0); + + out = context->mutable_output(0); + VLOG(1) << "Shapes (output): " << out->shape().DebugString(); + } +}; + +//---------- Registration macros for various element-wise ops ----------- +// We will need to redefine "REGISTER" to include the mkl_op_registry flag +#pragma push_macro("REGISTER") +#undef REGISTER +#define REGISTER(OP, D, N, F, T) \ + REGISTER_KERNEL_BUILDER(Name(N) \ + .Device(DEVICE_##D) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + OP>); + +REGISTER5(MklBinaryOp, CPU, "_MklAdd", functor::add, float, Eigen::half, double, + int32, int64); +REGISTER7(MklBinaryOp, CPU, "_MklSub", functor::sub, float, Eigen::half, double, + int32, int64, complex64, complex128); +REGISTER5(MklBinaryOp, CPU, "_MklMul", functor::mul, float, Eigen::half, double, + uint8, int32); +REGISTER5(MklBinaryOp, CPU, "_MklMaximum", functor::maximum, float, Eigen::half, + double, int32, int64); +REGISTER5(MklBinaryOp, CPU, "_MklSquaredDifference", + functor::squared_difference, float, Eigen::half, double, int32, + int64); + +#undef REGISTER +#pragma pop_macro("REGISTER") +//----------------------------------------------------------------------- + +} // end namespace tensorflow + +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_identity_op.cc b/tensorflow/core/kernels/mkl_identity_op.cc index ca20294a2683059488d9d2b3c7fe9f232b093dfb..f31e7afd46873a02c10277283862a7e5e2384803 100644 --- a/tensorflow/core/kernels/mkl_identity_op.cc +++ b/tensorflow/core/kernels/mkl_identity_op.cc @@ -41,9 +41,9 @@ class MklIdentityOp : public OpKernel { bool input_in_mkl_format = mkl_shape_input.IsMklTensor(); if (input_in_mkl_format) { - ForwarMklTensorInToOut(context, 0, 0); + ForwardMklTensorInToOut(context, 0, 0); } else { - FowardTfTensorInToOut(context, 0, 0); + ForwardTfTensorInToOut(context, 0, 0); } } diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b58e44e39800c8c047d5557ab3c84113bb78d3ca --- /dev/null +++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc @@ -0,0 +1,259 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL + +#include +#include +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/tensor_format.h" + +#include "tensorflow/core/kernels/mkl_tfconv_op.h" +#include "tensorflow/core/util/mkl_util.h" + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; + +/////////////////////////////////////////////////////////// +// Op kernel +// Checks and ensures that the 2 inputs are compatible for mkl binary ops. +// Here's the basic logic: +// +// if both inputs are in TF format: +// pass the inputs through to the output +// else if both inputs are in mkl format: +// if both have the same shape: +// pass the inputs through to the output +// else: +// convert both to TF +// else if one is TF and one is MKL: +// if broadcast is needed: +// convert the MKL format input to TF format +// else: +// convert the TF format input to MKL format +/////////////////////////////////////////////////////////// + +template +class MklInputConversionOp : public OpKernel { + public: + explicit MklInputConversionOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); + OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type)); + has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F); + } + + private: + void Compute(OpKernelContext* context) override { + // Check if input tensors are in MKL format. + const Tensor& input_tensor_0 = MklGetInput(context, 0); + MklShape input_shape_0; + GetMklShape(context, 0, &input_shape_0); + + const Tensor& input_tensor_1 = MklGetInput(context, 1); + MklShape input_shape_1; + GetMklShape(context, 1, &input_shape_1); + + bool tf_shapes_are_same = MklCompareShapes(&context->input(0).shape(), + &context->input(1).shape()); + + VLOG(1) << "MklInputConversionOp: Input shapes are " + << (tf_shapes_are_same ? "*same*" : "*different*") << ": " + << context->input(0).shape().DebugString() << " and " + << context->input(1).shape().DebugString(); + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // if both inputs are in TF format, just copy input tensors to output. + if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) { + VLOG(1) << "MklInputConversionOp: No conversion needed, " + << "copying TF inputs to output"; + + ForwardTfTensorInToOut(context, 0, 0); + ForwardTfTensorInToOut(context, 1, 1); + return; + } + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // If both inputs are in MKL format + if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) { + // If both have the same shape, pass them through + if (tf_shapes_are_same) { + VLOG(1) << "MklInputConversionOp: No conversion needed, " + << "copying MKL inputs with identical shapes to output"; + + ForwardMklTensorInToOut(context, 0, 0); + ForwardMklTensorInToOut(context, 1, 1); + return; + } + + // Sanity check + bool mkl_shapes_are_same = + MklCompareShapes(&input_shape_0, &input_shape_1); + if (mkl_shapes_are_same) { + CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are " + "different but MKL shapes are same"; + } + + // Both have different shapes, so broadcast will be necessary. + // Convert to TF and pass both tensors through (we can't do broadcast + // with MKL tensors) + VLOG(1) << "MklInputConversionOp: Broadcast needed, " + << "converted MKL inputs to TF format"; + + MklToTfOp::ConvertMklToTf(this, context, data_format_str, + op_data_type, has_avx512f_, 0); + MklToTfOp::ConvertMklToTf(this, context, data_format_str, + op_data_type, has_avx512f_, 1); + SetDummyMklShapeOutput(context, 0); + SetDummyMklShapeOutput(context, 1); + return; + } + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // One input is MKL and one is TF. If no broadcast is needed, convert + // the TF tensor to MKL, otherwise convert the MKL tensor to TF format + VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)"; + + const Tensor* mkl_tensor; + const MklShape* mkl_shape; + const Tensor* tf_tensor; + MklShape* tf_mkl_shape; + uint mkl_tensor_index; + uint tf_tensor_index; + if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) { + mkl_tensor = &input_tensor_0; + mkl_shape = &input_shape_0; + mkl_tensor_index = 0; + tf_tensor = &input_tensor_1; + tf_mkl_shape = &input_shape_1; + tf_tensor_index = 1; + } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) { + mkl_tensor = &input_tensor_1; + mkl_shape = &input_shape_1; + mkl_tensor_index = 1; + tf_tensor = &input_tensor_0; + tf_mkl_shape = &input_shape_0; + tf_tensor_index = 0; + } else { + CHECK(false) << "MklInputConversionOp: Unexpected combination of input " + "shapes for MKL " + << "element-wise op"; + } + + // Broadcast is needed if the shapes are not the same + bool broadcast_needed; + + size_t in0_size = 1; + for (size_t i = 0; i < mkl_shape->GetDimension(); ++i) + in0_size *= mkl_shape->tf_dim_size(i); + + size_t in1_size = 1; + for (size_t i = 0; i < tf_tensor->shape().dims(); ++i) + in1_size *= tf_tensor->shape().dim_size(i); + + broadcast_needed = (in0_size != in1_size); + + if (!broadcast_needed) { + // Both shapes are same, convert the TF input to MKL + VLOG(1) << "MklInputConversionOp: No broadcast needed."; + VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index + << " to MKL format"; + + // Create MklShape + Tensor* tensor_out; + MklShape mkl_output_mkl_shape; + mkl_output_mkl_shape.SetMklTensor(true); + mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(), + mkl_shape->GetSizes(), + mkl_shape->GetStrides()); + mkl_output_mkl_shape.SetTfDimOrder(mkl_shape->GetDimension()); + + // ** Temporarily borrow the layout from the MKL input ** + mkl_output_mkl_shape.SetMklLayout(mkl_shape->GetCurLayout()); + + // Create output tensor + AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out, + mkl_tensor->shape(), mkl_output_mkl_shape); + + // Since the shapes are the same, use information from the other tensor + tf_mkl_shape->SetTfLayout(mkl_shape->GetDimension(), + mkl_shape->GetSizes(), mkl_shape->GetStrides()); + // Convert the data format + tf_mkl_shape->GetConvertedFlatData( + mkl_shape->GetCurLayout(), + const_cast(tf_tensor->flat().data()), + const_cast(tensor_out->flat().data())); + + // ** Release the borrowed layout to avoid double deletion + // in the destructor call ** + mkl_output_mkl_shape.SetMklLayout(nullptr); + + // -- The tensor in MKL format passes through -- + ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index); + } else { + // Broadcast is needed, so convert the MKL input to TF + VLOG(1) << "MklInputConversionOp: Broadcast needed."; + VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index + << " to TF format"; + MklToTfOp::ConvertMklToTf(this, context, data_format_str, + op_data_type, has_avx512f_, + mkl_tensor_index); + SetDummyMklShapeOutput(context, mkl_tensor_index); + + // The tensor in TF format passes through + ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index); + } + + VLOG(1) << "MklInputConversionOp: Shapes (output): " + << context->mutable_output(0)->shape().DebugString() << " and " + << context->mutable_output(1)->shape().DebugString(); + + VLOG(1) << "MklInputConversion completed successfully."; + } + + private: + /// Data format of the operation + string data_format_str; + + /// Data type of the operation + DataType op_data_type; + + /// CPUIDInfo + bool has_avx512f_ = false; +}; + +/////////////////////////////////////////////////////////// +// Register kernel +/////////////////////////////////////////////////////////// + +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER(Name("_MklInputConversion") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklInputConversionOp); + +TF_CALL_NUMBER_TYPES(REGISTER_CPU); +#undef REGISTER_CPU +} // namespace tensorflow +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc index b3763f17bc1393ba42ace07f21db36568eaae6cb..5e985824750befb702f8fa7a59d699f853f40267 100644 --- a/tensorflow/core/kernels/mkl_reshape_op.cc +++ b/tensorflow/core/kernels/mkl_reshape_op.cc @@ -43,30 +43,26 @@ class MklReshapeOp : public OpKernel { OP_REQUIRES(context, IsLegacyVector(sizes.shape()), errors::InvalidArgument("sizes input must be 1-D, not shape ", sizes.shape().DebugString())); - const int64 num_dims = sizes.NumElements(); // Compute the output shape. Determine product of specified // dimensions, and find the index of the unspecified one. TensorShape shape; int64 product = 1; int unknown_index = -1; - auto vec_size = sizes.flat(); - for (int d = 0; d < num_dims; ++d) { - const int32 size = vec_size(d); - if (size == -1) { - OP_REQUIRES( - context, unknown_index == -1, - errors::InvalidArgument("only one input size may be -1, not both ", - unknown_index, " and ", d)); - unknown_index = d; - shape.AddDim(1); - } else { - OP_REQUIRES(context, size >= 0, - errors::InvalidArgument( - "size ", d, " must be non-negative, not ", size)); - shape.AddDim(size); - product *= size; - } + switch (sizes.dtype()) { + case DT_INT32: + OP_REQUIRES_OK(context, ValidateSizes(sizes, &product, + &unknown_index, &shape)); + break; + case DT_INT64: + OP_REQUIRES_OK(context, ValidateSizes(sizes, &product, + &unknown_index, &shape)); + break; + default: + context->CtxFailure(errors::InvalidArgument( + "desired shape must be a DT_INT32 or DT_INT64 vector, not a ", + DataTypeString(sizes.dtype()))); + return; } if (unknown_index != -1) { OP_REQUIRES( @@ -132,6 +128,35 @@ class MklReshapeOp : public OpKernel { CopyTfTensorInToOutWithShape(context, 0, 0, shape); } } + + private: + template + Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index, + TensorShape* shape) { + *product = 1; + *unknown_index = -1; + const int64 num_dims = sizes.NumElements(); + auto Svec = sizes.flat(); + for (int d = 0; d < num_dims; ++d) { + const Tshape size = Svec(d); + if (size == -1) { + if (*unknown_index != -1) { + return errors::InvalidArgument( + "Only one input size may be -1, not both ", *unknown_index, + " and ", d); + } + *unknown_index = d; + shape->AddDim(1); + } else if (size < 0) { + return errors::InvalidArgument("Size ", d, + " must be non-negative, not ", size); + } else { + shape->AddDim(size); + (*product) *= size; + } + } + return Status::OK(); + } }; #define REGISTER_MKL_CPU(T) \ @@ -141,6 +166,13 @@ class MklReshapeOp : public OpKernel { .TypeConstraint("T") \ .TypeConstraint("Tshape") \ .Label(mkl_op_registry::kMklOpLabel), \ + MklReshapeOp); \ + REGISTER_KERNEL_BUILDER(Name("_MklReshape") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .TypeConstraint("T") \ + .TypeConstraint("Tshape") \ + .Label(mkl_op_registry::kMklOpLabel), \ MklReshapeOp); TF_CALL_float(REGISTER_MKL_CPU); #undef REGISTER_MKL_CPU diff --git a/tensorflow/core/kernels/mkl_tfconv_op.cc b/tensorflow/core/kernels/mkl_tfconv_op.h similarity index 80% rename from tensorflow/core/kernels/mkl_tfconv_op.cc rename to tensorflow/core/kernels/mkl_tfconv_op.h index b48c735d12465ab20549501e40f638315513a5e7..a240ee44fb014555b467ff2a920604dcc425972d 100644 --- a/tensorflow/core/kernels/mkl_tfconv_op.cc +++ b/tensorflow/core/kernels/mkl_tfconv_op.h @@ -15,6 +15,9 @@ limitations under the License. #ifdef INTEL_MKL +#ifndef TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_ +#define TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_ + #include #include #include "tensorflow/core/framework/numeric_op.h" @@ -28,9 +31,9 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/tensor_format.h" -#include "tensorflow/core/util/mkl_util.h" #include "mkl_dnn.h" #include "mkl_dnn_types.h" +#include "tensorflow/core/util/mkl_util.h" namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -49,14 +52,22 @@ class MklToTfOp : public OpKernel { } void Compute(OpKernelContext* context) override { + ConvertMklToTf(this, context, data_format_str, op_data_type, has_avx512f_, + 0); + VLOG(1) << "MKLToTFConversion complete successfully."; + } + + static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context, + string data_format_str, DataType op_data_type, + bool has_avx512f, uint input_number) { // Check that input tensor is in MKL format. - const Tensor& input_tensor = MklGetInput(context, 0); + const Tensor& input_tensor = MklGetInput(context, input_number); MklShape input_shape; - GetMklShape(context, 0, &input_shape); + GetMklShape(context, input_number, &input_shape); // if input is already in Tf format, then just copy input tensor to output. if (!input_shape.IsMklTensor()) { - context->set_output(0, input_tensor); + context->set_output(input_number, input_tensor); VLOG(1) << "MKLToTFConversion: No conversion needed, " << "copying input to output"; return; @@ -64,8 +75,8 @@ class MklToTfOp : public OpKernel { // Check that input data type is same as operator data type and that it is // same as output data type. - DataType input_data_type = input_type(0); - DataType output_data_type = output_type(0); + DataType input_data_type = op_kernel->input_type(input_number); + DataType output_data_type = op_kernel->output_type(input_number); CHECK_EQ(op_data_type, input_data_type); CHECK_EQ(op_data_type, output_data_type); @@ -81,7 +92,7 @@ class MklToTfOp : public OpKernel { // Allocate output tensor. Tensor* output_tensor = NULL; OP_REQUIRES_OK(context, - context->allocate_output(0, output_shape, &output_tensor)); + context->allocate_output(input_number, output_shape, &output_tensor)); dnnLayout_t output_layout = static_cast(input_shape.GetTfLayout()); @@ -118,7 +129,8 @@ class MklToTfOp : public OpKernel { .Label(mkl_op_registry::kMklOpLabel), \ MklToTfOp); -TF_CALL_float(REGISTER_CPU); +TF_CALL_NUMBER_TYPES(REGISTER_CPU); #undef REGISTER_CPU } // namespace tensorflow -#endif /* INTEL_MKL */ +#endif // TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_ +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/neon/BUILD b/tensorflow/core/kernels/neon/BUILD index e74310228a5887b9a7d27481c8e649d7cb36e534..536b2bdc03c5dc91e8e3e25dd9fbba82cd29fd5b 100644 --- a/tensorflow/core/kernels/neon/BUILD +++ b/tensorflow/core/kernels/neon/BUILD @@ -37,6 +37,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:ops_util", "@gemmlowp//:gemmlowp", ], diff --git a/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc b/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc index 818b44aab3908648b770b34c9bb3c86eca13e7bf..17f2af550f248a6924bb3d1e7546eca84d4c1e51 100644 --- a/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc +++ b/tensorflow/core/kernels/neon/neon_depthwise_conv_op.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/neon/depthwiseconv_float.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" @@ -95,10 +96,10 @@ class NeonDepthwiseConv2dNativeOp : public BinaryOp { padding_, &out_cols, &pad_cols)); TensorShape out_shape({batch, out_rows, out_cols, out_depth}); OP_REQUIRES( - context, out_shape.num_elements() <= 2147483647, - errors::InvalidArgument("total number of outputs should be within the " - "range of int which is used in the GPU kernel", - in_depth, " vs ", filter.dim_size(2))); + context, + FastBoundsCheck(out_shape.num_elements(), + std::numeric_limits::max()), + errors::InvalidArgument("Output elements too large for NEON kernel")); // Output tensor is of the following dimensions: // [ in_batch, out_rows, out_cols, out_depth ] diff --git a/tensorflow/core/kernels/ops_util.cc b/tensorflow/core/kernels/ops_util.cc index 130939263bef47eb5635e205413272e6a5e995f6..efacd05dd39cceb33397d647bbbc4c71228f1029 100644 --- a/tensorflow/core/kernels/ops_util.cc +++ b/tensorflow/core/kernels/ops_util.cc @@ -37,11 +37,6 @@ Eigen::PaddingType BrainPadding2EigenPadding(Padding padding) { Status GetBroadcastSize(const int index, const int in_size, const int ksize, const int stride, const int pad_size, int* bindex, int* bsize) { - // Cannot have strides larger than the patch size. - if (stride > ksize) { - return errors::InvalidArgument( - "stride must be less than or equal to kernel size"); - } // Cannot have index beyond the input size. if (index * stride > in_size) { return errors::InvalidArgument( diff --git a/tensorflow/core/kernels/ops_util_test.cc b/tensorflow/core/kernels/ops_util_test.cc index 42ffef6735bce6de3336c0c743a616837abc8d86..9d53882deef89230bd39d8318f11d84269406f20 100644 --- a/tensorflow/core/kernels/ops_util_test.cc +++ b/tensorflow/core/kernels/ops_util_test.cc @@ -173,12 +173,6 @@ TEST_F(OpsUtilTest, Get2dOutputSizeVerbose) { VerifyGet2dOutputVerboseSizeValues(pad_struct2, error::OK); } -// Test stride > ksize fails with INVALID_ARGUMENT. -TEST_F(OpsUtilTest, GetBroadcastTest3_1_2_0) { - bcast_struct bcast = {{0, 3, 1, 2, 0}, {0, 3}}; - VerifyBoundaries(bcast, error::INVALID_ARGUMENT); -} - // Test index * stride > in_size fails with INVALID_ARGUMENT. TEST_F(OpsUtilTest, GetBroadcastTestBadIndex) { bcast_struct bcast = {{2, 3, 1, 2, 0}, {0, 3}}; @@ -281,6 +275,38 @@ TEST_F(OpsUtilTest, GetBroadcastTest3_3_3_2) { } } +// in_size = 3, ksize = 1, stride = 2, pad_size = 0 +TEST_F(OpsUtilTest, GetBroadcastTest3_1_2_0) { + bcast_struct bcast[] = { + {{0, 3, 1, 2, 0}, {0, 1}}, + {{1, 3, 1, 2, 0}, {2, 1}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +// in_size = 3, ksize = 2, stride = 3, pad_size = 0 +TEST_F(OpsUtilTest, GetBroadcastTest3_2_3_0) { + bcast_struct bcast[] = { + {{0, 3, 2, 3, 0}, {0, 2}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + +// in_size = 3, ksize = 2, stride = 3, pad_size = 1 +TEST_F(OpsUtilTest, GetBroadcastTest3_2_3_1) { + bcast_struct bcast[] = { + {{0, 3, 2, 3, 1}, {0, 1}}, + {{1, 3, 2, 3, 1}, {2, 1}}, + }; + for (size_t i = 0; i < sizeof(bcast) / sizeof(bcast[0]); ++i) { + VerifyBcastValues(bcast[i]); + } +} + TEST_F(OpsUtilTest, SanitizeThreadSuffix) { EXPECT_EQ("_aBc123_-___", SanitizeThreadSuffix("/aBc123_- /")); } diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc index 6e8b09d05003ef64c7acb896e84e8887a503d75b..6196c5ed93ee3c0ff4001ad2b1d3bb7ac2776022 100644 --- a/tensorflow/core/kernels/pad_op.cc +++ b/tensorflow/core/kernels/pad_op.cc @@ -146,9 +146,9 @@ class PadOp : public OpKernel { Tensor* output) { CHECK_EQ(Dims, paddings.dimension(0)); CHECK_EQ(2, paddings.dimension(1)); - Eigen::array, Dims> paddings_array; + Eigen::array, Dims> paddings_array; for (int i = 0; i < Dims; ++i) { - paddings_array[i] = std::make_pair(paddings(i, 0), paddings(i, 1)); + paddings_array[i] = {paddings(i, 0), paddings(i, 1)}; } functor::Pad functor; functor(context->eigen_device(), output->tensor(), input, @@ -180,7 +180,7 @@ namespace functor { void Pad::operator()( \ const GPUDevice& d, typename TTypes::Tensor output, \ typename TTypes::ConstTensor input, \ - Eigen::array, Dims> paddings, T pad_value); \ + Eigen::array, Dims> paddings, T pad_value); \ extern template struct Pad; #define DECLARE_GPU_SPECS(T) \ diff --git a/tensorflow/core/kernels/pad_op.h b/tensorflow/core/kernels/pad_op.h index 6a973833e2d31961309e5bd1a6e4c15363862aff..95a7c9a3ae58b66fd7711a31aa90365aef5a4a46 100644 --- a/tensorflow/core/kernels/pad_op.h +++ b/tensorflow/core/kernels/pad_op.h @@ -31,7 +31,7 @@ struct Pad { // See pad_op.cc for details. void operator()(const Device& d, typename TTypes::Tensor output, typename TTypes::ConstTensor input, - Eigen::array, Dims> paddings, + Eigen::array, Dims> paddings, T pad_value) { if (Eigen::internal::is_same::value && (output.size() <= std::numeric_limits::max())) { @@ -47,7 +47,7 @@ struct Pad { // In the scalar case we simply copy the input. void operator()(const Device& d, typename TTypes::Tensor output, typename TTypes::ConstTensor input, - Eigen::array, 0>, T) { + Eigen::array, 0>, T) { output.device(d) = input; } }; diff --git a/tensorflow/core/kernels/parallel_map_dataset_op.cc b/tensorflow/core/kernels/parallel_map_dataset_op.cc index 00093b695655273693c45c9d1308346b9dc9b3eb..3f503581fb66a5c6bcf7fe1e31fda1cc57c3759a 100644 --- a/tensorflow/core/kernels/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/parallel_map_dataset_op.cc @@ -205,7 +205,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { opts.step_container = step_container; opts.runner = ctx->runner(); dataset()->captured_func_->RunAsync( - opts, input_element, &result->return_values, prefix(), + opts, input_element, &result->return_values, [result, step_container, result_index](Status ret_status) { delete step_container; result->status.Update(ret_status); diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc index 79199ff5c3fcd73a485aeb180060317ce3f281c1..ab91a6ef677a95a498df1c3de85c8ea07d6451e8 100644 --- a/tensorflow/core/kernels/parse_tensor_op.cc +++ b/tensorflow/core/kernels/parse_tensor_op.cc @@ -16,11 +16,13 @@ limitations under the License. // See docs in ../ops/parsing_ops.cc. #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/framework/register_types.h" namespace tensorflow { @@ -65,4 +67,31 @@ class ParseTensorOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("ParseTensor").Device(DEVICE_CPU), ParseTensorOp); +template +class SerializeTensorOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* context) override { + const Tensor& tensor = context->input(0); + TensorProto proto; + if (tensor.dtype() == DT_STRING) { + tensor.AsProtoField(&proto); + } else { + tensor.AsProtoTensorContent(&proto); + } + Tensor* proto_string = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &proto_string)); + CHECK(proto.SerializeToString(&proto_string->scalar()())); + } +}; + +#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint("T"), \ + SerializeTensorOp); +TF_CALL_ALL_TYPES(REGISTER) +#undef REGISTER + } // namespace tensorflow diff --git a/tensorflow/core/kernels/parse_tensor_test.cc b/tensorflow/core/kernels/parse_tensor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a5fc07935c02b3932d20f67c45122150c7559fb --- /dev/null +++ b/tensorflow/core/kernels/parse_tensor_test.cc @@ -0,0 +1,198 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" + +namespace tensorflow { +namespace { + +class SerializeTensorOpTest : public OpsTestBase { + protected: + template + void MakeOp(const TensorShape& input_shape, std::function functor) { + TF_ASSERT_OK(NodeDefBuilder("myop", "SerializeTensor") + .Input(FakeInput(DataTypeToEnum::value)) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + AddInput(input_shape, functor); + } + void ParseSerializedWithNodeDef(const NodeDef& parse_node_def, + Tensor* serialized, Tensor* parse_output) { + std::unique_ptr device( + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); + gtl::InlinedVector inputs; + inputs.push_back({nullptr, serialized}); + Status status; + std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), + cpu_allocator(), parse_node_def, + TF_GRAPH_DEF_VERSION, &status)); + TF_EXPECT_OK(status); + OpKernelContext::Params params; + params.device = device.get(); + params.inputs = &inputs; + params.frame_iter = FrameAndIter(0, 0); + params.op_kernel = op.get(); + std::vector attrs; + test::SetOutputAttrs(¶ms, &attrs); + OpKernelContext ctx(¶ms); + op->Compute(&ctx); + TF_EXPECT_OK(status); + *parse_output = *ctx.mutable_output(0); + } + template + void ParseSerializedOutput(Tensor* serialized, Tensor* parse_output) { + NodeDef parse; + TF_ASSERT_OK(NodeDefBuilder("parse", "ParseTensor") + .Input(FakeInput(DT_STRING)) + .Attr("out_type", DataTypeToEnum::value) + .Finalize(&parse)); + ParseSerializedWithNodeDef(parse, serialized, parse_output); + } +}; + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_half) { + MakeOp(TensorShape({10}), [](int x) -> Eigen::half { + return static_cast(x / 10.); + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_float) { + MakeOp(TensorShape({1, 10}), + [](int x) -> float { return static_cast(x / 10.); }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_double) { + MakeOp(TensorShape({5, 5}), + [](int x) -> double { return static_cast(x / 10.); }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int64) { + MakeOp(TensorShape({2, 3, 4}), + [](int x) -> int64 { return static_cast(x - 10); }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int32) { + MakeOp(TensorShape({4, 2}), + [](int x) -> int32 { return static_cast(x + 7); }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int16) { + MakeOp(TensorShape({8}), + [](int x) -> int16 { return static_cast(x + 18); }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_int8) { + MakeOp(TensorShape({2}), + [](int x) -> int8 { return static_cast(x + 8); }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint16) { + MakeOp(TensorShape({1, 3}), + [](int x) -> uint16 { return static_cast(x + 2); }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_uint8) { + MakeOp(TensorShape({2, 1, 1}), + [](int x) -> uint8 { return static_cast(x + 1); }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex64) { + MakeOp(TensorShape({}), [](int x) -> complex64 { + return complex64{static_cast(x / 8.), static_cast(x / 2.)}; + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_complex128) { + MakeOp(TensorShape({3}), [](int x) -> complex128 { + return complex128{x / 3., x / 2.}; + }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_bool) { + MakeOp(TensorShape({1}), + [](int x) -> bool { return static_cast(x % 2); }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +TEST_F(SerializeTensorOpTest, SerializeTensorOpTest_string) { + MakeOp(TensorShape({10}), + [](int x) -> string { return std::to_string(x / 10.); }); + TF_ASSERT_OK(RunOpKernel()); + Tensor parse_output; + ParseSerializedOutput(GetOutput(0), &parse_output); + test::ExpectTensorEqual(parse_output, GetInput(0)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h index 99d2d19bfda12241ee58a1ba301c618cdeb26352..005ea6e26dac51c023b3aa06315cc5415b385fe6 100644 --- a/tensorflow/core/kernels/queue_op.h +++ b/tensorflow/core/kernels/queue_op.h @@ -56,13 +56,6 @@ class TypedQueueOp : public QueueOp { public: using QueueOp::QueueOp; - void Compute(OpKernelContext* context) override { - QueueOp::Compute(context); - if (queue_ && context->track_allocations()) { - context->record_host_persistent_memory_allocation(queue_->MemoryUsed()); - } - } - protected: template Status CreateTypedQueue(TypedQueue* queue, QueueInterface** ret) { diff --git a/tensorflow/core/kernels/range_dataset_op.cc b/tensorflow/core/kernels/range_dataset_op.cc index a32a02f57d44b11661b359e622f59e1162f0cd82..9976c558387b1e79b4ca8b0cd0fedad535a4e5e6 100644 --- a/tensorflow/core/kernels/range_dataset_op.cc +++ b/tensorflow/core/kernels/range_dataset_op.cc @@ -86,6 +86,7 @@ class RangeDatasetOp : public DatasetOpKernel { if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) || (dataset()->step_ < 0 && next_ <= dataset()->stop_)) { *end_of_sequence = true; + is_exhausted_ = true; return Status::OK(); } Tensor value_tensor(cpu_allocator(), DT_INT64, {}); @@ -97,9 +98,26 @@ class RangeDatasetOp : public DatasetOpKernel { return Status::OK(); } + protected: + Status SaveStateInternal(OpKernelContext* ctx, + IteratorBundleWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR( + writer->WriteScalar(next_, full_name("next"))); + return Status::OK(); + } + + Status RestoreStateInternal(OpKernelContext* ctx, + IteratorBundleReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR( + reader->ReadScalar(&next_, full_name("next"))); + return Status::OK(); + } + private: mutex mu_; - int64 next_; + int64 next_ GUARDED_BY(mu_); }; const int64 start_; diff --git a/tensorflow/core/kernels/reader_dataset_ops.cc b/tensorflow/core/kernels/reader_dataset_ops.cc index 407f69cde73051c836c89dbc4319d6f15b39a414..73fc09abc8fcfd90b60f0a1981c259a33ae07f21 100644 --- a/tensorflow/core/kernels/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/reader_dataset_ops.cc @@ -315,6 +315,7 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel { // Iteration ends when there are no more files to process. if (current_file_index_ == dataset()->filenames_.size()) { *end_of_sequence = true; + is_exhausted_ = true; return Status::OK(); } @@ -332,6 +333,51 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel { } while (true); } + protected: + Status SaveStateInternal(OpKernelContext* ctx, + IteratorBundleWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar( + current_file_index_, full_name("current_file_index"))); + + // `input_buffer_` is empty if + // 1. GetNext has not been called even once. + // 2. All files have been read and iterator has been exhausted. + int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1; + TF_RETURN_IF_ERROR( + writer->WriteScalar(current_pos, full_name("current_pos"))); + return Status::OK(); + } + + Status RestoreStateInternal(OpKernelContext* ctx, + IteratorBundleReader* reader) override { + mutex_lock l(mu_); + int64 current_file_index; + TF_RETURN_IF_ERROR(reader->ReadScalar( + ¤t_file_index, full_name("current_file_index"))); + current_file_index_ = size_t(current_file_index); + int64 current_pos; + TF_RETURN_IF_ERROR( + reader->ReadScalar(¤t_pos, full_name("current_pos"))); + + // Seek to current_pos. + input_buffer_.reset(); + file_.reset(); + if (current_pos >= 0) { // There was an active input_buffer_. + uint64 file_size; + TF_RETURN_IF_ERROR(ctx->env()->GetFileSize( + dataset()->filenames_[current_file_index_], &file_size)); + file_pos_limit_ = file_size - dataset()->footer_bytes_; + TF_RETURN_IF_ERROR(ctx->env()->NewRandomAccessFile( + dataset()->filenames_[current_file_index_], &file_)); + input_buffer_.reset( + new io::InputBuffer(file_.get(), dataset()->buffer_size_)); + TF_RETURN_IF_ERROR(input_buffer_->Seek(current_pos)); + } + + return Status::OK(); + } + private: mutex mu_; size_t current_file_index_ GUARDED_BY(mu_) = 0; diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h index 5db9e6032e0ad678653c2d602d3f29de0505ac60..e43d2828f3093a39d2fdbe26c3557627839b6c36 100644 --- a/tensorflow/core/kernels/reduction_ops.h +++ b/tensorflow/core/kernels/reduction_ops.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" namespace tensorflow { @@ -67,7 +68,7 @@ void FillIdentityEigenImpl(const Device& d, OUT_T out, const Reducer& reducer) { template struct ReduceFunctor { template - static void Reduce(const Device& d, OUT_T out, IN_T in, + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, const ReductionAxes& reduction_axes, const Reducer& reducer); diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h index 553f8895232277cd1fa570ccf05e898515acd6c0..71af9d88dc1d34db392cd1e29714bdcad645abd9 100644 --- a/tensorflow/core/kernels/reduction_ops_common.h +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -190,24 +190,24 @@ class ReductionOp : public OpKernel { Functor::FillIdentity(d, tmp_out.flat(), reducer); } else if ((helper.ndims() == 1) && helper.reduce_first_axis()) { // Reduce to a scalar. - Functor::Reduce(d, helper.out(&tmp_out), helper.in(data), + Functor::Reduce(ctx, helper.out(&tmp_out), helper.in(data), constants.kZero, reducer); } else if ((helper.ndims() == 2) && helper.reduce_first_axis()) { // Can be viewed as a reduction of a matrix along 1st dimension. - Functor::Reduce(d, helper.out(&tmp_out), helper.in(data), + Functor::Reduce(ctx, helper.out(&tmp_out), helper.in(data), constants.kZero, reducer); } else if ((helper.ndims() == 2) && !helper.reduce_first_axis()) { // Can be viewed as a reduction of a matrix along 2nd dimension. - Functor::Reduce(d, helper.out(&tmp_out), helper.in(data), + Functor::Reduce(ctx, helper.out(&tmp_out), helper.in(data), constants.kOne, reducer); } else if ((helper.ndims() == 3) && helper.reduce_first_axis()) { // Can be viewed as a reduction of a 3D tensor along 1st and 3rd // dimensions. - Functor::Reduce(d, helper.out(&tmp_out), helper.in(data), + Functor::Reduce(ctx, helper.out(&tmp_out), helper.in(data), constants.kZeroTwo, reducer); } else if ((helper.ndims() == 3) && !helper.reduce_first_axis()) { // Can be viewed as a reduction of a 3D tensor along 2nd dimension. - Functor::Reduce(d, helper.out(&tmp_out), helper.in(data), + Functor::Reduce(ctx, helper.out(&tmp_out), helper.in(data), constants.kOne, reducer); } else { // If we don't hit one of the cases above, transpose the data so that @@ -223,7 +223,7 @@ class ReductionOp : public OpKernel { const int64 unreduced = tmp_out.NumElements(); const int64 reduced = shuffled.NumElements() / unreduced; const Tensor& const_shuffled = shuffled; - Functor::Reduce(d, tmp_out.flat(), + Functor::Reduce(ctx, tmp_out.flat(), const_shuffled.shaped({unreduced, reduced}), constants.kOne, reducer); } @@ -258,9 +258,10 @@ namespace functor { template struct ReduceFunctorBase { template - static void Reduce(const Device& d, OUT_T out, IN_T in, + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, const ReductionAxes& reduction_axes, const Reducer& reducer) { + const Device& d = ctx->eigen_device(); ReduceEigenImpl(d, out, in, reduction_axes, reducer); } diff --git a/tensorflow/core/kernels/reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc index ec4490db83fc38c23208991b15ee93ca072c7d50..8fd9165eb9f32b1449033e8ae598cba58e5f0d2f 100644 --- a/tensorflow/core/kernels/reduction_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/reduction_ops_gpu.cu.cc @@ -17,8 +17,7 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/framework/numeric_types.h" -#include "tensorflow/core/kernels/reduction_ops.h" +#include "tensorflow/core/kernels/reduction_ops_gpu_kernels.h" namespace tensorflow { namespace functor { @@ -33,15 +32,27 @@ typedef TTypes::Tensor::Index Index; template struct ReduceFunctor { template - static void Reduce(const GPUDevice& d, OUT_T out, IN_T in, + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, const ReductionAxes& reduction_axes, - const Reducer& reducer) { - ReduceEigenImpl(d, To32Bit(out), To32Bit(in), reduction_axes, reducer); + const Reducer& reducer); +}; + +template +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::SumReducer& reducer) { + ReduceImpl( + ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + cub::Sum(), T(0)); } template static void FillIdentity(const GPUDevice& d, OUT_T out, - const Reducer& reducer) { + const Eigen::internal::SumReducer& reducer) { FillIdentityEigenImpl(d, To32Bit(out), reducer); } }; @@ -49,19 +60,30 @@ struct ReduceFunctor { template struct ReduceFunctor> { template - static void Reduce(const GPUDevice& d, OUT_T out, IN_T in, + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, const ReductionAxes& reduction_axes, const Eigen::internal::MeanReducer& reducer) { - typedef typename IN_T::Index Index; - // Eigen sum reductions are much faster on GPU than mean reductions: - // Simply trigger them by computing the sum of the weighted inputs. - Index num_coeffs_to_reduce = 1; - for (int i = 0; i < Eigen::internal::array_size::value; - ++i) { - num_coeffs_to_reduce *= in.dimension(reduction_axes[i]); - } - T scale = T(1.0 / num_coeffs_to_reduce); - out.device(d) = (in * scale).sum(reduction_axes); + int divisor = 1; + if (out.rank() == 0) + divisor = in.size(); + else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0) + divisor = in.dimension(0); + else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1) + divisor = in.dimension(1); + else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 && + reduction_axes[1] == 2) + divisor = in.dimension(0) * in.dimension(2); + else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1) + divisor = in.dimension(1); + + DividesBy div_op(static_cast(divisor)); + TransformOutputIterator> itr((T*)out.data(), div_op); + ReduceImpl>, T*, + ReductionAxes>(ctx, itr, (T*)in.data(), in.rank(), + in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), + reduction_axes, cub::Sum(), T(0)); } template @@ -71,15 +93,159 @@ struct ReduceFunctor> { } }; +template <> +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::MeanReducer& reducer) { + float divisor = 1.f; + if (out.rank() == 0) + divisor = in.size(); + else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0) + divisor = in.dimension(0); + else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1) + divisor = in.dimension(1); + else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 && + reduction_axes[1] == 2) + divisor = in.dimension(0) * in.dimension(2); + else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1) + divisor = in.dimension(1); + DividesBy div_op(divisor); + + typedef cub::TransformInputIterator + inputIterType; + inputIterType input_itr((Eigen::half*)in.data(), HalfToFloat()); + + typedef TransformOutputIterator> + outputIterType; + outputIterType itr((Eigen::half*)out.data(), div_op); + + ReduceImpl( + ctx, itr, input_itr, in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + cub::Sum(), 0.f); + } + + template + static void FillIdentity( + const GPUDevice& d, OUT_T out, + const Eigen::internal::MeanReducer& reducer) { + FillIdentityEigenImpl(d, To32Bit(out), reducer); + } +}; + +template +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::MaxReducer& reducer) { + ReduceImpl( + ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + cub::Max(), std::numeric_limits::lowest()); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const Eigen::internal::MaxReducer& reducer) { + FillIdentityEigenImpl(d, To32Bit(out), reducer); + } +}; + +template +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::MinReducer& reducer) { + ReduceImpl( + ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + cub::Min(), std::numeric_limits::max()); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const Eigen::internal::MinReducer& reducer) { + FillIdentityEigenImpl(d, To32Bit(out), reducer); + } +}; + +template +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::ProdReducer& reducer) { + ReduceImpl, T*, T*, ReductionAxes>( + ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + Prod(), T(1)); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const Eigen::internal::ProdReducer& reducer) { + FillIdentityEigenImpl(d, To32Bit(out), reducer); + } +}; + +template <> +struct ReduceFunctor { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::AndReducer& reducer) { + ReduceImpl( + ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, And(), + true); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const Eigen::internal::AndReducer& reducer) { + FillIdentityEigenImpl(d, To32Bit(out), reducer); + } +}; + +template <> +struct ReduceFunctor { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::OrReducer& reducer) { + ReduceImpl( + ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, Or(), + false); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const Eigen::internal::OrReducer& reducer) { + FillIdentityEigenImpl(d, To32Bit(out), reducer); + } +}; + // T: the data type // REDUCER: the reducer functor // NUM_AXES: the number of axes to reduce // IN_DIMS: the number of dimensions of the input tensor -#define DEFINE(T, REDUCER, IN_DIMS, NUM_AXES) \ - template void ReduceFunctor::Reduce( \ - const GPUDevice& d, TTypes::Tensor out, \ - TTypes::ConstTensor in, \ - const Eigen::array& reduction_axes, \ +#define DEFINE(T, REDUCER, IN_DIMS, NUM_AXES) \ + template void ReduceFunctor::Reduce( \ + OpKernelContext* ctx, TTypes::Tensor out, \ + TTypes::ConstTensor in, \ + const Eigen::array& reduction_axes, \ const REDUCER& reducer); #define DEFINE_IDENTITY(T, REDUCER) \ diff --git a/tensorflow/core/kernels/reduction_ops_gpu_kernels.h b/tensorflow/core/kernels/reduction_ops_gpu_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..ce471c672c7d235213d576d3f35b414cb8283415 --- /dev/null +++ b/tensorflow/core/kernels/reduction_ops_gpu_kernels.h @@ -0,0 +1,713 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "external/cub_archive/cub/device/device_reduce.cuh" +#include "external/cub_archive/cub/device/device_segmented_reduce.cuh" +#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh" +#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh" +#include "external/cub_archive/cub/warp/warp_reduce.cuh" +#include "cuda/include/cuComplex.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/reduction_ops.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" +#include "tensorflow/core/util/permutation_input_iterator.h" +#include "tensorflow/core/util/transform_output_iterator.h" + +#include + +namespace tensorflow { +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; + +template +struct Prod { + __host__ __device__ T operator()(const T& a, const T& b) const { + return a * b; + } +}; + +// needed to work around a compiler bug in nvcc - it doesn't seem to like +// the overloaded multiply op for std::complex +template <> +struct Prod> { + __host__ __device__ std::complex operator()( + const std::complex& a, const std::complex& b) const { + auto result = cuCmulf(make_cuComplex(a.real(), a.imag()), + make_cuComplex(b.real(), b.imag())); + return std::complex(result.x, result.y); + } +}; + +template <> +struct Prod> { + __host__ __device__ std::complex operator()( + const std::complex& a, const std::complex& b) const { + auto result = cuCmul(make_cuDoubleComplex(a.real(), a.imag()), + make_cuDoubleComplex(b.real(), b.imag())); + return std::complex(result.x, result.y); + } +}; + +template +struct DividesBy { + T divisor; + + __host__ __device__ explicit DividesBy(T divisor) : divisor(divisor) {} + + __host__ __device__ outT operator()(const T& x) const { return x / divisor; } +}; + +// needed to work around a compiler bug in nvcc - it doesn't seem to like +// the overloaded ops for std::complex +template <> +struct DividesBy> { + cuFloatComplex divisor; + + __host__ __device__ explicit DividesBy(std::complex divisor) + : divisor(make_cuComplex(divisor.real(), divisor.imag())) {} + + // implements + __host__ __device__ std::complex operator()( + const std::complex& x) const { + auto result = cuCdivf(make_cuComplex(x.real(), x.imag()), divisor); + return std::complex(result.x, result.y); + } +}; + +template <> +struct DividesBy> { + cuDoubleComplex divisor; + + __host__ __device__ explicit DividesBy(std::complex divisor) + : divisor(make_cuDoubleComplex(divisor.real(), divisor.imag())) {} + + // implements + __host__ __device__ std::complex operator()( + const std::complex& x) const { + auto result = cuCdiv(make_cuDoubleComplex(x.real(), x.imag()), divisor); + return std::complex(result.x, result.y); + } +}; + +template <> +struct DividesBy { + float divisor; + + __host__ __device__ explicit DividesBy(float divisor) : divisor(divisor) {} + + __host__ __device__ Eigen::half operator()(const float& x) const { + return Eigen::half(x / divisor); + } +}; + +struct HalfToFloat { + __host__ __device__ float operator()(const Eigen::half& x) const { + return Eigen::half_impl::half_to_float(x); + } +}; + +struct FloatToHalf { + __host__ __device__ Eigen::half operator()(const float& x) const { + return Eigen::half_impl::float_to_half_rtne(x); + } +}; + +struct And { + __host__ __device__ bool operator()(const bool& a, const bool& b) const { + return a && b; + } +}; + +struct Or { + __host__ __device__ bool operator()(const bool& a, const bool& b) const { + return a || b; + } +}; + +// each block does a grid strided loop and reduces its values locally +// the case of one block is used for low latency small reductions to scalars +template +__global__ void BlockReduceKernel( + T in, outT out, int num_elems, Op op, + typename std::iterator_traits::value_type initVal) { + const int bid = blockIdx.x; + const int tid = threadIdx.x; + + const int gid = bid * blockDim.x + tid; + const int stride = blockDim.x * gridDim.x; + + typedef typename std::iterator_traits::value_type value_type; + + value_type sum = initVal; + if (gid < num_elems) { + sum = in[gid]; + for (int pos = gid + stride; pos < num_elems; pos += stride) { + sum = op(sum, in[pos]); + } + } + + typedef cub::BlockReduce BlockReduce; + + __shared__ typename BlockReduce::TempStorage temp_storage; + + // only include input values in the reduction + // + // elements: ----------------- + // grid: |====|====|====|====|====| + const int num_elements_to_reduce = + max(min(num_elems - bid * blockDim.x, num_threads), 0); + + sum = BlockReduce(temp_storage) + .template Reduce(sum, op, num_elements_to_reduce); + + if (tid == 0) out[bid] = sum; +} + +// maps a warp to each row +template +__global__ void RowReduceKernel( + T in, outT out, int num_rows, int num_cols, Op op, + typename std::iterator_traits::value_type initVal) { + typedef typename std::iterator_traits::value_type value_type; + const int row = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + const int lane = threadIdx.x % 32; + + if (num_cols == 1) { + int gid = threadIdx.x + blockIdx.x * blockDim.x; + if (gid < num_rows) out[gid] = in[gid]; + return; + } + + value_type sum = initVal; + int col = lane; + + if (row < num_rows && col < num_cols) { + sum = in[row * num_cols + col]; + col += 32; + for (; col < num_cols; col += 32) { + sum = op(sum, in[row * num_cols + col]); + } + } + + typedef cub::WarpReduce WarpReduce; + + __shared__ typename WarpReduce::TempStorage temp_storage; + + sum = WarpReduce(temp_storage).template Reduce(sum, op, min(num_cols, 32)); + + if (row < num_rows && lane == 0) out[row] = sum; +} + +// Works only if there are <= 16 columns +// each warps sums over multiple rows at once +template +__global__ void ColumnReduceMax16ColumnsKernel( + T in, outT out, int num_rows, int num_cols, Op op, + typename std::iterator_traits::value_type initVal) { + typedef typename std::iterator_traits::value_type value_type; + int rows_per_warp = 32 / num_cols; + + const int lane = threadIdx.x % 32; + const int lane_row = lane / num_cols; + + const int start_row_warp = + rows_per_warp * (blockIdx.y * blockDim.y + threadIdx.y); + const int start_row_lane = start_row_warp + lane_row; + int row = start_row_lane; + int col = lane % num_cols; + + value_type sum = initVal; + if (row * num_cols + col < num_rows * num_cols) + sum = in[row * num_cols + col]; + + __shared__ value_type partial_sums[32][33]; + + row += rows_per_warp * gridDim.y * blockDim.y; + for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) { + int global_pos = row * num_cols + col; + if (global_pos < (num_rows * num_cols)) + sum = op(sum, in[row * num_cols + col]); + } + + const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp); + // not the most efficient way to do this sum + for (int i = 1; i < rows_in_this_warp; ++i) { + value_type tmp = + cub::ShuffleIndex(sum, threadIdx.x + i * num_cols, 32, 0xffffffff); + if (lane < num_cols) sum = op(sum, tmp); + } + + if (lane < num_cols) partial_sums[lane][threadIdx.y] = sum; + + __syncthreads(); + + if (threadIdx.y == 0 && threadIdx.x < num_cols) { + value_type s = partial_sums[threadIdx.x][0]; + + if (blockDim.y > 1) { + for (int row = 1; row < blockDim.y; ++row) { + s = op(s, partial_sums[threadIdx.x][row]); + } + } + + out[col * gridDim.y + blockIdx.y] = s; + } +} + +// Maps each block to a column range 32 wide +template +__global__ void ColumnReduceKernel( + T in, outT out, int num_rows, int num_cols, Op op, + typename std::iterator_traits::value_type initVal) { + typedef typename std::iterator_traits::value_type value_type; + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * 32 + threadIdx.x; + + value_type sum = initVal; + if (row < num_rows && col < num_cols) + sum = in[row * num_cols + col]; + + __shared__ value_type partial_sums[32][33]; + + row += gridDim.y * blockDim.y; + + if (col < num_cols) { + for (; row < num_rows; row += gridDim.y * blockDim.y) { + sum = op(sum, in[row * num_cols + col]); + } + } + + partial_sums[threadIdx.x][threadIdx.y] = sum; + + __syncthreads(); + + if (threadIdx.y == 0 && col < num_cols) { + value_type s = partial_sums[threadIdx.x][0]; + + // only include input values in the reduction + // elem block_rows + // - = + // - = + // # # block boundary + // - = + // - = + // # # block boundary + // - = + // = + const int numRowsThisBlock = + min(blockDim.y, num_rows - blockIdx.y * blockDim.y); + + for (int row = 1; row < numRowsThisBlock; ++row) { + s = op(s, partial_sums[threadIdx.x][row]); + } + + out[col * gridDim.y + blockIdx.y] = s; + } +} + +// does multiple warp size segmented reductions in parallel +// segments cannot cross warp boundaries (mainly used for reducing the segments +// that come from the Max16Columns column reduction kernel) +template +__global__ void CleanupSegments( + T partial_sums, outT out, int num_rows, int num_cols, int segment_size, + Op op, typename std::iterator_traits::value_type initVal) { + typedef typename std::iterator_traits::value_type value_type; + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + + value_type val = initVal; + if (tid < segment_size * num_cols) + val = partial_sums[tid]; + + typedef cub::WarpReduce WarpReduce; + + __shared__ typename WarpReduce::TempStorage temp_storage; + + const bool head_flag = (threadIdx.x % segment_size) == 0; + value_type sum = + WarpReduce(temp_storage).HeadSegmentedReduce(val, head_flag, op); + + if (head_flag && tid < segment_size * num_cols) { + out[tid / segment_size] = sum; + } +} + +// assigns one thread to a column +template +__global__ void ColumnReduceSimpleKernel(T in, outT out, int num_planes, + int num_rows, int num_cols, Op op) { + typedef typename std::iterator_traits::value_type value_type; + const int gid = threadIdx.x + blockIdx.x * blockDim.x; + const int elems_per_plane = num_rows * num_cols; + + const int plane = gid / num_cols; + const int col = gid % num_cols; + + if (plane >= num_planes) return; + + if (num_rows == 1) { + out[plane * elems_per_plane + col] = in[plane * elems_per_plane + col]; + return; + } + + value_type sum = op(in[plane * elems_per_plane + col], + in[plane * elems_per_plane + num_cols + col]); + for (int row = 2; row < num_rows; ++row) { + sum = op(sum, in[plane * elems_per_plane + row * num_cols + col]); + } + + out[plane * num_cols + col] = sum; +} + +struct RowOffset { + __host__ __device__ explicit RowOffset(const int& cols) : cols_(cols) {} + + __host__ __device__ int operator()(const int& x) const { return cols_ * x; } + + int cols_; +}; + +struct GatherOp { + __host__ __device__ GatherOp(const int& extent_x, const int& extent_y, + const int& extent_z, bool kOne) + : extent_x_(extent_x), + extent_y_(extent_y), + extent_z_(extent_z), + kOne_(kOne) { + if (kOne_) + group_size_ = extent_y_; + else + group_size_ = extent_x_ * extent_z_; + } + + __host__ __device__ int operator()(const int& ind) const { + const int group = kOne_ ? ind / group_size_ : ind % group_size_; + const int offset = kOne_ ? ind % group_size_ : ind / group_size_; + + const int x = group / extent_z_; + const int z = group % extent_z_; + + return x * extent_y_ * extent_z_ + z + offset * extent_z_; + } + + int extent_x_; + int extent_y_; + int extent_z_; + bool kOne_; + int group_size_; +}; + +template +void LaunchScalarReduction(OpKernelContext* ctx, OUT_T out, IN_T in, + int in_size, Op op, T init, + const cudaStream_t& cu_stream) { + // handle situations where low latency is important better than CUB + if (in_size <= 4096) { + const int num_blocks = 1; + const int num_threads = 256; + BlockReduceKernel + <<>>(in, out, in_size, op, init); + return; + } else if (in_size <= 1 << 19) { + const int num_threads = 256; + const int num_blocks = min(32, Eigen::divup(in_size, num_threads)); + // it seems like tailoring this to the GPU + // would be more effective, but all attempts + // at making this a multiple of the number of + // multiprocessors have lead to lower perf + // in general + // TODO(eriche) investigate this more + + Tensor temp_storage; + OP_REQUIRES_OK( + ctx, + ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(num_blocks * sizeof(T))}), + &temp_storage)); + + BlockReduceKernel + <<>>( + in, (T*)temp_storage.flat().data(), in_size, op, init); + + // take care that we only reduce blocks that had some valid elements in them + // TODO(eriche): CUB currently has a bug in HeadSegmentedReduce that + // requires it to be used with a full warp. Can reduce 32 -> num_blocks + // when this is fixed. + CleanupSegments<<<1, 32, 0, cu_stream>>>( + (T*)temp_storage.flat().data(), out, 1, 1, num_blocks, op, + init); + return; + } + std::size_t temp_storage_bytes = 0; + + Tensor temp_storage; + // written as a loop because it reduces clutter + // first pass allocates memory, second launches kernel(s) + for (int i = 0; i < 2; ++i) { + auto success = cub::DeviceReduce::Reduce( + i == 0 ? nullptr : temp_storage.flat().data(), + temp_storage_bytes, in, out, in_size, op, init, cu_stream); + + OP_REQUIRES( + ctx, success == 0, + errors::Internal("CUB reduce error", cudaGetErrorString(success))); + + if (i == 0) + OP_REQUIRES_OK( + ctx, + ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + } +} + +template +void LaunchRowReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int num_rows, + int num_cols, Op op, T init, + const cudaStream_t& cu_stream) { + if (num_cols < 1024) { + const int threads_per_block = 128; + const int warps_per_block = threads_per_block / 32; + int num_blocks = (num_rows + warps_per_block - 1) / warps_per_block; + + RowReduceKernel<<>>( + in, out, num_rows, num_cols, op, init); + return; + } + + // setup segment offsets with counting and transform iterator + RowOffset row_offset_op(num_cols); + cub::CountingInputIterator counting_iter(0); + cub::TransformInputIterator> + transform_iter(counting_iter, row_offset_op); + + std::size_t temp_storage_bytes = 0; + Tensor temp_storage; + for (int i = 0; i < 2; ++i) { + auto success = cub::DeviceSegmentedReduce::Reduce( + i == 0 ? nullptr : temp_storage.flat().data(), + temp_storage_bytes, in, out, num_rows, transform_iter, + transform_iter + 1, op, init, cu_stream); + + OP_REQUIRES(ctx, success == 0, + errors::Internal("CUB segmented reduce error", + cudaGetErrorString(success))); + + if (i == 0) + OP_REQUIRES_OK( + ctx, + ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + } +} + +template +void LaunchColumnReduction_LTE16Cols(OpKernelContext* ctx, OUT_T out, IN_T in, + int extent_x, int extent_y, Op op, T init, + const cudaStream_t& cu_stream) { + int rows_per_warp = 32 / extent_y; + dim3 block_dim(32, min(Eigen::divup(extent_x, rows_per_warp), 32), 1); + dim3 grid_dim(1, + Eigen::divup(static_cast(extent_x), + rows_per_warp * block_dim.y), + 1); + + grid_dim.y = min((int)grid_dim.y, 32); + + if (grid_dim.y > 2 && grid_dim.y < 32) { + int log2 = Log2Floor(grid_dim.y); + grid_dim.y = 1 << log2; + } + + if (grid_dim.y == 1) { + ColumnReduceMax16ColumnsKernel<<>>( + in, out, extent_x, extent_y, op, init); + } else { + Tensor temp_storage; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DT_INT8, + TensorShape({static_cast( + sizeof(T) * extent_y * grid_dim.y)}), + &temp_storage)); + ColumnReduceMax16ColumnsKernel<<>>( + in, (T*)temp_storage.flat().data(), extent_x, extent_y, op, + init); + + dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1); + dim3 num_threads(128, 1, 1); + CleanupSegments<<>>( + (T*)temp_storage.flat().data(), out, extent_x, extent_y, + grid_dim.y, op, init); + } +} + +template +void LaunchColumnReduction_LTE4096Cols(OpKernelContext* ctx, OUT_T out, IN_T in, + int extent_x, int extent_y, Op op, + T init, const cudaStream_t& cu_stream) { + dim3 block_dim(32, min(extent_x, 32), 1); + dim3 grid_dim((extent_y + 31) / 32, 1, 1); + + if (grid_dim.x < 16) grid_dim.y = min((extent_x + 31) / 32, 32); + + if (grid_dim.y > 2 && grid_dim.y < 32) { + int log2 = Log2Floor(grid_dim.y); + grid_dim.y = 1 << log2; + } + + if (grid_dim.y == 1) { + ColumnReduceKernel<<>>( + in, out, extent_x, extent_y, op, init); + } else { + Tensor temp_storage; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DT_INT8, + TensorShape({static_cast( + sizeof(T) * extent_y * grid_dim.y)}), + &temp_storage)); + + ColumnReduceKernel<<>>( + in, (T*)temp_storage.flat().data(), extent_x, extent_y, op, + init); + + dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1); + dim3 num_threads(128, 1, 1); + CleanupSegments<<>>( + (T*)temp_storage.flat().data(), out, extent_x, extent_y, + grid_dim.y, op, init); + } +} + +template +void LaunchColumnReduction(OpKernelContext* ctx, OUT_T out, IN_T in, + int extent_x, int extent_y, Op op, T init, + const cudaStream_t& cu_stream) { + if (extent_y <= 16) { + LaunchColumnReduction_LTE16Cols(ctx, out, in, extent_x, extent_y, op, init, + cu_stream); + } else if (extent_y <= 4096) { + LaunchColumnReduction_LTE4096Cols(ctx, out, in, extent_x, extent_y, op, + init, cu_stream); + } else { + int threads_per_block = 128; + int num_blocks = Eigen::divup(extent_y, threads_per_block); + + ColumnReduceSimpleKernel<<>>( + in, out, 1, extent_x, extent_y, op); + } +} + +template +void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x, + int extent_y, int extent_z, Op op, T init, + const cudaStream_t& cu_stream) { + int threads_per_block = 128; + int num_blocks = + (extent_x * extent_z + threads_per_block - 1) / threads_per_block; + + // TODO(eriche): this won't be very good in the case of small x + // small z and large y. + ColumnReduceSimpleKernel<<>>( + in, out, extent_x, extent_y, extent_z, op); +} + +template +void Launch3DXZReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x, + int extent_y, int extent_z, Op op, T init, + const cudaStream_t& cu_stream) { + // setup segment offsets with counting and transform iterator + RowOffset row_offset_op(extent_x * extent_z); + cub::CountingInputIterator counting_iter(0); + cub::TransformInputIterator> + transform_iter(counting_iter, row_offset_op); + + GatherOp gather_op(extent_x, extent_y, extent_z, false); + typedef cub::TransformInputIterator> + gatherIterType; + gatherIterType gather_iter(counting_iter, gather_op); + + PermutationInputIterator permute_iter(in, + gather_iter); + + std::size_t temp_storage_bytes = 0; + Tensor temp_storage; + + for (int i = 0; i < 2; ++i) { + auto success = cub::DeviceSegmentedReduce::Reduce( + i == 0 ? nullptr : temp_storage.flat().data(), + temp_storage_bytes, permute_iter, out, extent_y, transform_iter, + transform_iter + 1, op, init, cu_stream); + + OP_REQUIRES(ctx, success == 0, + errors::Internal("CUB segmented reduce error", + cudaGetErrorString(success))); + + if (i == 0) + OP_REQUIRES_OK( + ctx, + ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + } +} + +template +void ReduceImpl(OpKernelContext* ctx, OUT_T out, IN_T in, int in_rank, + int in_dim0, int in_dim1, int in_dim2, int out_rank, + const ReductionAxes& reduction_axes, Op op, T init) { + const cudaStream_t& cu_stream = GetCudaStream(ctx); + if (out_rank == 0) { + const int in_size = in_dim0 * in_dim1 * in_dim2; + LaunchScalarReduction(ctx, out, in, in_size, op, init, cu_stream); + } else if (in_rank == 2 && out_rank == 1 && + reduction_axes[0] == 1) { // row reduction + LaunchRowReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream); + } else if (in_rank == 2 && out_rank == 1 && + reduction_axes[0] == 0) { // column reduction + LaunchColumnReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream); + } else if (in_rank == 3 && out_rank == 2 && reduction_axes[0] == 1) { + Launch3DYReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init, + cu_stream); + } else if (in_rank == 3 && out_rank == 1 && reduction_axes[0] == 0 && + reduction_axes[1] == 2) { + Launch3DXZReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init, + cu_stream); + } else { + std::stringstream ss; + ss << "Invalid reduction requested: in_rank, out_rank, axes " << in_rank + << " " << out_rank; + if (out_rank == 1) ss << " " << reduction_axes[0]; + if (out_rank == 2) ss << " " << reduction_axes[1]; + LOG(FATAL) << ss.str(); + } +} + +} // namespace functor +} // namespace tensorflow + +#endif diff --git a/tensorflow/core/kernels/reduction_ops_test.cc b/tensorflow/core/kernels/reduction_ops_test.cc index 9cdebdd4f2308012cda3e6e281aac4dcfd4c060e..9bbe993a2f93e522688738abaf41a518e95ef871 100644 --- a/tensorflow/core/kernels/reduction_ops_test.cc +++ b/tensorflow/core/kernels/reduction_ops_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" @@ -22,14 +23,59 @@ namespace tensorflow { // Creates a Graph which "reduce"s a 3D float tensor of "num" elements // into a scalar. -static Graph* ToScalar(const string& reduce, int num) { - Graph* g = new Graph(OpRegistry::Global()); - Tensor data(DT_FLOAT, TensorShape({64, 64, num / (64 * 64)})); - data.flat().setRandom(); - Tensor axes(DT_INT32, TensorShape({3})); +template +static Graph* ToScalar(const string& reduce, int num_x, int num_y) { + auto* g = new Graph(OpRegistry::Global()); + Tensor data(DataTypeToEnum::value, TensorShape({num_x, num_y})); + data.flat().setRandom(); + Tensor axes(DT_INT32, TensorShape({2})); axes.flat()(0) = 0; axes.flat()(1) = 1; - axes.flat()(2) = 2; + test::graph::Reduce(g, reduce, test::graph::Constant(g, data), + test::graph::Constant(g, axes)); + return g; +} + +static Graph* ColReduce(const string& reduce, int num_x, int num_y) { + auto* g = new Graph(OpRegistry::Global()); + Tensor data(DT_FLOAT, TensorShape({num_x, num_y})); + data.flat().setRandom(); + Tensor axes(DT_INT32, TensorShape({1})); + axes.flat()(0) = 0; + test::graph::Reduce(g, reduce, test::graph::Constant(g, data), + test::graph::Constant(g, axes)); + return g; +} + +static Graph* RowReduce(const string& reduce, int num_x, int num_y) { + auto* g = new Graph(OpRegistry::Global()); + Tensor data(DT_FLOAT, TensorShape({num_x, num_y})); + data.flat().setRandom(); + Tensor axes(DT_INT32, TensorShape({1})); + axes.flat()(0) = 1; + test::graph::Reduce(g, reduce, test::graph::Constant(g, data), + test::graph::Constant(g, axes)); + return g; +} + +static Graph* ThreeDYReduce(const string& reduce, int num_y, int num_z) { + auto* g = new Graph(OpRegistry::Global()); + Tensor data(DT_FLOAT, TensorShape({4, num_y, num_z})); + data.flat().setRandom(); + Tensor axes(DT_INT32, TensorShape({1})); + axes.flat()(0) = 1; + test::graph::Reduce(g, reduce, test::graph::Constant(g, data), + test::graph::Constant(g, axes)); + return g; +} + +static Graph* ThreeDXZReduce(const string& reduce, int num_y, int num_z) { + auto* g = new Graph(OpRegistry::Global()); + Tensor data(DT_FLOAT, TensorShape({4, num_y, num_z})); + data.flat().setRandom(); + Tensor axes(DT_INT32, TensorShape({2})); + axes.flat()(0) = 0; + axes.flat()(1) = 2; test::graph::Reduce(g, reduce, test::graph::Constant(g, data), test::graph::Constant(g, axes)); return g; @@ -37,51 +83,100 @@ static Graph* ToScalar(const string& reduce, int num) { // Creates a bench which reduces a 3D tensor with total "num" floats // into a scalar on a "device". Runs the bench for "iters" times. +template static void ReduceToScalar(int iters, const string& device, - const string& reduce, int num) { - testing::ItemsProcessed(static_cast(iters) * num); - testing::BytesProcessed(static_cast(iters) * num * sizeof(float)); - test::Benchmark(device, ToScalar(reduce, num)).Run(iters); + const string& reduce, int num_x, int num_y) { + testing::ItemsProcessed(static_cast(iters) * num_x * num_y); + testing::BytesProcessed(static_cast(iters) * num_x * num_y * + sizeof(T)); + test::Benchmark(device, ToScalar(reduce, num_x, num_y)).Run(iters); +} + +static void DoRowReduce(int iters, const string& device, const string& reduce, + int num_x, int num_y) { + testing::ItemsProcessed(static_cast(iters) * num_x * num_y); + testing::BytesProcessed(static_cast(iters) * num_x * num_y * + sizeof(float)); + test::Benchmark(device, RowReduce(reduce, num_x, num_y)).Run(iters); +} + +static void DoColReduce(int iters, const string& device, const string& reduce, + int num_x, int num_y) { + testing::ItemsProcessed(static_cast(iters) * num_x * num_y); + testing::BytesProcessed(static_cast(iters) * num_x * num_y * + sizeof(float)); + test::Benchmark(device, ColReduce(reduce, num_x, num_y)).Run(iters); +} + +static void Do3DYReduce(int iters, const string& device, const string& reduce, + int num_x, int num_y) { + testing::ItemsProcessed(static_cast(iters) * num_x * num_y); + testing::BytesProcessed(static_cast(iters) * num_x * num_y * + sizeof(float)); + test::Benchmark(device, ThreeDYReduce(reduce, num_x, num_y)).Run(iters); +} + +static void Do3DXZReduce(int iters, const string& device, const string& reduce, + int num_x, int num_y) { + testing::ItemsProcessed(static_cast(iters) * num_x * num_y); + testing::BytesProcessed(static_cast(iters) * num_x * num_y * + sizeof(float)); + test::Benchmark(device, ThreeDXZReduce(reduce, num_x, num_y)).Run(iters); +} + +static void BM_Sum2DToScalarGPU(int iters, int num_x, int num_y) { + ReduceToScalar(iters, "gpu", "Sum", num_x, num_y); +} +BENCHMARK(BM_Sum2DToScalarGPU)->RangePair(1, 8192, 1, 8192); + +static void BM_Sum2DToScalarGPUComplex(int iters, int num_x, int num_y) { + ReduceToScalar>(iters, "gpu", "Sum", num_x, num_y); +} +BENCHMARK(BM_Sum2DToScalarGPUComplex)->RangePair(1, 8192, 1, 8192); + +static void BM_Sum2DToScalarGPUHalf(int iters, int num_x, int num_y) { + ReduceToScalar(iters, "gpu", "Sum", num_x, num_y); } +BENCHMARK(BM_Sum2DToScalarGPUHalf)->RangePair(1, 8192, 1, 8192); -static void BM_Sum3DToScalarCPU(int iters, int num) { - ReduceToScalar(iters, "cpu", "Sum", num); +static void BM_Sum2DRowReduceGPU(int iters, int num_x, int num_y) { + DoRowReduce(iters, "gpu", "Sum", num_x, num_y); } -BENCHMARK(BM_Sum3DToScalarCPU)->Range(1 << 13, 1 << 20); +BENCHMARK(BM_Sum2DRowReduceGPU)->RangePair(1, 8192, 1, 8192); -static void BM_Max3DToScalarCPU(int iters, int num) { - ReduceToScalar(iters, "cpu", "Max", num); +static void BM_Sum2DColumnReduceGPU(int iters, int num_x, int num_y) { + DoColReduce(iters, "gpu", "Sum", num_x, num_y); } -BENCHMARK(BM_Max3DToScalarCPU)->Range(1 << 13, 1 << 20); +BENCHMARK(BM_Sum2DColumnReduceGPU)->RangePair(1, 8192, 1, 8192); -static void BM_Prod3DToScalarCPU(int iters, int num) { - ReduceToScalar(iters, "cpu", "Prod", num); +static void BM_Sum3DYReduceGPU(int iters, int num_x, int num_y) { + Do3DYReduce(iters, "gpu", "Sum", num_x, num_y); } -BENCHMARK(BM_Prod3DToScalarCPU)->Range(1 << 13, 1 << 20); +BENCHMARK(BM_Sum3DYReduceGPU)->RangePair(64, 4096, 64, 4096); -static void BM_Mean3DToScalarCPU(int iters, int num) { - ReduceToScalar(iters, "cpu", "Mean", num); +static void BM_Sum3DXZReduceGPU(int iters, int num_x, int num_y) { + Do3DXZReduce(iters, "gpu", "Sum", num_x, num_y); } -BENCHMARK(BM_Mean3DToScalarCPU)->Range(1 << 13, 1 << 20); +BENCHMARK(BM_Sum3DXZReduceGPU)->RangePair(64, 4096, 64, 4096); -static void BM_Sum3DToScalarGPU(int iters, int num) { - ReduceToScalar(iters, "gpu", "Sum", num); +static void BM_Mean2DToScalarGPU(int iters, int num_x, int num_y) { + ReduceToScalar(iters, "gpu", "Mean", num_x, num_y); } -BENCHMARK(BM_Sum3DToScalarGPU)->Range(1 << 13, 1 << 20); +BENCHMARK(BM_Mean2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192); -static void BM_Max3DToScalarGPU(int iters, int num) { - ReduceToScalar(iters, "gpu", "Max", num); +static void BM_Max2DToScalarGPU(int iters, int num_x, int num_y) { + ReduceToScalar(iters, "gpu", "Max", num_x, num_y); } -BENCHMARK(BM_Max3DToScalarGPU)->Range(1 << 13, 1 << 20); +BENCHMARK(BM_Max2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192); -static void BM_Prod3DToScalarGPU(int iters, int num) { - ReduceToScalar(iters, "gpu", "Prod", num); +static void BM_Min2DToScalarGPU(int iters, int num_x, int num_y) { + ReduceToScalar(iters, "gpu", "Min", num_x, num_y); } -BENCHMARK(BM_Prod3DToScalarGPU)->Range(1 << 13, 1 << 20); +BENCHMARK(BM_Min2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192); -static void BM_Mean3DToScalarGPU(int iters, int num) { - ReduceToScalar(iters, "gpu", "Mean", num); +static void BM_Bool2DToScalarGPU(int iters, int num_x, int num_y) { + ReduceToScalar(iters, "gpu", "All", num_x, num_y); } -BENCHMARK(BM_Mean3DToScalarGPU)->Range(1 << 13, 1 << 20); +BENCHMARK(BM_Bool2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192); } // end namespace tensorflow diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc index ef17fc8d7d70a0403d64e711625dc67ab8a6d0e8..6ed69ecf2ebb1b8b16ae9cd41d4172ed8ce12e05 100644 --- a/tensorflow/core/kernels/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/repeat_dataset_op.cc @@ -107,10 +107,28 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { input_impl_ = dataset()->input_->MakeIterator(prefix()); } *end_of_sequence = true; + is_exhausted_ = true; input_impl_.reset(); return Status::OK(); } + protected: + Status SaveStateInternal(OpKernelContext* ctx, + IteratorBundleWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(i_, full_name("i"))); + TF_RETURN_IF_ERROR(writer->SaveParentState(ctx, input_impl_)); + return Status::OK(); + } + + Status RestoreStateInternal(OpKernelContext* ctx, + IteratorBundleReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(reader->ReadScalar(&i_, full_name("i"))); + TF_RETURN_IF_ERROR(reader->RestoreParentState(ctx, input_impl_)); + return Status::OK(); + } + private: mutex mu_; int64 i_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 2100b2d49150a512a94262856a482cd3cd87e3d4..98f3718c128a6130a2bc79c192bf974adfb9e311 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -48,12 +48,13 @@ class ReadVariableOp : public OpKernel { void Compute(OpKernelContext* ctx) override { Var* variable = nullptr; ResourceHandle handle = HandleFromInput(ctx, 0); - OP_REQUIRES( - ctx, LookupResource(ctx, handle, &variable).ok(), - errors::NotFound("Attempted to read a nonexistent variable. " - "This usually means that the variable was not " - "initialized. Container: ", - handle.container(), ", name: ", handle.name())); + const auto status = LookupResource(ctx, handle, &variable); + OP_REQUIRES(ctx, status.ok(), + errors::NotFound( + "Error while reading resource variable ", handle.name(), + " from Container: ", handle.container(), + ". This could mean that the variable was not initialized. ", + status.ToString())); core::ScopedUnref s(variable); // TODO(apassos): It's possible to do copy-on-write here instead of always @@ -279,11 +280,11 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); #undef REGISTER_KERNELS #if GOOGLE_CUDA -#define REGISTER_GPU_KERNELS(type) \ - REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("dtype") \ - .HostMemory("resource"), \ +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("dtype") \ + .HostMemory("resource"), \ AssignVariableOp); TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); @@ -430,7 +431,16 @@ class ResourceGatherOp : public OpKernel { TF_CALL_ALL_TYPES(REGISTER_GATHER_CPU); TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU); +// Registers GPU kernels. +#if GOOGLE_CUDA +#define REGISTER_GATHER_GPU(type) REGISTER_GATHER_ALL_INDICES(GPU, type) + +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GATHER_GPU); + +#endif // GOOGLE_CUDA + #undef REGISTER_GATHER_CPU +#undef REGISTER_GATHER_GPU #undef REGISTER_GATHER_ALL_INDICES #undef REGISTER_GATHER_FULL @@ -450,11 +460,12 @@ class ResourceScatterUpdateOp : public OpKernel { // Check that we have enough index space const int64 N_big = indices.NumElements(); - OP_REQUIRES(c, N_big <= std::numeric_limits::max(), - errors::InvalidArgument( - "indices has too many elements for ", - DataTypeString(DataTypeToEnum::v()), " indexing: ", - N_big, " > ", std::numeric_limits::max())); + OP_REQUIRES( + c, N_big <= std::numeric_limits::max(), + errors::InvalidArgument("indices has too many elements for ", + DataTypeString(DataTypeToEnum::v()), + " indexing: ", N_big, " > ", + std::numeric_limits::max())); const Index N = static_cast(indices.NumElements()); OP_REQUIRES( c, params->dim_size(0) <= std::numeric_limits::max(), diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index 80d490174064a366212ffe5a48681a2c48f5f42e..6b06cf650a849d3ff606b62b00f437ac9accb013 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -216,9 +216,12 @@ void RestoreTensor(OpKernelContext* context, if (output_shape.num_elements() == 0) return; -#define READER_COPY(T) \ - case DataTypeToEnum::value: \ - reader->CopySliceData(tensor_name, slice_to_load, t->flat().data()); \ +#define READER_COPY(T) \ + case DataTypeToEnum::value: \ + OP_REQUIRES(context, \ + reader->CopySliceData(tensor_name, slice_to_load, \ + t->flat().data()), \ + errors::InvalidArgument("Error copying slice data")); \ break; switch (type) { diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc index 9cdbe89457cbe25a65aff6aa655776b43cbd8b4a..5624d5cd1b1882ea29a52233ac597f92d48e46b1 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.cc +++ b/tensorflow/core/kernels/segment_reduction_ops.cc @@ -16,6 +16,9 @@ limitations under the License. // See docs in ../ops/math_ops.cc. #define EIGEN_USE_THREADS +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA #include "tensorflow/core/kernels/segment_reduction_ops.h" #include @@ -32,6 +35,14 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/util.h" +#if GOOGLE_CUDA +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/kernels/cuda_solvers.h" +#include "tensorflow/core/platform/cuda.h" + +using ::perftools::gputools::cuda::ScopedActivateExecutorContext; +#endif // GOOGLE_CUDA + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -183,6 +194,106 @@ class SegmentReductionOp : public OpKernel { } }; +#ifdef GOOGLE_CUDA +// SegmentSumGPUOp is a segment sum operator implemented for GPU only. +// TODO: This implementation of SegmentSumGPUOp is sometimes slower than +// its unsorted counterpart (mostly when problem size is small). +// This is due to the following two main reasons and a cost-effective way +// to resolve these problems is desirable. +// 1. Sorted segment sum requires a memory transfer from device to host in +// order to know the size of the output dimension whereas unsorted segment +// sum receives the size of the output dimension as an input parameter. +// 2. Sorted segment sum is essentially a tiled version of unsorted segment +// sum and therefore such optimization comes at an inherent cost. However +// such cost may not be justified when the problem size is small. When to +// use the tiled version or the untiled version depends on many factors +// including data alignments, ratio of calculation to memory traffic and +// obviously, the problem sizes. +template +class SegmentSumGPUOp : public AsyncOpKernel { + public: + explicit SegmentSumGPUOp(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + const Tensor& input = context->input(0); + const Tensor& segment_ids = context->input(1); + + OP_REQUIRES_ASYNC( + context, TensorShapeUtils::IsVector(segment_ids.shape()), + errors::InvalidArgument("segment_ids should be a vector."), done); + + const int64 num_indices = segment_ids.NumElements(); + OP_REQUIRES_ASYNC( + context, num_indices == input.dim_size(0), + errors::InvalidArgument( + "segment_ids should be the same size as dimension 0 of" + " input."), + done); + + if (num_indices == 0) { + TensorShape output_shape = input.shape(); + output_shape.set_dim(0, 0); + + Tensor* output = nullptr; + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(0, output_shape, &output), done); + done(); + return; + } + + perftools::gputools::DeviceMemoryBase output_rows_device( + (void*)(segment_ids.template flat().data() + (num_indices - 1))); + ScratchSpace output_rows_host(context, 1, /* on_host */ true); + + auto stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC( + context, + stream + ->ThenMemcpy(output_rows_host.mutable_data(), output_rows_device, + sizeof(Index)) + .ok(), + errors::Internal( + "SegmentSumGPUOp: failed to copy output_rows from device"), + done); + + functor::SegmentSumFunctor functor_; + auto create_and_check_output = [context, output_rows_host, &input, + &segment_ids, &functor_, done]() { + // Ensure that within the callback, the proper GPU settings are + // configured. + auto stream = context->op_device_context()->stream(); + ScopedActivateExecutorContext scoped_activation{stream->parent()}; + + Index output_rows = *output_rows_host.data(); + output_rows++; + OP_REQUIRES_ASYNC(context, output_rows > 0, + errors::InvalidArgument("segment ids must be >= 0"), + done); + + TensorShape output_shape = input.shape(); + output_shape.set_dim(0, output_rows); + + Tensor* output = nullptr; + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(0, output_shape, &output), done); + + auto output_flat = output->flat_outer_dims(); + auto data_ptr = input.template flat().data(); + auto segment_flat = segment_ids.flat(); + functor_(context, context->eigen_device(), output_rows, + segment_ids.shape(), segment_flat, input.NumElements(), data_ptr, + output_flat); + + done(); + }; + + context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute( + stream, create_and_check_output); + } +}; +#endif // GOOGLE_CUDA + #define REGISTER_CPU_KERNEL_SEGMENT(name, functor, type, index_type, \ default_value) \ REGISTER_KERNEL_BUILDER( \ @@ -227,6 +338,23 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128); #undef REGISTER_REAL_CPU_KERNELS_ALL #undef REGISTER_COMPLEX_CPU_KERNELS_ALL +#if GOOGLE_CUDA +#define REGISTER_GPU_SORTED_KERNELS(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("SegmentSum") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .TypeConstraint("Tindices"), \ + SegmentSumGPUOp) + +#define REGISTER_GPU_SORTED_KERNELS_ALL(type) \ + REGISTER_GPU_SORTED_KERNELS(type, int32); \ + REGISTER_GPU_SORTED_KERNELS(type, int64); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_SORTED_KERNELS_ALL); +#undef REGISTER_GPU_SORTED_KERNELS +#undef REGISTER_GPU_SORTED_KERNELS_ALL +#endif // GOOGLE_CUDA + namespace functor { // UnsortedSegmentSumFunctor implementation for CPUDevice. diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index ee09c213b7cd7db9f0712e017504572ca39f1d72..412c1d601d3116b7de5ee09afe1e4f1d0253b349 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -26,6 +26,28 @@ namespace tensorflow { class OpKernelContext; namespace functor { + +#ifdef GOOGLE_CUDA +typedef Eigen::GpuDevice GPUDevice; +// Functor for SegmentSumGPUOp. +// 'output_rows': the number of output segments (unique segment ids in +// 'segment_ids'). +// 'segment_ids_shape': shape of 'segment_ids' tensor. +// 'segment_ids': unsorted map from input to output segment ids at which to +// perform segment sum operation. +// 'data_size': size of input data tensor. +// 'data': input data tensor. +// 'output': output reshaped to {output_rows, output.size/output_rows} +template +struct SegmentSumFunctor { + void operator()(OpKernelContext* ctx, const GPUDevice& d, + const Index output_rows, const TensorShape& segment_ids_shape, + typename TTypes::ConstFlat segment_ids, + const Index data_size, const T* data, + typename TTypes::Tensor output); +}; +#endif + // BaseFunctor for definition of UnsorteSegmentReductionOp // for usage without templates. template diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc index b132b1e8f8b004ff4ad5c675488f33dcb74a6948..159fada621bd88de259e9b044491f3ecebf10b19 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc @@ -54,6 +54,77 @@ __device__ __forceinline__ void AccumulateInto( CudaAtomicAdd(dest_scalar + 1, value.imag()); } +// SortedSegmentSumFunctor kernel reduces input data just as +// UnsortedSegmentSumCustomKernel does except that input data +// is partitioned along the outer reduction dimension. This is +// because consecutive rows (elements in a row share the same +// outer dimension index) in the flattened 2D input data likely +// belong to the same segment in sorted segment sum operation. +// Therefore such partitioning strategy has two advantages over +// the UnsortedSegmentSumFunctor kernel: +// 1. Each thread reduces across multiple rows before writing +// answers to the global memory, we can therefore +// write reduction results to global memory less often. +// 2. We may know that the current thread is the only contributor +// to an output element because of the increasing nature of segment +// ids. In such cases, we do not need to use atomic operations +// to write results to global memory. +// In the flattened view of input data (with only outer and inner +// dimension), every thread processes a strip of input data of +// size OuterDimTileSize x 1. This strip runs across multiple +// rows of input data and all reduction elements share one inner +// dimension index. +template +__global__ void SortedSegmentSumCustomKernel(const Index input_outer_dim_size, + const Index inner_dim_size, + const Index output_outer_dim_size, + const Index* segment_ids, + const T* input, T* output, + const Index total_stripe_count) { + CUDA_1D_KERNEL_LOOP(stripe_index, total_stripe_count) { + const Index segment_offset = stripe_index % inner_dim_size; + const Index input_outer_dim_index_base = + stripe_index / inner_dim_size * Index(OuterDimTileSize); + + T sum = T(0); + Index first_segment_id = segment_ids[input_outer_dim_index_base]; + Index last_output_segment_id = output_outer_dim_size; + + const Index actual_stripe_height = + min(Index(OuterDimTileSize), + input_outer_dim_size - input_outer_dim_index_base); + for (Index j = 0; j < actual_stripe_height; j++) { + Index current_output_segment_id = + segment_ids[input_outer_dim_index_base + j]; + // Decide whether to write result to global memory. + // Result is only written to global memory if we move + // to another segment. Otherwise we can keep accumulating + // locally. + if (current_output_segment_id > last_output_segment_id) { + const Index output_index = + last_output_segment_id * inner_dim_size + segment_offset; + // decide whether to write result to global memory using atomic + // operations + if (last_output_segment_id == first_segment_id) { + AccumulateInto(output + output_index, sum); + } else { + *(output + output_index) = sum; + } + sum = T(0); + } + sum += ldg(input + (input_outer_dim_index_base + j) * inner_dim_size + + segment_offset); + last_output_segment_id = current_output_segment_id; + } + // For the last result in a strip, always write using atomic operations + // due to possible race conditions with threads computing + // the following strip. + const Index output_index = + last_output_segment_id * inner_dim_size + segment_offset; + AccumulateInto(output + output_index, sum); + } +} + // UnsortedSegmentSumFunctor kernel processes 'input_total_size' elements. // Each element is mapped from input to output by a combination of its // 'segment_ids' mapping and 'inner_dim_size'. @@ -80,6 +151,47 @@ __global__ void UnsortedSegmentSumCustomKernel( namespace functor { +template +void SegmentSumFunctor::operator()( + OpKernelContext* ctx, const GPUDevice& d, const Index output_rows, + const TensorShape& segment_ids_shape, + typename TTypes::ConstFlat segment_ids, const Index data_size, + const T* data, typename TTypes::Tensor output) { + if (output.size() == 0) { + return; + } + // Set 'output' to zeros. + CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d); + SetZero<<>>( + output.size(), output.data()); + if (data_size == 0 || segment_ids_shape.num_elements() == 0) { + return; + } + + // Launch kernel to compute sorted segment sum. + // Notes: + // *) 'input_total_size' is the total number of elements to process. + // *) 'segment_ids.shape' is a prefix of data's shape. + // *) 'input_outer_dim_size' is the total number of segments to process. + const Index input_total_size = data_size; + const Index input_outer_dim_size = segment_ids.dimension(0); + const Index input_inner_dim_size = input_total_size / input_outer_dim_size; + + const int OuterDimTileSize = 8; + + const Index input_outer_dim_num_stripe = + Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize)); + + const Index total_stripe_count = + input_inner_dim_size * input_outer_dim_num_stripe; + + config = GetCudaLaunchConfig(total_stripe_count, d); + SortedSegmentSumCustomKernel + <<>>( + input_outer_dim_size, input_inner_dim_size, output_rows, + segment_ids.data(), data, output.data(), total_stripe_count); +}; + // UnsortedSegmentSumFunctor implementation for GPUDevice. template struct UnsortedSegmentSumFunctor: UnsortedSegmentBaseFunctor { @@ -117,6 +229,15 @@ struct UnsortedSegmentSumFunctor: UnsortedSegmentBaseFuncto } }; +#define DEFINE_SORTED_GPU_SPECS_INDEX(T, Index) \ + template struct SegmentSumFunctor + +#define DEFINE_SORTED_GPU_SPECS(T) \ + DEFINE_SORTED_GPU_SPECS_INDEX(T, int32); \ + DEFINE_SORTED_GPU_SPECS_INDEX(T, int64); + +TF_CALL_GPU_NUMBER_TYPES(DEFINE_SORTED_GPU_SPECS); + #define DEFINE_GPU_SPECS_INDEX(T, Index) \ template struct UnsortedSegmentSumFunctor diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc index a305598fe2b488c854d382381e759745657758a1..96eaa4ac75b7049f963dc80bfe85ce22d1a9cb83 100644 --- a/tensorflow/core/kernels/shape_op_test.cc +++ b/tensorflow/core/kernels/shape_op_test.cc @@ -101,7 +101,7 @@ TEST_F(ShapeOpTest, Simple) { Tensor variant_tensor(DT_VARIANT, TensorShape({1})); Status s = session.Run({{input, variant_tensor}}, {shape_output}, &outputs); EXPECT_FALSE(s.ok()); - ExpectHasError(s, "Shape of non-scalar Variant not supported."); + ExpectHasError(s, "Shape of non-unary Variant not supported."); } { diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h index 0c39d46aeaf053c5592bf08a2d59bb192e56cee3..ac607f4e8b8ec05e23b90b74b1dbcc8aa3f2cc2a 100644 --- a/tensorflow/core/kernels/shape_ops.h +++ b/tensorflow/core/kernels/shape_ops.h @@ -35,7 +35,7 @@ inline Status GetRegularOrVariantShape(OpKernelContext* ctx, int input_index, if (ctx->input_dtype(0) == DT_VARIANT) { if (inp.dims() != 0) { return errors::InvalidArgument( - "Shape of non-scalar Variant not supported."); + "Shape of non-unary Variant not supported."); } TF_RETURN_IF_ERROR(GetUnaryVariantShape(inp, shape)); } else { diff --git a/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc b/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d95f51f0f21094a4e25f844e5b46d6e1c1aa8b45 --- /dev/null +++ b/tensorflow/core/kernels/sloppy_interleave_dataset_op.cc @@ -0,0 +1,370 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/dataset.h" + +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/dataset_utils.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/random/random.h" + +#include "tensorflow/core/kernels/captured_function.h" + +namespace tensorflow { + +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. + +class SloppyInterleaveDatasetOp : public UnaryDatasetOpKernel { + public: + explicit SloppyInterleaveDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); + std::vector other_arguments; + other_arguments.reserve(inputs.size()); + for (const Tensor& t : inputs) { + other_arguments.push_back(t); + } + + int64 cycle_length; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "cycle_length", &cycle_length)); + OP_REQUIRES(ctx, cycle_length > 0, + errors::InvalidArgument("`cycle_length` must be > 0")); + + int64 block_length; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_length", &block_length)); + OP_REQUIRES(ctx, block_length > 0, + errors::InvalidArgument("`block_length` must be > 0")); + + std::unique_ptr captured_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, func_, graph_def_version_, + std::move(other_arguments), + &captured_func)); + + *output = new Dataset(input, std::move(captured_func), cycle_length, + block_length, output_types_, output_shapes_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(const DatasetBase* input, + std::unique_ptr captured_func, int64 cycle_length, + int64 block_length, const DataTypeVector& output_types, + const std::vector& output_shapes) + : input_(input), + captured_func_(std::move(captured_func)), + cycle_length_(cycle_length), + block_length_(block_length), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::SloppyInterleave")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() override { + return "SloppyInterleaveDatasetOp::Dataset"; + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + input_impl_(params.dataset->input_->MakeIterator(params.prefix)), + output_elements_(params.dataset->cycle_length_) {} + + ~Iterator() override { + mutex_lock l(mu_); + cancelled_ = true; + // Notify all workers in case they are blocked. + for (int64 i = 0; i < dataset()->cycle_length_; ++i) { + output_elements_[i].cond_var.notify_all(); + } + } + + // It is implemented so that it matches the deterministic interleave + // unless we would block waiting for an element, at which point it skips + // along to the next available value. + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx)); + // Search for available items, blocking if necessary. + while (!cancelled_) { + for (size_t i = 0; i < dataset()->cycle_length_; ++i) { + size_t index = (next_index_ + i) % dataset()->cycle_length_; + if (output_elements_[index].is_produced) { + next_index_ = index; + if (i == 0) { + block_count_++; + if (block_count_ == dataset()->block_length_) { + next_index_ = (index + 1) % dataset()->cycle_length_; + block_count_ = 0; + } + } else { + block_count_ = 0; + } + // If we encounter an EoF, advance to the next iterator + if (output_elements_[index].end_of_sequence) { + output_elements_[index].is_produced = false; + output_elements_[index].cond_var.notify_one(); + next_index_ = (index + 1) % dataset()->cycle_length_; + block_count_ = 0; + i = -1; // Restart the inner loop + continue; + } + *end_of_sequence = false; + if (output_elements_[index].output_status.ok()) { + output_elements_[index].output_value.swap(*out_tensors); + } + output_elements_[index].is_produced = false; + output_elements_[index].cond_var.notify_one(); + return output_elements_[index].output_status; + } + } + + if (num_active_threads_ == 0) { + // No potential for future values. + // + // Note: this condition check must occur after checking the output + // buffer, as its possible for there to be values in the output + // buffer, even if the number of live threads is zero. + *end_of_sequence = true; + return Status::OK(); + } + // No values available; wait until woken up. + cond_var_.wait(l); + } + return errors::Cancelled( + "SloppyInterleaveDatasetOp::Dataset::Iterator::GetNext"); + } + + private: + // Internal structure to manage thread coordination. All values are + // guarded by the enclosing Iterator's mu_. + struct OutputBufferElement { + // The producer must set `is_produced` to `true` after + // `output_status` or `output_value` has been written. + bool is_produced = false; + // The producer sets `output_status` if either getting the input element + // or applying the function to it fails. + Status output_status; + // Reached end of sequence for the underlying iterator. + bool end_of_sequence = false; + // The output data element. + std::vector output_value; + // The producer thread waits on this condition variable after having + // produced an element. The reader thread notifies this condition + // variable after reading the value. + condition_variable cond_var; + }; + + Status EnsureWorkerThreadsStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (worker_threads_.empty()) { + for (int64 i = 0; i < dataset()->cycle_length_; ++i) { + // Serialize the creation of the workers and their corresponding + // input elements to ensure we match the standard interleave when + // the underlying iterators induce no delay. + std::vector args; + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &args, &end_of_input_)); + if (end_of_input_) { + LOG(WARNING) << "Input iterator exhausted after " << i + << " elements; cannot start all " + << dataset()->cycle_length_ << " worker threads."; + return Status::OK(); + } + std::unique_ptr itr; + TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + ctx, args, i, dataset()->captured_func_.get(), prefix(), &itr)); + worker_threads_.emplace_back( + std::unique_ptr(ctx->env()->StartThread( + {}, "worker_thread", + std::bind(&Iterator::WorkerThread, this, + new IteratorContext(*ctx), i, itr.release())))); + num_active_threads_ = i + 1; + } + } + return Status::OK(); + } + + void BlockAndUpdateOutputBuffer(mutex_lock* l, const int64 thread_index, + const Status& status, + bool end_of_sequence, + std::vector* out_tensors) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // We have produced an element; push it into the output buffer + // when space is available. + while (!cancelled_ && output_elements_[thread_index].is_produced) { + output_elements_[thread_index].cond_var.wait(*l); + } + if (cancelled_) { + return; + } + output_elements_[thread_index].is_produced = true; + output_elements_[thread_index].output_status = status; + output_elements_[thread_index].end_of_sequence = end_of_sequence; + if (status.ok()) { + output_elements_[thread_index].output_value.swap(*out_tensors); + } else { + output_elements_[thread_index].output_value.clear(); + } + cond_var_.notify_one(); + } + + // Races to produce elements into the output queue buffers. + void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index, + IteratorBase* out_iterator_ptr) { + // std::function arguments are copy-constructable, so we pass raw + // pointers, and then immediately wrap them to ensure correct ownership. + std::unique_ptr ctx(ctx_ptr); + std::unique_ptr out_iterator(out_iterator_ptr); + auto cleanup = gtl::MakeCleanup([this, thread_index] { + mutex_lock l(mu_); + num_active_threads_--; + cond_var_.notify_all(); + }); + while (true) { + // Attempt to produce an element. + bool end_of_out_itr_input = false; + std::vector out_tensors; + Status element_status = out_iterator->GetNext(ctx.get(), &out_tensors, + &end_of_out_itr_input); + // Handle output. + { + mutex_lock l(mu_); + BlockAndUpdateOutputBuffer(&l, thread_index, element_status, + end_of_out_itr_input, &out_tensors); + if (end_of_out_itr_input) { + // We have exhausted our current iterator; get a new iterator; + // loop to handle errors. + while (!cancelled_) { + if (end_of_input_) { + // No more iterator inputs; we're done! + return; + } + std::vector args; + // BlockAndUpdateOutputBuffer() sequences calls to + // input_impl_->GetNext when the out_iterator doesn't cause + // slopping. + Status input_status = + input_impl_->GetNext(ctx.get(), &args, &end_of_input_); + if (end_of_input_) { + // No more elements to produce, stop the worker thread. + return; + } + if (input_status.ok()) { + input_status = dataset::MakeIteratorFromInputElement( + ctx.get(), args, thread_index, + dataset()->captured_func_.get(), prefix(), &out_iterator); + } + if (input_status.ok()) { + // Successfully have a new out_iterator; restart the outer + // loop to produce an element. + break; + } + + // We encountered an error; push the error to the output buffer. + BlockAndUpdateOutputBuffer(&l, thread_index, input_status, + /* end_of_sequence = */ false, + &out_tensors); + } + } + + // Check if we should exit. + if (cancelled_) { + return; + } + } + } + } + + // Mutex & condition variable to guard mutable iterator internals and + // coordinate among worker threads and client thread[s]. + mutex mu_; + condition_variable cond_var_; + // The iterator producing elements which are converted to datasets by + // the dataset()->captured_func_ then interleaved together. + const std::unique_ptr input_impl_ GUARDED_BY(mu_); + // Whether the input_impl_ can produce future elements. + bool end_of_input_ GUARDED_BY(mu_) = false; + // The buffer of elements to be produced. Each worker thread operates + // on a single OutputBufferElement. + std::vector output_elements_ GUARDED_BY(mu_); + // The index into output_elements_ for next element to produce. + size_t next_index_ GUARDED_BY(mu_) = 0; + // The number of items produced so far within the block + size_t block_count_ GUARDED_BY(mu_) = 0; + // Number of active threads. + size_t num_active_threads_ GUARDED_BY(mu_) = 0; + // Flag to instruct the worker threads to exit. + bool cancelled_ GUARDED_BY(mu_) = false; + // Pointers to the worker threads. This must be last to ensure the + // threads have exited before any other members are deallocated. + // TODO(b/65178177): Avoid allocating additional threads. + std::vector> worker_threads_ GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + const std::unique_ptr captured_func_; + const int64 cycle_length_; + const int64 block_length_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + }; + + const int graph_def_version_; + DataTypeVector output_types_; + std::vector output_shapes_; + const NameAttrList* func_; +}; + +REGISTER_KERNEL_BUILDER(Name("SloppyInterleaveDataset").Device(DEVICE_CPU), + SloppyInterleaveDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/platform/vmodule_benchmark_test.cc b/tensorflow/core/kernels/sql/driver_manager.cc similarity index 58% rename from tensorflow/core/platform/vmodule_benchmark_test.cc rename to tensorflow/core/kernels/sql/driver_manager.cc index 0f9e75bf9cd7b2021ccb52c2ed4b671350b721aa..9a5d5aa853c438ef4e893fac2322af17ae863fa8 100644 --- a/tensorflow/core/platform/vmodule_benchmark_test.cc +++ b/tensorflow/core/kernels/sql/driver_manager.cc @@ -12,17 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/kernels/sql/driver_manager.h" +#include "tensorflow/core/kernels/sql/sqlite_query_connection.h" namespace tensorflow { -static void BM_DisabledVlog(int iters) { - for (int i = 0; i < iters; ++i) { - VLOG(1) << "Testing VLOG(1)!"; +namespace sql { + +std::unique_ptr DriverManager::CreateQueryConnection( + const string& driver_name) { + if (driver_name == "sqlite") { + return std::unique_ptr(new SqliteQueryConnection()); + } else { // TODO(b/64276826, b/64276995) Add support for other db types. + // Change to registry pattern. + return nullptr; } } -BENCHMARK(BM_DisabledVlog); + +} // namespace sql } // namespace tensorflow diff --git a/tensorflow/core/kernels/sql/driver_manager.h b/tensorflow/core/kernels/sql/driver_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..53350268d30f4f7215eb543a28ae3fedf837ac0d --- /dev/null +++ b/tensorflow/core/kernels/sql/driver_manager.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SQL_DRIVER_MANAGER_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SQL_DRIVER_MANAGER_H_ + +#include "tensorflow/core/kernels/sql/query_connection.h" + +namespace tensorflow { + +namespace sql { + +// A factory class for creating `QueryConnection` instances. +class DriverManager { + public: + // A factory method for creating `QueryConnection` instances. + // + // `driver_name` is the database type (e.g. 'sqlite'). `driver_name` + // corresponds to a `QueryConnection` subclass. For example, if `driver_name` + // == `sqlite`, then `CreateQueryConnection` will create a + // `SqliteQueryConnection` instance. + static std::unique_ptr CreateQueryConnection( + const string& driver_name); +}; + +} // namespace sql + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SQL_DRIVER_MANAGER_H_ diff --git a/tensorflow/core/kernels/sql/query_connection.h b/tensorflow/core/kernels/sql/query_connection.h new file mode 100644 index 0000000000000000000000000000000000000000..f9945aee7dc6ac59df8cc9063ab5c4d9aedf4018 --- /dev/null +++ b/tensorflow/core/kernels/sql/query_connection.h @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SQL_QUERY_CONNECTION_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SQL_QUERY_CONNECTION_H_ + +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +namespace sql { +// This interface allows a user to connect to a database, execute a query, and +// iterate over the result set, putting the results into an output tensor. +// A subclass implementation is required for each type of database +// (e.g. sqlite3, mysql, etc.) +// +// Presently, a `QueryConnection` instance can only handle one query at a time. +// In a future extension, this class may be refactored so that it creates +// instances of a new class (named, say, `Statement`) which could have a +// one-to-one correspondence with queries. This would make `QueryConnection` +// more consistent with `Connection` classes of other database APIs. +// `QueryConnection` would then be renamed simply `Connection`. +// +// This class is not thread safe. Access to it is guarded by a mutex in +// `SqlDatasetOp::Dataset::Iterator`. +class QueryConnection { + public: + virtual ~QueryConnection() {} + // Opens a connection to the database named by `data_source_name`. Prepares to + // execute `query` against the database. + // + // The client must call `Close()` to release the connection resources, even + // if `Open()` fails. `Close()` must be called before making another call + // to `Open()`. + virtual Status Open(const string& data_source_name, const string& query, + const DataTypeVector& output_types) = 0; + // Closes an opened connection. + virtual Status Close() = 0; + // Retrieves the next row of the result set of the query from the most recent + // call to `Open()`. + // + // If such a row exists, then the row will be stored in `*out_tensors`, and + // `false` will be stored in `*end_of_sequence`. + // + // If there are no more rows in the result set, then instead `true` will be + // stored in `*end_of_sequence`, and the content of `*out_tensors` will be + // undefined. + virtual Status GetNext(std::vector* out_tensors, + bool* end_of_sequence) = 0; +}; + +} // namespace sql + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SQL_QUERY_CONNECTION_H_ diff --git a/tensorflow/core/kernels/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/sql/sqlite_query_connection.cc new file mode 100644 index 0000000000000000000000000000000000000000..4bcf82ae2861fa5e80ceb5d1fe6ccaa78a07d420 --- /dev/null +++ b/tensorflow/core/kernels/sql/sqlite_query_connection.cc @@ -0,0 +1,121 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/sql/sqlite_query_connection.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace tensorflow { + +namespace sql { + +// Returns a Status with the sqlite error message corresponding to the +// sqlite error number, `sqlite_err`. +static Status SqliteErrorToStatus(sqlite3* db, int sqlite_err) { + if (sqlite_err == SQLITE_OK) { + return Status::OK(); + } else { + const char* err_msg = sqlite3_errmsg(db); + // TODO(b/64276468) Be smart about the error code being returned + return errors::Unknown( + tensorflow::strings::Printf("Sqlite error: %s", err_msg)); + } +} + +SqliteQueryConnection::SqliteQueryConnection(){}; + +SqliteQueryConnection::~SqliteQueryConnection() { + Status s = Close(); + if (!s.ok()) { + LOG(WARNING) << "Failed to close query connection: " << s; + } +} + +Status SqliteQueryConnection::Open(const string& data_source_name, + const string& query, + const DataTypeVector& output_types) { + if (db_ != nullptr) { + return errors::FailedPrecondition( + "Failed to open query connection: Connection already opeend."); + } + int err = sqlite3_open(data_source_name.c_str(), &db_); + Status s = SqliteErrorToStatus(db_, err); + if (s.ok()) { + query_ = query; + output_types_ = output_types; + } + return s; +} + +Status SqliteQueryConnection::Close() { + int err = sqlite3_finalize(stmt_); + if (err != SQLITE_OK) { + return SqliteErrorToStatus(db_, err); + } + stmt_ = nullptr; + err = sqlite3_close(db_); + if (err != SQLITE_OK) { + return SqliteErrorToStatus(db_, err); + } + db_ = nullptr; + return Status::OK(); +} + +Status SqliteQueryConnection::GetNext(std::vector* out_tensors, + bool* end_of_sequence) { + if (stmt_ == nullptr) { + Status s = ExecuteQuery(); + if (!s.ok()) { + return s; + } + } + int rc = sqlite3_step(stmt_); + if (rc == SQLITE_ROW) { + for (int i = 0; i < column_count_; i++) { + // TODO(b/64276939) Support other tensorflow types. Interpret columns as + // the types that the client specifies. + Tensor tensor(cpu_allocator(), DT_STRING, {}); + string value( + reinterpret_cast(sqlite3_column_text(stmt_, i))); + tensor.scalar()() = value; + out_tensors->emplace_back(std::move(tensor)); + } + *end_of_sequence = false; + return Status::OK(); + } else if (rc == SQLITE_DONE) { + *end_of_sequence = true; + return Status::OK(); + } else { + return SqliteErrorToStatus(db_, rc); + } +} + +Status SqliteQueryConnection::ExecuteQuery() { + int err = sqlite3_prepare_v2(db_, query_.c_str(), -1, &stmt_, nullptr); + Status s = SqliteErrorToStatus(db_, err); + if (s.ok()) { + int column_count = sqlite3_column_count(stmt_); + if (column_count != output_types_.size()) { + return errors::InvalidArgument(tensorflow::strings::Printf( + "The number of columns in query (%d) must match the number of " + "elements in output_types (%zu).", + column_count, output_types_.size())); + } + column_count_ = column_count; + } + return s; +} + +} // namespace sql + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/sql/sqlite_query_connection.h b/tensorflow/core/kernels/sql/sqlite_query_connection.h new file mode 100644 index 0000000000000000000000000000000000000000..f93b203a5b75d926799cd87538317f875263a818 --- /dev/null +++ b/tensorflow/core/kernels/sql/sqlite_query_connection.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SQL_SQLITE_QUERY_CONNECTION_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SQL_SQLITE_QUERY_CONNECTION_H_ + +#include "sqlite3.h" +#include "tensorflow/core/kernels/sql/query_connection.h" + +namespace tensorflow { + +namespace sql { + +class SqliteQueryConnection : public QueryConnection { + public: + SqliteQueryConnection(); + ~SqliteQueryConnection() override; + Status Open(const string& data_source_name, const string& query, + const DataTypeVector& output_types) override; + Status Close() override; + Status GetNext(std::vector* out_tensors, + bool* end_of_sequence) override; + + private: + // Executes the query string `query_`. + Status ExecuteQuery(); + sqlite3* db_ = nullptr; + sqlite3_stmt* stmt_ = nullptr; + int column_count_ = 0; + string query_; + DataTypeVector output_types_; +}; + +} // namespace sql + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SQL_SQLITE_QUERY_CONNECTION_H_ diff --git a/tensorflow/core/kernels/sql_dataset_ops.cc b/tensorflow/core/kernels/sql_dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..d17ae53cb4ffe0da072dbfe5131cc318150cfd46 --- /dev/null +++ b/tensorflow/core/kernels/sql_dataset_ops.cc @@ -0,0 +1,154 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/dataset.h" +#include "tensorflow/core/kernels/sql/driver_manager.h" +#include "tensorflow/core/kernels/sql/query_connection.h" +#include "tensorflow/core/lib/io/inputbuffer.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace tensorflow { + +namespace { +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following ops. + +class SqlDatasetOp : public DatasetOpKernel { + public: + explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + string driver_name; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "driver_name", &driver_name)); + + string data_source_name; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "data_source_name", + &data_source_name)); + + string query; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "query", &query)); + + // TODO(b/64276826) Change this check when we add support for other + // databases. + OP_REQUIRES(ctx, driver_name == "sqlite", + errors::InvalidArgument(tensorflow::strings::Printf( + "The database type, %s, is not supported by SqlDataset. " + "The set of supported databases is: {'sqlite'}.", + driver_name.c_str()))); + // TODO(b/64276939) Remove this check when we add support for other + // tensorflow types. + for (const DataType& dt : output_types_) { + OP_REQUIRES(ctx, dt == DataType::DT_STRING, + errors::InvalidArgument( + "Each element of `output_types_` must be DT_STRING.")); + } + for (const PartialTensorShape& pts : output_shapes_) { + OP_REQUIRES(ctx, pts.dims() == 0, + errors::InvalidArgument( + "Each element of `output_shapes_` must be a scalar.")); + } + + *output = new Dataset(driver_name, data_source_name, query, output_types_, + output_shapes_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(const string& driver_name, const string& data_source_name, + const string& query, const DataTypeVector& output_types, + const std::vector& output_shapes) + : driver_name_(driver_name), + data_source_name_(data_source_name), + query_(query), + output_types_(output_types), + output_shapes_(output_shapes) {} + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Sql")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() override { return "SqlDatasetOp::Dataset"; } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + ~Iterator() override { + if (query_connection_initialized_) { + Status s = query_connection_->Close(); + if (!s.ok()) { + LOG(WARNING) << "Failed to close query connection: " << s; + } + } + } + + Status GetNextInternal(IteratorContext* /*ctx*/, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (!query_connection_initialized_) { + query_connection_initialized_ = true; + query_connection_ = sql::DriverManager::CreateQueryConnection( + dataset()->driver_name_); + Status s = query_connection_->Open(dataset()->data_source_name_, + dataset()->query_, + dataset()->output_types_); + if (!s.ok()) { + LOG(WARNING) << "Failed to connect to database: " << s; + return s; + } + } + return query_connection_->GetNext(out_tensors, end_of_sequence); + } + + private: + mutex mu_; + std::unique_ptr query_connection_ GUARDED_BY(mu_); + bool query_connection_initialized_ GUARDED_BY(mu_) = false; + }; + const string driver_name_; + const string data_source_name_; + const string query_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + }; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/summary_interface.cc b/tensorflow/core/kernels/summary_interface.cc new file mode 100644 index 0000000000000000000000000000000000000000..19e0f702f9f900f3b11e84073dd72cbb39062e76 --- /dev/null +++ b/tensorflow/core/kernels/summary_interface.cc @@ -0,0 +1,432 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/summary_interface.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/png/png_io.h" +#include "tensorflow/core/lib/wav/wav_io.h" +#include "tensorflow/core/util/event.pb.h" +#include "tensorflow/core/util/events_writer.h" + +namespace tensorflow { +namespace { +template +Status TensorValueAt(Tensor t, int index, T* out) { + switch (t.dtype()) { + case DT_FLOAT: + *out = t.flat()(index); + break; + case DT_DOUBLE: + *out = t.flat()(index); + break; + case DT_HALF: + *out = T(t.flat()(index)); + break; + case DT_INT32: + *out = t.flat()(index); + break; + case DT_UINT8: + *out = t.flat()(index); + break; + case DT_INT16: + *out = t.flat()(index); + break; + case DT_INT8: + *out = t.flat()(index); + break; + case DT_BOOL: + *out = t.flat()(index); + break; + case DT_INT64: + *out = t.flat()(index); + break; + default: + return errors::Unimplemented("Scalar summary for dtype ", + DataTypeString(t.dtype()), + " is not supported."); + } + return Status::OK(); +} + +typedef Eigen::Tensor Uint8Image; + +// Add the sequence of images specified by ith_image to the summary. +// +// Factoring this loop out into a helper function lets ith_image behave +// differently in the float and uint8 cases: the float case needs a temporary +// buffer which can be shared across calls to ith_image, but the uint8 case +// does not. +Status AddImages(const string& tag, int max_images, int batch_size, int w, + int h, int depth, + const std::function& ith_image, Summary* s) { + const int N = std::min(max_images, batch_size); + for (int i = 0; i < N; ++i) { + Summary::Value* v = s->add_value(); + // The tag depends on the number of requested images (not the number + // produced.) + // + // Note that later on avisu uses "/" to figure out a consistent naming + // convention for display, so we append "/image" to guarantee that the + // image(s) won't be displayed in the global scope with no name. + if (max_images > 1) { + v->set_tag(strings::StrCat(tag, "/image/", i)); + } else { + v->set_tag(strings::StrCat(tag, "/image")); + } + + auto image = ith_image(i); + Summary::Image* si = v->mutable_image(); + si->set_height(h); + si->set_width(w); + si->set_colorspace(depth); + const int channel_bits = 8; + const int compression = -1; // Use zlib default + if (!png::WriteImageToBuffer(image.data(), w, h, w * depth, depth, + channel_bits, compression, + si->mutable_encoded_image_string(), nullptr)) { + return errors::Internal("PNG encoding failed"); + } + } + return Status::OK(); +} + +template +void NormalizeFloatImage(int hw, int depth, + typename TTypes::ConstMatrix values, + typename TTypes::ConstVec bad_color, + Uint8Image* image) { + if (!image->size()) return; // Nothing to do for empty images + + // Rescale the image to uint8 range. + // + // We are trying to generate an RGB image from a float/half tensor. We do + // not have any info about the expected range of values in the tensor + // but the generated image needs to have all RGB values within [0, 255]. + // + // We use two different algorithms to generate these values. If the + // tensor has only positive values we scale them all by 255/max(values). + // If the tensor has both negative and positive values we scale them by + // the max of their absolute values and center them around 127. + // + // This works for most cases, but does not respect the relative dynamic + // range across different instances of the tensor. + + // Compute min and max ignoring nonfinite pixels + float image_min = std::numeric_limits::infinity(); + float image_max = -image_min; + for (int i = 0; i < hw; i++) { + bool finite = true; + for (int j = 0; j < depth; j++) { + if (!Eigen::numext::isfinite(values(i, j))) { + finite = false; + break; + } + } + if (finite) { + for (int j = 0; j < depth; j++) { + float value(values(i, j)); + image_min = std::min(image_min, value); + image_max = std::max(image_max, value); + } + } + } + + // Pick an affine transform into uint8 + const float kZeroThreshold = 1e-6; + T scale, offset; + if (image_min < 0) { + float max_val = std::max(std::abs(image_min), std::abs(image_max)); + scale = T(max_val < kZeroThreshold ? 0.0f : 127.0f / max_val); + offset = T(128.0f); + } else { + scale = T(image_max < kZeroThreshold ? 0.0f : 255.0f / image_max); + offset = T(0.0f); + } + + // Transform image, turning nonfinite values to bad_color + for (int i = 0; i < hw; i++) { + bool finite = true; + for (int j = 0; j < depth; j++) { + if (!Eigen::numext::isfinite(values(i, j))) { + finite = false; + break; + } + } + if (finite) { + image->chip<0>(i) = + (values.template chip<0>(i) * scale + offset).template cast(); + } else { + image->chip<0>(i) = bad_color; + } + } +} + +template +Status NormalizeAndAddImages(const Tensor& tensor, int max_images, int h, int w, + int hw, int depth, int batch_size, + const string& base_tag, Tensor bad_color_tensor, + Summary* s) { + // For float and half images, nans and infs are replaced with bad_color. + if (bad_color_tensor.dim_size(0) < depth) { + return errors::InvalidArgument( + "expected depth <= bad_color.size, got depth = ", depth, + ", bad_color.size = ", bad_color_tensor.dim_size(0)); + } + auto bad_color_full = bad_color_tensor.vec(); + typename TTypes::ConstVec bad_color(bad_color_full.data(), depth); + + // Float images must be scaled and translated. + Uint8Image image(hw, depth); + auto ith_image = [&tensor, &image, bad_color, batch_size, hw, depth](int i) { + auto tensor_eigen = tensor.template shaped({batch_size, hw, depth}); + typename TTypes::ConstMatrix values( + &tensor_eigen(i, 0, 0), Eigen::DSizes(hw, depth)); + NormalizeFloatImage(hw, depth, values, bad_color, &image); + return image; + }; + return AddImages(base_tag, max_images, batch_size, w, h, depth, ith_image, s); +} + +} // namespace + +class SummaryWriterImpl : public SummaryWriterInterface { + public: + SummaryWriterImpl(int max_queue, int flush_millis) + : SummaryWriterInterface(), + max_queue_(max_queue), + flush_millis_(flush_millis) {} + + Status Initialize(const string& logdir, const string& filename_suffix, + Env* env) { + Status is_dir = env->IsDirectory(logdir); + if (!is_dir.ok()) { + if (is_dir.code() != tensorflow::error::NOT_FOUND) { + return is_dir; + } + TF_RETURN_IF_ERROR(env->CreateDir(logdir)); + } + mutex_lock ml(mu_); + events_writer_ = + xla::MakeUnique(io::JoinPath(logdir, "events")); + if (!events_writer_->InitWithSuffix(filename_suffix)) { + return errors::Unknown("Could not initialize events writer."); + } + last_flush_ = Env::Default()->NowMicros(); + return Status::OK(); + } + + Status Flush() override { + mutex_lock ml(mu_); + return InternalFlush(); + } + + ~SummaryWriterImpl() override { + (void)Flush(); // Ignore errors. + } + + Status WriteTensor(int64 global_step, Tensor t, const string& tag, + const string& serialized_metadata) override { + Summary s; + Summary::Value* v = s.add_value(); + t.AsProtoTensorContent(v->mutable_tensor()); + v->set_tag(tag); + v->mutable_metadata()->ParseFromString(serialized_metadata); + return Enqueue(global_step, s); + } + + Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { + Summary s; + Summary::Value* v = s.add_value(); + v->set_tag(tag); + float value; + TF_RETURN_IF_ERROR(TensorValueAt(t, 0, &value)); + v->set_simple_value(value); + return Enqueue(global_step, s); + } + + Status WriteHistogram(int64 global_step, Tensor t, + const string& tag) override { + Summary s; + Summary::Value* v = s.add_value(); + v->set_tag(tag); + histogram::Histogram histo; + for (int64 i = 0; i < t.NumElements(); i++) { + double double_val; + TF_RETURN_IF_ERROR(TensorValueAt(t, i, &double_val)); + if (Eigen::numext::isnan(double_val)) { + return errors::InvalidArgument("Nan in summary histogram for: ", tag); + } else if (Eigen::numext::isinf(double_val)) { + return errors::InvalidArgument("Infinity in summary histogram for: ", + tag); + } + histo.Add(double_val); + } + + histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */); + return Enqueue(global_step, s); + } + + Status WriteImage(int64 global_step, Tensor tensor, const string& tag, + int max_images, Tensor bad_color) override { + if (!(tensor.dims() == 4 && + (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 || + tensor.dim_size(3) == 4))) { + return errors::InvalidArgument( + "Tensor must be 4-D with last dim 1, 3, or 4, not ", + tensor.shape().DebugString()); + } + if (!(tensor.dim_size(0) < (1LL << 31) && + tensor.dim_size(1) < (1LL << 31) && + tensor.dim_size(2) < (1LL << 31) && + (tensor.dim_size(1) * tensor.dim_size(2)) < (1LL << 29))) { + return errors::InvalidArgument("Tensor too large for summary ", + tensor.shape().DebugString()); + } + Summary s; + // The casts and h * w cannot overflow because of the limits above. + const int batch_size = static_cast(tensor.dim_size(0)); + const int h = static_cast(tensor.dim_size(1)); + const int w = static_cast(tensor.dim_size(2)); + const int hw = h * w; // Compact these two dims for simplicity + const int depth = static_cast(tensor.dim_size(3)); + if (tensor.dtype() == DT_UINT8) { + // For uint8 input, no normalization is necessary + auto ith_image = [&tensor, batch_size, hw, depth](int i) { + auto values = tensor.shaped({batch_size, hw, depth}); + return typename TTypes::ConstMatrix( + &values(i, 0, 0), Eigen::DSizes(hw, depth)); + }; + TF_RETURN_IF_ERROR( + AddImages(tag, max_images, batch_size, w, h, depth, ith_image, &s)); + } else if (tensor.dtype() == DT_HALF) { + TF_RETURN_IF_ERROR(NormalizeAndAddImages( + tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, &s)); + } else if (tensor.dtype() == DT_FLOAT) { + TF_RETURN_IF_ERROR(NormalizeAndAddImages( + tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, &s)); + } else { + return errors::InvalidArgument( + "Only DT_INT8, DT_HALF, and DT_FLOAT images are supported. Got ", + DataTypeString(tensor.dtype())); + } + + return Enqueue(global_step, s); + } + + Status WriteAudio(int64 global_step, Tensor tensor, const string& tag, + int max_outputs, float sample_rate) override { + if (sample_rate <= 0.0f) { + return errors::InvalidArgument("sample_rate must be > 0"); + } + const int batch_size = tensor.dim_size(0); + const int64 length_frames = tensor.dim_size(1); + const int64 num_channels = + tensor.dims() == 2 ? 1 : tensor.dim_size(tensor.dims() - 1); + Summary s; + const int N = std::min(max_outputs, batch_size); + for (int i = 0; i < N; ++i) { + Summary::Value* v = s.add_value(); + if (max_outputs > 1) { + v->set_tag(strings::StrCat(tag, "/audio/", i)); + } else { + v->set_tag(strings::StrCat(tag, "/audio")); + } + + Summary::Audio* sa = v->mutable_audio(); + sa->set_sample_rate(sample_rate); + sa->set_num_channels(num_channels); + sa->set_length_frames(length_frames); + sa->set_content_type("audio/wav"); + + auto values = + tensor.shaped({batch_size, length_frames, num_channels}); + auto channels_by_frames = typename TTypes::ConstMatrix( + &values(i, 0, 0), + Eigen::DSizes(length_frames, num_channels)); + size_t sample_rate_truncated = lrintf(sample_rate); + if (sample_rate_truncated == 0) { + sample_rate_truncated = 1; + } + TF_RETURN_IF_ERROR(wav::EncodeAudioAsS16LEWav( + channels_by_frames.data(), sample_rate_truncated, num_channels, + length_frames, sa->mutable_encoded_audio_string())); + } + + return Enqueue(global_step, s); + } + + string DebugString() override { return "SummaryWriterImpl"; } + + private: + Status Enqueue(int64 global_step, const Summary& summary) { + mutex_lock ml(mu_); + queue_.emplace_back(global_step, summary, Env::Default()->NowMicros()); + if (queue_.size() >= max_queue_ || + Env::Default()->NowMicros() - last_flush_ > 1000 * flush_millis_) { + return InternalFlush(); + } + return Status::OK(); + } + + Status InternalFlush() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + for (const EventInfo& e : queue_) { + Event event; + event.set_step(std::get<0>(e)); + *event.mutable_summary() = std::get<1>(e); + event.set_wall_time(std::get<2>(e)); + events_writer_->WriteEvent(event); + } + queue_.clear(); + if (!events_writer_->Flush()) { + return errors::InvalidArgument("Could not flush events file."); + } + last_flush_ = Env::Default()->NowMicros(); + return Status::OK(); + } + + const int max_queue_; + const int flush_millis_; + uint64 last_flush_; + using EventInfo = std::tuple; + mutex mu_; + std::vector queue_ GUARDED_BY(mu_); + // A pointer to allow deferred construction. + std::unique_ptr events_writer_ GUARDED_BY(mu_); + std::vector> registered_summaries_ + GUARDED_BY(mu_); +}; + +Status CreateSummaryWriter(int max_queue, int flush_millis, + const string& logdir, const string& filename_suffix, + Env* env, SummaryWriterInterface** result) { + SummaryWriterImpl* w = new SummaryWriterImpl(max_queue, flush_millis); + Status s = w->Initialize(logdir, filename_suffix, env); + if (!s.ok()) { + w->Unref(); + *result = nullptr; + return s; + } + *result = w; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/summary_interface.h b/tensorflow/core/kernels/summary_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..ae2fbb70fe3580bdd1d4f4f34a487b33f5a6a9c2 --- /dev/null +++ b/tensorflow/core/kernels/summary_interface.h @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_ +#define TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_ + + +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { + +// Main interface for the summary writer resource. +class SummaryWriterInterface : public ResourceBase { + public: + virtual ~SummaryWriterInterface() override {} + + // Flushes all unwritten messages in the queue. + virtual Status Flush() = 0; + + // These are called in the OpKernel::Compute methods for the summary ops. + virtual Status WriteTensor(int64 global_step, Tensor t, const string& tag, + const string& serialized_metadata) = 0; + + virtual Status WriteScalar(int64 global_step, Tensor t, + const string& tag) = 0; + + virtual Status WriteHistogram(int64 global_step, Tensor t, + const string& tag) = 0; + + virtual Status WriteImage(int64 global_step, Tensor t, const string& tag, + int max_images, Tensor bad_color) = 0; + + virtual Status WriteAudio(int64 global_step, Tensor t, const string& tag, + int max_outputs_, float sample_rate) = 0; +}; + +// Creates a SummaryWriterInterface instance which writes to a file. It will +// enqueue up to max_queue summaries, and flush at least every flush_millis +// milliseconds. The summaries will be written to the directory specified by +// logdir and with the filename suffixed by filename_suffix. The caller owns a +// reference to result if the returned status is ok. +Status CreateSummaryWriter(int max_queue, int flush_millis, + const string& logdir, const string& filename_suffix, + Env* env, SummaryWriterInterface** result); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_ diff --git a/tensorflow/core/kernels/summary_interface_test.cc b/tensorflow/core/kernels/summary_interface_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e24e8122a0760980ffed69790c482175f4623e3 --- /dev/null +++ b/tensorflow/core/kernels/summary_interface_test.cc @@ -0,0 +1,170 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/kernels/summary_interface.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/event.pb.h" + +namespace tensorflow { +namespace { + +Status SummaryTestHelper( + const string& test_name, + std::function writer_fn, + std::function test_fn) { + static std::set* tests = new std::set(); + CHECK(tests->insert(test_name).second) << ": " << test_name; + + SummaryWriterInterface* writer; + Env* env = Env::Default(); + TF_CHECK_OK( + CreateSummaryWriter(1, 1, testing::TmpDir(), test_name, env, &writer)); + core::ScopedUnref deleter(writer); + + TF_CHECK_OK(writer_fn(writer)); + TF_CHECK_OK(writer->Flush()); + + std::vector files; + TF_CHECK_OK(env->GetChildren(testing::TmpDir(), &files)); + bool found = false; + for (const string& f : files) { + if (StringPiece(f).contains(test_name)) { + if (found) { + return errors::Unknown("Found more than one file for ", test_name); + } + found = true; + std::unique_ptr read_file; + TF_CHECK_OK(env->NewRandomAccessFile(io::JoinPath(testing::TmpDir(), f), + &read_file)); + io::RecordReader reader(read_file.get(), io::RecordReaderOptions()); + string record; + uint64 offset = 0; + TF_CHECK_OK(reader.ReadRecord(&offset, + &record)); // The first event is irrelevant + TF_CHECK_OK(reader.ReadRecord(&offset, &record)); + Event e; + e.ParseFromString(record); + test_fn(e); + } + } + if (!found) { + return errors::Unknown("Found no file for ", test_name); + } + return Status::OK(); +} + +TEST(SummaryInterfaceTest, WriteTensor) { + TF_CHECK_OK(SummaryTestHelper("tensor_test", + [](SummaryWriterInterface* writer) { + Tensor one(DT_FLOAT, TensorShape({})); + one.scalar()() = 1.0; + TF_RETURN_IF_ERROR(writer->WriteTensor( + 2, one, "name", + SummaryMetadata().SerializeAsString())); + TF_RETURN_IF_ERROR(writer->Flush()); + return Status::OK(); + }, + [](const Event& e) { + EXPECT_EQ(e.step(), 2); + CHECK_EQ(e.summary().value_size(), 1); + EXPECT_EQ(e.summary().value(0).tag(), "name"); + })); +} + +TEST(SummaryInterfaceTest, WriteScalar) { + TF_CHECK_OK(SummaryTestHelper( + "scalar_test", + [](SummaryWriterInterface* writer) { + Tensor one(DT_FLOAT, TensorShape({})); + one.scalar()() = 1.0; + TF_RETURN_IF_ERROR(writer->WriteScalar(2, one, "name")); + TF_RETURN_IF_ERROR(writer->Flush()); + return Status::OK(); + }, + [](const Event& e) { + EXPECT_EQ(e.step(), 2); + CHECK_EQ(e.summary().value_size(), 1); + EXPECT_EQ(e.summary().value(0).tag(), "name"); + EXPECT_EQ(e.summary().value(0).simple_value(), 1.0); + })); +} + +TEST(SummaryInterfaceTest, WriteHistogram) { + TF_CHECK_OK(SummaryTestHelper("hist_test", + [](SummaryWriterInterface* writer) { + Tensor one(DT_FLOAT, TensorShape({})); + one.scalar()() = 1.0; + TF_RETURN_IF_ERROR( + writer->WriteHistogram(2, one, "name")); + TF_RETURN_IF_ERROR(writer->Flush()); + return Status::OK(); + }, + [](const Event& e) { + EXPECT_EQ(e.step(), 2); + CHECK_EQ(e.summary().value_size(), 1); + EXPECT_EQ(e.summary().value(0).tag(), "name"); + EXPECT_TRUE(e.summary().value(0).has_histo()); + })); +} + +TEST(SummaryInterfaceTest, WriteImage) { + TF_CHECK_OK(SummaryTestHelper( + "image_test", + [](SummaryWriterInterface* writer) { + Tensor one(DT_UINT8, TensorShape({1, 1, 1, 1})); + one.scalar()() = 1; + TF_RETURN_IF_ERROR(writer->WriteImage(2, one, "name", 1, Tensor())); + TF_RETURN_IF_ERROR(writer->Flush()); + return Status::OK(); + }, + [](const Event& e) { + EXPECT_EQ(e.step(), 2); + CHECK_EQ(e.summary().value_size(), 1); + EXPECT_EQ(e.summary().value(0).tag(), "name/image"); + CHECK(e.summary().value(0).has_image()); + EXPECT_EQ(e.summary().value(0).image().height(), 1); + EXPECT_EQ(e.summary().value(0).image().width(), 1); + EXPECT_EQ(e.summary().value(0).image().colorspace(), 1); + })); +} + +TEST(SummaryInterfaceTest, WriteAudio) { + TF_CHECK_OK(SummaryTestHelper( + "audio_test", + [](SummaryWriterInterface* writer) { + Tensor one(DT_FLOAT, TensorShape({1, 1})); + one.scalar()() = 1.0; + TF_RETURN_IF_ERROR(writer->WriteAudio(2, one, "name", 1, 1)); + TF_RETURN_IF_ERROR(writer->Flush()); + return Status::OK(); + }, + [](const Event& e) { + EXPECT_EQ(e.step(), 2); + CHECK_EQ(e.summary().value_size(), 1); + EXPECT_EQ(e.summary().value(0).tag(), "name/audio"); + CHECK(e.summary().value(0).has_audio()); + })); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..cfa707de715ba41ad4f5eb2ab1732324bb1c222c --- /dev/null +++ b/tensorflow/core/kernels/summary_kernels.cc @@ -0,0 +1,221 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/summary_interface.h" + +namespace tensorflow { + +REGISTER_KERNEL_BUILDER(Name("SummaryWriter").Device(DEVICE_CPU), + ResourceHandleOp); + +class CreateSummaryFileWriterOp : public OpKernel { + public: + explicit CreateSummaryFileWriterOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor* tmp; + OP_REQUIRES_OK(ctx, ctx->input("logdir", &tmp)); + const string logdir = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("max_queue", &tmp)); + const int32 max_queue = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("flush_millis", &tmp)); + const int32 flush_millis = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp)); + const string filename_suffix = tmp->scalar()(); + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, CreateSummaryWriter(max_queue, flush_millis, logdir, + filename_suffix, ctx->env(), &s)); + OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s)); + } +}; +REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU), + CreateSummaryFileWriterOp); + +class FlushSummaryWriterOp : public OpKernel { + public: + explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + OP_REQUIRES_OK(ctx, s->Flush()); + } +}; +REGISTER_KERNEL_BUILDER(Name("FlushSummaryWriter").Device(DEVICE_CPU), + FlushSummaryWriterOp); + +class CloseSummaryWriterOp : public OpKernel { + public: + explicit CloseSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + OP_REQUIRES_OK(ctx, DeleteResource( + ctx, HandleFromInput(ctx, 0))); + } +}; +REGISTER_KERNEL_BUILDER(Name("CloseSummaryWriter").Device(DEVICE_CPU), + CloseSummaryWriterOp); + +class WriteSummaryOp : public OpKernel { + public: + explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + const Tensor* tmp; + OP_REQUIRES_OK(ctx, ctx->input("global_step", &tmp)); + const int64 global_step = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); + const string& tag = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("summary_metadata", &tmp)); + const string& serialized_metadata = tmp->scalar()(); + + const Tensor* t; + OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); + + OP_REQUIRES_OK(ctx, + s->WriteTensor(global_step, *t, tag, serialized_metadata)); + } +}; +REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU), + WriteSummaryOp); + +class WriteScalarSummaryOp : public OpKernel { + public: + explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + const Tensor* tmp; + OP_REQUIRES_OK(ctx, ctx->input("global_step", &tmp)); + const int64 global_step = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); + const string& tag = tmp->scalar()(); + + const Tensor* t; + OP_REQUIRES_OK(ctx, ctx->input("value", &t)); + + OP_REQUIRES_OK(ctx, s->WriteScalar(global_step, *t, tag)); + } +}; +REGISTER_KERNEL_BUILDER(Name("WriteScalarSummary").Device(DEVICE_CPU), + WriteScalarSummaryOp); + +class WriteHistogramSummaryOp : public OpKernel { + public: + explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + const Tensor* tmp; + OP_REQUIRES_OK(ctx, ctx->input("global_step", &tmp)); + const int64 global_step = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); + const string& tag = tmp->scalar()(); + + const Tensor* t; + OP_REQUIRES_OK(ctx, ctx->input("values", &t)); + + OP_REQUIRES_OK(ctx, s->WriteHistogram(global_step, *t, tag)); + } +}; +REGISTER_KERNEL_BUILDER(Name("WriteHistogramSummary").Device(DEVICE_CPU), + WriteHistogramSummaryOp); + +class WriteImageSummaryOp : public OpKernel { + public: + explicit WriteImageSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + int64 max_images_tmp; + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_images", &max_images_tmp)); + OP_REQUIRES(ctx, max_images_tmp < (1LL << 31), + errors::InvalidArgument("max_images must be < 2^31")); + max_images_ = static_cast(max_images_tmp); + } + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + const Tensor* tmp; + OP_REQUIRES_OK(ctx, ctx->input("global_step", &tmp)); + const int64 global_step = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); + const string& tag = tmp->scalar()(); + const Tensor* bad_color; + OP_REQUIRES_OK(ctx, ctx->input("bad_color", &bad_color)); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(bad_color->shape()), + errors::InvalidArgument("bad_color must be a vector, got shape ", + bad_color->shape().DebugString())); + + const Tensor* t; + OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); + + OP_REQUIRES_OK( + ctx, s->WriteImage(global_step, *t, tag, max_images_, *bad_color)); + } + + private: + int32 max_images_; +}; +REGISTER_KERNEL_BUILDER(Name("WriteImageSummary").Device(DEVICE_CPU), + WriteImageSummaryOp); + +class WriteAudioSummaryOp : public OpKernel { + public: + explicit WriteAudioSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_outputs", &max_outputs_)); + OP_REQUIRES(ctx, max_outputs_ > 0, + errors::InvalidArgument("max_outputs must be > 0")); + } + + void Compute(OpKernelContext* ctx) override { + SummaryWriterInterface* s; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); + core::ScopedUnref unref(s); + const Tensor* tmp; + OP_REQUIRES_OK(ctx, ctx->input("global_step", &tmp)); + const int64 global_step = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp)); + const string& tag = tmp->scalar()(); + OP_REQUIRES_OK(ctx, ctx->input("sample_rate", &tmp)); + const float sample_rate = tmp->scalar()(); + + const Tensor* t; + OP_REQUIRES_OK(ctx, ctx->input("tensor", &t)); + + OP_REQUIRES_OK( + ctx, s->WriteAudio(global_step, *t, tag, max_outputs_, sample_rate)); + } + + private: + int max_outputs_; + bool has_sample_rate_attr_; + float sample_rate_attr_; +}; +REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU), + WriteAudioSummaryOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc index b14e555103946a90b801c4d3d643db1abea491f4..36b8ff09d7381a0b8bbb8b6f8d71b14e47fa4663 100644 --- a/tensorflow/core/kernels/variable_ops.cc +++ b/tensorflow/core/kernels/variable_ops.cc @@ -83,7 +83,6 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL); IsVariableInitializedOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); -TF_CALL_bool(REGISTER_GPU_KERNELS) #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA diff --git a/tensorflow/core/lib/io/buffered_inputstream.cc b/tensorflow/core/lib/io/buffered_inputstream.cc index 6f72da47131692130844c1a11d4eb7f1092dc441..b247e9c5756f85af45e163ca41e6f127f5e1be6f 100644 --- a/tensorflow/core/lib/io/buffered_inputstream.cc +++ b/tensorflow/core/lib/io/buffered_inputstream.cc @@ -41,9 +41,18 @@ BufferedInputStream::~BufferedInputStream() { } Status BufferedInputStream::FillBuffer() { + if (!file_status_.ok()) { + pos_ = 0; + limit_ = 0; + return file_status_; + } Status s = input_stream_->ReadNBytes(size_, &buf_); pos_ = 0; limit_ = buf_.size(); + if (buf_.empty()) { + DCHECK(!s.ok()); + file_status_ = s; + } return s; } @@ -82,6 +91,9 @@ Status BufferedInputStream::ReadNBytes(int64 bytes_to_read, string* result) { bytes_to_read); } result->clear(); + if (!file_status_.ok() && bytes_to_read > 0) { + return file_status_; + } result->reserve(bytes_to_read); Status s; @@ -91,6 +103,8 @@ Status BufferedInputStream::ReadNBytes(int64 bytes_to_read, string* result) { s = FillBuffer(); // If we didn't read any bytes, we're at the end of the file; break out. if (limit_ == 0) { + DCHECK(!s.ok()); + file_status_ = s; break; } } @@ -124,6 +138,9 @@ Status BufferedInputStream::SkipNBytes(int64 bytes_to_skip) { Status s = input_stream_->SkipNBytes(bytes_to_skip - (limit_ - pos_)); pos_ = 0; limit_ = 0; + if (errors::IsOutOfRange(s)) { + file_status_ = s; + } return s; } return Status::OK(); @@ -163,6 +180,7 @@ Status BufferedInputStream::ReadAll(string* result) { } if (errors::IsOutOfRange(status)) { + file_status_ = status; return Status::OK(); } return status; @@ -172,6 +190,7 @@ Status BufferedInputStream::Reset() { TF_RETURN_IF_ERROR(input_stream_->Reset()); pos_ = 0; limit_ = 0; + file_status_ = Status::OK(); return Status::OK(); } diff --git a/tensorflow/core/lib/io/buffered_inputstream.h b/tensorflow/core/lib/io/buffered_inputstream.h index b37766005a920645c604330fbf792f69df889132..2b824f35f80de47f951477a9352bedeca1290848 100644 --- a/tensorflow/core/lib/io/buffered_inputstream.h +++ b/tensorflow/core/lib/io/buffered_inputstream.h @@ -94,6 +94,9 @@ class BufferedInputStream : public InputStreamInterface { size_t pos_ = 0; // current position in buf_. size_t limit_ = 0; // just past the end of valid data in buf_. bool owns_input_stream_ = false; + // When EoF is reached, file_status_ contains the status to skip unnecessary + // buffer allocations. + Status file_status_ = Status::OK(); TF_DISALLOW_COPY_AND_ASSIGN(BufferedInputStream); }; diff --git a/tensorflow/core/lib/io/buffered_inputstream_test.cc b/tensorflow/core/lib/io/buffered_inputstream_test.cc index 7265101e1bef402a655192aeac111375aba4b51a..49b2b1a861ab2d18f23f80715f12d9182f0190c8 100644 --- a/tensorflow/core/lib/io/buffered_inputstream_test.cc +++ b/tensorflow/core/lib/io/buffered_inputstream_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/io/random_inputstream.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace tensorflow { namespace io { @@ -362,6 +363,45 @@ TEST(BufferedInputStream, ReadAll_Text) { } } +void BM_BufferedReaderSmallReads(const int iters, const int buff_size, + const int file_size) { + testing::StopTiming(); + Env* env = Env::Default(); + string fname = testing::TmpDir() + "/buffered_inputstream_test"; + + const string file_elem = "0123456789"; + std::unique_ptr write_file; + TF_ASSERT_OK(env->NewWritableFile(fname, &write_file)); + for (int i = 0; i < file_size; ++i) { + TF_ASSERT_OK(write_file->Append(file_elem)); + } + TF_ASSERT_OK(write_file->Close()); + + std::unique_ptr file; + TF_ASSERT_OK(env->NewRandomAccessFile(fname, &file)); + + string result; + testing::StartTiming(); + + for (int itr = 0; itr < iters; ++itr) { + BufferedInputStream in(file.get(), buff_size); + for (int64 i = 0; i < 10 * file_size; ++i) { + TF_ASSERT_OK(in.ReadNBytes(1, &result)) + << "i: " << i << " itr: " << itr << " buff_size: " << buff_size + << " file size: " << file_size; + } + } +} +BENCHMARK(BM_BufferedReaderSmallReads) + ->ArgPair(1, 5) + ->ArgPair(1, 1024) + ->ArgPair(10, 5) + ->ArgPair(10, 1024) + ->ArgPair(1024, 1024) + ->ArgPair(1024 * 1024, 1024) + ->ArgPair(1024 * 1024, 1024 * 1024) + ->ArgPair(256 * 1024 * 1024, 1024); + } // anonymous namespace } // namespace io } // namespace tensorflow diff --git a/tensorflow/core/lib/io/zlib_inputstream.h b/tensorflow/core/lib/io/zlib_inputstream.h index a8a4e7c83cca4c282fcfe940c04c83d0088bf569..8faa7dcb8f4139746132934813602bcb4a4e0ea9 100644 --- a/tensorflow/core/lib/io/zlib_inputstream.h +++ b/tensorflow/core/lib/io/zlib_inputstream.h @@ -37,7 +37,7 @@ namespace io { // by multiple threads class ZlibInputStream : public InputStreamInterface { public: - // Create a ZlibInputBuffer for `input_stream` with a buffer of size + // Create a ZlibInputStream for `input_stream` with a buffer of size // `input_buffer_bytes` bytes for reading contents from `input_stream` and // another buffer with size `output_buffer_bytes` for caching decompressed // contents. Does *not* take ownership of "input_stream". diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc index 258793aa1e665102b564ebd35310588054950d48..3c7e5ca696dc06d4946e820ec9b70210c9b5fdcd 100644 --- a/tensorflow/core/lib/jpeg/jpeg_mem.cc +++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc @@ -70,13 +70,24 @@ class FewerArgsForCompiler { int stride_; }; +// Check whether the crop window is valid, assuming crop is true. +bool IsCropWindowValid(const UncompressFlags& flags, int input_image_width, + int input_image_height) { + // Crop window is valid only if it is non zero and all the window region is + // within the original image. + return flags.crop_width > 0 && flags.crop_height > 0 && flags.crop_x >= 0 && + flags.crop_y >= 0 && + flags.crop_y + flags.crop_height <= input_image_height && + flags.crop_x + flags.crop_width <= input_image_width; +} + uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { // unpack the argball const int datasize = argball->datasize_; const auto& flags = argball->flags_; const int ratio = flags.ratio; int components = flags.components; - int stride = flags.stride; // may be 0 + int stride = flags.stride; // may be 0 int64* const nwarn = argball->pnwarn_; // may be NULL // Can't decode if the ratio is not recognized by libjpeg @@ -159,8 +170,43 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { return nullptr; } + JDIMENSION target_output_width = cinfo.output_width; + JDIMENSION target_output_height = cinfo.output_height; + JDIMENSION skipped_scanlines = 0; +#if !defined(WIN32) + if (flags.crop) { + // Update target output height and width based on crop window. + target_output_height = flags.crop_height; + target_output_width = flags.crop_width; + + // So far, cinfo holds the original input image information. + if (!IsCropWindowValid(flags, cinfo.output_width, cinfo.output_height)) { + LOG(ERROR) << "Invalid crop window: x=" << flags.crop_x + << ", y=" << flags.crop_y << ", w=" << target_output_width + << ", h=" << target_output_height + << " for image_width: " << cinfo.output_width + << " and image_height: " << cinfo.output_height; + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + + // Update cinfo.output_width. It is tricky that cinfo.output_width must + // fall on an Minimum Coded Unit (MCU) boundary; if it doesn't, then it will + // be moved left to the nearest MCU boundary, and width will be increased + // accordingly. Therefore, the final cinfo.crop_width might differ from the + // given flags.crop_width. Please see libjpeg library for details. + JDIMENSION crop_width = flags.crop_width; + JDIMENSION crop_x = flags.crop_x; + jpeg_crop_scanline(&cinfo, &crop_x, &crop_width); + + // Update cinfo.output_scanline. + skipped_scanlines = jpeg_skip_scanlines(&cinfo, flags.crop_y); + CHECK_EQ(skipped_scanlines, flags.crop_y); + } +#endif + // check for compatible stride - const int min_stride = cinfo.output_width * components * sizeof(JSAMPLE); + const int min_stride = target_output_width * components * sizeof(JSAMPLE); if (stride == 0) { stride = min_stride; } else if (stride < min_stride) { @@ -170,47 +216,88 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { } // Remember stride and height for use in Uncompress - argball->height_ = cinfo.output_height; + argball->height_ = target_output_height; argball->stride_ = stride; - uint8* const dstdata = argball->allocate_output_( - cinfo.output_width, cinfo.output_height, components); +#if defined(WIN32) + uint8* dstdata = nullptr; + if (flags.crop) { + dstdata = new JSAMPLE[stride * target_output_height]; + } else { + dstdata = argball->allocate_output_(target_output_width, + target_output_height, components); + } +#else + uint8* dstdata = argball->allocate_output_(target_output_width, + target_output_height, components); +#endif if (dstdata == nullptr) { jpeg_destroy_decompress(&cinfo); return nullptr; } JSAMPLE* output_line = static_cast(dstdata); - // Temporary buffer used for CMYK -> RGB conversion. + // jpeg_read_scanlines requires the buffers to be allocated based on + // cinfo.output_width, but the target image width might be different if crop + // is enabled and crop_width is not MCU aligned. In this case, we need to + // realign the scanline output to achieve the exact cropping. Notably, only + // cinfo.output_width needs to fall on MCU boundary, while cinfo.output_height + // has no such constraint. + const bool need_realign_cropped_scanline = + (target_output_width != cinfo.output_width); const bool use_cmyk = (cinfo.out_color_space == JCS_CMYK); - tempdata = use_cmyk ? new JSAMPLE[cinfo.output_width * 4] : nullptr; + + if (use_cmyk) { + // Temporary buffer used for CMYK -> RGB conversion. + tempdata = new JSAMPLE[cinfo.output_width * 4]; + } else if (need_realign_cropped_scanline) { + // Temporary buffer used for MCU-aligned scanline data. + tempdata = new JSAMPLE[cinfo.output_width * components]; + } // If there is an error reading a line, this aborts the reading. // Save the fraction of the image that has been read. - argball->height_read_ = cinfo.output_height; - while (cinfo.output_scanline < cinfo.output_height) { + argball->height_read_ = target_output_height; + + // These variables are just to avoid repeated computation in the loop. + const int max_scanlines_to_read = skipped_scanlines + target_output_height; + const int mcu_align_offset = + (cinfo.output_width - target_output_width) * (use_cmyk ? 4 : components); + while (cinfo.output_scanline < max_scanlines_to_read) { int num_lines_read = 0; - if (cinfo.out_color_space == JCS_CMYK) { + if (use_cmyk) { num_lines_read = jpeg_read_scanlines(&cinfo, &tempdata, 1); - // Convert CMYK to RGB - for (size_t i = 0; i < cinfo.output_width; ++i) { - int c = tempdata[4 * i + 0]; - int m = tempdata[4 * i + 1]; - int y = tempdata[4 * i + 2]; - int k = tempdata[4 * i + 3]; - int r, g, b; - if (cinfo.saw_Adobe_marker) { - r = (k * c) / 255; - g = (k * m) / 255; - b = (k * y) / 255; - } else { - r = (255 - k) * (255 - c) / 255; - g = (255 - k) * (255 - m) / 255; - b = (255 - k) * (255 - y) / 255; + if (num_lines_read > 0) { + // Convert CMYK to RGB if scanline read succeeded. + for (size_t i = 0; i < target_output_width; ++i) { + int offset = 4 * i; + if (need_realign_cropped_scanline) { + // Align the offset for MCU boundary. + offset += mcu_align_offset; + } + const int c = tempdata[offset + 0]; + const int m = tempdata[offset + 1]; + const int y = tempdata[offset + 2]; + const int k = tempdata[offset + 3]; + int r, g, b; + if (cinfo.saw_Adobe_marker) { + r = (k * c) / 255; + g = (k * m) / 255; + b = (k * y) / 255; + } else { + r = (255 - k) * (255 - c) / 255; + g = (255 - k) * (255 - m) / 255; + b = (255 - k) * (255 - y) / 255; + } + output_line[3 * i + 0] = r; + output_line[3 * i + 1] = g; + output_line[3 * i + 2] = b; } - output_line[3 * i + 0] = r; - output_line[3 * i + 1] = g; - output_line[3 * i + 2] = b; + } + } else if (need_realign_cropped_scanline) { + num_lines_read = jpeg_read_scanlines(&cinfo, &tempdata, 1); + if (num_lines_read > 0) { + memcpy(output_line, tempdata + mcu_align_offset, min_stride); } } else { num_lines_read = jpeg_read_scanlines(&cinfo, &output_line, 1); @@ -218,12 +305,13 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { // Handle error cases if (num_lines_read == 0) { LOG(ERROR) << "Premature end of JPEG data. Stopped at line " - << cinfo.output_scanline << "/" << cinfo.output_height; + << cinfo.output_scanline - skipped_scanlines << "/" + << target_output_height; if (!flags.try_recover_truncated_jpeg) { - argball->height_read_ = cinfo.output_scanline; + argball->height_read_ = cinfo.output_scanline - skipped_scanlines; error = JPEGERRORS_UNEXPECTED_END_OF_DATA; } else { - for (size_t line = cinfo.output_scanline; line < cinfo.output_height; + for (size_t line = cinfo.output_scanline; line < max_scanlines_to_read; ++line) { if (line == 0) { // If even the first line is missing, fill with black color @@ -235,9 +323,9 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { output_line += stride; } argball->height_read_ = - cinfo.output_height; // consider all lines as read + target_output_height; // consider all lines as read // prevent error-on-exit in libjpeg: - cinfo.output_scanline = cinfo.output_height; + cinfo.output_scanline = max_scanlines_to_read; } break; } @@ -248,23 +336,33 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { delete[] tempdata; tempdata = nullptr; +#if !defined(WIN32) + if (flags.crop && cinfo.output_scanline < cinfo.output_height) { + // Skip the rest of scanlines, required by jpeg_destroy_decompress. + jpeg_skip_scanlines(&cinfo, + cinfo.output_height - flags.crop_y - flags.crop_height); + // After this, cinfo.output_height must be equal to cinfo.output_height; + // otherwise, jpeg_destroy_decompress would fail. + } +#endif + // Convert the RGB data to RGBA, with alpha set to 0xFF to indicate // opacity. // RGBRGBRGB... --> RGBARGBARGBA... if (components == 4) { // Start on the last line. JSAMPLE* scanlineptr = static_cast( - dstdata + static_cast(cinfo.output_height - 1) * stride); + dstdata + static_cast(target_output_height - 1) * stride); const JSAMPLE kOpaque = -1; // All ones appropriate for JSAMPLE. - const int right_rgb = (cinfo.output_width - 1) * 3; - const int right_rgba = (cinfo.output_width - 1) * 4; + const int right_rgb = (target_output_width - 1) * 3; + const int right_rgba = (target_output_width - 1) * 4; - for (int y = cinfo.output_height; y-- > 0;) { + for (int y = target_output_height; y-- > 0;) { // We do all the transformations in place, going backwards for each row. const JSAMPLE* rgb_pixel = scanlineptr + right_rgb; JSAMPLE* rgba_pixel = scanlineptr + right_rgba; scanlineptr -= stride; - for (int x = cinfo.output_width; x-- > 0; + for (int x = target_output_width; x-- > 0; rgba_pixel -= 4, rgb_pixel -= 3) { // We copy the 3 bytes at rgb_pixel into the 4 bytes at rgba_pixel // The "a" channel is set to be opaque. @@ -319,8 +417,61 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { LOG(ERROR) << "Unhandled case " << error; break; } - jpeg_destroy_decompress(&cinfo); +#if defined(WIN32) + // TODO(tanmingxing): delete all these code after migrating to libjpeg_turbo + // for Windows. + if (flags.crop) { + // Update target output height and width based on crop window. + target_output_height = flags.crop_height; + target_output_width = flags.crop_width; + + // cinfo holds the original input image information. + if (!IsCropWindowValid(flags, cinfo.output_width, cinfo.output_height)) { + LOG(ERROR) << "Invalid crop window: x=" << flags.crop_x + << ", y=" << flags.crop_y << ", w=" << target_output_width + << ", h=" << target_output_height + << " for image_width: " << cinfo.output_width + << " and image_height: " << cinfo.output_height; + delete[] dstdata; + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + + const uint8* full_image = dstdata; + dstdata = argball->allocate_output_(target_output_width, + target_output_height, components); + if (dstdata == nullptr) { + delete[] full_image; + jpeg_destroy_decompress(&cinfo); + return nullptr; + } + + const int full_image_stride = stride; + // Update stride and hight for crop window. + const int min_stride = target_output_width * components * sizeof(JSAMPLE); + if (flags.stride == 0) { + stride = min_stride; + } + argball->height_ = target_output_height; + argball->stride_ = stride; + + if (argball->height_read_ > target_output_height) { + argball->height_read_ = target_output_height; + } + const int crop_offset = flags.crop_x * components * sizeof(JSAMPLE); + const uint8* full_image_ptr = full_image + flags.crop_y * full_image_stride; + uint8* crop_image_ptr = dstdata; + for (int i = 0; i < argball->height_read_; i++) { + memcpy(crop_image_ptr, full_image_ptr + crop_offset, min_stride); + crop_image_ptr += stride; + full_image_ptr += full_image_stride; + } + delete[] full_image; + } +#endif + + jpeg_destroy_decompress(&cinfo); return dstdata; } diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.h b/tensorflow/core/lib/jpeg/jpeg_mem.h index ac34f29f2219f0fbe23260e6f10a805defea4460..59342d28c0f411a90b68ec0590c5a6f86aaf8ca5 100644 --- a/tensorflow/core/lib/jpeg/jpeg_mem.h +++ b/tensorflow/core/lib/jpeg/jpeg_mem.h @@ -61,6 +61,17 @@ struct UncompressFlags { // // Setting this has a quality/speed trade-off implication. J_DCT_METHOD dct_method = JDCT_DEFAULT; + + // Settings of crop window before decompression. + bool crop = false; + // Vertical coordinate of the top-left corner of the result in the input. + int crop_x = 0; + // Horizontal coordinate of the top-left corner of the result in the input. + int crop_y = 0; + // Width of the output image. + int crop_width = 0; + // Height of the output image. + int crop_height = 0; }; // Uncompress some raw JPEG data given by the pointer srcdata and the length diff --git a/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc index cc8646750e1b2828a17c438c238f94fb1177f6dd..15266af1dbd877ff2023ec32e19c172dc3d00fa9 100644 --- a/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc +++ b/tensorflow/core/lib/jpeg/jpeg_mem_unittest.cc @@ -57,7 +57,7 @@ void ReadFileToStringOrDie(Env* env, const string& filename, string* output) { void TestJPEG(Env* env, const string& jpegfile) { // Read the data from the jpeg file into memory string jpeg; - ReadFileToStringOrDie(Env::Default(), jpegfile, &jpeg); + ReadFileToStringOrDie(env, jpegfile, &jpeg); const int fsize = jpeg.size(); const uint8* const temp = bit_cast(jpeg.data()); @@ -95,6 +95,194 @@ TEST(JpegMemTest, Jpeg) { TestJPEG(env, data_path + "jpeg_merge_test1_cmyk.jpg"); } +void TestCropAndDecodeJpeg(Env* env, const string& jpegfile, + const UncompressFlags& default_flags) { + // Read the data from the jpeg file into memory + string jpeg; + ReadFileToStringOrDie(env, jpegfile, &jpeg); + const int fsize = jpeg.size(); + auto temp = bit_cast(jpeg.data()); + + // Decode the whole image. + std::unique_ptr imgdata1; + int w1, h1, c1; + { + UncompressFlags flags = default_flags; + if (flags.stride == 0) { + imgdata1.reset(Uncompress(temp, fsize, flags, &w1, &h1, &c1, nullptr)); + } else { + // If stride is not zero, the default allocator would fail because it + // allocate w*h*c bytes, but the actual required bytes should be stride*h. + // Therefore, we provide a specialized allocator here. + uint8* buffer = nullptr; + imgdata1.reset(Uncompress(temp, fsize, flags, nullptr, + [&](int width, int height, int components) { + w1 = width; + h1 = height; + c1 = components; + buffer = new uint8[flags.stride * height]; + return buffer; + })); + } + ASSERT_NE(imgdata1, nullptr); + } + + auto check_crop_and_decode_func = [&](int crop_x, int crop_y, int crop_width, + int crop_height) { + std::unique_ptr imgdata2; + int w, h, c; + UncompressFlags flags = default_flags; + flags.crop = true; + flags.crop_x = crop_x; + flags.crop_y = crop_y; + flags.crop_width = crop_width; + flags.crop_height = crop_height; + if (flags.stride == 0) { + imgdata2.reset(Uncompress(temp, fsize, flags, &w, &h, &c, nullptr)); + } else { + uint8* buffer = nullptr; + imgdata2.reset(Uncompress(temp, fsize, flags, nullptr, + [&](int width, int height, int components) { + w = width; + h = height; + c = components; + buffer = new uint8[flags.stride * height]; + return buffer; + })); + } + ASSERT_NE(imgdata2, nullptr); + + ASSERT_EQ(w, crop_width); + ASSERT_EQ(h, crop_height); + ASSERT_EQ(c, c1); + + const int stride1 = (flags.stride != 0) ? flags.stride : w1 * c; + const int stride2 = (flags.stride != 0) ? flags.stride : w * c; + for (int i = 0; i < crop_height; i++) { + const uint8* p1 = &imgdata1[(i + crop_y) * stride1 + crop_x * c]; + const uint8* p2 = &imgdata2[i * stride2]; + + for (int j = 0; j < c * w; j++) { + ASSERT_EQ(p1[j], p2[j]) + << "p1 != p2 in [" << i << "][" << j / 3 << "][" << j % 3 << "]"; + } + } + }; + + // Check different crop windows. + check_crop_and_decode_func(0, 0, 5, 5); + check_crop_and_decode_func(0, 0, w1, 5); + check_crop_and_decode_func(0, 0, 5, h1); + check_crop_and_decode_func(0, 0, w1, h1); + check_crop_and_decode_func(w1 - 5, h1 - 6, 5, 6); + check_crop_and_decode_func(5, 6, 10, 15); +} + +TEST(JpegMemTest, CropAndDecodeJpeg) { + Env* env = Env::Default(); + const string data_path = kTestData; + UncompressFlags flags; + + // Test basic flags for jpeg and cmyk jpeg. + TestCropAndDecodeJpeg(env, data_path + "jpeg_merge_test1.jpg", flags); + TestCropAndDecodeJpeg(env, data_path + "jpeg_merge_test1_cmyk.jpg", flags); +} + +TEST(JpegMemTest, CropAndDecodeJpegWithRatio) { + Env* env = Env::Default(); + const string data_path = kTestData; + UncompressFlags flags; + for (int ratio : {1, 2, 4, 8}) { + flags.ratio = ratio; + TestCropAndDecodeJpeg(env, data_path + "jpeg_merge_test1.jpg", flags); + } +} + +TEST(JpegMemTest, CropAndDecodeJpegWithComponents) { + Env* env = Env::Default(); + const string data_path = kTestData; + UncompressFlags flags; + for (const int components : {0, 1, 3}) { + flags.components = components; + TestCropAndDecodeJpeg(env, data_path + "jpeg_merge_test1.jpg", flags); + } +} + +TEST(JpegMemTest, CropAndDecodeJpegWithUpScaling) { + Env* env = Env::Default(); + const string data_path = kTestData; + UncompressFlags flags; + flags.fancy_upscaling = true; + TestCropAndDecodeJpeg(env, data_path + "jpeg_merge_test1.jpg", flags); +} + +TEST(JpegMemTest, CropAndDecodeJpegWithStride) { + Env* env = Env::Default(); + const string data_path = kTestData; + + // Read the data from the jpeg file into memory + string jpeg; + ReadFileToStringOrDie(env, data_path + "jpeg_merge_test1.jpg", &jpeg); + const int fsize = jpeg.size(); + auto temp = bit_cast(jpeg.data()); + + int w, h, c; + ASSERT_TRUE(GetImageInfo(temp, fsize, &w, &h, &c)); + + // stride must be either 0 or > w*c; otherwise, uncompress would fail. + UncompressFlags flags; + flags.stride = w * c; + TestCropAndDecodeJpeg(env, data_path + "jpeg_merge_test1.jpg", flags); + flags.stride = w * c * 3; + TestCropAndDecodeJpeg(env, data_path + "jpeg_merge_test1.jpg", flags); + flags.stride = w * c + 100; + TestCropAndDecodeJpeg(env, data_path + "jpeg_merge_test1.jpg", flags); +} + +void CheckInvalidCropWindowFailed(const uint8* const temp, int fsize, int x, + int y, int w, int h) { + std::unique_ptr imgdata; + int ww, hh, cc; + UncompressFlags flags; + flags.components = 3; + flags.crop = true; + flags.crop_x = x; + flags.crop_y = y; + flags.crop_width = w; + flags.crop_height = h; + imgdata.reset(Uncompress(temp, fsize, flags, &ww, &hh, &cc, nullptr)); + CHECK(imgdata == nullptr); +} + +TEST(JpegMemTest, CropAndDecodeJpegWithInvalidCropWindow) { + Env* env = Env::Default(); + const string data_path = kTestData; + + // Read the data from the jpeg file into memory + string jpeg; + ReadFileToStringOrDie(env, data_path + "jpeg_merge_test1.jpg", &jpeg); + const int fsize = jpeg.size(); + auto temp = bit_cast(jpeg.data()); + + int w, h, c; + ASSERT_TRUE(GetImageInfo(temp, fsize, &w, &h, &c)); + + // Width and height for the crop window must be non zero. + CheckInvalidCropWindowFailed(temp, fsize, 11, 11, /*w=*/0, 11); + CheckInvalidCropWindowFailed(temp, fsize, 11, 11, 11, /*h=*/0); + + // Crop window must be non negative. + CheckInvalidCropWindowFailed(temp, fsize, /*x=*/-1, 11, 11, 11); + CheckInvalidCropWindowFailed(temp, fsize, 11, /*y=*/-1, 11, 11); + CheckInvalidCropWindowFailed(temp, fsize, 11, 11, /*w=*/-1, 11); + CheckInvalidCropWindowFailed(temp, fsize, 11, 11, 11, /*h=*/-1); + + // Invalid crop window width: x + crop_width = w + 1 > w + CheckInvalidCropWindowFailed(temp, fsize, /*x=*/w - 10, 11, 11, 11); + // Invalid crop window height: y + crop_height= h + 1 > h + CheckInvalidCropWindowFailed(temp, fsize, 11, /*y=*/h - 10, 11, 11); +} + TEST(JpegMemTest, Jpeg2) { // create known data, for size in_w x in_h const int in_w = 256; diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc index c68e14f09fbd4a89ad9cd75a8df94144d0cd2c75..8509c9a0417621f9c9550c6af92dcbf4b7075347 100644 --- a/tensorflow/core/lib/strings/str_util.cc +++ b/tensorflow/core/lib/strings/str_util.cc @@ -248,6 +248,58 @@ string Uppercase(StringPiece s) { return result; } +string ArgDefCase(StringPiece s) { + const size_t n = s.size(); + + // Compute the size of resulting string. + // Number of extra underscores we will need to add. + size_t extra_us = 0; + // Number of non-alpha chars in the beginning to skip. + size_t to_skip = 0; + for (size_t i = 0; i < n; ++i) { + // If we are skipping and current letter is non-alpha, skip it as well + if (i == to_skip && !isalpha(s[i])) { + ++to_skip; + continue; + } + + // If we are here, we are not skipping any more. + // If this letter is upper case, not the very first char in the + // resulting string, and previous letter isn't replaced with an underscore, + // we will need to insert an underscore. + if (isupper(s[i]) && i != to_skip && i > 0 && isalnum(s[i - 1])) { + ++extra_us; + } + } + + // Initialize result with all '_'s. There is no string + // constructor that does not initialize memory. + string result(n + extra_us - to_skip, '_'); + // i - index into s + // j - index into result + for (size_t i = to_skip, j = 0; i < n; ++i, ++j) { + DCHECK_LT(j, result.size()); + char c = s[i]; + // If c is not alphanumeric, we don't need to do anything + // since there is already an underscore in its place. + if (isalnum(c)) { + if (isupper(c)) { + // If current char is upper case, we might need to insert an + // underscore. + if (i != to_skip) { + DCHECK_GT(j, 0); + if (result[j - 1] != '_') ++j; + } + result[j] = tolower(c); + } else { + result[j] = c; + } + } + } + + return result; +} + void TitlecaseString(string* s, StringPiece delimiters) { bool upper = true; for (string::iterator ss = s->begin(); ss != s->end(); ++ss) { diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h index 669f0d3c5279b90fe31398410c4a95a053d16fd5..8cea0f0718652690a79891e1f5beb8d14b80c74b 100644 --- a/tensorflow/core/lib/strings/str_util.h +++ b/tensorflow/core/lib/strings/str_util.h @@ -81,6 +81,17 @@ string Lowercase(StringPiece s); // Return upper-cased version of s. string Uppercase(StringPiece s); +// Converts "^2ILoveYou!" to "i_love_you_". More specifically: +// - converts all non-alphanumeric characters to underscores +// - replaces each occurence of a capital letter (except the very +// first character and if there is already an '_' before it) with '_' +// followed by this letter in lower case +// - Skips leading non-alpha characters +// This method is useful for producing strings matching "[a-z][a-z0-9_]*" +// as required by OpDef.ArgDef.name. The resulting string is either empty or +// matches this regex. +string ArgDefCase(StringPiece s); + // Capitalize first character of each word in "*s". "delimiters" is a // set of characters that can be used as word boundaries. void TitlecaseString(string* s, StringPiece delimiters); diff --git a/tensorflow/core/lib/strings/str_util_test.cc b/tensorflow/core/lib/strings/str_util_test.cc index 040f7447e4d2d13a9f679ba92670ee74a866dae3..5c735a87a39d2b7583da208edd9af35dad33c55e 100644 --- a/tensorflow/core/lib/strings/str_util_test.cc +++ b/tensorflow/core/lib/strings/str_util_test.cc @@ -338,6 +338,38 @@ TEST(Uppercase, Basic) { EXPECT_EQ("HELLO WORLD", str_util::Uppercase("Hello World")); } +TEST(SnakeCase, Basic) { + EXPECT_EQ("", str_util::ArgDefCase("")); + EXPECT_EQ("", str_util::ArgDefCase("!")); + EXPECT_EQ("", str_util::ArgDefCase("5")); + EXPECT_EQ("", str_util::ArgDefCase("!:")); + EXPECT_EQ("", str_util::ArgDefCase("5-5")); + EXPECT_EQ("", str_util::ArgDefCase("_!")); + EXPECT_EQ("", str_util::ArgDefCase("_5")); + EXPECT_EQ("a", str_util::ArgDefCase("_a")); + EXPECT_EQ("a", str_util::ArgDefCase("_A")); + EXPECT_EQ("i", str_util::ArgDefCase("I")); + EXPECT_EQ("i", str_util::ArgDefCase("i")); + EXPECT_EQ("i_", str_util::ArgDefCase("I%")); + EXPECT_EQ("i_", str_util::ArgDefCase("i%")); + EXPECT_EQ("i", str_util::ArgDefCase("%I")); + EXPECT_EQ("i", str_util::ArgDefCase("-i")); + EXPECT_EQ("i", str_util::ArgDefCase("3i")); + EXPECT_EQ("i", str_util::ArgDefCase("32i")); + EXPECT_EQ("i3", str_util::ArgDefCase("i3")); + EXPECT_EQ("i_a3", str_util::ArgDefCase("i_A3")); + EXPECT_EQ("i_i", str_util::ArgDefCase("II")); + EXPECT_EQ("i_i", str_util::ArgDefCase("I_I")); + EXPECT_EQ("i__i", str_util::ArgDefCase("I__I")); + EXPECT_EQ("i_i_32", str_util::ArgDefCase("II-32")); + EXPECT_EQ("ii_32", str_util::ArgDefCase("Ii-32")); + EXPECT_EQ("hi_there", str_util::ArgDefCase("HiThere")); + EXPECT_EQ("hi_hi", str_util::ArgDefCase("Hi!Hi")); + EXPECT_EQ("hi_hi", str_util::ArgDefCase("HiHi")); + EXPECT_EQ("hihi", str_util::ArgDefCase("Hihi")); + EXPECT_EQ("hi_hi", str_util::ArgDefCase("Hi_Hi")); +} + TEST(TitlecaseString, Basic) { string s = "sparse_lookup"; str_util::TitlecaseString(&s, "_"); diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 651f22c6eae4c1ba83abe7d770dd9631bb27c149..62c86c771458925dc366043e6a0783c646bf215d 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -5488,24 +5488,28 @@ REGISTER_OP("BatchMatrixDiag") .Input("diagonal: T") .Output("output: T") .Attr("T: type") - .Deprecated(14, "Use MatrixDiag"); + .Deprecated(14, "Use MatrixDiag") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixSetDiag") .Input("input: T") .Input("diagonal: T") .Output("output: T") .Attr("T: type") - .Deprecated(14, "Use MatrixSetDiag"); + .Deprecated(14, "Use MatrixSetDiag") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixDiagPart") .Input("input: T") .Output("diagonal: T") .Attr("T: type") - .Deprecated(14, "Use MatrixDiagPart"); + .Deprecated(14, "Use MatrixDiagPart") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixBandPart") .Input("input: T") .Input("num_lower: int64") .Input("num_upper: int64") .Output("band: T") .Attr("T: type") - .Deprecated(14, "Use MatrixBandPart"); + .Deprecated(14, "Use MatrixBandPart") + .SetShapeFn(shape_inference::UnknownShape); } // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index dbca823a579a014a4068fb67f603f0f046e122c6..a8338620d6909a6db759db9942a051e184f50bb7 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -8288,6 +8288,30 @@ op { } } } +op { + name: "ExtractJpegShape" + input_arg { + name: "contents" + type: DT_STRING + } + output_arg { + name: "image_shape" + type_attr: "output_type" + } + attr { + name: "output_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} op { name: "FFT" input_arg { @@ -10443,8 +10467,8 @@ op { type_list_attr: "Treduce_func_other_arguments" } input_arg { - name: "window_size" - type: DT_INT64 + name: "window_size_func_other_arguments" + type_list_attr: "Twindow_size_func_other_arguments" } output_arg { name: "handle" @@ -10458,6 +10482,10 @@ op { name: "reduce_func" type: "func" } + attr { + name: "window_size_func" + type: "func" + } attr { name: "Tkey_func_other_arguments" type: "list(type)" @@ -10468,6 +10496,11 @@ op { type: "list(type)" has_minimum: true } + attr { + name: "Twindow_size_func_other_arguments" + type: "list(type)" + has_minimum: true + } attr { name: "output_types" type: "list(type)" @@ -22370,6 +22403,18 @@ op { } is_stateful: true } +op { + name: "RestoreIterator" + input_arg { + name: "iterator" + type: DT_RESOURCE + } + input_arg { + name: "path" + type: DT_STRING + } + is_stateful: true +} op { name: "RestoreSlice" input_arg { @@ -23093,6 +23138,18 @@ op { } is_stateful: true } +op { + name: "SaveIterator" + input_arg { + name: "iterator" + type: DT_RESOURCE + } + input_arg { + name: "path" + type: DT_STRING + } + is_stateful: true +} op { name: "SaveSlices" input_arg { @@ -24333,6 +24390,21 @@ op { type: "type" } } +op { + name: "SerializeTensor" + input_arg { + name: "tensor" + type_attr: "T" + } + output_arg { + name: "serialized" + type: DT_STRING + } + attr { + name: "T" + type: "type" + } +} op { name: "SetSize" input_arg { @@ -24787,6 +24859,51 @@ op { } } } +op { + name: "SloppyInterleaveDataset" + input_arg { + name: "input_dataset" + type: DT_RESOURCE + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "Softmax" input_arg { @@ -27670,6 +27787,38 @@ op { } } } +op { + name: "SqlDataset" + input_arg { + name: "driver_name" + type: DT_STRING + } + input_arg { + name: "data_source_name" + type: DT_STRING + } + input_arg { + name: "query" + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "Sqrt" input_arg { @@ -28728,6 +28877,40 @@ op { } } } +op { + name: "Sub" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} op { name: "Substr" input_arg { diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index f10e905361a1aa13c517af0bcccb03f70eb9a6ad..7cc8dccb95ca549097940c8e7e617513c37bf716 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -233,16 +233,46 @@ f: A function mapping elements of `input_dataset`, concatenated with `output_types` and `output_shapes`. )doc"); +REGISTER_OP("SloppyInterleaveDataset") + .Input("input_dataset: resource") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Output("handle: resource") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset`. + +The resulting dataset is similar to the `InterleaveDataset`, with the exception +that if retrieving the next value from a dataset would cause the requester to +block, it will skip that input dataset. This dataset is especially useful +when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it +allows the training step to proceed so long as some data is available. + +!! WARNING !! This dataset is not deterministic! + +f: A function mapping elements of `input_dataset`, concatenated with + `other_arguments`, to a Dataset resource that contains elements matching + `output_types` and `output_shapes`. +)doc"); + REGISTER_OP("GroupByWindowDataset") .Input("input_dataset: resource") .Input("key_func_other_arguments: Tkey_func_other_arguments") .Input("reduce_func_other_arguments: Treduce_func_other_arguments") - .Input("window_size: int64") + .Input( + "window_size_func_other_arguments: Twindow_size_func_other_arguments") .Output("handle: resource") .Attr("key_func: func") .Attr("reduce_func: func") + .Attr("window_size_func: func") .Attr("Tkey_func_other_arguments: list(type) >= 0") .Attr("Treduce_func_other_arguments: list(type) >= 0") + .Attr("Twindow_size_func_other_arguments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape) @@ -420,6 +450,22 @@ compression_type: A scalar containing either (i) the empty string (no buffer_size: A scalar containing the number of bytes to buffer. )doc"); +REGISTER_OP("SqlDataset") + .Input("driver_name: string") + .Input("data_source_name: string") + .Input("query: string") + .Output("handle: resource") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that executes a SQL query and emits rows of the result set. + +driver_name: The database type. Currently, the only supported type is 'sqlite'. +data_source_name: A connection string to connect to the database. +query: A SQL query to execute. +)doc"); + REGISTER_OP("FixedLengthRecordDataset") .Input("filenames: string") .Input("header_bytes: int64") @@ -483,6 +529,24 @@ This operation may be executed multiple times. Each execution will reset the iterator in `iterator` to the first element of `dataset`. )doc"); +REGISTER_OP("SaveIterator") + .Input("iterator: resource") + .Input("path: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Saves the state of the `iterator` at `path`. + +This state can be restored using "RestoreIterator". +)doc"); + +REGISTER_OP("RestoreIterator") + .Input("iterator: resource") + .Input("path: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Restores the state of the `iterator` from the checkpoint saved at `path` using "SaveIterator". +)doc"); + REGISTER_OP("OneShotIterator") .Output("handle: resource") .Attr("dataset_factory: func") diff --git a/tensorflow/core/ops/debug_ops.cc b/tensorflow/core/ops/debug_ops.cc index bd7f7c2c018000656a048c815702a90bf24f5426..5aebdca1ea5388763ef8422704e86ae55058621e 100644 --- a/tensorflow/core/ops/debug_ops.cc +++ b/tensorflow/core/ops/debug_ops.cc @@ -32,6 +32,7 @@ REGISTER_OP("Copy") .Attr("tensor_name: string = ''") .Attr("debug_ops_spec: list(string) = []") .SetAllowsUninitializedInput() + .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Copy Op. @@ -61,6 +62,7 @@ REGISTER_OP("CopyHost") .Attr("tensor_name: string = ''") .Attr("debug_ops_spec: list(string) = []") .SetAllowsUninitializedInput() + .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Copy Host Op. @@ -118,6 +120,7 @@ REGISTER_OP("DebugNanCount") .Attr("debug_urls: list(string) = []") .Attr("gated_grpc: bool = false") .SetAllowsUninitializedInput() + .SetShapeFn(shape_inference::ScalarShape) .Doc(R"doc( Debug NaN Value Counter Op @@ -148,6 +151,8 @@ REGISTER_OP("DebugNumericSummary") .Attr("mute_if_healthy: bool = false") .Attr("gated_grpc: bool = false") .SetAllowsUninitializedInput() + // Note: this could return a more specific shape if needed in future. + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Debug Numeric Summary Op. diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 1bfa37f5a7865c9f18d2870e013624dcc3a22414..8ddf3561ce1f5e3bd3915503f89cd83a642b07b2 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -457,6 +457,28 @@ xmp_metadata: If not empty, embed this XMP metadata in the image header. contents: 0-D. JPEG-encoded image. )doc"); +// -------------------------------------------------------------------------- +REGISTER_OP("ExtractJpegShape") + .Input("contents: string") + .Output("image_shape: output_type") + .Attr("output_type: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + c->set_output(0, c->Vector(3)); + return Status::OK(); + }) + .Doc(R"doc( +Extract the shape information of a JPEG-encoded image. + +This op only parses the image header, so it is much faster than DecodeJpeg. + +contents: 0-D. The JPEG-encoded image. +image_shape: 1-D. The image shape with format [height, width, channels]. +output_type: (Optional) The output type of the operation (int32 or int64). + Defaults to int32. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("AdjustContrast") .Input("images: T") diff --git a/tensorflow/core/ops/image_ops_test.cc b/tensorflow/core/ops/image_ops_test.cc index ea202edfb37e99071196ded4c572381d07d5bfa5..c757cefddaf9443119f685ff0eae3a26e94f9260 100644 --- a/tensorflow/core/ops/image_ops_test.cc +++ b/tensorflow/core/ops/image_ops_test.cc @@ -101,6 +101,16 @@ TEST(ImageOpsTest, EncodeImage_ShapeFn) { } } +TEST(ImageOpsTest, ExtractJpegShape_ShapeFn) { + ShapeInferenceTestOp op("ExtractJpegShape"); + + // Rank check. + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1]"); + + // Only specify input data. Output must be a 1-D tensor with 3 elements. + INFER_OK(op, "?", "[3]"); +} + TEST(ImageOpsTest, Colorspace_ShapeFn) { for (const char* op_name : {"HSVToRGB", "RGBToHSV"}) { ShapeInferenceTestOp op(op_name); diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index 5b75bda1f1b01ba6a4d33d0e068d57cb3838334a..48b2362342275ae8ffbdfa8f12def59d631d4697 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -557,34 +558,39 @@ REGISTER_OP("BatchSelfAdjointEig") .Input("input: T") .Output("output: T") .Attr("T: {double, float}") - .Deprecated(11, "Use SelfAdjointEigV2 instead."); + .Deprecated(11, "Use SelfAdjointEigV2 instead.") + .SetShapeFn(shape_inference::UnknownShape); // Can all be deleted after 9mar2017. REGISTER_OP("BatchMatrixDeterminant") .Input("input: T") .Output("output: T") .Attr("T: {float, double, complex64, complex128}") - .Deprecated(13, "Use MatrixDeterminant instead."); + .Deprecated(13, "Use MatrixDeterminant instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixInverse") .Input("input: T") .Output("output: T") .Attr("adjoint: bool = False") .Attr("T: {double, float}") - .Deprecated(13, "Use MatrixInverse instead."); + .Deprecated(13, "Use MatrixInverse instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchCholesky") .Input("input: T") .Output("output: T") .Attr("T: {double, float}") - .Deprecated(13, "Use Cholesky instead."); + .Deprecated(13, "Use Cholesky instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchCholeskyGrad") .Input("l: T") .Input("grad: T") .Output("output: T") .Attr("T: {float, double}") - .Deprecated(13, "Use CholeskyGrad instead."); + .Deprecated(13, "Use CholeskyGrad instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchSelfAdjointEigV2") .Input("input: T") @@ -592,7 +598,8 @@ REGISTER_OP("BatchSelfAdjointEigV2") .Output("v: T") .Attr("compute_v: bool = True") .Attr("T: {double, float}") - .Deprecated(13, "Use SelfAdjointEigV2 instead."); + .Deprecated(13, "Use SelfAdjointEigV2 instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixSolve") .Input("matrix: T") @@ -600,7 +607,8 @@ REGISTER_OP("BatchMatrixSolve") .Output("output: T") .Attr("adjoint: bool = False") .Attr("T: {double, float}") - .Deprecated(13, "Use MatrixSolve instead."); + .Deprecated(13, "Use MatrixSolve instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixTriangularSolve") .Input("matrix: T") @@ -609,7 +617,8 @@ REGISTER_OP("BatchMatrixTriangularSolve") .Attr("lower: bool = True") .Attr("adjoint: bool = False") .Attr("T: {double, float}") - .Deprecated(13, "Use MatrixTriangularSolve instead."); + .Deprecated(13, "Use MatrixTriangularSolve instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchMatrixSolveLs") .Input("matrix: T") @@ -618,7 +627,8 @@ REGISTER_OP("BatchMatrixSolveLs") .Output("output: T") .Attr("T: {double, float}") .Attr("fast: bool = True") - .Deprecated(13, "Use MatrixSolveLs instead."); + .Deprecated(13, "Use MatrixSolveLs instead.") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("BatchSvd") .Input("input: T") @@ -628,6 +638,7 @@ REGISTER_OP("BatchSvd") .Attr("compute_uv: bool = True") .Attr("full_matrices: bool = False") .Attr("T: {double, float, complex64, complex128}") - .Deprecated(13, "Use Svd instead."); + .Deprecated(13, "Use Svd instead.") + .SetShapeFn(shape_inference::UnknownShape); } // namespace tensorflow diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index f959b52b57585a3f43f5d190d0aff9316a2fe483..c21b9a7977afd2d08627a9efeebe13d039f44e98 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -498,8 +498,39 @@ Returns x + y element-wise. [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) )doc"); +REGISTER_OP("_MklAdd") + .Input("x: T") + .Input("y: T") + .Input("mkl_x: uint8") + .Input("mkl_y: uint8") + .Output("z: T") + .Output("mkl_z: uint8") + .Attr( + "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " + "complex128, string}") + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) + .Doc(R"doc( +Returns x + y element-wise. + +*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +)doc"); + REGISTER_OP("Sub") + .BINARY_MORE() + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) + .Doc(R"doc( +Returns x - y element-wise. + +*NOTE*: `Sub` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +)doc"); + +REGISTER_OP("_MklSub") .BINARY_FEWER() + .Input("mkl_x: uint8") + .Input("mkl_y: uint8") + .Output("mkl_z: uint8") .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns x - y element-wise. @@ -519,6 +550,20 @@ Returns x * y element-wise. [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) )doc"); +REGISTER_OP("_MklMul") + .BINARY_MORE() + .Input("mkl_x: uint8") + .Input("mkl_y: uint8") + .Output("mkl_z: uint8") + .SetIsCommutative() + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) + .Doc(R"doc( +Returns x * y element-wise. + +*NOTE*: `Mul` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +)doc"); + REGISTER_OP("Div") .BINARY_MORE() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) @@ -577,6 +622,20 @@ Returns (x - y)(x - y) element-wise. [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) )doc"); +REGISTER_OP("_MklSquaredDifference") + .BINARY_FEWER() + .Input("mkl_x: uint8") + .Input("mkl_y: uint8") + .Output("mkl_z: uint8") + .SetIsCommutative() + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) + .Doc(R"doc( +Returns (x - y)(x - y) element-wise. + +*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +)doc"); + #undef BINARY_FEWER #undef BINARY_MORE @@ -594,6 +653,23 @@ Returns the max of x and y (i.e. x > y ? x : y) element-wise. [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) )doc"); +REGISTER_OP("_MklMaximum") + .Input("x: T") + .Input("y: T") + .Input("mkl_x: uint8") + .Input("mkl_y: uint8") + .Output("z: T") + .Output("mkl_z: uint8") + .Attr("T: {half, float, double, int32, int64}") + .SetIsCommutative() + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) + .Doc(R"doc( +Returns the max of x and y (i.e. x > y ? x : y) element-wise. + +*NOTE*: `Maximum` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +)doc"); + REGISTER_OP("Minimum") .Input("x: T") .Input("y: T") diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 0a96258dd1fac95190cf88cc768c2337951555c2..6651ad41e9a1da73fe679c75ef07b4a6bbf7a6a8 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1945,7 +1945,7 @@ Computes softsign gradients for a softsign operation. gradients: The backpropagated gradients to the corresponding softsign operation. features: The features passed as input to the corresponding softsign operation. -backprops: The gradients: `gradients / (1 + abs(-features)) ** 2`. +backprops: The gradients: `gradients / (1 + abs(features)) ** 2`. )doc"); // -------------------------------------------------------------------------- @@ -2791,7 +2791,9 @@ REGISTER_OP("_MklConv2D") .Input("mkl_input: uint8") .Input("mkl_filter: uint8") .Output("output: T") + .Output("filter_output: T") .Output("mkl_output: uint8") + .Output("mkl_filter_output: uint8") .Attr("T: {half, float, double}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") @@ -2813,7 +2815,9 @@ REGISTER_OP("_MklConv2DWithBias") .Input("mkl_filter: uint8") .Input("mkl_bias: uint8") .Output("output: T") + .Output("filter_output: T") .Output("mkl_output: uint8") + .Output("mkl_filter_output: uint8") .Attr("T: {half, float, double}") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") @@ -3230,6 +3234,29 @@ REGISTER_OP("_MklToTf") .Doc(R"doc( MKL operator to convert a tensor from MKL layout to TensorFlow layout. +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklInputConversion") + .Input("input_0: T") + .Input("input_1: T") + .Input("mkl_input_0: uint8") + .Input("mkl_input_1: uint8") + .Output("output_0: T") + .Output("output_1: T") + .Output("mkl_output_0: uint8") + .Output("mkl_output_1: uint8") + // All datatypes supported by element-wise ops + .Attr( + "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, " + "complex64, complex128}") + .Attr(GetConvnetDataFormatAttrString()) + .Doc(R"doc( +MKL operator to process the inputs to an elementwise MKL op. Both inputs +need to be either in TF or in MKL format. This op is added before every +element-wise MKL op. + NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 640ca5591e3fa1b401bdc7292bc33c6789833971..cfd3869d0597779a82dd8367962341bc6a0e677a 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -7805,6 +7805,35 @@ op { } summary: "Extract `patches` from `images` and put them in the \"depth\" output dimension." } +op { + name: "ExtractJpegShape" + input_arg { + name: "contents" + description: "0-D. The JPEG-encoded image." + type: DT_STRING + } + output_arg { + name: "image_shape" + description: "1-D. The image shape with format [height, width, channels]." + type_attr: "output_type" + } + attr { + name: "output_type" + type: "type" + default_value { + type: DT_INT32 + } + description: "(Optional) The output type of the operation (int32 or int64).\nDefaults to int32." + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + summary: "Extract the shape information of a JPEG-encoded image." + description: "This op only parses the image header, so it is much faster than DecodeJpeg." +} op { name: "FFT" input_arg { @@ -9582,8 +9611,8 @@ op { type_list_attr: "Treduce_func_other_arguments" } input_arg { - name: "window_size" - type: DT_INT64 + name: "window_size_func_other_arguments" + type_list_attr: "Twindow_size_func_other_arguments" } output_arg { name: "handle" @@ -9598,6 +9627,10 @@ op { name: "reduce_func" type: "func" } + attr { + name: "window_size_func" + type: "func" + } attr { name: "Tkey_func_other_arguments" type: "list(type)" @@ -9608,6 +9641,11 @@ op { type: "list(type)" has_minimum: true } + attr { + name: "Twindow_size_func_other_arguments" + type: "list(type)" + has_minimum: true + } attr { name: "output_types" type: "list(type)" @@ -15766,6 +15804,25 @@ op { } summary: "Transforms a serialized tensorflow.TensorProto proto into a Tensor." } +op { + name: "SerializeTensor" + input_arg { + name: "tensor" + description: "A Tensor of type `T`." + type: "T" + } + output_arg { + name: "serialized" + description: "A serialized TensorProto proto of the input tensor." + type_attr: DT_STRING + } + attr { + name: "T" + type: "type" + description: "The type of the input tensor." + } + summary: "Transforms a Tensor into a serialized TensorProto proto." +} op { name: "Placeholder" output_arg { @@ -22064,6 +22121,19 @@ op { description: "Reads a tensor stored in one or several files. If there are several files (for\ninstance because a tensor was saved as slices), `file_pattern` may contain\nwildcard symbols (`*` and `?`) in the filename portion only, not in the\ndirectory portion.\n\nIf a `file_pattern` matches several files, `preferred_shard` can be used to hint\nin which file the requested tensor is likely to be found. This op will first\nopen the file at index `preferred_shard` in the list of matching files and try\nto restore tensors from that file. Only if some tensors or tensor slices are\nnot found in that first file, then the Op opens all the files. Setting\n`preferred_shard` to match the value passed as the `shard` input\nof a matching `Save` Op may speed up Restore. This attribute only affects\nperformance, not correctness. The default value -1 means files are processed in\norder.\n\nSee also `RestoreSlice`." is_stateful: true } +op { + name: "RestoreIterator" + input_arg { + name: "iterator" + type: DT_RESOURCE + } + input_arg { + name: "path" + type: DT_STRING + } + summary: "Restores the state of the `iterator` from the checkpoint saved at `path` using \"SaveIterator\"." + is_stateful: true +} op { name: "RestoreSlice" input_arg { @@ -22624,6 +22694,20 @@ op { description: "The size of `tensor_names` must match the number of tensors in `data`. `data[i]`\nis written to `filename` with name `tensor_names[i]`.\n\nSee also `SaveSlices`." is_stateful: true } +op { + name: "SaveIterator" + input_arg { + name: "iterator" + type: DT_RESOURCE + } + input_arg { + name: "path" + type: DT_STRING + } + summary: "Saves the state of the `iterator` at `path`." + description: "This state can be restored using \"RestoreIterator\"." + is_stateful: true +} op { name: "SaveSlices" input_arg { @@ -23950,6 +24034,25 @@ op { } summary: "Serialize a `SparseTensor` into a string 3-vector (1-D `Tensor`) object." } +op { + name: "SerializeTensor" + input_arg { + name: "tensor" + description: "A Tensor of type `T`." + type_attr: "T" + } + output_arg { + name: "serialized" + description: "A serialized TensorProto proto of the input tensor." + type: DT_STRING + } + attr { + name: "T" + type: "type" + description: "The type of the input tensor." + } + summary: "Transforms a Tensor into a serialized TensorProto proto." +} op { name: "SetSize" input_arg { @@ -24451,6 +24554,54 @@ op { summary: "Return a slice from \'input\'." description: "The output tensor is a tensor with dimensions described by \'size\'\nwhose values are extracted from \'input\' starting at the offsets in\n\'begin\'.\n\n*Requirements*:\n 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n)" } +op { + name: "SloppyInterleaveDataset" + input_arg { + name: "input_dataset" + type: DT_RESOURCE + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "f" + type: "func" + description: "A function mapping elements of `input_dataset`, concatenated with\n`other_arguments`, to a Dataset resource that contains elements matching\n`output_types` and `output_shapes`." + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`." + description: "The resulting dataset is similar to the `InterleaveDataset`, with the exception\nthat if retrieving the next value from a dataset would cause the requester to\nblock, it will skip that input dataset. This dataset is especially useful\nwhen loading data from a variable-latency datastores (e.g. HDFS, GCS), as it\nallows the training step to proceed so long as some data is available.\n\n!! WARNING !! This dataset is not deterministic!" + is_stateful: true +} op { name: "Softmax" input_arg { @@ -24621,7 +24772,7 @@ op { } output_arg { name: "backprops" - description: "The gradients: `gradients / (1 + abs(-features)) ** 2`." + description: "The gradients: `gradients / (1 + abs(features)) ** 2`." type_attr: "T" } attr { @@ -27704,6 +27855,42 @@ op { } summary: "Splits a tensor into `num_split` tensors along one dimension." } +op { + name: "SqlDataset" + input_arg { + name: "driver_name" + description: "The database type. Currently, the only supported type is \'sqlite\'." + type: DT_STRING + } + input_arg { + name: "data_source_name" + description: "A connection string to connect to the database." + type: DT_STRING + } + input_arg { + name: "query" + description: "A SQL query to execute." + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + summary: "Creates a dataset that executes a SQL query and emits rows of the result set." + is_stateful: true +} op { name: "Sqrt" input_arg { @@ -28788,6 +28975,10 @@ op { type: DT_HALF type: DT_FLOAT type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 type: DT_INT32 type: DT_INT64 type: DT_COMPLEX64 diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index 2e605fdffcfbb2514a58af5b3f13adce95356e72..1f7ebe91cf0409ea52b1c2bbaf17ee57a4ca58ca 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -292,6 +292,19 @@ out_type: The type of the serialized tensor. The provided type must match the output: A Tensor of type `out_type`. )doc"); +REGISTER_OP("SerializeTensor") + .Input("tensor: T") + .Output("serialized: string") + .Attr("T: type") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Transforms a Tensor into a serialized TensorProto proto. + +tensor: A Tensor of type `T`. +T: The type of the input tensor. +serialized: A serialized TensorProto proto of the input tensor. +)doc"); + REGISTER_OP("DecodeJSONExample") .Input("json_examples: string") .Output("binary_examples: string") diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f778b48797263e50e132ac369e70432276b7e8fb --- /dev/null +++ b/tensorflow/core/ops/summary_ops.cc @@ -0,0 +1,218 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); + +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("SummaryWriter") + .Output("writer: resource") + .Attr("shared_name: string = ''") + .Attr("container: string = ''") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Returns a handle to be used to access a summary writer. + +The summary writer is an in-graph resource which can be used by ops to write +summaries to event files. + +writer: the summary writer resource. Scalar handle. +)doc"); + +REGISTER_OP("CreateSummaryFileWriter") + .Input("writer: resource") + .Input("logdir: string") + .Input("max_queue: int32") + .Input("flush_millis: int32") + .Input("filename_suffix: string") + .Doc(R"doc( +Creates a summary file writer accessible by the given resource handle. + +writer: A handle to the summary writer resource +logdir: Directory where the event file will be written. +max_queue: Size of the queue of pending events and summaries. +flush_millis: How often, in milliseconds, to flush the pending events and + summaries to disk. +filename_suffix: Every event file's name is suffixed with this suffix. +)doc"); + +REGISTER_OP("FlushSummaryWriter") + .Input("writer: resource") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"( +Flushes the writer's unwritten events. + +writer: A handle to the summary writer resource. +)"); + +REGISTER_OP("CloseSummaryWriter") + .Input("writer: resource") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"( +Flushes and closes the summary writer. + +Also removes it from the resource manager. To reopen, use another +CreateSummaryFileWriter op. + +writer: A handle to the summary writer resource. +)"); + +REGISTER_OP("WriteSummary") + .Input("writer: resource") + .Input("global_step: int64") + .Input("tensor: T") + .Input("tag: string") + .Input("summary_metadata: string") + .Attr("T: type") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Outputs a `Summary` protocol buffer with a tensor. + +writer: A handle to a summary writer. +global_step: The step to write the summary for. +tensor: A tensor to serialize. +tag: The summary's tag. +summary_metadata: Serialized SummaryMetadata protocol buffer containing + plugin-related metadata for this summary. +)doc"); + +REGISTER_OP("WriteScalarSummary") + .Input("writer: resource") + .Input("global_step: int64") + .Input("tag: string") + .Input("value: T") + .Attr("T: realnumbertype") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Writes a `Summary` protocol buffer with scalar values. + +The input `tag` and `value` must have the scalars. + +writer: A handle to a summary writer. +global_step: The step to write the summary for. +tag: Tag for the summary. +value: Value for the summary. +)doc"); + +REGISTER_OP("WriteHistogramSummary") + .Input("writer: resource") + .Input("global_step: int64") + .Input("tag: string") + .Input("values: T") + .Attr("T: realnumbertype = DT_FLOAT") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Writes a `Summary` protocol buffer with a histogram. + +The generated +[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) +has one summary value containing a histogram for `values`. + +This op reports an `InvalidArgument` error if any value is not finite. + +writer: A handle to a summary writer. +global_step: The step to write the summary for. +tag: Scalar. Tag to use for the `Summary.Value`. +values: Any shape. Values to use to build the histogram. +)doc"); + +REGISTER_OP("WriteImageSummary") + .Input("writer: resource") + .Input("global_step: int64") + .Input("tag: string") + .Input("tensor: T") + .Input("bad_color: uint8") + .Attr("max_images: int >= 1 = 3") + .Attr("T: {uint8, float, half} = DT_FLOAT") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Writes a `Summary` protocol buffer with images. + +The summary has up to `max_images` summary values containing images. The +images are built from `tensor` which must be 4-D with shape `[batch_size, +height, width, channels]` and where `channels` can be: + +* 1: `tensor` is interpreted as Grayscale. +* 3: `tensor` is interpreted as RGB. +* 4: `tensor` is interpreted as RGBA. + +The images have the same number of channels as the input tensor. For float +input, the values are normalized one image at a time to fit in the range +`[0, 255]`. `uint8` values are unchanged. The op uses two different +normalization algorithms: + +* If the input values are all positive, they are rescaled so the largest one + is 255. + +* If any input value is negative, the values are shifted so input value 0.0 + is at 127. They are then rescaled so that either the smallest value is 0, + or the largest one is 255. + +The `tag` argument is a scalar `Tensor` of type `string`. It is used to +build the `tag` of the summary values: + +* If `max_images` is 1, the summary value tag is '*tag*/image'. +* If `max_images` is greater than 1, the summary value tags are + generated sequentially as '*tag*/image/0', '*tag*/image/1', etc. + +The `bad_color` argument is the color to use in the generated images for +non-finite input values. It is a `unit8` 1-D tensor of length `channels`. +Each element must be in the range `[0, 255]` (It represents the value of a +pixel in the output image). Non-finite values in the input tensor are +replaced by this tensor in the output image. The default value is the color +red. + +writer: A handle to a summary writer. +global_step: The step to write the summary for. +tag: Scalar. Used to build the `tag` attribute of the summary values. +tensor: 4-D of shape `[batch_size, height, width, channels]` where + `channels` is 1, 3, or 4. +max_images: Max number of batch elements to generate images for. +bad_color: Color to use for pixels with non-finite values. +)doc"); + +REGISTER_OP("WriteAudioSummary") + .Input("writer: resource") + .Input("global_step: int64") + .Input("tag: string") + .Input("tensor: float") + .Input("sample_rate: float") + .Attr("max_outputs: int >= 1 = 3") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Writes a `Summary` protocol buffer with audio. + +The summary has up to `max_outputs` summary values containing audio. The +audio is built from `tensor` which must be 3-D with shape `[batch_size, +frames, channels]` or 2-D with shape `[batch_size, frames]`. The values are +assumed to be in the range of `[-1.0, 1.0]` with a sample rate of `sample_rate`. + +The `tag` argument is a scalar `Tensor` of type `string`. It is used to +build the `tag` of the summary values: + +* If `max_outputs` is 1, the summary value tag is '*tag*/audio'. +* If `max_outputs` is greater than 1, the summary value tags are + generated sequentially as '*tag*/audio/0', '*tag*/audio/1', etc. + +writer: A handle to a summary writer. +global_step: The step to write the summary for. +tag: Scalar. Used to build the `tag` attribute of the summary values. +tensor: 2-D of shape `[batch_size, frames]`. +sample_rate: The sample rate of the signal in hertz. +max_outputs: Max number of batch elements to generate audio for. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 126558cac38e5809b950a24a6c5e81db085ff307..e1ad66c387a221c54f18d642ea20a66452222398 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -75,6 +75,9 @@ def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[], def tf_jspb_proto_library(**kwargs): pass +def tf_nano_proto_library(**kwargs): + pass + def tf_proto_library(name, srcs = [], has_services = None, protodeps = [], visibility = [], testonly = 0, cc_libs = [], diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc index ac0988e70474661a867377738a68019c4723b890..ebdd4b624aa423983cdeb2d31c0bf27ff30c89e2 100644 --- a/tensorflow/core/platform/default/logging.cc +++ b/tensorflow/core/platform/default/logging.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/default/logging.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env_time.h" #include "tensorflow/core/platform/macros.h" @@ -25,12 +24,8 @@ limitations under the License. #endif #include -#include #include -#include -#include - namespace tensorflow { namespace internal { @@ -88,11 +83,11 @@ void LogMessage::GenerateLogMessage() { const size_t time_buffer_size = 30; char time_buffer[time_buffer_size]; strftime(time_buffer, time_buffer_size, "%Y-%m-%d %H:%M:%S", - localtime(&now_seconds)); + localtime(&now_seconds)); // TODO(jeff,sanjay): Replace this with something that logs through the env. fprintf(stderr, "%s.%06d: %c %s:%d] %s\n", time_buffer, micros_remainder, - "IWEF"[severity_], fname_, line_, str().c_str()); + "IWEF"[severity_], fname_, line_, str().c_str()); } #endif @@ -129,48 +124,6 @@ int64 MinVLogLevelFromEnv() { return LogLevelStrToInt(tf_env_var_val); } -using VmoduleMap = std::unordered_map; - -// Returns a mapping from module name to VLOG level, derived from the -// TF_CPP_VMOUDLE environment variable; ownership is transferred to the caller. -VmoduleMap* VmoduleRecordsFromEnv() { - // The value of the env var is supposed to be of the form: - // "foo=1,bar=2,baz=3" - const char* tf_env_var_val = getenv("TF_CPP_VMODULE"); - auto* result = new VmoduleMap(); - if (tf_env_var_val == nullptr) return result; - while (true) { - const char* eq = strchr(tf_env_var_val, '='); - if (eq == nullptr) break; - const char* after_eq = eq + 1; - - // Comma either points at the next comma delimiter, or at a null terminator. - // We check that the integer we parse ends at this delimiter. - const char* comma = strchr(after_eq, ','); - const char* new_tf_env_var_val; - if (comma == nullptr) { - comma = strchr(after_eq, '\0'); - new_tf_env_var_val = comma; - } else { - new_tf_env_var_val = comma + 1; - } - - char* endptr = nullptr; - int level = strtol(after_eq, &endptr, 10); - if (endptr != comma) { - fprintf(stderr, - "warning: could not parse integer in vmodule specification in " - "\"%s\".\n", - after_eq); - break; - } - StringPiece module(tf_env_var_val, eq - tf_env_var_val); - tf_env_var_val = new_tf_env_var_val; - (*result)[module] = level; - } - return result; -} - } // namespace LogMessage::~LogMessage() { @@ -184,19 +137,6 @@ int64 LogMessage::MinVLogLevel() { return min_vlog_level; } -bool LogMessage::VmoduleActivated(const char* fname, int lvl) { - static VmoduleMap* vmodule_records = VmoduleRecordsFromEnv(); - const char* last_slash = strrchr(fname, '/'); - const char* module_start = last_slash == nullptr ? fname : last_slash + 1; - const char* dot_after = strchr(module_start, '.'); - const char* module_limit = - dot_after == nullptr ? strchr(fname, '\0') : dot_after; - StringPiece module(module_start, module_limit - module_start); - auto it = vmodule_records->find(module); - if (it == vmodule_records->end()) return false; - return it->second >= lvl; -} - LogMessageFatal::LogMessageFatal(const char* file, int line) : LogMessage(file, line, FATAL) {} LogMessageFatal::~LogMessageFatal() { diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h index c8c9b2da11a19e80e17aa2ddb585e0f8c15d8982..d5f7350cdd805eb71edab0fde72db8383c32addb 100644 --- a/tensorflow/core/platform/default/logging.h +++ b/tensorflow/core/platform/default/logging.h @@ -46,16 +46,6 @@ class LogMessage : public std::basic_ostringstream { // but VLOG(3) will not. Defaults to 0. static int64 MinVLogLevel(); - // Returns whether VLOG level lvl is activated for the file fname. - // - // E.g. if the environment variable TF_CPP_VMODULE contains foo=3 and fname is - // foo.cc and lvl is <= 3, this will return true. - // - // It is expected that the result of this query will be cached in the VLOG-ing - // call site to avoid repeated lookups. This routine performs a hash-map - // access against the VLOG-ing specification provided by the env var. - static bool VmoduleActivated(const char* fname, int lvl); - protected: void GenerateLogMessage(); @@ -86,38 +76,18 @@ class LogMessageFatal : public LogMessage { #define LOG(severity) _TF_LOG_##severity -#if defined(IS_MOBILE_PLATFORM) - +#ifdef IS_MOBILE_PLATFORM // Turn VLOG off when under mobile devices for considerations of binary size. -#define _VLOG_IS_ON(lvl, file) ((lvl) <= 0) - -#elif defined(PLATFORM_WINDOWS) - -// TODO(b/64279502) The _VLOG_IS_ON definition below appears to cause MSVC to -// fatal error, so we fall back to the vmodule-less implementation for now. -#define _VLOG_IS_ON(lvl, file) \ - ((lvl) <= ::tensorflow::internal::LogMessage::MinVLogLevel()) - +#define VLOG_IS_ON(lvl) ((lvl) <= 0) #else - -// Otherwise, set TF_CPP_MIN_VLOG_LEVEL environment to update minimum log level -// of VLOG, or TF_CPP_VMODULE to set the minimum log level for individual -// translation units. -#define _VLOG_IS_ON(lvl, file) \ - (([](int level, const char* fname) { \ - if (level <= ::tensorflow::internal::LogMessage::MinVLogLevel()) \ - return true; \ - static bool vmodule_activated = \ - ::tensorflow::internal::LogMessage::VmoduleActivated(fname, level); \ - return vmodule_activated; \ - })(lvl, file)) - +// Otherwise, Set TF_CPP_MIN_VLOG_LEVEL environment to update minimum log level +// of VLOG +#define VLOG_IS_ON(lvl) \ + ((lvl) <= ::tensorflow::internal::LogMessage::MinVLogLevel()) #endif -#define VLOG_IS_ON(lvl) _VLOG_IS_ON(lvl, __FILE__) - -#define VLOG(lvl) \ - if (TF_PREDICT_FALSE(_VLOG_IS_ON(lvl, __FILE__))) \ +#define VLOG(lvl) \ + if (TF_PREDICT_FALSE(VLOG_IS_ON(lvl))) \ ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::INFO) // CHECK dies with a fatal error if condition is not true. It is *not* diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index 50dd0cd58b88a478a0c8656e98cfa5eb5458ff85..c9b362f18235f8ddec0994bc1110aaec950eef72 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -226,14 +226,28 @@ TEST_F(DefaultEnvTest, RecursivelyCreateDirSubdirsExist) { TEST_F(DefaultEnvTest, LocalFileSystem) { // Test filename with file:// syntax. + int expected_num_files = 0; + std::vector matching_paths; for (const int length : {0, 1, 1212, 2553, 4928, 8196, 9000, (1 << 20) - 1, 1 << 20, (1 << 20) + 1}) { - string filename = io::JoinPath(BaseDir(), strings::StrCat("file", length)); + string filename = io::JoinPath(BaseDir(), strings::StrCat("len", length)); filename = strings::StrCat("file://", filename); // Write a file with the given length const string input = CreateTestFile(env_, filename, length); + ++expected_num_files; + + // Ensure that GetMatchingPaths works as intended. + TF_EXPECT_OK(env_->GetMatchingPaths( + // Try it with the "file://" URI scheme. + strings::StrCat("file://", io::JoinPath(BaseDir(), "l*")), + &matching_paths)); + EXPECT_EQ(expected_num_files, matching_paths.size()); + TF_EXPECT_OK(env_->GetMatchingPaths( + // Try it without any URI scheme. + io::JoinPath(BaseDir(), "l*"), &matching_paths)); + EXPECT_EQ(expected_num_files, matching_paths.size()); // Read the file back and check equality string output; diff --git a/tensorflow/core/platform/vmodule_test.cc b/tensorflow/core/platform/vmodule_test.cc deleted file mode 100644 index 47b4b2e0e78f4710db0742981f23f16cad5cbbf8..0000000000000000000000000000000000000000 --- a/tensorflow/core/platform/vmodule_test.cc +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Test that popens a child process with the VLOG-ing environment variable set -// for the logging framework, and observes VLOG_IS_ON and VLOG macro output. - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/platform.h" -#include "tensorflow/core/platform/test.h" - -#include - -namespace tensorflow { -namespace { - -int RealMain(const char* argv0, bool do_vlog) { - if (do_vlog) { -#if !defined(PLATFORM_GOOGLE) - // Note, we only test this when !defined(PLATFORM_GOOGLE) because - // VmoduleActivated doesn't exist in that implementation. - // - // Also, we call this internal API to simulate what would happen if - // differently-named translation units attempted to VLOG, so we don't need - // to create dummy translation unit files. - bool ok = internal::LogMessage::VmoduleActivated("vmodule_test.cc", 7) && - internal::LogMessage::VmoduleActivated("shoobadooba.h", 3); - if (!ok) { - fprintf(stderr, "vmodule activated levels not as expected.\n"); - return EXIT_FAILURE; - } -#endif - - // Print info on which VLOG levels are activated. - fprintf(stderr, "VLOG_IS_ON(8)? %d\n", VLOG_IS_ON(8)); - fprintf(stderr, "VLOG_IS_ON(7)? %d\n", VLOG_IS_ON(7)); - fprintf(stderr, "VLOG_IS_ON(6)? %d\n", VLOG_IS_ON(6)); - // Do some VLOG-ing. - VLOG(8) << "VLOG(8)"; - VLOG(7) << "VLOG(7)"; - VLOG(6) << "VLOG(6)"; - LOG(INFO) << "INFO"; - return EXIT_SUCCESS; - } - - // Popen the child process. - std::string command = std::string(argv0); -#if defined(PLATFORM_GOOGLE) - command = command + " do_vlog --vmodule=vmodule_test=7 --alsologtostderr"; -#else - command = - "TF_CPP_VMODULE=vmodule_test=7,shoobadooba=3 " + command + " do_vlog"; -#endif - command += " 2>&1"; - fprintf(stderr, "Running: \"%s\"\n", command.c_str()); - FILE* f = popen(command.c_str(), "r"); - if (f == nullptr) { - fprintf(stderr, "Failed to popen child: %s\n", strerror(errno)); - return EXIT_FAILURE; - } - - // Read data from the child's stdout. - constexpr int kBufferSizeBytes = 4096; - char buffer[kBufferSizeBytes]; - size_t result = fread(buffer, sizeof(buffer[0]), kBufferSizeBytes - 1, f); - if (result == 0) { - fprintf(stderr, "Failed to read from child stdout: %zu %s\n", result, - strerror(errno)); - return EXIT_FAILURE; - } - buffer[result] = '\0'; - int status = pclose(f); - if (status == -1) { - fprintf(stderr, "Failed to close popen child: %s\n", strerror(errno)); - return EXIT_FAILURE; - } - - // Check output is as expected. - const char kExpected[] = - "VLOG_IS_ON(8)? 0\nVLOG_IS_ON(7)? 1\nVLOG_IS_ON(6)? 1\n"; - if (strstr(buffer, kExpected) == nullptr) { - fprintf(stderr, "error: unexpected output from child: \"%.*s\"\n", - kBufferSizeBytes, buffer); - return EXIT_FAILURE; - } - bool ok = strstr(buffer, "VLOG(7)\n") != nullptr && - strstr(buffer, "VLOG(6)\n") != nullptr && - strstr(buffer, "VLOG(8)\n") == nullptr; - if (!ok) { - fprintf(stderr, "error: VLOG output not as expected: \"%.*s\"\n", - kBufferSizeBytes, buffer); - return EXIT_FAILURE; - } - - // Success! - return EXIT_SUCCESS; -} - -} // namespace -} // namespace tensorflow - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - bool do_vlog = argc >= 2 && strcmp(argv[1], "do_vlog") == 0; - return tensorflow::RealMain(argv[0], do_vlog); -} diff --git a/tensorflow/core/profiler/README.md b/tensorflow/core/profiler/README.md index 5c50a86c88ff597326db699d826b0997104431d5..f0d4dafd3eafb6b3bdd47cd6925ebaca20400b05 100644 --- a/tensorflow/core/profiler/README.md +++ b/tensorflow/core/profiler/README.md @@ -56,7 +56,7 @@ with tf.contrib.tfprof.ProfileContext() as pctx: ```shell # Profiling from Python API is not interactive. -# Dump the profiles to files and profile with interactive command line. +# Dump the profiles to files and profile with interactive command line or web UI. with tf.contrib.tfprof.ProfileContext() as pctx: pctx.add_auto_profile_dump('/tmp/profiles', [100]) train_loop() @@ -66,7 +66,15 @@ bazel-bin/tensorflow/core/profiler/profiler \ --run_meta_path=/tmp/profiles/run_meta \ --op_log_path=/tmp/profiles/tfprof_log \ tfprof> op -select micros,bytes,occurrence -order_by micros + + +# To be open sourced... +bazel-bin/third_party/tensorflow/python/profiler/profiler_ui \ + --graph_path=/tmp/profiles/graph.pbtxt \ + --run_meta_path=/tmp/profiles/run_meta \ + --op_log_path=/tmp/profiles/tfprof_log \ ``` +![ProfilerUI](g3doc/profiler_ui.jpg) Detail Tutorials @@ -239,5 +247,6 @@ bug fix. `OpLogProto` is a good plus if it is used. #### Teams * Xin Pan (xpan@google.com, github: panyx0718) +* Chris Antaki * Yao Zhang * Jon Shlens diff --git a/tensorflow/core/profiler/g3doc/advise.md b/tensorflow/core/profiler/g3doc/advise.md index d87b0d8603d12546dfe12ab459ebe407e9810f88..d0de8317f6950a89567b6d3c5705c42fcc8f4653 100644 --- a/tensorflow/core/profiler/g3doc/advise.md +++ b/tensorflow/core/profiler/g3doc/advise.md @@ -86,7 +86,7 @@ For example: * Checks RecvTensor RPC latency and bandwidth. * Checks CPU/Memory utilization of the job. -####AcceleratorUtilization Checker +#### AcceleratorUtilization Checker * Checks what percentage of time the accelerator spends on computation. #### OperationChecker @@ -100,7 +100,7 @@ For example: * Checks the most expensive graph nodes. * Checks the most expensive graph-building Python codes. -####Contribute Your Checker +#### Contribute Your Checker Follow examples of accelerator_utilization_checker.h diff --git a/tensorflow/core/profiler/g3doc/command_line.md b/tensorflow/core/profiler/g3doc/command_line.md index 857b5e64590db193baa6e7f836634745f35eb5dc..fb4207c784170ff998c5cd94a84e7d1c625a0ac5 100644 --- a/tensorflow/core/profiler/g3doc/command_line.md +++ b/tensorflow/core/profiler/g3doc/command_line.md @@ -51,13 +51,13 @@ It defines _checkpoint_variable op type. It also provides checkpointed tensors' Note: this feature is not well maintained now. -###Start `tfprof` +### Start `tfprof` #### Build `tfprof` ```shell # Build the tool. -bazel build --config opt third_party/tensorflow/core/profiler/... +bazel build --config opt tensorflow/core/profiler:profiler # Help information, including detail 'option' instructions. bazel-bin/tensorflow/core/profiler/profiler help @@ -140,9 +140,9 @@ tfprof> -output ``` -###Examples +### Examples -####Profile Python Time +#### Profile Python Time ```shell # Requires --graph_path --op_log_path tfprof> code -max_depth 1000 -show_name_regexes .*model_analyzer.*py.* -select micros -account_type_regexes .* -order_by micros diff --git a/tensorflow/core/profiler/g3doc/options.md b/tensorflow/core/profiler/g3doc/options.md index 15712d04c25bdf9fdb012a4129c706c2f1124483..ddee63ad42ade86a6951045e23a290e14756deec 100644 --- a/tensorflow/core/profiler/g3doc/options.md +++ b/tensorflow/core/profiler/g3doc/options.md @@ -1,6 +1,6 @@ -##Options +## Options -###Overview +### Overview For all tfprof views, the profiles are processed with the following procedures @@ -35,7 +35,7 @@ For all tfprof views, the profiles are processed with the following procedures 4) Finally, the filtered data structure is output in a format depending on the `-output` option. -####Option Semantics In Different View +#### Option Semantics In Different View options usually have the same semantics in different views. However, some can vary. For example `-max_depth` in scope view means the depth of name scope tree. In op view, it means the length of operation list. @@ -68,7 +68,7 @@ output_bytes: The memory output by the operation. It's not necessarily requested by the current operation. For example, it can be a tensor forwarded from input to output, with in-place mutation. -###Docs +### Docs `-max_depth`: Show nodes that are at most this number of hops from starting node in the data structure. diff --git a/tensorflow/core/profiler/g3doc/profile_memory.md b/tensorflow/core/profiler/g3doc/profile_memory.md index a00683d0626759cbaf707c2d6465f0fb7885d082..6eda5abdd973ece435855b0952a5edd4a86b8217 100644 --- a/tensorflow/core/profiler/g3doc/profile_memory.md +++ b/tensorflow/core/profiler/g3doc/profile_memory.md @@ -1,4 +1,4 @@ -##Profile Memory +## Profile Memory It is generally a good idea to visualize the memory usage in timeline. It allows you to see the memory consumption of each GPU over time. diff --git a/tensorflow/core/profiler/g3doc/profile_model_architecture.md b/tensorflow/core/profiler/g3doc/profile_model_architecture.md index a42b2e918da0c40e3c8caff1400b25927b46d9c9..61bb66bd21b336074475142ee564414ee154cafc 100644 --- a/tensorflow/core/profiler/g3doc/profile_model_architecture.md +++ b/tensorflow/core/profiler/g3doc/profile_model_architecture.md @@ -1,9 +1,9 @@ -##Profile Model Architecture +## Profile Model Architecture * [Profile Model Parameters](#profile-model-parameters) * [Profile Model Float Operations](#profile-model-float-operations) -###Profile Model Parameters +### Profile Model Parameters Notes: `VariableV2` operation type might contain variables created by TensorFlow @@ -39,9 +39,9 @@ param_stats = tf.profiler.profile( sys.stdout.write('total_params: %d\n' % param_stats.total_parameters) ``` -###Profile Model Float Operations +### Profile Model Float Operations -####Caveats +#### Caveats For an operation to have float operation statistics: diff --git a/tensorflow/core/profiler/g3doc/profile_time.md b/tensorflow/core/profiler/g3doc/profile_time.md index e11a75553b23ccb0c5e698e25e4b22afd1e364d7..4aafc697a9b3f437d0d555d14c99aec7afcd836d 100644 --- a/tensorflow/core/profiler/g3doc/profile_time.md +++ b/tensorflow/core/profiler/g3doc/profile_time.md @@ -1,4 +1,4 @@ -##Profile Time +## Profile Time * [Times in TensorFlow and tfprof](#times-in-tensorflow-and-tfprof) * [Profile by Python Code](#profile-by-python-code) @@ -7,7 +7,7 @@ * [Profile by Name Scope](#profile-by-name-scope) -###Times in TensorFlow and tfprof +### Times in TensorFlow and tfprof When we run a model, Tensorflow schedules and runs the nodes (operations) in the graph. An operation can be placed on an accelerator or on CPU. @@ -37,7 +37,7 @@ When an operation is placed on CPU, it will completely run on CPU. Hence, should be 0. -###Profile by Python Code +### Profile by Python Code ```python # In code view, the time of each line of Python code is the aggregated # times of all operations created by that line. @@ -112,7 +112,7 @@ Set ```-output timeline:outfile=``` to generate timeline instead of st -###Profile by Operation Type +### Profile by Operation Type ```python # In op view, you can view the aggregated time of each operation type. tfprof> op -select micros,occurrence -order_by micros @@ -138,7 +138,7 @@ MatMul 618.97ms (63.56%, 16.51%), |/job:worker/replica:0/ ``` -###Profile by Graph +### Profile by Graph Usually, use graph view to generate a timeline to visualize the result. @@ -163,7 +163,7 @@ Open a Chrome browser, enter URL chrome://tracing and load the timeline file. ****************************************************** ``` -###Profile by Name Scope +### Profile by Name Scope Usually scope view allows you to pin point the problematic places if you have properly named your operations with tf.name_scope or tf.variable_scope. diff --git a/tensorflow/core/profiler/g3doc/profiler_ui.jpg b/tensorflow/core/profiler/g3doc/profiler_ui.jpg new file mode 100644 index 0000000000000000000000000000000000000000..36aa94502a8c3de7915fb0e388c861cd706c3af8 Binary files /dev/null and b/tensorflow/core/profiler/g3doc/profiler_ui.jpg differ diff --git a/tensorflow/core/profiler/internal/print_model_analysis.cc b/tensorflow/core/profiler/internal/print_model_analysis.cc index 65b54f01aa11b1b4f7e61f60b74e83ffd43e6be5..fd46b957e85257ec4107fcf4ed4a656e48ea1d7e 100644 --- a/tensorflow/core/profiler/internal/print_model_analysis.cc +++ b/tensorflow/core/profiler/internal/print_model_analysis.cc @@ -87,7 +87,11 @@ bool NewProfiler(const string* graph, const string* op_log) { CHECK(!tf_stat) << "Currently only 1 living tfprof profiler is allowed"; CHECK(graph) << "graph mustn't be null"; std::unique_ptr graph_ptr(new GraphDef()); - graph_ptr->ParseFromString(*graph); + if (!graph_ptr->ParseFromString(*graph)) { + if (!protobuf::TextFormat::ParseFromString(*graph, graph_ptr.get())) { + fprintf(stderr, "Failed to parse graph\n"); + } + } std::unique_ptr op_log_ptr; if (op_log && !op_log->empty()) { diff --git a/tensorflow/core/profiler/internal/tfprof_code.cc b/tensorflow/core/profiler/internal/tfprof_code.cc index 7f4d682cdaba49556cecf281bb4a445d1e51ae08..c9c0baa908f60a55d93626c3c24ef8bd67c9bc84 100644 --- a/tensorflow/core/profiler/internal/tfprof_code.cc +++ b/tensorflow/core/profiler/internal/tfprof_code.cc @@ -44,11 +44,6 @@ string GetTraceString(const CodeDef::Trace& trace) { } else { ntrace += ":" + trace.function().substr(0, 17) + "..."; } - if (trace.line().length() < 20) { - ntrace += ":" + trace.line(); - } else { - ntrace += ":" + trace.line().substr(0, 17) + "..."; - } return ntrace; } diff --git a/tensorflow/core/profiler/internal/tfprof_stats.cc b/tensorflow/core/profiler/internal/tfprof_stats.cc index 012c49525a5631156da86344dc269c850936946f..81db76679691715ee711b069e3ee7438d07e7c7d 100644 --- a/tensorflow/core/profiler/internal/tfprof_stats.cc +++ b/tensorflow/core/profiler/internal/tfprof_stats.cc @@ -125,7 +125,7 @@ const MultiGraphNodeProto& TFStats::ShowMultiGraphNode( if (!Validate(opts)) { return empty_multi_graph_node_; } - if (cmd == kCmds[2]) { + if (cmd == kCmds[2] && has_code_traces()) { return code_view_->Show(opts); } else if (cmd == kCmds[3]) { return op_view_->Show(opts); diff --git a/tensorflow/core/util/activation_mode.cc b/tensorflow/core/util/activation_mode.cc index 4bf947a0a9abd12fa73898591fec9b066f1d5a8a..efb5ab146aa4cdc0114f5d18a5d0542b94c5abc4 100644 --- a/tensorflow/core/util/activation_mode.cc +++ b/tensorflow/core/util/activation_mode.cc @@ -22,7 +22,9 @@ namespace tensorflow { Status GetActivationModeFromString(const string& str_value, ActivationMode* value) { - if (str_value == "Sigmoid") { + if (str_value == "None") { + *value = NONE; + } else if (str_value == "Sigmoid") { *value = SIGMOID; } else if (str_value == "Relu") { *value = RELU; diff --git a/tensorflow/core/util/activation_mode.h b/tensorflow/core/util/activation_mode.h index 2a8564847dd5b6020ca0f779c17974b07ee0d51b..2e03ccd5c85d16d058d34dac7d6217167c08f7ba 100644 --- a/tensorflow/core/util/activation_mode.h +++ b/tensorflow/core/util/activation_mode.h @@ -28,6 +28,7 @@ namespace tensorflow { // ActivationMode: the activation function we apply to the input tensor: enum ActivationMode { + NONE = 0, SIGMOID = 1, RELU = 2, RELU6 = 3, diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc index 6139c8e7bcd1015e17b796896404ccf33064123f..ad8c824461a1b36598e3cff20e5b69517cde3fc7 100644 --- a/tensorflow/core/util/command_line_flags_test.cc +++ b/tensorflow/core/util/command_line_flags_test.cc @@ -224,6 +224,25 @@ TEST(CommandLineFlagsTest, FailedStringHook) { EXPECT_EQ(argc, 1); } +TEST(CommandLineFlagsTest, RepeatedStringHook) { + int argc = 3; + std::vector argv_strings = {"program_name", "--some_name=this", + "--some_name=that"}; + std::vector argv_array = CharPointerVectorFromStrings(argv_strings); + int call_count = 0; + bool parsed_ok = Flags::Parse(&argc, argv_array.data(), + {Flag("some_name", + [&call_count](string value) { + call_count++; + return true; + }, + "", "some name")}); + + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(argc, 1); + EXPECT_EQ(call_count, 2); +} + // Return whether str==pat, but allowing any whitespace in pat // to match zero or more whitespace characters in str. static bool MatchWithAnyWhitespace(const string &str, const string &pat) { diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index cb22a50e8f1151519c35ed55cef76183348c2dbd..f4bec9524adb72997bcfb2b776c6ab7fe30daf33 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -65,6 +65,8 @@ class MklShape { void SetDimensions(const size_t dimension) { dimension_ = dimension; } + void SetMklLayout(dnnLayout_t mklLayout) { mklLayout_ = mklLayout; } + void SetMklLayout(const void* primitive, size_t resourceType) { CHECK_EQ( dnnLayoutCreateFromPrimitive_F32(&mklLayout_, (dnnPrimitive_t)primitive, @@ -135,6 +137,7 @@ class MklShape { size_t GetDimension() const { return dimension_; } const size_t* GetSizes() const { return sizes_; } int64 dim_size(int index) const { return sizes_[index]; } + int64 tf_dim_size(int index) const { return sizes_[tf_to_mkl_dim_map_[index]]; } const size_t* GetStrides() const { return strides_; } const size_t* GetTfToMklDimMap() const { return tf_to_mkl_dim_map_; } size_t tf_dim_idx(int index) const { return tf_to_mkl_dim_map_[index]; } @@ -581,7 +584,7 @@ inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, context->set_output(idx_data_out, output); } -inline void FowardTfTensorInToOut(OpKernelContext* context, +inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in, int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); @@ -598,7 +601,7 @@ inline void FowardTfTensorInToOut(OpKernelContext* context, } } -inline void ForwarMklTensorInToOut(OpKernelContext* context, +inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in, int idx_out) { int num_inputs = context->num_inputs(); int num_outputs = context->num_outputs(); @@ -616,6 +619,98 @@ inline void ForwarMklTensorInToOut(OpKernelContext* context, } } +// Forward the MKL shape ONLY (used in elementwise and other ops where +// we call the eigen implementation and MKL shape is not used) +inline void ForwardMklMetaDataInToOut(OpKernelContext* context, + uint idx_data_in, uint idx_data_out) { + uint idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs()); + uint idx_meta_out = + GetTensorMetaDataIndex(idx_data_out, context->num_outputs()); + + if (IsRefType(context->input_dtype(idx_data_in))) { + context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out); + } else { + context->set_output(idx_meta_out, context->input(idx_meta_in)); + } +} + +// Set a dummy MKL shape (called when the output is in TF format) +inline void SetDummyMklShapeOutput(OpKernelContext* context, + uint idx_data_out) { + MklShape mkl_shape_output; + mkl_shape_output.SetMklTensor(false); + AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output); +} + +// Checks if the TF shape for both MKL tensors is the same or not +// Returns: true if both TF shapes are the same, false otherwise +inline bool MklCompareShapes(const MklShape* input_shape_0, + const MklShape* input_shape_1) { + // Check for number of dimensions + if (input_shape_0->GetDimension() != input_shape_1->GetDimension()) { + return false; + } + + // Check size of each dimension + size_t ndims = input_shape_0->GetDimension(); + for (size_t i = 0; i < ndims; i++) { + if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) { + return false; + } + } + + return true; +} + +// Checks if the TF shape for both tensors is the same or not +// Returns: true if TF shapes for both are the same, false otherwise +inline bool MklCompareShapes(const MklShape* input_shape_0, + const TensorShape* input_shape_1) { + // Check for number of dimensions + if (input_shape_0->GetDimension() != input_shape_1->dims()) { + return false; + } + + // Check size of each dimension + size_t ndims = input_shape_0->GetDimension(); + for (size_t i = 0; i < ndims; i++) { + if (input_shape_0->tf_dim_size(i) != input_shape_1->dim_size(i)) { + return false; + } + } + + return true; +} + +// Checks if the TF shape for both tensors is the same or not +// Returns: true if TF shapes for both are the same, false otherwise +inline bool MklCompareShapes(const TensorShape* input_shape_0, + const MklShape* input_shape_1) { + return MklCompareShapes(input_shape_1, input_shape_0); +} + +// Checks if the TF shape for both tensors is the same or not +// Returns: true if TF shapes for both are the same, false otherwise +inline bool MklCompareShapes(const TensorShape* input_shape_0, + const TensorShape* input_shape_1) { + // Check for number of dimensions + if (input_shape_0->dims() != input_shape_1->dims()) { + return false; + } + + // Check size of each dimension + size_t ndims = input_shape_0->dims(); + for (size_t i = 0; i < ndims; i++) { + if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) { + return false; + } + } + + return true; +} + +// TODO(intel_tf): Remove this routine when faster MKL layout conversion is +// out. inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) { const float* buf_in = input.flat().data(); float* buf_out = (*output)->flat().data(); @@ -652,11 +747,19 @@ namespace mkl_op_registry { static const char* kMklOpLabel = "MklOp"; static const char* kMklOpLabelPattern = "label='MklOp'"; +// Get the name of Mkl op from original TensorFlow op +// We prefix 'Mkl' to the original op to get Mkl op. +inline string GetMklOpName(const string& name) { + // Prefix that we add to Tensorflow op name to construct Mkl op name. + const char* const kMklOpPrefix = "_Mkl"; + return string(kMklOpPrefix) + name; +} + // Check whether opname with type T is registered as MKL-compliant. // // @input: name of the op // @input: T datatype to be used for checking op -// @return: true if opname is registered as Mkl op +// @return: true if opname is registered as Mkl op; false otherwise static inline bool IsMklOp(const std::string& op_name, DataType T) { string kernel = KernelsRegisteredForOp(op_name); bool result = @@ -667,6 +770,28 @@ static inline bool IsMklOp(const std::string& op_name, DataType T) { return result; } +// Check whether opname with type T is registered as MKL-compliant and +// is element-wise. +// +// @input: name of the op +// @input: T datatype to be used for checking op +// @return: true if opname is registered as element-wise Mkl op; false otherwise +static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) { + if (!IsMklOp(op_name, T)) { + return false; + } + + bool result = (0 == op_name.compare(GetMklOpName("Add")) || + 0 == op_name.compare(GetMklOpName("Sub")) || + 0 == op_name.compare(GetMklOpName("Mul")) || + 0 == op_name.compare(GetMklOpName("Maximum")) || + 0 == op_name.compare(GetMklOpName("SquaredDifference"))); + + VLOG(1) << "mkl_op_registry::" << op_name + << " is elementwise MKL op: " << result; + return result; +} + } // namespace mkl_op_registry } // namespace tensorflow diff --git a/tensorflow/core/util/permutation_input_iterator.h b/tensorflow/core/util/permutation_input_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..f6375b25157644cda97aa195958b60ac27b8a4d6 --- /dev/null +++ b/tensorflow/core/util/permutation_input_iterator.h @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_UTIL_PERMUTATION_INPUT_ITERATOR_H_ +#define TENSORFLOW_UTIL_PERMUTATION_INPUT_ITERATOR_H_ + +#include +#include + +namespace tensorflow { + +template +class PermutationInputIterator { + public: + // Required iterator traits + typedef PermutationInputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of + ///< subtracting one iterator from another + typedef ValueType + value_type; ///< The type of the element the iterator can point to + typedef ValueType* pointer; ///< The type of a pointer to an element the + ///< iterator can point to + typedef ValueType reference; ///< The type of a reference to an element the + ///< iterator can point to + + typedef std::random_access_iterator_tag + iterator_category; ///< The iterator category + + private: + InputIteratorT input_itr; + IndexIteratorT index_itr; + + public: + /// Constructor + __host__ __device__ __forceinline__ PermutationInputIterator( + InputIteratorT input_itr, ///< Input iterator to wrap + IndexIteratorT index_itr) ///< Conversion functor to wrap + : input_itr(input_itr), index_itr(index_itr) {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) { + self_type retval = *this; + index_itr++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() { + index_itr++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ reference operator*() const { + return input_itr[*index_itr]; + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const { + self_type retval(input_itr, index_itr + n); + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) { + index_itr += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const { + self_type retval(input_itr, index_itr - n); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) { + index_itr -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type + operator-(self_type other) const { + return index_itr - other.index_itr; + } + + /// Array subscript + template + __host__ __device__ __forceinline__ reference operator[](Distance n) const { + return input_itr[index_itr[n]]; + } + + /// Structure dereference + __host__ __device__ __forceinline__ pointer operator->() { + return input_itr + *index_itr; + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) { + return (index_itr == rhs.index_itr && input_itr == rhs.input_itr); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) { + return !(*this == rhs); + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) { + return os; + } +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_PERMUTATION_INPUT_ITERATOR_H_ diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h index eeb31295737dc262c4eef425060c6fd30cd64be0..5932d59a159f8517c4449067606798177939bc59 100644 --- a/tensorflow/core/util/tensor_slice_reader.h +++ b/tensorflow/core/util/tensor_slice_reader.h @@ -165,13 +165,18 @@ bool TensorSliceReader::CopySliceData(const string& name, CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname; // We read a record in the corresponding sstable const string key = EncodeTensorNameSlice(name, slice_s); - CHECK(sss_[idx]->Get(key, &value)) - << "Failed to seek to the record for tensor " << name << ", slice " - << slice_s.DebugString() << ": computed key = " << key; + if (!sss_[idx]->Get(key, &value)) { + VLOG(1) << "Failed to seek to the record for tensor " << name + << ", slice " << slice_s.DebugString() + << ": computed key = " << key; + return false; + } SavedTensorSlices sts; - CHECK(ParseProtoUnlimited(&sts, value)) - << "Failed to parse the record for tensor " << name << ", slice " - << slice_s.DebugString() << ": computed key = " << key; + if (!ParseProtoUnlimited(&sts, value)) { + VLOG(1) << "Failed to parse the record for tensor " << name << ", slice " + << slice_s.DebugString() << ": computed key = " << key; + return false; + } CopyDataFromTensorSliceToTensorSlice( tss->shape(), slice_s, slice, checkpoint::TensorProtoData(sts.data().data()), data); diff --git a/tensorflow/core/util/transform_output_iterator.h b/tensorflow/core/util/transform_output_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..1640791ad1729a57283ab5f2b91b7734c9447d8f --- /dev/null +++ b/tensorflow/core/util/transform_output_iterator.h @@ -0,0 +1,149 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_ +#define TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_ + +#include +#include + +namespace tensorflow { + +template +class TransformOutputIterator { + private: + // Proxy object + struct Reference { + StoreType* ptr; + ConversionOp conversion_op; + + /// Constructor + __host__ __device__ __forceinline__ Reference(StoreType* ptr, + ConversionOp conversion_op) + : ptr(ptr), conversion_op(conversion_op) {} + + /// Assignment + __host__ __device__ __forceinline__ InputType operator=(InputType val) { + *ptr = conversion_op(val); + return val; + } + }; + + public: + // Required iterator traits + typedef TransformOutputIterator self_type; ///< My own type + typedef OffsetT difference_type; ///< Type to express the result of + ///< subtracting one iterator from another + typedef void + value_type; ///< The type of the element the iterator can point to + typedef void pointer; ///< The type of a pointer to an element the iterator + ///< can point to + typedef Reference reference; ///< The type of a reference to an element the + ///< iterator can point to + + typedef std::random_access_iterator_tag + iterator_category; ///< The iterator category + + /*private:*/ + + StoreType* ptr; + ConversionOp conversion_op; + + public: + /// Constructor + template + __host__ __device__ __forceinline__ TransformOutputIterator( + QualifiedStoreType* ptr, + ConversionOp conversionOp) ///< Native pointer to wrap + : ptr(ptr), conversion_op(conversionOp) {} + + /// Postfix increment + __host__ __device__ __forceinline__ self_type operator++(int) { + self_type retval = *this; + ptr++; + return retval; + } + + /// Prefix increment + __host__ __device__ __forceinline__ self_type operator++() { + ptr++; + return *this; + } + + /// Indirection + __host__ __device__ __forceinline__ reference operator*() const { + return Reference(ptr, conversion_op); + } + + /// Addition + template + __host__ __device__ __forceinline__ self_type operator+(Distance n) const { + self_type retval(ptr + n, conversion_op); + return retval; + } + + /// Addition assignment + template + __host__ __device__ __forceinline__ self_type& operator+=(Distance n) { + ptr += n; + return *this; + } + + /// Subtraction + template + __host__ __device__ __forceinline__ self_type operator-(Distance n) const { + self_type retval(ptr - n, conversion_op); + return retval; + } + + /// Subtraction assignment + template + __host__ __device__ __forceinline__ self_type& operator-=(Distance n) { + ptr -= n; + return *this; + } + + /// Distance + __host__ __device__ __forceinline__ difference_type + operator-(self_type other) const { + return ptr - other.ptr; + } + + /// Array subscript + template + __host__ __device__ __forceinline__ reference operator[](Distance n) const { + return Reference(ptr + n, conversion_op); + } + + /// Equal to + __host__ __device__ __forceinline__ bool operator==(const self_type& rhs) { + return (ptr == rhs.ptr); + } + + /// Not equal to + __host__ __device__ __forceinline__ bool operator!=(const self_type& rhs) { + return (ptr != rhs.ptr); + } + + /// ostream operator + friend std::ostream& operator<<(std::ostream& os, const self_type& itr) { + return os; + } +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_UTIL_TRANSFORM_OUTPUT_ITERATOR_H_ diff --git a/tensorflow/docs_src/about/bib.md b/tensorflow/docs_src/about/bib.md index 0c0e88c1fed98dd63b494c9420cc58ca7c3f7f77..c9f0c532c62791a9fcf854f11fd2f330955ee7d6 100644 --- a/tensorflow/docs_src/about/bib.md +++ b/tensorflow/docs_src/about/bib.md @@ -37,7 +37,7 @@ system, we suggest you cite this whitepaper.

 @misc{tensorflow2015-whitepaper,
 title={ {TensorFlow}: Large-Scale Machine Learning on Heterogeneous Systems},
-url={http://tensorflow.org/},
+url={https://www.tensorflow.org/},
 note={Software available from tensorflow.org},
 author={
     Mart\'{\i}n~Abadi and
diff --git a/tensorflow/docs_src/community/welcome.md b/tensorflow/docs_src/community/welcome.md
index 194649a304d236147176201ebe3e99a2ad3b31c5..4991783a53a5a5fd5168aca14e4cf7db6847e665 100644
--- a/tensorflow/docs_src/community/welcome.md
+++ b/tensorflow/docs_src/community/welcome.md
@@ -37,6 +37,7 @@ Asia:
 * [TensorFlow Korea (TF-KR) User Group](https://www.facebook.com/groups/TensorFlowKR/) _(Korean language)_
 * [TensorFlow User Group Tokyo](https://tfug-tokyo.connpass.com/) _(Japanese Language)_
 * [Soleil Data Dojo](https://soleildatadojo.connpass.com/) _(Japanese language)_
+* [TensorFlow User Group Utsunomiya](https://tfug-utsunomiya.connpass.com/)
 
 
 Europe:
diff --git a/tensorflow/docs_src/get_started/estimator.md b/tensorflow/docs_src/get_started/estimator.md
index a55454f8af362cd97d1ef18ab750e2ee95291bd0..4f3a438d17d20a6a7698e3a767b9f1e63417a953 100644
--- a/tensorflow/docs_src/get_started/estimator.md
+++ b/tensorflow/docs_src/get_started/estimator.md
@@ -273,9 +273,7 @@ Then, the code creates a `DNNClassifier` model using the following arguments:
     containing 10, 20, and 10 neurons, respectively.
 *   `n_classes=3`. Three target classes, representing the three Iris species.
 *   `model_dir=/tmp/iris_model`. The directory in which TensorFlow will save
-    checkpoint data during model training. For more on logging and monitoring
-    with TensorFlow, see
-    @{$monitors$Logging and Monitoring Basics with tf.estimator}.
+    checkpoint data and TensorBoard summaries during model training.
 
 ## Describe the training input pipeline {#train-input}
 
@@ -315,9 +313,7 @@ classifier.train(input_fn=train_input_fn, steps=1000)
 
 However, if you're looking to track the model while it trains, you'll likely
 want to instead use a TensorFlow @{tf.train.SessionRunHook$`SessionRunHook`}
-to perform logging operations. See the tutorial
-@{$monitors$Logging and Monitoring Basics with tf.estimator}
-for more on this topic.
+to perform logging operations.
 
 ## Evaluate Model Accuracy {#evaluate-accuracy}
 
diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md
index 3e700daa30417a023d06bf6db11ec80c610e7af8..003fac1a287688e1d1d343b1dcc834500fd20856 100644
--- a/tensorflow/docs_src/get_started/index.md
+++ b/tensorflow/docs_src/get_started/index.md
@@ -24,8 +24,6 @@ To learn about the high-level API, read the following guides:
     API.
   * @{$get_started/input_fn$Building Input Functions},
     which takes you into a somewhat more sophisticated use of this API.
-  * @{$get_started/monitors$Logging and Monitoring Basics with tf.contrib.learn},
-    which explains how to audit the progress of model training.
 
 TensorBoard is a utility to visualize different aspects of machine learning.
 The following guides explain how to use TensorBoard:
diff --git a/tensorflow/docs_src/get_started/input_fn.md b/tensorflow/docs_src/get_started/input_fn.md
index 422f45c586aa587c22f9d72eab23833f85b5a2eb..7706c07b1d940f98acf89ecf63df5e9f7af31366 100644
--- a/tensorflow/docs_src/get_started/input_fn.md
+++ b/tensorflow/docs_src/get_started/input_fn.md
@@ -249,7 +249,7 @@ here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/input_fn/bos
 
 ### Importing the Housing Data
 
-To start, set up your imports (including `pandas` and `tensorflow`) and @{$monitors#enabling-logging-with-tensorflow$set logging verbosity} to
+To start, set up your imports (including `pandas` and `tensorflow`) and set logging verbosity to
 `INFO` for more detailed log output:
 
 ```python
diff --git a/tensorflow/docs_src/get_started/leftnav_files b/tensorflow/docs_src/get_started/leftnav_files
index b656033f7e8df1e68292f9fe30d9fca5c406dfd8..bb67eaddda369c0271c4fdb17a686016ffa80a2e 100644
--- a/tensorflow/docs_src/get_started/leftnav_files
+++ b/tensorflow/docs_src/get_started/leftnav_files
@@ -5,7 +5,6 @@ mnist/pros.md
 mnist/mechanics.md
 estimator.md
 input_fn.md
-monitors.md
 summaries_and_tensorboard.md
 graph_viz.md
 tensorboard_histograms.md
diff --git a/tensorflow/docs_src/get_started/monitors.md b/tensorflow/docs_src/get_started/monitors.md
deleted file mode 100644
index 5606e95365812a7287b844a86172287c1aafa766..0000000000000000000000000000000000000000
--- a/tensorflow/docs_src/get_started/monitors.md
+++ /dev/null
@@ -1,406 +0,0 @@
-# Logging and Monitoring Basics with tf.contrib.learn
-
-When training a model, it’s often valuable to track and evaluate progress in
-real time. In this tutorial, you’ll learn how to use TensorFlow’s logging
-capabilities and the `Monitor` API to audit the in-progress training of a neural
-network classifier for categorizing irises. This tutorial builds on the code
-developed in @{$estimator$tf.estimator Quickstart} so if you
-haven't yet completed that tutorial, you may want to explore it first,
-especially if you're looking for an intro/refresher on tf.contrib.learn basics.
-
-## Setup {#setup}
-
-For this tutorial, you'll be building upon the following code from
-@{$estimator$tf.estimator Quickstart}:
-
-```python
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import numpy as np
-import tensorflow as tf
-
-# Data sets
-IRIS_TRAINING = os.path.join(os.path.dirname(__file__), "iris_training.csv")
-IRIS_TEST = os.path.join(os.path.dirname(__file__), "iris_test.csv")
-
-def main(unused_argv):
-    # Load datasets.
-    training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
-        filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32)
-    test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
-        filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32)
-
-    # Specify that all features have real-value data
-    feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
-
-    # Build 3 layer DNN with 10, 20, 10 units respectively.
-    classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
-                                                hidden_units=[10, 20, 10],
-                                                n_classes=3,
-                                                model_dir="/tmp/iris_model")
-
-    # Fit model.
-    classifier.fit(x=training_set.data,
-                   y=training_set.target,
-                   steps=2000)
-
-    # Evaluate accuracy.
-    accuracy_score = classifier.evaluate(x=test_set.data,
-                                         y=test_set.target)["accuracy"]
-    print('Accuracy: {0:f}'.format(accuracy_score))
-
-    # Classify two new flower samples.
-    new_samples = np.array(
-        [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
-    y = list(classifier.predict(new_samples, as_iterable=True))
-    print('Predictions: {}'.format(str(y)))
-
-if __name__ == "__main__":
-  tf.app.run()
-```
-
-Copy the above code into a file, and download the corresponding
-[training](http://download.tensorflow.org/data/iris_training.csv) and
-[test](http://download.tensorflow.org/data/iris_test.csv) data sets to the same
-directory.
-
-In the following sections, you'll progressively make updates to the above code
-to add logging and monitoring capabilities. Final code incorporating all updates
-is [available for download
-here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/monitors/iris_monitors.py).
-
-## Overview
-
-The @{$estimator$tf.estimator Quickstart tutorial} walked through
-how to implement a neural net classifier to categorize iris examples into one of
-three species.
-
-But when [the code](#setup) from this tutorial is run, the output contains no
-logging tracking how model training is progressing—only the results of the
-`print` statements that were included:
-
-```none
-Accuracy: 0.933333
-Predictions: [1 2]
-```
-
-Without any logging, model training feels like a bit of a black box; you can't
-see what's happening as TensorFlow steps through gradient descent, get a sense
-of whether the model is converging appropriately, or audit to determine whether
-[early stopping](https://en.wikipedia.org/wiki/Early_stopping) might be
-appropriate.
-
-One way to address this problem would be to split model training into multiple
-`fit` calls with smaller numbers of steps in order to evaluate accuracy more
-progressively. However, this is not recommended practice, as it greatly slows
-down model training. Fortunately, tf.contrib.learn offers another solution: a
-@{tf.contrib.learn.monitors$Monitor API} designed to help
-you log metrics and evaluate your model while training is in progress. In the
-following sections, you'll learn how to enable logging in TensorFlow, set up a
-ValidationMonitor to do streaming evaluations, and visualize your metrics using
-TensorBoard.
-
-## Enabling Logging with TensorFlow
-
-TensorFlow uses five different levels for log messages. In order of ascending
-severity, they are `DEBUG`, `INFO`, `WARN`, `ERROR`, and `FATAL`. When you
-configure logging at any of these levels, TensorFlow will output all log
-messages corresponding to that level and all levels of higher severity. For
-example, if you set a logging level of `ERROR`, you'll get log output containing
-`ERROR` and `FATAL` messages, and if you set a level of `DEBUG`, you'll get log
-messages from all five levels.
-
-By default, TensorFlow is configured at a logging level of `WARN`, but when
-tracking model training, you'll want to adjust the level to `INFO`, which will
-provide additional feedback as `fit` operations are in progress.
-
-Add the following line to the beginning of your code (right after your
-`import`s):
-
-```python
-tf.logging.set_verbosity(tf.logging.INFO)
-```
-
-Now when you run the code, you'll see additional log output like the following:
-
-```none
-INFO:tensorflow:loss = 1.18812, step = 1
-INFO:tensorflow:loss = 0.210323, step = 101
-INFO:tensorflow:loss = 0.109025, step = 201
-```
-
-With `INFO`-level logging, tf.contrib.learn automatically outputs [training-loss
-metrics](https://en.wikipedia.org/wiki/Loss_function) to stderr after every 100
-steps.
-
-## Configuring a ValidationMonitor for Streaming Evaluation
-
-Logging training loss is helpful to get a sense whether your model is
-converging, but what if you want further insight into what's happening during
-training? tf.contrib.learn provides several high-level `Monitor`s you can attach
-to your `fit` operations to further track metrics and/or debug lower-level
-TensorFlow operations during model training, including:
-
-Monitor             | Description
-------------------- | -----------
-`CaptureVariable`   | Saves a specified variable's values into a collection at every _n_ steps of training
-`PrintTensor`       | Logs a specified tensor's values at every _n_ steps of training
-`SummarySaver`      | Saves @{tf.Summary} [protocol buffers](https://developers.google.com/protocol-buffers/) for a given tensor using a @{tf.summary.FileWriter} at every _n_ steps of training
-`ValidationMonitor` | Logs a specified set of evaluation metrics at every _n_ steps of training, and, if desired, implements early stopping under certain conditions
-
-### Evaluating Every *N* Steps
-
-For the iris neural network classifier, while logging training loss, you might
-also want to simultaneously evaluate against test data to see how well the model
-is generalizing. You can accomplish this by configuring a `ValidationMonitor`
-with the test data (`test_set.data` and `test_set.target`), and setting how
-often to evaluate with `every_n_steps`. The default value of `every_n_steps` is
-`100`; here, set `every_n_steps` to `50` to evaluate after every 50 steps of
-model training:
-
-```python
-validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
-    test_set.data,
-    test_set.target,
-    every_n_steps=50)
-```
-
-Place this code right before the line instantiating the `classifier`.
-
-`ValidationMonitor`s rely on saved checkpoints to perform evaluation operations,
-so you'll want to modify instantiation of the `classifier` to add a
-@{tf.contrib.learn.RunConfig} that includes
-`save_checkpoints_secs`, which specifies how many seconds should elapse between
-checkpoint saves during training. Because the iris data set is quite small, and
-thus trains quickly, it makes sense to set `save_checkpoints_secs` to 1 (saving
-a checkpoint every second) to ensure a sufficient number of checkpoints:
-
-```python
-classifier = tf.contrib.learn.DNNClassifier(
-    feature_columns=feature_columns,
-    hidden_units=[10, 20, 10],
-    n_classes=3,
-    model_dir="/tmp/iris_model",
-    config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1))
-```
-
-NOTE: The `model_dir` parameter specifies an explicit directory
-(`/tmp/iris_model`) for model data to be stored; this directory path will be
-easier to reference later on than an autogenerated one. Each time you run the
-code, any existing data in `/tmp/iris_model` will be loaded, and model training
-will continue where it left off in the last run (e.g., running the script twice
-in succession will execute 4000 steps during training—2000 during each
-`fit` operation). To start over model training from scratch, delete
-`/tmp/iris_model` before running the code.
-
-Finally, to attach your `validation_monitor`, update the `fit` call to include a
-`monitors` param, which takes a list of all monitors to run during model
-training:
-
-```python
-classifier.fit(x=training_set.data,
-               y=training_set.target,
-               steps=2000,
-               monitors=[validation_monitor])
-```
-
-Now, when you rerun the code, you should see validation metrics in your log
-output, e.g.:
-
-```none
-INFO:tensorflow:Validation (step 50): loss = 1.71139, global_step = 0, accuracy = 0.266667
-...
-INFO:tensorflow:Validation (step 300): loss = 0.0714158, global_step = 268, accuracy = 0.966667
-...
-INFO:tensorflow:Validation (step 1750): loss = 0.0574449, global_step = 1729, accuracy = 0.966667
-```
-
-### Customizing the Evaluation Metrics with MetricSpec
-
-By default, if no evaluation metrics are specified, `ValidationMonitor` will log
-both [loss](https://en.wikipedia.org/wiki/Loss_function) and accuracy, but you
-can customize the list of metrics that will be run every 50 steps. To specify
-the exact metrics you'd like to run in each evaluation pass, you can add a
-`metrics` param to the `ValidationMonitor` constructor. `metrics` takes a dict
-of key/value pairs, where each key is the name you'd like logged for the metric,
-and the corresponding value is a
-[`MetricSpec`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/metric_spec.py)
-object.
-
-The `MetricSpec` constructor accepts four parameters:
-
-*   `metric_fn`. The function that calculates and returns the value of a metric.
-    This can be a predefined function available in the
-    @{tf.contrib.metrics} module, such as
-    @{tf.contrib.metrics.streaming_precision} or
-    @{tf.contrib.metrics.streaming_recall}.
-
-    Alternatively, you can define your own custom metric function, which must
-    take `predictions` and `labels` tensors as arguments (a `weights` argument
-    can also optionally be supplied). The function must return the value of the
-    metric in one of two formats:
-
-    *   A single tensor
-    *   A pair of ops `(value_op, update_op)`, where `value_op` returns the
-        metric value and `update_op` performs a corresponding operation to
-        update internal model state.
-
-*   `prediction_key`. The key of the tensor containing the predictions returned
-    by the model. This argument may be omitted if the model returns either a
-    single tensor or a dict with a single entry. For a `DNNClassifier` model,
-    class predictions will be returned in a tensor with the key
-    @{tf.contrib.learn.PredictionKey.CLASSES}.
-
-*   `label_key`. The key of the tensor containing the labels returned by the
-    model, as specified by the model's @{$input_fn$`input_fn`}. As
-    with `prediction_key`, this argument may be omitted if the `input_fn`
-    returns either a single tensor or a dict with a single entry. In the iris
-    example in this tutorial, the `DNNClassifier` does not have an `input_fn`
-    (`x`,`y` data is passed directly to `fit`), so it's not necessary to provide
-    a `label_key`.
-
-*   `weights_key`. *Optional*. The key of the tensor (returned by the
-    @{$input_fn$`input_fn`}) containing weights inputs for the
-    `metric_fn`.
-
-The following code creates a `validation_metrics` dict that defines three
-metrics to log during model evaluation:
-
-*   `"accuracy"`, using @{tf.contrib.metrics.streaming_accuracy}
-    as the `metric_fn`
-*   `"precision"`, using @{tf.contrib.metrics.streaming_precision}
-    as the `metric_fn`
-*   `"recall"`, using @{tf.contrib.metrics.streaming_recall}
-    as the `metric_fn`
-
-```python
-validation_metrics = {
-    "accuracy":
-        tf.contrib.learn.MetricSpec(
-            metric_fn=tf.contrib.metrics.streaming_accuracy,
-            prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
-    "precision":
-        tf.contrib.learn.MetricSpec(
-            metric_fn=tf.contrib.metrics.streaming_precision,
-            prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
-    "recall":
-        tf.contrib.learn.MetricSpec(
-            metric_fn=tf.contrib.metrics.streaming_recall,
-            prediction_key=tf.contrib.learn.PredictionKey.CLASSES)
-}
-```
-
-Add the above code before the `ValidationMonitor` constructor. Then revise the
-`ValidationMonitor` constructor as follows to add a `metrics` parameter to log
-the accuracy, precision, and recall metrics specified in `validation_metrics`
-(loss is always logged, and doesn't need to be explicitly specified):
-
-```python
-validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
-    test_set.data,
-    test_set.target,
-    every_n_steps=50,
-    metrics=validation_metrics)
-```
-
-Rerun the code, and you should see precision and recall included in your log
-output, e.g.:
-
-```none
-INFO:tensorflow:Validation (step 50): recall = 0.0, loss = 1.20626, global_step = 1, precision = 0.0, accuracy = 0.266667
-...
-INFO:tensorflow:Validation (step 600): recall = 1.0, loss = 0.0530696, global_step = 571, precision = 1.0, accuracy = 0.966667
-...
-INFO:tensorflow:Validation (step 1500): recall = 1.0, loss = 0.0617403, global_step = 1452, precision = 1.0, accuracy = 0.966667
-```
-
-### Early Stopping with ValidationMonitor
-
-Note that in the above log output, by step 600, the model has already achieved
-precision and recall rates of 1.0. This raises the question as to whether model
-training could benefit from
-[early stopping](https://en.wikipedia.org/wiki/Early_stopping).
-
-In addition to logging eval metrics, `ValidationMonitor`s make it easy to
-implement early stopping when specified conditions are met, via three params:
-
-| Param                            | Description                               |
-| -------------------------------- | ----------------------------------------- |
-| `early_stopping_metric`          | Metric that triggers early stopping       |
-:                                  : (e.g., loss or accuracy) under conditions :
-:                                  : specified in `early_stopping_rounds` and  :
-:                                  : `early_stopping_metric_minimize`. Default :
-:                                  : is `"loss"`.                              :
-| `early_stopping_metric_minimize` | `True` if desired model behavior is to    |
-:                                  : minimize the value of                     :
-:                                  : `early_stopping_metric`; `False` if       :
-:                                  : desired model behavior is to maximize the :
-:                                  : value of `early_stopping_metric`. Default :
-:                                  : is `True`.                                :
-| `early_stopping_rounds`          | Sets a number of steps during which if    |
-:                                  : the `early_stopping_metric` does not      :
-:                                  : decrease (if                              :
-:                                  : `early_stopping_metric_minimize` is       :
-:                                  : `True`) or increase (if                   :
-:                                  : `early_stopping_metric_minimize` is       :
-:                                  : `False`), training will be stopped.       :
-:                                  : Default is `None`, which means early      :
-:                                  : stopping will never occur.                :
-
-Make the following revision to the `ValidationMonitor` constructor, which
-specifies that if loss (`early_stopping_metric="loss"`) does not decrease
-(`early_stopping_metric_minimize=True`) over a period of 200 steps
-(`early_stopping_rounds=200`), model training will stop immediately at that
-point, and not complete the full 2000 steps specified in `fit`:
-
-```python
-validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
-    test_set.data,
-    test_set.target,
-    every_n_steps=50,
-    metrics=validation_metrics,
-    early_stopping_metric="loss",
-    early_stopping_metric_minimize=True,
-    early_stopping_rounds=200)
-```
-
-Rerun the code to see if model training stops early:
-
-```none
-...
-INFO:tensorflow:Validation (step 1150): recall = 1.0, loss = 0.056436, global_step = 1119, precision = 1.0, accuracy = 0.966667
-INFO:tensorflow:Stopping. Best step: 800 with loss = 0.048313818872.
-```
-
-Indeed, here training stops at step 1150, indicating that for the past 200
-steps, loss did not decrease, and that overall, step 800 produced the smallest
-loss value against the test data set. This suggests that additional calibration
-of hyperparameters by decreasing the step count might further improve the model.
-
-## Visualizing Log Data with TensorBoard
-
-Reading through the log produced by `ValidationMonitor` provides plenty of raw
-data on model performance during training, but it may also be helpful to see
-visualizations of this data to get further insight into trends—for
-example, how accuracy is changing over step count. You can use TensorBoard (a
-separate program packaged with TensorFlow) to plot graphs like this by setting
-the `logdir` command-line argument to the directory where you saved your model
-training data (here, `/tmp/iris_model`). Run the following on your command line:
-
-
$ tensorboard --logdir=/tmp/iris_model/
-Starting TensorBoard 39 on port 6006
- -Then navigate to `http://0.0.0.0:`*``* in your browser, where -*``* is the port specified in the command-line output (here, -`6006`). - -If you click on the accuracy field, you'll see an image like the following, -which shows accuracy plotted against step count: - -![Accuracy over step count in TensorBoard](https://www.tensorflow.org/images/validation_monitor_tensorboard_accuracy.png "Accuracy over step count in TensorBoard") - -For more on using TensorBoard, see @{$summaries_and_tensorboard$TensorBoard: Visualizing Learning} and @{$graph_viz$TensorBoard: Graph Visualization}. diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md index 43e09906f7360933cc697a24a9b83a4c605c84f5..d5e481520c43fc988c645a03146066bafa8f9662 100644 --- a/tensorflow/docs_src/install/install_linux.md +++ b/tensorflow/docs_src/install/install_linux.md @@ -151,10 +151,10 @@ Take the following steps to install TensorFlow with Virtualenv: (tensorflow)$ pip install --upgrade tensorflow-gpu # for Python 2.7 and GPU (tensorflow)$ pip3 install --upgrade tensorflow-gpu # for Python 3.n and GPU
- If the preceding command succeeds, skip Step 5. If the preceding - command fails, perform Step 5. + If the preceding command succeeds, skip Step 6. If the preceding + command fails, perform Step 6. - 5. (Optional) If Step 4 failed (typically because you invoked a pip version + 6. (Optional) If Step 5 failed (typically because you invoked a pip version lower than 8.1), install TensorFlow in the active virtualenv environment by issuing a command of the following format: diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md index be6a490ff9bc7e1f5c5df81588cf5c255ab68d13..3025c9971abdbbe7e3aee2d835e1d842c64e01ca 100644 --- a/tensorflow/docs_src/install/install_windows.md +++ b/tensorflow/docs_src/install/install_windows.md @@ -71,12 +71,14 @@ Use that package at your own risk. ## Installing with native pip -If the following version of Python is not installed on your machine, +If one of the following versions of Python is not installed on your machine, install it now: * [Python 3.5.x 64-bit from python.org](https://www.python.org/downloads/release/python-352/) + * [Python 3.6.x 64-bit from python.org](https://www.python.org/downloads/release/python-362/) -Note that Python 3.5.x comes with the pip3 package manager, which is the +-TensorFlow supports Python 3.5.x and 3.6.x on Windows. +Note that Python 3 comes with the pip3 package manager, which is the program you'll use to install TensorFlow. To install TensorFlow, start a terminal. Then issue the appropriate diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md index bf3cb5bf196916c834b19d37ffa16a02882cc9c1..aaebabfddf9f1d9e0d4b98a7abe6fab2ff0dc9bd 100644 --- a/tensorflow/docs_src/programmers_guide/datasets.md +++ b/tensorflow/docs_src/programmers_guide/datasets.md @@ -1,4 +1,4 @@ -# Using the `Dataset` API for TensorFlow Input Pipelines +# Importing Data The `Dataset` API enables you to build complex input pipelines from simple, reusable pieces. For example, the pipeline for an image model might @@ -146,6 +146,9 @@ for i in range(100): assert i == value ``` +Note: Currently, one-shot iterators are the only type that is easily usable +with an `Estimator`. + An **initializable** iterator requires you to run an explicit `iterator.initializer` operation before using it. In exchange for this inconvenience, it enables you to *parameterize* the definition of the dataset, @@ -452,6 +455,9 @@ dataset = dataset.flat_map( .filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#")))) ``` +For a full example of parsing a CSV file using datasets, see [`imports85.py`](https://www.tensorflow.org/code/tensorflow/examples/get_started/regression/imports85.py) +in @{$get_started/linear_regression}. + - - - ossrh - https://oss.sonatype.org/content/repositories/snapshots - - + + + ossrh - https://oss.sonatype.org/service/local/staging/deploy/maven2/ - - + + + + ossrh + https://oss.sonatype.org/content/repositories/snapshots + + + ossrh + https://oss.sonatype.org/service/local/staging/deploy/maven2/ + + + + + bintray + + + + bintray + https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0 + + + + @@ -55,19 +72,6 @@ - - - org.sonatype.plugins - nexus-staging-maven-plugin - 1.6.7 - true - - ossrh - https://oss.sonatype.org/ - - false - - org.apache.maven.plugins diff --git a/tensorflow/java/maven/release.sh b/tensorflow/java/maven/release.sh index b95a4d4674e7dba785a1f4d3bc90b01e6c03d94d..9012ea14ea6c3322ceddd0951b4245733e8d1659 100755 --- a/tensorflow/java/maven/release.sh +++ b/tensorflow/java/maven/release.sh @@ -49,6 +49,8 @@ fi set -ex docker run \ -e TF_VERSION="${TF_VERSION}" \ + -e DEPLOY_OSSRH="${DEPLOY_OSSRH:-true}" \ + -e DEPLOY_BINTRAY="${DEPLOY_BINTRAY:-true}" \ -v ${PWD}:/tensorflow \ -v "${SETTINGS_XML}":/root/.m2/settings.xml \ -v ${HOME}/.gnupg:/root/.gnupg \ diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh index 6b4d5d70327a595805f160332a5837c07949a4e1..a2ce097195450eff566f0be48ca4f1a6b99401cc 100644 --- a/tensorflow/java/maven/run_inside_container.sh +++ b/tensorflow/java/maven/run_inside_container.sh @@ -19,11 +19,23 @@ RELEASE_URL_PREFIX="https://storage.googleapis.com/tensorflow/libtensorflow" + +# By default we deploy to both ossrh and bintray. These two +# environment variables can be set to skip either repository. +DEPLOY_BINTRAY="${DEPLOY_BINTRAY:-true}" +DEPLOY_OSSRH="${DEPLOY_OSSRH:-true}" + IS_SNAPSHOT="false" if [[ "${TF_VERSION}" == *"-SNAPSHOT" ]]; then IS_SNAPSHOT="true" + # Bintray does not allow snapshots. + DEPLOY_BINTRAY="false" fi PROTOC_RELEASE_URL="https://github.com/google/protobuf/releases/download/v3.3.0/protoc-3.3.0-linux-x86_64.zip" +if [[ "${DEPLOY_BINTRAY}" != "true" && "${DEPLOY_OSSRH}" != "true" ]]; then + echo "Must deploy to at least one of Bintray or OSSRH" >&2 + exit 2 +fi set -ex @@ -39,6 +51,20 @@ update_version_in_pom() { mvn versions:set -DnewVersion="${TF_VERSION}" } +# Fetch a property from pom files for a given profile. +# Arguments: +# profile - name of the selected profile. +# property - name of the property to be retrieved. +# Output: +# Echo property value to stdout +mvn_property() { + local profile="$1" + local prop="$2" + mvn -q --non-recursive exec:exec -P "${profile}" \ + -Dexec.executable='echo' \ + -Dexec.args="\${${prop}}" +} + download_libtensorflow() { if [[ "${IS_SNAPSHOT}" == "true" ]]; then URL="http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow-src.jar" @@ -137,29 +163,50 @@ generate_java_protos() { rm -rf "${DIR}/proto/tmp" } +# Deploy artifacts using a specific profile. +# Arguments: +# profile - name of selected profile. +# Outputs: +# n/a +deploy_profile() { + local profile="$1" + # Deploy the non-android pieces. + mvn deploy -P"${profile}" + # Determine the correct pom file property to use + # for the repository url. + local rtype + if [[ "${IS_SNAPSHOT}" == "true" ]]; then + rtype='snapshotRepository' + else + rtype='repository' + fi + local url=$(mvn_property "${profile}" "project.distributionManagement.${rtype}.url") + local repositoryId=$(mvn_property "${profile}" "project.distributionManagement.${rtype}.id") + mvn gpg:sign-and-deploy-file \ + -Dfile="${DIR}/tensorflow-android/target/tensorflow.aar" \ + -DpomFile="${DIR}/tensorflow-android/target/pom-android.xml" \ + -Durl="${url}" \ + -DrepositoryId="${repositoryId}" +} + # If successfully built, try to deploy. # If successfully deployed, clean. # If deployment fails, debug with # ./release.sh ${TF_VERSION} ${SETTINGS_XML} bash # To get a shell to poke around the maven artifacts with. deploy_artifacts() { - # This deploys the non-android pieces - mvn deploy - - # Sign and deploy the previously downloaded aar file as a single - # maven artifact. - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - REPO="https://oss.sonatype.org/content/repositories/snapshots" - else - REPO="https://oss.sonatype.org/service/local/staging/deploy/maven2/" + # Deploy artifacts to ossrh if requested. + if [[ "${DEPLOY_OSSRH}" == "true" ]]; then + deploy_profile 'ossrh' + fi + # Deploy artifacts to bintray if requested. + if [[ "${DEPLOY_BINTRAY}" == "true" ]]; then + deploy_profile 'bintray' fi - mvn gpg:sign-and-deploy-file -Dfile="${DIR}/tensorflow-android/target/tensorflow.aar" -DpomFile="${DIR}/tensorflow-android/target/pom-android.xml" -Durl=${REPO} -DrepositoryId=ossrh - # Clean up when everything works clean } - if [ -z "${TF_VERSION}" ] then echo "Must set the TF_VERSION environment variable" @@ -189,8 +236,14 @@ set +ex if [[ "${IS_SNAPSHOT}" == "false" ]]; then echo "Uploaded to the staging repository" echo "After validating the release: " - echo "1. Login to https://oss.sonatype.org/#stagingRepositories" - echo "2. Find the 'org.tensorflow' staging release and click either 'Release' to release or 'Drop' to abort" + if [[ "${DEPLOY_OSSRH}" == "true" ]]; then + echo "* Login to https://oss.sonatype.org/#stagingRepositories" + echo "* Find the 'org.tensorflow' staging release and click either 'Release' to release or 'Drop' to abort" + fi + if [[ "${DEPLOY_BINTRAY}" == "true" ]]; then + echo "* Login to https://bintray.com/google/tensorflow/tensorflow" + echo "* Either 'Publish' unpublished items to release, or 'Discard' to abort" + fi else echo "Uploaded to the snapshot repository" fi diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..a7c66dda893a3109e0e0bfe76f5becef766afb0e --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/java/src/gen/cc/op_generator.h" + +namespace tensorflow { +namespace op_gen { + +const char kUsageHeader[] = + "\n\nGenerator of operation wrappers in Java.\n\n" + "This executable generates wrappers for all registered operations it has " + "been compiled with. A wrapper exposes an intuitive and strongly-typed\n" + "interface for building its underlying operation and linking it into a " + "graph.\n\n" + "Operation wrappers are generated under the path specified by the " + "'--output_dir' argument. This path can be absolute or relative to the\n" + "current working directory and will be created if it does not exists.\n\n" + "The '--lib_name' argument is used to classify the set of operations. If " + "the chosen name contains more than one word, it must be provided in \n" + "snake_case. This value is declined into other meaningful names, such as " + "the group and package of the generated operations. For example,\n" + "'--lib_name=my_lib' generates the operations under the " + "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n" + "group.\n\n" + "Note that the operator group assigned to the generated wrappers is just " + "an annotation tag at this stage. Operations will not be available " + "through\n" + "the 'org.tensorflow.op.Ops' API as a group until the generated classes " + "are compiled using an appropriate annotation processor.\n\n" + "Finally, the '--base_package' overrides the default parent package " + "under which the generated subpackage and classes are to be located.\n\n"; + +} // namespace op_gen +} // namespace tensorflow + +int main(int argc, char* argv[]) { + tensorflow::string lib_name; + tensorflow::string output_dir; + tensorflow::string base_package = "org.tensorflow.op"; + std::vector flag_list = { + tensorflow::Flag("output_dir", &output_dir, + "Root directory into which output files are generated"), + tensorflow::Flag( + "lib_name", &lib_name, + "A name, in snake_case, used to classify this set of operations"), + tensorflow::Flag( + "base_package", &base_package, + "Package parent to the generated subpackage and classes")}; + tensorflow::string usage = tensorflow::op_gen::kUsageHeader; + usage += tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage; + + tensorflow::OpGenerator generator; + tensorflow::OpList ops; + tensorflow::OpRegistry::Global()->Export(true, &ops); + tensorflow::Status status = + generator.Run(ops, lib_name, base_package, output_dir); + TF_QCHECK_OK(status); + + return 0; +} diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc new file mode 100644 index 0000000000000000000000000000000000000000..df130c32e6afcba157da282026280756b778f3ad --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/java/src/gen/cc/op_generator.h" + +namespace tensorflow { +namespace { + +string CamelCase(const string& str, char delimiter, bool upper) { + string result; + bool cap = upper; + for (string::const_iterator it = str.begin(); it != str.end(); ++it) { + const char c = *it; + if (c == delimiter) { + cap = true; + } else if (cap) { + result += toupper(c); + cap = false; + } else { + result += c; + } + } + return result; +} + +} // namespace + +OpGenerator::OpGenerator() : env(Env::Default()) {} + +OpGenerator::~OpGenerator() {} + +Status OpGenerator::Run(const OpList& ops, const string& lib_name, + const string& base_package, const string& output_dir) { + const string package = + base_package + '.' + str_util::StringReplace(lib_name, "_", "", true); + const string package_path = + output_dir + '/' + str_util::StringReplace(package, ".", "/", true); + const string group = CamelCase(lib_name, '_', false); + + if (!env->FileExists(package_path).ok()) { + TF_CHECK_OK(env->RecursivelyCreateDir(package_path)); + } + + LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations"; + // TODO(karllessard) generate wrappers from list of ops + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..eec1082b5162298e68fbd05d82d5563777e865db --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_ +#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_ + +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +/// \brief A generator of Java operation wrappers. +/// +/// Such generator is normally ran only once per executable, outputting +/// wrappers for the all registered operations it has been compiled with. +/// Nonetheless, it is designed to support multiple runs, giving a different +/// list of operations on each cycle. +class OpGenerator { + public: + OpGenerator(); + virtual ~OpGenerator(); + + /// \brief Generates wrappers for the given list of 'ops'. + /// + /// Output files are generated in //, + /// where 'lib_package' is derived from 'lib_name'. + Status Run(const OpList& ops, const string& lib_name, + const string& base_package, const string& output_dir); + + private: + Env* env; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_ diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl new file mode 100644 index 0000000000000000000000000000000000000000..e3710c49d02a5f2e979b1ce4814345d55a14db78 --- /dev/null +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -0,0 +1,59 @@ +# -*- Python -*- + +load("//tensorflow:tensorflow.bzl", "tf_copts") + +# Given a list of "ops_libs" (a list of files in the core/ops directory +# without their .cc extensions), generate Java wrapper code for all operations +# found in the ops files. +# Then, combine all those source files into a single archive (.srcjar). +# +# For example: +# tf_java_op_gen_srcjar("gen_sources", "gen_tool", "my.package", [ "array_ops", "math_ops" ]) +# +# will create a genrule named "gen_sources" that first generate source files: +# ops/src/main/java/my/package/array/*.java +# ops/src/main/java/my/package/math/*.java +# +# and then archive those source files in: +# ops/gen_sources.srcjar +# +def tf_java_op_gen_srcjar(name, + gen_tool, + gen_base_package, + ops_libs=[], + ops_libs_pkg="//tensorflow/core", + out_dir="ops/", + out_src_dir="src/main/java/", + visibility=["//tensorflow/java:__pkg__"]): + + gen_tools = [] + gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files + + # Construct an op generator binary for each ops library. + for ops_lib in ops_libs: + gen_lib = ops_lib[:ops_lib.rfind("_")] + out_gen_tool = out_dir + ops_lib + "_gen_tool" + + native.cc_binary( + name=out_gen_tool, + copts=tf_copts(), + linkopts=["-lm"], + linkstatic=1, # Faster to link this one-time-use binary dynamically + deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"]) + + gen_tools += [":" + out_gen_tool] + gen_cmds += ["$(location :" + out_gen_tool + ")" + + " --output_dir=$(@D)/" + out_src_dir + + " --lib_name=" + gen_lib + + " --base_package=" + gen_base_package] + + # Generate a source archive containing generated code for these ops. + gen_srcjar = out_dir + name + ".srcjar" + gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) ."] + gen_tools += ["@local_jdk//:jar"] + + native.genrule( + name=name, + outs=[gen_srcjar], + tools=gen_tools, + cmd="&&".join(gen_cmds)) diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index f4f853f716caeaa538aba6c944c2e21b68d8d683..442410039073a1410de7b6959f9e9bca9034d396 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -71,6 +71,27 @@ public final class Tensor implements AutoCloseable { * Tensor x = Tensor.create(twoD); * } * + * {@link DataType#STRING} typed Tensors are multi-dimensionary arrays of arbitrary byte sequences + * and thus have {@code byte[]} and not {@code String}-valued elements. For example: + * + *
{@code
+   * // Valid: A DataType.STRING tensor.
+   * Tensor s = Tensor.create(new byte[]{1, 2, 3});
+   *
+   * // Java Strings will need to be encoded into a byte-sequence.
+   * String mystring = "foo";
+   * Tensor s = Tensor.create(mystring.getBytes("UTF-8"));
+   *
+   * // Valid: Matrix of DataType.STRING tensors.
+   * // Each element might have a different length.
+   * byte[][][] matrix = new byte[2][2][];
+   * matrix[0][0] = "this".getBytes("UTF-8");
+   * matrix[0][1] = "is".getBytes("UTF-8");
+   * matrix[1][0] = "a".getBytes("UTF-8");
+   * matrix[1][1] = "matrix".getBytes("UTF-8");
+   * Tensor m = Tensor.create(matrix);
+   * }
+ * * @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type * system, or if obj does not disambiguate between multiple DataTypes. In that case, consider * using {@link #create(DataType, long[], ByteBuffer)} instead. @@ -85,10 +106,7 @@ public final class Tensor implements AutoCloseable { t.nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize); setValue(t.nativeHandle, obj); } else if (t.shapeCopy.length != 0) { - throw new UnsupportedOperationException( - String.format( - "non-scalar DataType.STRING tensors are not supported yet (version %s). Please file a feature request at https://github.com/tensorflow/tensorflow/issues/new", - TensorFlow.version())); + t.nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj); } else { t.nativeHandle = allocateScalarBytes((byte[]) obj); } @@ -172,7 +190,8 @@ public final class Tensor implements AutoCloseable { * *

Creates a Tensor with the provided shape of any type where the tensor's data has been * encoded into {@code data} as per the specification of the TensorFlow C API. + * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C + * API. * * @param dataType the tensor datatype. * @param shape the tensor shape. @@ -328,9 +347,9 @@ public final class Tensor implements AutoCloseable { * Copies the contents of the tensor to {@code dst} and returns {@code dst}. * *

For non-scalar tensors, this method copies the contents of the underlying tensor to a Java - * array. For scalar tensors, use one of {@link #floatValue()}, {@link #doubleValue()}, {@link - * #intValue()}, {@link #longValue()} or {@link #booleanValue()} instead. The type and shape of - * {@code dst} must be compatible with the tensor. For example: + * array. For scalar tensors, use one of {@link #bytesValue()}, {@link #floatValue()}, {@link + * #doubleValue()}, {@link #intValue()}, {@link #longValue()} or {@link #booleanValue()} instead. + * The type and shape of {@code dst} must be compatible with the tensor. For example: * *

{@code
    * int matrix[2][2] = {{1,2},{3,4}};
@@ -496,8 +515,6 @@ public final class Tensor implements AutoCloseable {
 
   private static int elemByteSize(DataType dataType) {
     switch (dataType) {
-      case UINT8:
-        return 1;
       case FLOAT:
       case INT32:
         return 4;
@@ -505,6 +522,7 @@ public final class Tensor implements AutoCloseable {
       case INT64:
         return 8;
       case BOOL:
+      case UINT8:
         return 1;
       case STRING:
         throw new IllegalArgumentException("STRING tensors do not have a fixed element size");
@@ -512,6 +530,13 @@ public final class Tensor implements AutoCloseable {
     throw new IllegalArgumentException("DataType " + dataType + " is not supported yet");
   }
 
+  private static void throwExceptionIfNotByteOfByteArrays(Object array) {
+    if (!array.getClass().getName().equals("[[B")) {
+      throw new IllegalArgumentException(
+          "object cannot be converted to a Tensor as it includes an array with null elements");
+    }
+  }
+
   private static DataType dataTypeOf(Object o) {
     if (o.getClass().isArray()) {
       if (Array.getLength(o) == 0) {
@@ -519,6 +544,10 @@ public final class Tensor implements AutoCloseable {
       }
       // byte[] is a DataType.STRING scalar.
       Object e = Array.get(o, 0);
+      if (e == null) {
+        throwExceptionIfNotByteOfByteArrays(o);
+        return DataType.STRING;
+      }
       if (Byte.class.isInstance(e) || byte.class.isInstance(e)) {
         return DataType.STRING;
       }
@@ -541,9 +570,11 @@ public final class Tensor implements AutoCloseable {
 
   private static int numDimensions(Object o) {
     if (o.getClass().isArray()) {
-      // byte[] is a DataType.STRING scalar.
       Object e = Array.get(o, 0);
-      if (Byte.class.isInstance(e) || byte.class.isInstance(e)) {
+      if (e == null) {
+        throwExceptionIfNotByteOfByteArrays(o);
+        return 1;
+      } else if (Byte.class.isInstance(e) || byte.class.isInstance(e)) {
         return 0;
       }
       return 1 + numDimensions(e);
@@ -568,11 +599,12 @@ public final class Tensor implements AutoCloseable {
   }
 
   private void throwExceptionIfTypeIsIncompatible(Object o) {
-    if (numDimensions(o) != numDimensions()) {
+    final int rank = numDimensions();
+    final int oRank = numDimensions(o);
+    if (oRank != rank) {
       throw new IllegalArgumentException(
           String.format(
-              "cannot copy Tensor with %d dimensions into an object with %d",
-              numDimensions(), numDimensions(o)));
+              "cannot copy Tensor with %d dimensions into an object with %d", rank, oRank));
     }
     if (dataTypeOf(o) != dtype) {
       throw new IllegalArgumentException(
@@ -580,7 +612,7 @@ public final class Tensor implements AutoCloseable {
               "cannot copy Tensor with DataType %s into an object of type %s",
               dtype.toString(), o.getClass().getName()));
     }
-    long[] oShape = new long[numDimensions()];
+    long[] oShape = new long[rank];
     fillShape(o, 0, oShape);
     for (int i = 0; i < oShape.length; ++i) {
       if (oShape[i] != shape()[i]) {
@@ -596,6 +628,8 @@ public final class Tensor implements AutoCloseable {
 
   private static native long allocateScalarBytes(byte[] value);
 
+  private static native long allocateNonScalarBytes(long[] shape, Object[] value);
+
   private static native void delete(long handle);
 
   private static native ByteBuffer buffer(long handle);
diff --git a/tensorflow/java/src/main/native/operation_builder_jni.cc b/tensorflow/java/src/main/native/operation_builder_jni.cc
index 37f01a943a24c4164d3abd1cf2e5ed73b1d05162..e03be7b1103d5507310c3423e537b6809083e6c3 100644
--- a/tensorflow/java/src/main/native/operation_builder_jni.cc
+++ b/tensorflow/java/src/main/native/operation_builder_jni.cc
@@ -75,8 +75,10 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_OperationBuilder_finish(
   TF_Status* status = TF_NewStatus();
   TF_Operation* op = TF_FinishOperation(d, status);
   if (throwExceptionIfNotOK(env, status)) {
+    TF_DeleteStatus(status);
     return reinterpret_cast(op);
   }
+  TF_DeleteStatus(status);
   return 0;
 }
 
@@ -211,6 +213,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensor(
   TF_Status* status = TF_NewStatus();
   TF_SetAttrTensor(d, cname, t, status);
   throwExceptionIfNotOK(env, status);
+  TF_DeleteStatus(status);
   env->ReleaseStringUTFChars(name, cname);
 }
 
@@ -234,6 +237,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrTensorList(
   TF_Status* status = TF_NewStatus();
   TF_SetAttrTensorList(d, cname, tensors.get(), n, status);
   throwExceptionIfNotOK(env, status);
+  TF_DeleteStatus(status);
   env->ReleaseStringUTFChars(name, cname);
 }
 
@@ -259,7 +263,8 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShape(
 }
 
 JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrStringList(
-        JNIEnv* env, jclass object, jlong handle, jstring name, jobjectArray values) {
+    JNIEnv* env, jclass object, jlong handle, jstring name,
+    jobjectArray values) {
   TF_OperationDescription* d = requireHandle(env, handle);
   if (d == nullptr) return;
   const char* cname = env->GetStringUTFChars(name, nullptr);
@@ -267,12 +272,13 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrStringList(
   static_assert(sizeof(jbyte) == 1,
                 "Require Java byte to be represented as a single byte");
   std::unique_ptr jarrays(new jbyteArray[num_values]);
-  std::unique_ptr jvalues(new jbyte*[num_values]);
-  std::unique_ptr cvalues(new void*[num_values]);
+  std::unique_ptr jvalues(new jbyte*[num_values]);
+  std::unique_ptr cvalues(new void*[num_values]);
   std::unique_ptr lengths(new size_t[num_values]);
 
   for (int i = 0; i < num_values; ++i) {
-    jbyteArray v = static_cast(env->GetObjectArrayElement(values, i));
+    jbyteArray v =
+        static_cast(env->GetObjectArrayElement(values, i));
     jarrays[i] = v;
     jvalues[i] = env->GetByteArrayElements(v, nullptr);
     cvalues[i] = jvalues[i];
diff --git a/tensorflow/java/src/main/native/tensor_jni.cc b/tensorflow/java/src/main/native/tensor_jni.cc
index dfdca357f78a38c5e09b370ac014d17ec1190198..745abec244d1528e918464473e5d3fb19ad5082c 100644
--- a/tensorflow/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/java/src/main/native/tensor_jni.cc
@@ -41,8 +41,11 @@ size_t elemByteSize(TF_DataType dtype) {
   // have the same byte sizes. Validate that:
   switch (dtype) {
     case TF_BOOL:
+    case TF_UINT8:
       static_assert(sizeof(jboolean) == 1,
                     "Java boolean not compatible with TF_BOOL");
+      static_assert(sizeof(jbyte) == 1,
+                    "Java byte not compatible with TF_UINT8");
       return 1;
     case TF_FLOAT:
     case TF_INT32:
@@ -90,6 +93,7 @@ void writeScalar(JNIEnv* env, jobject src, TF_DataType dtype, void* dst,
     CASE(TF_DOUBLE, jdouble, "doubleValue", "()D", Double);
     CASE(TF_INT32, jint, "intValue", "()I", Int);
     CASE(TF_INT64, jlong, "longValue", "()J", Long);
+    CASE(TF_UINT8, jbyte, "byteValue", "()B", Byte);
 #undef CASE
     case TF_BOOL: {
       jclass clazz = env->FindClass("java/lang/Boolean");
@@ -134,6 +138,7 @@ size_t write1DArray(JNIEnv* env, jarray array, TF_DataType dtype, void* dst,
     CASE(TF_INT32, jint, Int);
     CASE(TF_INT64, jlong, Long);
     CASE(TF_BOOL, jboolean, Boolean);
+    CASE(TF_UINT8, jbyte, Byte);
 #undef CASE
     default:
       throwException(env, kIllegalStateException, "invalid DataType(%d)",
@@ -168,6 +173,7 @@ size_t read1DArray(JNIEnv* env, TF_DataType dtype, const void* src,
     CASE(TF_INT32, jint, Int);
     CASE(TF_INT64, jlong, Long);
     CASE(TF_BOOL, jboolean, Boolean);
+    CASE(TF_UINT8, jbyte, Byte);
 #undef CASE
     default:
       throwException(env, kIllegalStateException, "invalid DataType(%d)",
@@ -213,6 +219,108 @@ size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
     return sz;
   }
 }
+
+jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const char* src,
+                                       size_t src_len, TF_Status* status) {
+  const char* dst = nullptr;
+  size_t dst_len = 0;
+  TF_StringDecode(src, src_len, &dst, &dst_len, status);
+  if (TF_GetCode(status) != TF_OK) {
+    return nullptr;
+  }
+  jbyteArray ret = env->NewByteArray(dst_len);
+  jbyte* cpy = env->GetByteArrayElements(ret, nullptr);
+  memcpy(cpy, dst, dst_len);
+  env->ReleaseByteArrayElements(ret, cpy, 0);
+  return ret;
+}
+
+class StringTensorWriter {
+ public:
+  StringTensorWriter(TF_Tensor* t, int num_elements)
+      : offset_(0),
+        poffsets_(static_cast(TF_TensorData(t))),
+        pdata_(poffsets_ + 8 * num_elements),
+        plimit_(poffsets_ + TF_TensorByteSize(t)) {}
+
+  void Add(const char* src, size_t len, TF_Status* status) {
+    if (TF_GetCode(status) != TF_OK) return;
+    if (plimit_ - poffsets_ < sizeof(offset_)) {
+      TF_SetStatus(status, TF_OUT_OF_RANGE,
+                   "TF_STRING tensor encoding ran out of space for offsets, "
+                   "this is likely a bug, please file an issue at "
+                   "https://github.com/tensorflow/tensorflow/issues/new");
+      return;
+    }
+    memcpy(poffsets_, &offset_, sizeof(offset_));
+    size_t written =
+        TF_StringEncode(src, len, pdata_, (plimit_ - pdata_), status);
+    offset_ += written;
+    poffsets_ += 8;
+    pdata_ += written;
+  }
+
+ private:
+  uint64_t offset_;
+  char* poffsets_;
+  char* pdata_;
+  const char* plimit_;
+};
+
+class StringTensorReader {
+ public:
+  StringTensorReader(const TF_Tensor* t, int num_elements)
+      : index_(0),
+        offsets_(static_cast(TF_TensorData(t))),
+        data_(offsets_ + 8 * num_elements),
+        limit_(offsets_ + TF_TensorByteSize(t)) {}
+
+  jbyteArray Next(JNIEnv* env, TF_Status* status) {
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    uint64_t offset = 0;
+    const char* poffset = offsets_ + sizeof(offset) * index_;
+    if (poffset >= limit_) {
+      TF_SetStatus(
+          status, TF_INTERNAL,
+          "Invalid TF_STRING tensor, offsets table seems to be too small");
+      return nullptr;
+    }
+    memcpy(&offset, poffset, sizeof(offset));
+    const char* pdata = data_ + offset;
+    if (pdata >= limit_) {
+      TF_SetStatus(status, TF_INTERNAL,
+                   "Invalid TF_STRING tensor, invalid entry in offset table");
+      return nullptr;
+    }
+    ++index_;
+    return TF_StringDecodeTojbyteArray(env, pdata, (limit_ - pdata), status);
+  }
+
+ private:
+  int index_;
+  const char* offsets_;
+  const char* data_;
+  const char* limit_;
+};
+
+void readNDStringArray(JNIEnv* env, StringTensorReader* reader, int dims_left,
+                       jobjectArray dst, TF_Status* status) {
+  jsize len = env->GetArrayLength(dst);
+  if (dims_left == 1) {
+    for (jsize i = 0; i < len; ++i) {
+      jbyteArray elem = reader->Next(env, status);
+      if (TF_GetCode(status) != TF_OK) return;
+      env->SetObjectArrayElement(dst, i, elem);
+    }
+    return;
+  }
+  for (jsize i = 0; i < len; ++i) {
+    jobjectArray arr =
+        static_cast(env->GetObjectArrayElement(dst, i));
+    readNDStringArray(env, reader, dims_left - 1, arr, status);
+    if (TF_GetCode(status) != TF_OK) return;
+  }
+}
 }  // namespace
 
 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
@@ -264,18 +372,13 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
   char* dst = static_cast(TF_TensorData(t));
   memset(dst, 0, 8);  // The offset table
 
-  // jbyte is a signed char, while the C standard doesn't require char and
-  // signed char to be the same. As a result, static_cast(src) will
-  // complain. Copy the string instead. sigh!
+  TF_Status* status = TF_NewStatus();
   jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
-  std::unique_ptr src(new char[src_len]);
-  static_assert(sizeof(jbyte) == sizeof(char),
-                "Cannot convert Java byte to a C char");
-  memcpy(src.get(), jsrc, src_len);
+  // jsrc is an unsigned byte*, TF_StringEncode requires a char*.
+  // reinterpret_cast<> for this conversion should be safe.
+  TF_StringEncode(reinterpret_cast(jsrc), src_len, dst + 8,
+                  dst_len, status);
   env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);
-
-  TF_Status* status = TF_NewStatus();
-  TF_StringEncode(src.get(), src_len, dst + 8, dst_len, status);
   if (!throwExceptionIfNotOK(env, status)) {
     TF_DeleteStatus(status);
     return 0;
@@ -284,6 +387,85 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
   return reinterpret_cast(t);
 }
 
+namespace {
+size_t nonScalarTF_STRINGTensorSize(JNIEnv* env, jarray value, int num_dims) {
+  if (num_dims == 0) {
+    // This is the last dimension, i.e., value should correspond to a jbyteArray
+    // encoding the string.
+    return TF_StringEncodedSize(
+        static_cast(env->GetArrayLength(value)));
+  }
+  jsize len = env->GetArrayLength(value);
+  size_t ret = 0;
+  for (jsize i = 0; i < len; ++i) {
+    jarray elem = static_cast(
+        env->GetObjectArrayElement(static_cast(value), i));
+    ret += nonScalarTF_STRINGTensorSize(env, elem, num_dims - 1);
+  }
+  return ret;
+}
+
+void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims,
+                                      StringTensorWriter* writer,
+                                      TF_Status* status) {
+  if (num_dims == 0) {
+    jbyte* jsrc =
+        env->GetByteArrayElements(static_cast(value), nullptr);
+    writer->Add(reinterpret_cast(jsrc), env->GetArrayLength(value),
+                status);
+    env->ReleaseByteArrayElements(static_cast(value), jsrc,
+                                  JNI_ABORT);
+    return;
+  }
+  jsize len = env->GetArrayLength(value);
+  for (jsize i = 0; i < len; ++i) {
+    jarray elem = static_cast(
+        env->GetObjectArrayElement(static_cast(value), i));
+    if (TF_GetCode(status) != TF_OK) return;
+    fillNonScalarTF_STRINGTensorData(env, elem, num_dims - 1, writer, status);
+  }
+}
+}  // namespace
+
+JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
+    JNIEnv* env, jclass clazz, jlongArray shape, jobjectArray value) {
+  // TF_STRING tensors are encoded with a table of 8-byte offsets following by
+  // TF_StringEncode-encoded bytes.
+  const int num_dims = static_cast(env->GetArrayLength(shape));
+  int64_t* dims = new int64_t[num_dims];
+  int64_t num_elements = 1;
+  {
+    jlong* jdims = env->GetLongArrayElements(shape, nullptr);
+    for (int i = 0; i < num_dims; ++i) {
+      dims[i] = static_cast(jdims[i]);
+      num_elements *= dims[i];
+    }
+    env->ReleaseLongArrayElements(shape, jdims, JNI_ABORT);
+  }
+  const size_t encoded_size =
+      nonScalarTF_STRINGTensorSize(env, value, num_dims);
+  TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims,
+                                   8 * num_elements + encoded_size);
+  if (t == nullptr) {
+    delete[] dims;
+    throwException(env, kNullPointerException,
+                   "unable to allocate memory for the Tensor");
+    return 0;
+  }
+  TF_Status* status = TF_NewStatus();
+  StringTensorWriter writer(t, num_elements);
+  fillNonScalarTF_STRINGTensorData(env, value, num_dims, &writer, status);
+  delete[] dims;
+  jlong ret = 0;
+  if (!throwExceptionIfNotOK(env, status)) {
+    TF_DeleteTensor(t);
+  } else {
+    ret = reinterpret_cast(t);
+  }
+  TF_DeleteStatus(status);
+  return ret;
+}
+
 JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
                                                          jclass clazz,
                                                          jlong handle) {
@@ -292,8 +474,8 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
 }
 
 JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv* env,
-                                                              jclass clazz,
-                                                              jlong handle) {
+                                                            jclass clazz,
+                                                            jlong handle) {
   TF_Tensor* t = requireHandle(env, handle);
   if (t == nullptr) return nullptr;
   void* data = TF_TensorData(t);
@@ -393,17 +575,9 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(
                    "invalid tensor encoding: bad offsets");
     return nullptr;
   }
-  jbyteArray ret = nullptr;
-  const char* dst = nullptr;
-  size_t dst_len = 0;
   TF_Status* status = TF_NewStatus();
-  TF_StringDecode(src, src_len, &dst, &dst_len, status);
-  if (throwExceptionIfNotOK(env, status)) {
-    ret = env->NewByteArray(dst_len);
-    jbyte* cpy = env->GetByteArrayElements(ret, nullptr);
-    memcpy(cpy, dst, dst_len);
-    env->ReleaseByteArrayElements(ret, cpy, 0);
-  }
+  jbyteArray ret = TF_StringDecodeTojbyteArray(env, src, src_len, status);
+  throwExceptionIfNotOK(env, status);
   TF_DeleteStatus(status);
   return ret;
 }
@@ -424,6 +598,19 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env,
                    "accessor (floatValue(), intValue() etc.) instead");
     return;
   }
+  if (dtype == TF_STRING) {
+    int64_t num_elements = 1;
+    for (int i = 0; i < num_dims; ++i) {
+      num_elements *= TF_Dim(t, i);
+    }
+    StringTensorReader reader(t, num_elements);
+    TF_Status* status = TF_NewStatus();
+    readNDStringArray(env, &reader, num_dims, static_cast(value),
+                      status);
+    throwExceptionIfNotOK(env, status);
+    TF_DeleteStatus(status);
+    return;
+  }
   readNDArray(env, dtype, static_cast(data), sz, num_dims,
               static_cast(value));
 }
diff --git a/tensorflow/java/src/main/native/tensor_jni.h b/tensorflow/java/src/main/native/tensor_jni.h
index 70850d250b8b3921e444a64eb1cae28dd9fe3720..a300936884c0bf25a6d92aa7e2b7b36abd85d646 100644
--- a/tensorflow/java/src/main/native/tensor_jni.h
+++ b/tensorflow/java/src/main/native/tensor_jni.h
@@ -28,7 +28,8 @@ extern "C" {
  * Signature: (I[JJ)J
  */
 JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv *, jclass,
-                                                            jint, jlongArray, jlong);
+                                                            jint, jlongArray,
+                                                            jlong);
 
 /*
  * Class:     org_tensorflow_Tensor
@@ -38,6 +39,14 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv *, jclass,
 JNIEXPORT jlong JNICALL
 Java_org_tensorflow_Tensor_allocateScalarBytes(JNIEnv *, jclass, jbyteArray);
 
+/*
+ * Class:     org_tensorflow_Tensor
+ * Method:    allocateNonScalarBytes
+ * Signature: ([J[Ljava/lang/Object;)J
+ */
+JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
+    JNIEnv *, jclass, jlongArray, jobjectArray);
+
 /*
  * Class:     org_tensorflow_Tensor
  * Method:    delete
@@ -52,7 +61,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv *, jclass,
  * Signature: (J)Ljava/nio/ByteBuffer;
  */
 JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv *, jclass,
-                                                              jlong);
+                                                            jlong);
 
 /*
  * Class:     org_tensorflow_Tensor
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
index 3ff59e71b22651fc13256714d9b024386162d75b..bb5f9a0708564d8c50ed6284afd619f502425b5b 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java
@@ -386,6 +386,30 @@ public class TensorTest {
     }
   }
 
+  @Test
+  public void testNDimensionalStringTensor() {
+    byte[][][] matrix = new byte[4][3][];
+    for (int i = 0; i < 4; ++i) {
+      for (int j = 0; j < 3; ++j) {
+        matrix[i][j] = String.format("(%d, %d) = %d", i, j, i << j).getBytes(UTF_8);
+      }
+    }
+    try (Tensor t = Tensor.create(matrix)) {
+      assertEquals(DataType.STRING, t.dataType());
+      assertEquals(2, t.numDimensions());
+      assertArrayEquals(new long[] {4, 3}, t.shape());
+
+      byte[][][] got = t.copyTo(new byte[4][3][]);
+      assertEquals(4, got.length);
+      for (int i = 0; i < 4; ++i) {
+        assertEquals(String.format("%d", i), 3, got[i].length);
+        for (int j = 0; j < 3; ++j) {
+          assertArrayEquals(String.format("(%d, %d)", i, j), matrix[i][j], got[i][j]);
+        }
+      }
+    }
+  }
+
   @Test
   public void failCreateOnMismatchedDimensions() {
     int[][][] invalid = new int[3][1][];
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index fa2bce843a676286452207fa575a0834d3889aa5..c1e63c0d856b9ee77916a6761bf5181dbc263809 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -94,6 +94,7 @@ py_library(
         "//tensorflow/python/ops/distributions",
         "//tensorflow/python/profiler",
         "//tensorflow/python/saved_model",
+        "//tensorflow/python/keras",
     ] + if_not_windows([
         "//tensorflow/contrib:contrib_py",
     ]),
@@ -1401,9 +1402,10 @@ py_library(
     deps = [
         ":array_ops",
         ":control_flow_ops",
-        ":framework",
         ":framework_for_generated_wrappers",
         ":math_ops",
+        ":sparse_tensor",
+        ":tensor_util",
         ":util",
         "//third_party/py/numpy",
     ],
@@ -1492,9 +1494,11 @@ py_library(
         ":array_ops",
         ":control_flow_ops",
         ":data_flow_ops_gen",
-        ":framework",
         ":framework_for_generated_wrappers",
         ":math_ops",
+        ":random_seed",
+        ":tensor_util",
+        "//tensorflow/python/eager:context",
         "@six_archive//:six",
     ],
 )
@@ -1632,7 +1636,10 @@ py_library(
 
 py_library(
     name = "linalg_ns",
-    srcs = ["ops/linalg_ns.py"],
+    srcs = [
+        "ops/linalg_impl.py",
+        "ops/linalg_ns.py",
+    ],
     srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
@@ -1687,10 +1694,14 @@ py_library(
     deps = [
         ":array_ops",
         ":constant_op",
-        ":framework",
+        ":control_flow_ops",
         ":framework_for_generated_wrappers",
         ":lookup_ops_gen",
         ":math_ops",
+        ":sparse_tensor",
+        ":string_ops",
+        ":util",
+        "//tensorflow/python/eager:context",
         "@six_archive//:six",
     ],
 )
@@ -1702,10 +1713,11 @@ py_library(
     deps = [
         ":array_ops",
         ":array_ops_gen",
-        ":framework",
         ":framework_for_generated_wrappers",
         ":math_ops",
         ":math_ops_gen",
+        ":tensor_util",
+        "//tensorflow/python/eager:context",
         "//third_party/py/numpy",
     ],
 )
@@ -1732,6 +1744,7 @@ py_library(
         ":state_ops_gen",
         ":tensor_shape",
         ":util",
+        "//tensorflow/python/eager:context",
         "//third_party/py/numpy",
     ],
 )
@@ -1754,6 +1767,8 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
+        ":array_ops_gen",
+        ":dtypes",
         ":framework_ops",
         ":resource_variable_ops_gen",
         ":tensor_shape",
@@ -1762,7 +1777,8 @@ py_library(
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:custom_gradient",
-        "//tensorflow/python/eager:tensor",
+        "//tensorflow/python/eager:tape",
+        "//tensorflow/python/eager:tensor_node",
     ],
 )
 
@@ -1802,6 +1818,8 @@ py_library(
         ":nn_ops",
         ":nn_ops_gen",
         ":sparse_ops",
+        ":tensor_util",
+        "//tensorflow/python/eager:context",
     ],
 )
 
@@ -2149,6 +2167,7 @@ py_library(
         ":resource_variable_ops_gen",
         ":state_ops_gen",
         ":tensor_shape",
+        "//tensorflow/python/eager:context",
     ],
 )
 
@@ -2210,6 +2229,7 @@ py_library(
         ":tensor_shape",
         ":tensor_util",
         ":util",
+        "//tensorflow/python/eager:context",
     ],
 )
 
@@ -2218,12 +2238,14 @@ py_library(
     srcs = ["ops/variable_scope.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":array_ops",
         ":dtypes",
         ":framework_ops",
         ":init_ops",
         ":platform",
         ":resource_variable_ops",
         ":tensor_shape",
+        ":util",
         ":variables",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/estimator:util",
@@ -2328,7 +2350,7 @@ cuda_py_test(
 
 cuda_py_test(
     name = "gradients_test",
-    size = "small",
+    size = "medium",
     srcs = ["ops/gradients_test.py"],
     additional_deps = [
         ":array_grad",
@@ -2531,6 +2553,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
+        ":checkpoint_ops_gen",
         ":client",
         ":control_flow_ops",
         ":data_flow_ops",
@@ -3555,23 +3578,47 @@ py_test(
 )
 
 py_test(
-    name = "monitored_session_test",
+    name = "checkpoint_ops_test",
     size = "small",
+    srcs = ["training/checkpoint_ops_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_windows"],
+    deps = [
+        ":checkpoint_ops_gen",
+        ":client",
+        ":client_testlib",
+        ":framework_for_generated_wrappers",
+        ":io_ops",
+        ":partitioned_variables",
+        ":platform",
+        ":pywrap_tensorflow",
+        ":state_ops",
+        ":training",
+        ":variable_scope",
+        ":variables",
+    ],
+)
+
+py_test(
+    name = "monitored_session_test",
+    size = "medium",
     srcs = ["training/monitored_session_test.py"],
     srcs_version = "PY2AND3",
     tags = ["no_windows"],
     deps = [
         ":array_ops",
-        ":client",
         ":client_testlib",
+        ":control_flow_ops",
         ":errors",
         ":framework_for_generated_wrappers",
+        ":session",
         ":state_ops",
         ":summary",
         ":training",
         ":variables",
         "//tensorflow/contrib/framework:framework_py",
         "//tensorflow/contrib/testing:testing_py",
+        "//tensorflow/core:protos_all_py",
     ],
 )
 
@@ -3716,6 +3763,7 @@ py_library(
         ":util",
         ":variable_scope",
         ":variables",
+        "//tensorflow/python/eager:context",
         "//third_party/py/numpy",
         "@six_archive//:six",
     ],
@@ -3730,6 +3778,7 @@ py_test(
     deps = [
         ":client_testlib",
         ":framework_for_generated_wrappers",
+        ":framework_test_lib",
         ":init_ops",
         ":layers",
         ":math_ops",
@@ -3748,6 +3797,7 @@ py_test(
         ":array_ops",
         ":client_testlib",
         ":framework_for_generated_wrappers",
+        ":framework_test_lib",
         ":layers",
         ":math_ops",
         ":nn_ops",
@@ -3767,6 +3817,7 @@ py_test(
     deps = [
         ":client_testlib",
         ":framework_for_generated_wrappers",
+        ":framework_test_lib",
         ":layers",
         ":math_ops",
         ":nn_ops",
@@ -3794,6 +3845,7 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":client_testlib",
+        ":framework_test_lib",
         ":layers",
         ":random_ops",
     ],
@@ -3807,6 +3859,7 @@ cuda_py_test(
         ":array_ops",
         ":client_testlib",
         ":framework_for_generated_wrappers",
+        ":framework_test_lib",
         ":layers",
         ":math_ops",
         ":random_ops",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index acda11bd18ddb86e9c09a3c1894c2c9832b5709d..18603c21812e180afe42edbcca85ccbd8f52b424 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -80,6 +80,7 @@ from tensorflow.python.ops import linalg_ns as linalg
 # Bring in subpackages.
 from tensorflow.python.estimator import estimator_lib as estimator
 from tensorflow.python.feature_column import feature_column_lib as feature_column
+from tensorflow.python import keras
 from tensorflow.python.layers import layers
 from tensorflow.python.ops import bitwise_ops as bitwise
 from tensorflow.python.ops import image_ops as image
@@ -248,6 +249,7 @@ _allowed_symbols.extend([
     'user_ops',
     'layers',
     'profiler',
+    'keras',
 ])
 
 # Variables framework.versions:
@@ -265,7 +267,7 @@ remove_undocumented(__name__, _allowed_symbols, [
     functional_ops, histogram_ops, io_ops,
     losses, math_ops, metrics, nn, resource_loader, sets, script_ops,
     session_ops, sparse_ops, state_ops, string_ops, summary, tensor_array_ops,
-    train, layers, profiler
+    train, layers, profiler, keras
 ])
 
 # Special dunders that we choose to export:
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 08dd3922dbe8ed650c1178576f7cb6e3a0230c81..fa49e66e87bba4921137b4575e00f90077925571 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -373,6 +373,33 @@ def TF_Reset(target, containers=None, config=None):
     TF_DeleteSessionOptions(opts)
 %}
 
+// We use TF_GraphToFunction_wrapper instead of TF_GraphToFunction
+%ignore TF_GraphToFunction;
+// TF_GraphToFunction_wrapper does not use any Python methods and
+// does not require GIL to be held.
+%unignore TF_GraphToFunction_wrapper;
+
+// $input is a Python list of wrapped TF_Operations
+%typemap(in) (const std::vector* opers)
+    (std::vector opers) {
+  if ($input != Py_None) {
+    if (!PyList_Check($input)) {
+      SWIG_exception_fail(SWIG_TypeError, "$symname: expected list");
+    }
+    size_t size = PyList_Size($input);
+    for (int i = 0; i < size; ++i) {
+      PyObject* item = PyList_GetItem($input, i);
+      TF_Operation* oper_ptr;
+      SWIG_ConvertPtr(item, reinterpret_cast(&oper_ptr),
+                      $descriptor(TF_Operation*), 0);
+      opers.push_back(oper_ptr);
+    }
+    $1 = &opers;
+  } else {
+    $1 = nullptr;
+  }
+}
+
 %include "tensorflow/python/client/tf_session_helper.h"
 
 %unignoreall
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index 60a589fa8bbdf059f12585ec9dd256783f408c03..72f560fa878b8d27afab82dd0238fe98ed4c4ebf 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -337,4 +337,38 @@ std::vector TF_OperationGetControlInputs_wrapper(
   return control_inputs;
 }
 
+TF_Function* TF_GraphToFunction_wrapper(const TF_Graph* fn_body,
+                                        const char* fn_name,
+                                        const std::vector* opers,
+                                        const std::vector& inputs,
+                                        const std::vector& outputs,
+                                        const NameVector& output_names,
+                                        const TF_FunctionOptions* opts,
+                                        TF_Status* out_status) {
+  if (!output_names.empty() && output_names.size() != outputs.size()) {
+    Set_TF_Status_from_Status(
+        out_status,
+        errors::InvalidArgument(
+            "output names must be either empty or equal in size to outputs. ",
+            "output names size = ", output_names.size(),
+            " outputs size = ", outputs.size()));
+    return nullptr;
+  }
+
+  int nopers = -1;
+  const TF_Operation* const* opers_array = nullptr;
+  if (opers != nullptr) {
+    nopers = opers->size();
+    opers_array = opers->data();
+  }
+
+  const char** output_names_ptr =
+      output_names.empty() ? nullptr
+                           : const_cast(output_names.data());
+
+  return TF_GraphToFunction(fn_body, fn_name, nopers, opers_array,
+                            inputs.size(), inputs.data(), outputs.size(),
+                            outputs.data(), output_names_ptr, opts, out_status);
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 3bc63f822fe078d56e835c71e3d3ee108271aa07..8fae6206c07f638e1d514bf720dbaf6da79cdc7b 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -148,6 +148,16 @@ void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
 std::vector TF_OperationGetControlInputs_wrapper(
     TF_Operation* oper);
 
+// `opers` equaling NULL are converted to `nopers = -1`.
+// `output_names` must be empty or have the same length as `outputs`.
+TF_Function* TF_GraphToFunction_wrapper(const TF_Graph* fn_body,
+                                        const char* fn_name,
+                                        const std::vector* opers,
+                                        const std::vector& inputs,
+                                        const std::vector& outputs,
+                                        const NameVector& output_names,
+                                        const TF_FunctionOptions* opts,
+                                        TF_Status* out_status);
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 8eb2212069142f28ce61dd4ad20d473072836741..c0926169995efbc9c7f1e0833d6dffa98a1deb83 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -49,11 +49,26 @@ py_library(
     ]),
 )
 
+py_library(
+    name = "debug_graphs",
+    srcs = ["lib/debug_graphs.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:op_def_registry",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:tensor_util",
+        "@six_archive//:six",
+    ],
+)
+
 py_library(
     name = "debug_data",
     srcs = ["lib/debug_data.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":debug_graphs",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:framework",
         "//tensorflow/python:op_def_registry",
@@ -70,6 +85,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":debug_data",
+        ":debug_graphs",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:framework",
         "//tensorflow/python:platform",
@@ -99,6 +115,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":debug_data",
+        ":debug_graphs",
         ":debug_utils",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:framework_for_generated_wrappers",
@@ -181,7 +198,7 @@ py_library(
     deps = [
         ":cli_shared",
         ":command_parser",
-        ":debug_data",
+        ":debug_graphs",
         ":debugger_cli_common",
         ":evaluator",
         ":source_utils",
@@ -400,6 +417,18 @@ py_binary(
     ],
 )
 
+py_test(
+    name = "debug_graphs_test",
+    size = "small",
+    srcs = ["lib/debug_graphs_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":debug_graphs",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_test_lib",
+    ],
+)
+
 py_test(
     name = "debug_data_test",
     size = "small",
@@ -569,6 +598,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":debug_data",
+        ":debug_graphs",
         ":debug_utils",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
@@ -608,7 +638,7 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
     deps = [
-        ":debug_data",
+        ":debug_graphs",
         ":debug_service_pb2_grpc",
         "//tensorflow/core/debug:debug_service_proto_py",
         "@six_archive//:six",
diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py
index 22e451e38cff722fc7f7d98a595430ee6a4a2f3b..50850bbc0dc4a6d9be5ae1596d9e3cfacb81000c 100644
--- a/tensorflow/python/debug/cli/analyzer_cli.py
+++ b/tensorflow/python/debug/cli/analyzer_cli.py
@@ -34,7 +34,7 @@ from tensorflow.python.debug.cli import command_parser
 from tensorflow.python.debug.cli import debugger_cli_common
 from tensorflow.python.debug.cli import evaluator
 from tensorflow.python.debug.cli import ui_factory
-from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_graphs
 from tensorflow.python.debug.lib import source_utils
 
 RL = debugger_cli_common.RichLine
@@ -716,7 +716,7 @@ class DebugAnalyzer(object):
 
     # Get a node name, regardless of whether the input is a node name (without
     # output slot attached) or a tensor name (with output slot attached).
-    node_name, unused_slot = debug_data.parse_node_or_tensor_name(
+    node_name, unused_slot = debug_graphs.parse_node_or_tensor_name(
         parsed.node_name)
 
     if not self._debug_dump.node_exists(node_name):
@@ -840,7 +840,7 @@ class DebugAnalyzer(object):
         parsed.op_type,
         do_outputs=False)
 
-    node_name = debug_data.get_node_name(parsed.node_name)
+    node_name = debug_graphs.get_node_name(parsed.node_name)
     _add_main_menu(output, node_name=node_name, enable_list_inputs=False)
 
     return output
@@ -871,7 +871,7 @@ class DebugAnalyzer(object):
     tensor_name, tensor_slicing = (
         command_parser.parse_tensor_name_with_slicing(parsed.tensor_name))
 
-    node_name, output_slot = debug_data.parse_node_or_tensor_name(tensor_name)
+    node_name, output_slot = debug_graphs.parse_node_or_tensor_name(tensor_name)
     if (self._debug_dump.loaded_partition_graphs() and
         not self._debug_dump.node_exists(node_name)):
       output = cli_shared.error(
@@ -1016,7 +1016,7 @@ class DebugAnalyzer(object):
         parsed.op_type,
         do_outputs=True)
 
-    node_name = debug_data.get_node_name(parsed.node_name)
+    node_name = debug_graphs.get_node_name(parsed.node_name)
     _add_main_menu(output, node_name=node_name, enable_list_outputs=False)
 
     return output
@@ -1087,7 +1087,7 @@ class DebugAnalyzer(object):
 
           label = RL(" " * 4)
           if self._debug_dump.debug_watch_keys(
-              debug_data.get_node_name(element)):
+              debug_graphs.get_node_name(element)):
             attribute = debugger_cli_common.MenuItem("", "pt %s" % element)
           else:
             attribute = cli_shared.COLOR_BLUE
@@ -1246,7 +1246,7 @@ class DebugAnalyzer(object):
     font_attr_segs = {}
 
     # Check if this is a tensor name, instead of a node name.
-    node_name, _ = debug_data.parse_node_or_tensor_name(node_name)
+    node_name, _ = debug_graphs.parse_node_or_tensor_name(node_name)
 
     # Check if node exists.
     if not self._debug_dump.node_exists(node_name):
@@ -1395,7 +1395,7 @@ class DebugAnalyzer(object):
       # Recursive call.
       # The input's/output's name can be a tensor name, in the case of node
       # with >1 output slots.
-      inp_node_name, _ = debug_data.parse_node_or_tensor_name(inp)
+      inp_node_name, _ = debug_graphs.parse_node_or_tensor_name(inp)
       self._dfs_from_node(
           lines,
           attr_segs,
diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py
index b2b3ec5d47036f6f83853bfd370a0dcf9d30fdd8..9ea279c004972013c9446cfa0db2a4b79d8cb664 100644
--- a/tensorflow/python/debug/lib/debug_data.py
+++ b/tensorflow/python/debug/lib/debug_data.py
@@ -26,14 +26,14 @@ import platform
 
 import numpy as np
 import six
-from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.util import event_pb2
-from tensorflow.python.framework import op_def_registry
+from tensorflow.python.debug.lib import debug_graphs
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 
 
@@ -155,30 +155,6 @@ def _load_log_message_from_event_file(event_file_path):
   return event.log_message.message
 
 
-def parse_node_or_tensor_name(name):
-  """Get the node name from a string that can be node or tensor name.
-
-  Args:
-    name: An input node name (e.g., "node_a") or tensor name (e.g.,
-      "node_a:0"), as a str.
-
-  Returns:
-    1) The node name, as a str. If the input name is a tensor name, i.e.,
-      consists of a colon, the final colon and the following output slot
-      will be stripped.
-    2) If the input name is a tensor name, the output slot, as an int. If
-      the input name is not a tensor name, None.
-  """
-
-  if ":" in name and not name.endswith(":"):
-    node_name = name[:name.rfind(":")]
-    output_slot = int(name[name.rfind(":") + 1:])
-
-    return node_name, output_slot
-  else:
-    return name, None
-
-
 def _is_graph_file(file_name):
   return file_name.startswith(METADATA_FILE_PREFIX + GRAPH_FILE_TAG)
 
@@ -191,25 +167,6 @@ def _is_run_feed_keys_info_file(file_name):
   return file_name == METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG
 
 
-def get_node_name(element_name):
-  return element_name.split(":")[0] if ":" in element_name else element_name
-
-
-def get_output_slot(element_name):
-  """Get the output slot number from the name of a graph element.
-
-  If element_name is a node name without output slot at the end, 0 will be
-  assumed.
-
-  Args:
-    element_name: (`str`) name of the graph element in question.
-
-  Returns:
-    (`int`) output slot number.
-  """
-  return int(element_name.split(":")[-1]) if ":" in element_name else 0
-
-
 def _get_tensor_name(node_name, output_slot):
   """Get tensor name given node name and output slot index.
 
@@ -241,78 +198,6 @@ def _get_tensor_watch_key(node_name, output_slot, debug_op):
   return "%s:%s" % (_get_tensor_name(node_name, output_slot), debug_op)
 
 
-def is_copy_node(node_name):
-  """Determine whether a node name is that of a debug Copy node.
-
-  Such nodes are inserted by TensorFlow core upon request in
-  RunOptions.debug_options.debug_tensor_watch_opts.
-
-  Args:
-    node_name: Name of the node.
-
-  Returns:
-    A bool indicating whether the input argument is the name of a debug Copy
-    node.
-  """
-  return node_name.startswith("__copy_")
-
-
-def is_debug_node(node_name):
-  """Determine whether a node name is that of a debug node.
-
-  Such nodes are inserted by TensorFlow core upon request in
-  RunOptions.debug_options.debug_tensor_watch_opts.
-
-  Args:
-    node_name: Name of the node.
-
-  Returns:
-    A bool indicating whether the input argument is the name of a debug node.
-  """
-  return node_name.startswith("__dbg_")
-
-
-def parse_debug_node_name(node_name):
-  """Parse the name of a debug node.
-
-  Args:
-    node_name: Name of the debug node.
-
-  Returns:
-    1. Name of the watched node, as a str.
-    2. Output slot index of the watched tensor, as an int.
-    3. Index of the debug node, as an int.
-    4. Name of the debug op, as a str, e.g, "DebugIdentity".
-
-  Raises:
-    ValueError: If the input node name is not a valid debug node name.
-  """
-  prefix = "__dbg_"
-
-  name = node_name
-  if not name.startswith(prefix):
-    raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)
-
-  name = name[len(prefix):]
-
-  if name.count("_") < 2:
-    raise ValueError("Invalid debug node name: '%s'" % node_name)
-
-  debug_op = name[name.rindex("_") + 1:]
-  name = name[:name.rindex("_")]
-
-  debug_op_index = int(name[name.rindex("_") + 1:])
-  name = name[:name.rindex("_")]
-
-  if name.count(":") != 1:
-    raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)
-
-  watched_node_name = name[:name.index(":")]
-  watched_output_slot = int(name[name.index(":") + 1:])
-
-  return watched_node_name, watched_output_slot, debug_op_index, debug_op
-
-
 def has_inf_or_nan(datum, tensor):
   """A predicate for whether a tensor consists of any bad numerical values.
 
@@ -573,88 +458,6 @@ class WatchKeyDoesNotExistInDebugDumpDirError(ValueError):
   pass
 
 
-class _GraphTracingReachedDestination(Exception):
-  pass
-
-
-class _DFSGraphTracer(object):
-  """Graph input tracer using depth-first search."""
-
-  def __init__(self,
-               input_lists,
-               skip_node_names=None,
-               destination_node_name=None):
-    """Constructor of _DFSGraphTracer.
-
-    Args:
-      input_lists: A list of dicts. Each dict is an adjacency (input) map from
-        the recipient node name as the key and the list of input node names
-        as the value.
-      skip_node_names: Optional: a list of node names to skip tracing.
-      destination_node_name: Optional: destination node name. If not `None`, it
-        should be the name of a destination not as a str and the graph tracing
-        will raise GraphTracingReachedDestination as soon as the node has been
-        reached.
-
-    Raises:
-      _GraphTracingReachedDestination: if stop_at_node_name is not None and
-        the specified node is reached.
-    """
-
-    self._input_lists = input_lists
-    self._skip_node_names = skip_node_names
-
-    self._inputs = []
-    self._visited_nodes = []
-    self._depth_count = 0
-    self._depth_list = []
-
-    self._destination_node_name = destination_node_name
-
-  def trace(self, graph_element_name):
-    """Trace inputs.
-
-    Args:
-      graph_element_name: Name of the node or an output tensor of the node, as a
-        str.
-
-    Raises:
-      _GraphTracingReachedDestination: if destination_node_name of this tracer
-        object is not None and the specified node is reached.
-    """
-    self._depth_count += 1
-
-    node_name = get_node_name(graph_element_name)
-
-    if node_name == self._destination_node_name:
-      raise _GraphTracingReachedDestination()
-
-    if node_name in self._skip_node_names:
-      return
-    if node_name in self._visited_nodes:
-      return
-
-    self._visited_nodes.append(node_name)
-
-    for input_list in self._input_lists:
-      for inp in input_list[node_name]:
-        if get_node_name(inp) in self._visited_nodes:
-          continue
-        self._inputs.append(inp)
-        self._depth_list.append(self._depth_count)
-        self.trace(inp)
-
-    self._depth_count -= 1
-
-  def inputs(self):
-    return self._inputs
-
-  def depth_list(self):
-    return self._depth_list
-
-
-# TODO(cais): This class is getting too large in line count. Refactor to make it
-# smaller and easier to maintain.
 class DebugDumpDir(object):
   """Data set from a debug-dump directory on filesystem.
 
@@ -963,52 +766,36 @@ class DebugDumpDir(object):
       ValueError: If the partition GraphDef of one or more devices fail to be
         loaded.
     """
-
-    self._node_attributes = {}
-    self._node_inputs = {}
-    self._node_reversed_ref_inputs = {}
-    self._node_ctrl_inputs = {}
-    self._node_recipients = {}
-    self._node_ctrl_recipients = {}
+    self._debug_graphs = {}
     self._node_devices = {}
-    self._node_op_types = {}
-    self._copy_send_nodes = {}
-    self._ref_args = {}
-
-    self._partition_graphs = {}
-    for device_name in self._device_names:
-      partition_graph = None
-      if device_name in self._dump_graph_file_paths:
-        partition_graph = _load_graph_def_from_event_file(
-            self._dump_graph_file_paths[device_name])
-      else:
-        partition_graph = self._find_partition_graph(partition_graphs,
-                                                     device_name)
-
-      if partition_graph:
-        self._partition_graphs[device_name] = partition_graph
 
-      self._node_attributes[device_name] = {}
-      self._node_inputs[device_name] = {}
-      self._node_reversed_ref_inputs[device_name] = {}
-      self._node_ctrl_inputs[device_name] = {}
-      self._node_recipients[device_name] = {}
-      self._node_ctrl_recipients[device_name] = {}
-      self._node_op_types[device_name] = {}
-      self._copy_send_nodes[device_name] = []
-      self._ref_args[device_name] = []
-
-      if partition_graph:
-        for node in partition_graph.node:
-          self._process_partition_graph_node(device_name, node)
-
-      self._prune_non_control_edges_of_debug_ops(device_name)
-      self._prune_control_edges_of_debug_ops(device_name)
+    if partition_graphs:
+      partition_graphs_and_device_names = [
+          (partition_graph, None) for partition_graph in partition_graphs]
+    else:
+      partition_graphs_and_device_names = []
+      for device_name in self._device_names:
+        partition_graph = None
+        if device_name in self._dump_graph_file_paths:
+          partition_graph = _load_graph_def_from_event_file(
+              self._dump_graph_file_paths[device_name])
+        else:
+          partition_graph = self._find_partition_graph(partition_graphs,
+                                                       device_name)
+        if partition_graph:
+          partition_graphs_and_device_names.append((partition_graph,
+                                                    device_name))
+        else:
+          logging.warn("Failed to load partition graphs from disk.")
 
-      self._populate_recipient_maps(device_name)
+    for partition_graph, maybe_device_name in partition_graphs_and_device_names:
+      debug_graph = debug_graphs.DebugGraph(partition_graph,
+                                            device_name=maybe_device_name)
+      self._debug_graphs[debug_graph.device_name] = debug_graph
+      self._collect_node_devices(debug_graph)
 
-      if device_name in self._partition_graphs and validate:
-        self._validate_dump_with_graphs(device_name)
+      if validate and debug_graph.device_name in self._dump_tensor_data:
+        self._validate_dump_with_graphs(debug_graph.device_name)
 
   def _find_partition_graph(self, partition_graphs, device_name):
     if partition_graphs is None:
@@ -1020,167 +807,13 @@ class DebugDumpDir(object):
             return graph_def
       return None
 
-  def _get_ref_args(self, node):
-    """Determine whether an input of an op is ref-type.
-
-    Args:
-      node: A `NodeDef`.
-
-    Returns:
-      A list of the arg names (as strs) that are ref-type.
-    """
-
-    op_def = op_def_registry.get_registered_ops().get(node.op)
-    ref_args = []
-    if op_def:
-      for i, output_arg in enumerate(op_def.output_arg):
-        if output_arg.is_ref:
-          arg_name = node.name if i == 0 else (node.name + ":%d" % i)
-          ref_args.append(arg_name)
-    return ref_args
-
-  def _process_partition_graph_node(self, device_name, node):
-    """Process a node from the partition graphs.
-
-    Args:
-      device_name: (str) device name.
-      node: (NodeDef) A partition-graph node to be processed.
-
-    Raises:
-      ValueError: If duplicate node names are encountered.
-    """
-
-    if is_debug_node(node.name):
-      # This is a debug node. Parse the node name and retrieve the
-      # information about debug watches on tensors. But do not include
-      # the node in the graph.
-      (watched_node_name, watched_output_slot, _,
-       debug_op) = parse_debug_node_name(node.name)
-
-      self._debug_watches[device_name][watched_node_name][
-          watched_output_slot].add(debug_op)
-
-      return
-
-    if node.name in self._node_inputs[device_name]:
-      raise ValueError("Duplicate node name on device %s: '%s'" %
-                       (device_name, node.name))
-
-    self._node_attributes[device_name][node.name] = node.attr
-
-    self._node_inputs[device_name][node.name] = []
-    self._node_ctrl_inputs[device_name][node.name] = []
-    self._node_recipients[device_name][node.name] = []
-    self._node_ctrl_recipients[device_name][node.name] = []
-
-    if node.name not in self._node_devices:
-      self._node_devices[node.name] = set()
-    self._node_devices[node.name].add(node.device)
-    self._node_op_types[device_name][node.name] = node.op
-    self._ref_args[device_name].extend(self._get_ref_args(node))
-
-    for inp in node.input:
-      if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
-        self._copy_send_nodes[device_name].append(node.name)
-
-      if inp.startswith("^"):
-        cinp = inp[1:]
-        self._node_ctrl_inputs[device_name][node.name].append(cinp)
+  def _collect_node_devices(self, debug_graph):
+    for node_name in debug_graph.node_devices:
+      if node_name in self._node_devices:
+        self._node_devices[node_name] = self._node_devices[node_name].union(
+            debug_graph.node_devices[node_name])
       else:
-        self._node_inputs[device_name][node.name].append(inp)
-
-  def _prune_nodes_from_input_and_recipient_maps(self,
-                                                 device_name,
-                                                 nodes_to_prune):
-    """Prune nodes out of input and recipient maps.
-
-    Args:
-      device_name: (`str`) device name.
-      nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.
-    """
-
-    for node in nodes_to_prune:
-      del self._node_inputs[device_name][node]
-      del self._node_ctrl_inputs[device_name][node]
-      del self._node_recipients[device_name][node]
-      del self._node_ctrl_recipients[device_name][node]
-
-  def _prune_non_control_edges_of_debug_ops(self, device_name):
-    """Prune (non-control) edges related to debug ops.
-
-    Prune the Copy ops and associated _Send ops inserted by the debugger out
-    from the non-control inputs and output recipients map. Replace the inputs
-    and recipients with original ones.
-
-    Args:
-      device_name: (`str`) device name.
-    """
-
-    copy_nodes = []
-    for node in self._node_inputs[device_name]:
-      if node in self._copy_send_nodes[device_name]:
-        continue
-
-      if is_copy_node(node):
-        copy_nodes.append(node)
-
-      inputs = self._node_inputs[device_name][node]
-
-      for i in xrange(len(inputs)):
-        inp = inputs[i]
-        if is_copy_node(inp):
-          # Find the input to the Copy node, which should be the original
-          # input to the node.
-          orig_inp = self._node_inputs[device_name][inp][0]
-          inputs[i] = orig_inp
-
-    self._prune_nodes_from_input_and_recipient_maps(device_name, copy_nodes)
-    self._prune_nodes_from_input_and_recipient_maps(
-        device_name, self._copy_send_nodes[device_name])
-
-  def _prune_control_edges_of_debug_ops(self, device_name):
-    """Prune control edges related to the debug ops."""
-
-    for node in self._node_ctrl_inputs[device_name]:
-      ctrl_inputs = self._node_ctrl_inputs[device_name][node]
-      debug_op_inputs = []
-      for ctrl_inp in ctrl_inputs:
-        if is_debug_node(ctrl_inp):
-          debug_op_inputs.append(ctrl_inp)
-      for debug_op_inp in debug_op_inputs:
-        ctrl_inputs.remove(debug_op_inp)
-
-  def _populate_recipient_maps(self, device_name):
-    """Populate the map from node name to recipient(s) of its output(s).
-
-    This method also populates the input map based on reversed ref edges.
-
-    Args:
-      device_name: name of device.
-    """
-
-    for node in self._node_inputs[device_name]:
-      inputs = self._node_inputs[device_name][node]
-      for inp in inputs:
-        inp = get_node_name(inp)
-        if inp not in self._node_recipients[device_name]:
-          self._node_recipients[device_name][inp] = []
-        self._node_recipients[device_name][inp].append(node)
-
-        if inp in self._ref_args[device_name]:
-          if inp not in self._node_reversed_ref_inputs[device_name]:
-            self._node_reversed_ref_inputs[device_name][inp] = []
-          self._node_reversed_ref_inputs[device_name][inp].append(node)
-
-    for node in self._node_ctrl_inputs[device_name]:
-      ctrl_inputs = self._node_ctrl_inputs[device_name][node]
-      for ctrl_inp in ctrl_inputs:
-        if ctrl_inp in self._copy_send_nodes[device_name]:
-          continue
-
-        if ctrl_inp not in self._node_ctrl_recipients[device_name]:
-          self._node_ctrl_recipients[device_name][ctrl_inp] = []
-        self._node_ctrl_recipients[device_name][ctrl_inp].append(node)
+        self._node_devices[node_name] = debug_graph.node_devices[node_name]
 
   def _validate_dump_with_graphs(self, device_name):
     """Validate the dumped tensor data against the partition graphs.
@@ -1197,31 +830,31 @@ class DebugDumpDir(object):
         Or if the temporal order of the dump's timestamps violate the
         input relations on the partition graphs.
     """
-
-    if not self._partition_graphs[device_name]:
+    if not self._debug_graphs:
       raise LookupError(
           "No partition graphs loaded for device %s" % device_name)
+    debug_graph = self._debug_graphs[device_name]
 
     # Verify that the node names in the dump data are all present in the
     # partition graphs.
     for datum in self._dump_tensor_data[device_name]:
-      if datum.node_name not in self._node_inputs[device_name]:
+      if datum.node_name not in debug_graph.node_inputs:
         raise ValueError("Node name '%s' is not found in partition graphs of "
                          "device %s." % (datum.node_name, device_name))
 
     pending_inputs = {}
-    for node in self._node_inputs[device_name]:
+    for node in debug_graph.node_inputs:
       pending_inputs[node] = []
-      inputs = self._node_inputs[device_name][node]
+      inputs = debug_graph.node_inputs[node]
       for inp in inputs:
-        inp_node = get_node_name(inp)
-        inp_output_slot = get_output_slot(inp)
+        inp_node = debug_graphs.get_node_name(inp)
+        inp_output_slot = debug_graphs.get_output_slot(inp)
         # Inputs from Enter and NextIteration nodes are not validated because
         # DebugNodeInserter::InsertNodes() in the debugger core skips creating
         # control edges from debug ops watching these types of nodes.
         if (inp_node in self._debug_watches[device_name] and
             inp_output_slot in self._debug_watches[device_name][inp_node] and
-            self._node_op_types[device_name].get(inp) not in (
+            debug_graph.node_op_types.get(inp) not in (
                 "Enter", "NextIteration") and
             (inp_node, inp_output_slot) not in pending_inputs[node]):
           pending_inputs[node].append((inp_node, inp_output_slot))
@@ -1240,7 +873,7 @@ class DebugDumpDir(object):
                          "these input(s) are not satisfied: %s" %
                          (node, datum.timestamp, repr(pending_inputs[node])))
 
-      recipients = self._node_recipients[device_name][node]
+      recipients = debug_graph.node_recipients[node]
       for recipient in recipients:
         recipient_pending_inputs = pending_inputs[recipient]
         if (node, slot) in recipient_pending_inputs:
@@ -1285,7 +918,7 @@ class DebugDumpDir(object):
 
   def loaded_partition_graphs(self):
     """Test whether partition graphs have been loaded."""
-    return self._partition_graphs is not None
+    return bool(self._debug_graphs)
 
   def partition_graphs(self):
     """Get the partition graphs.
@@ -1296,11 +929,10 @@ class DebugDumpDir(object):
     Raises:
       LookupError: If no partition graphs have been loaded.
     """
-
-    if self._partition_graphs is None:
+    if not self._debug_graphs:
       raise LookupError("No partition graphs have been loaded.")
-
-    return self._partition_graphs.values()
+    return [self._debug_graphs[key].debug_graph_def
+            for key in self._debug_graphs]
 
   @property
   def run_fetches_info(self):
@@ -1380,17 +1012,17 @@ class DebugDumpDir(object):
       LookupError: If no partition graphs have been loaded.
       ValueError: If specified node name does not exist.
     """
-    if self._partition_graphs is None:
+    if not self._debug_graphs:
       raise LookupError("No partition graphs have been loaded.")
     if device_name is None:
       nodes = []
-      for device_name in self._node_inputs:
-        nodes.extend(self._node_inputs[device_name].keys())
+      for device_name in self._debug_graphs:
+        nodes.extend(self._debug_graphs[device_name].node_inputs.keys())
       return nodes
     else:
-      if device_name not in self._node_inputs:
+      if device_name not in self._debug_graphs:
         raise ValueError("Invalid device name: %s" % device_name)
-      return self._node_inputs[device_name].keys()
+      return self._debug_graphs[device_name].node_inputs.keys()
 
   def node_attributes(self, node_name, device_name=None):
     """Get the attributes of a node.
@@ -1406,11 +1038,11 @@ class DebugDumpDir(object):
     Raises:
       LookupError: If no partition graphs have been loaded.
     """
-    if self._partition_graphs is None:
+    if not self._debug_graphs:
       raise LookupError("No partition graphs have been loaded.")
 
     device_name = self._infer_device_name(device_name, node_name)
-    return self._node_attributes[device_name][node_name]
+    return self._debug_graphs[device_name].node_attributes[node_name]
 
   def node_inputs(self, node_name, is_control=False, device_name=None):
     """Get the inputs of given node according to partition graphs.
@@ -1429,16 +1061,15 @@ class DebugDumpDir(object):
       LookupError: If node inputs and control inputs have not been loaded
          from partition graphs yet.
     """
-
-    if self._partition_graphs is None:
+    if not self._debug_graphs:
       raise LookupError(
           "Node inputs are not loaded from partition graphs yet.")
 
     device_name = self._infer_device_name(device_name, node_name)
     if is_control:
-      return self._node_ctrl_inputs[device_name][node_name]
+      return self._debug_graphs[device_name].node_ctrl_inputs[node_name]
     else:
-      return self._node_inputs[device_name][node_name]
+      return self._debug_graphs[device_name].node_inputs[node_name]
 
   def transitive_inputs(self,
                         node_name,
@@ -1466,19 +1097,19 @@ class DebugDumpDir(object):
       LookupError: If node inputs and control inputs have not been loaded
          from partition graphs yet.
     """
-
-    if self._partition_graphs is None:
+    if not self._debug_graphs:
       raise LookupError(
           "Node inputs are not loaded from partition graphs yet.")
 
     device_name = self._infer_device_name(device_name, node_name)
 
-    input_lists = [self._node_inputs[device_name]]
+    input_lists = [self._debug_graphs[device_name].node_inputs]
     if include_control:
-      input_lists.append(self._node_ctrl_inputs[device_name])
+      input_lists.append(self._debug_graphs[device_name].node_ctrl_inputs)
     if include_reversed_ref:
-      input_lists.append(self._node_reversed_ref_inputs[device_name])
-    tracer = _DFSGraphTracer(
+      input_lists.append(
+          self._debug_graphs[device_name].node_reversed_ref_inputs)
+    tracer = debug_graphs.DFSGraphTracer(
         input_lists,
         skip_node_names=self._get_merge_node_names(device_name))
     tracer.trace(node_name)
@@ -1492,9 +1123,10 @@ class DebugDumpDir(object):
     if not hasattr(self, "_merge_node_names"):
       self._merge_node_names = {}
     if device_name not in self._merge_node_names:
+      debug_graph = self._debug_graphs[device_name]
       self._merge_node_names[device_name] = [
-          node for node in self._node_op_types[device_name]
-          if self._node_op_types[device_name][node] == "Merge"]
+          node for node in debug_graph.node_op_types
+          if debug_graph.node_op_types[node] == "Merge"]
     return self._merge_node_names[device_name]
 
   def find_some_path(self,
@@ -1546,12 +1178,13 @@ class DebugDumpDir(object):
           "%s vs. %s" % (src_node_name, dst_node_name, src_device_name,
                          dst_device_name))
 
-    input_lists = [self._node_inputs[dst_device_name]]
+    input_lists = [self._debug_graphs[dst_device_name].node_inputs]
+    debug_graph = self._debug_graphs[dst_device_name]
     if include_control:
-      input_lists.append(self._node_ctrl_inputs[dst_device_name])
+      input_lists.append(debug_graph.node_ctrl_inputs)
     if include_reversed_ref:
-      input_lists.append(self._node_reversed_ref_inputs[dst_device_name])
-    tracer = _DFSGraphTracer(
+      input_lists.append(debug_graph.node_reversed_ref_inputs)
+    tracer = debug_graphs.DFSGraphTracer(
         input_lists,
         skip_node_names=self._get_merge_node_names(dst_device_name),
         destination_node_name=src_node_name)
@@ -1561,7 +1194,7 @@ class DebugDumpDir(object):
 
     try:
       tracer.trace(dst_node_name)
-    except _GraphTracingReachedDestination:
+    except debug_graphs.GraphTracingReachedDestination:
       # Prune nodes not on the path.
       inputs = [dst_node_name] + tracer.inputs()
       depth_list = [0] + tracer.depth_list()
@@ -1592,15 +1225,16 @@ class DebugDumpDir(object):
          from partition graphs yet.
     """
 
-    if self._partition_graphs is None:
+    if not self._debug_graphs:
       raise LookupError(
           "Node recipients are not loaded from partition graphs yet.")
 
     device_name = self._infer_device_name(device_name, node_name)
+    debug_graph = self._debug_graphs[device_name]
     if is_control:
-      return self._node_ctrl_recipients[device_name][node_name]
+      return debug_graph.node_ctrl_recipients[node_name]
     else:
-      return self._node_recipients[device_name][node_name]
+      return debug_graph.node_recipients[node_name]
 
   def devices(self):
     """Get the list of device names.
@@ -1608,7 +1242,6 @@ class DebugDumpDir(object):
     Returns:
       (`list` of `str`) names of the devices.
     """
-
     return self._device_names
 
   def node_exists(self, node_name, device_name=None):
@@ -1627,20 +1260,18 @@ class DebugDumpDir(object):
       LookupError: If no partition graphs have been loaded yet.
       ValueError: If device_name is specified but cannot be found.
     """
-
-    if self._node_inputs is None:
+    if not self._debug_graphs:
       raise LookupError(
           "Nodes have not been loaded from partition graphs yet.")
 
-    if (device_name is not None) and device_name not in self._node_inputs:
+    if (device_name is not None) and device_name not in self._debug_graphs:
       raise ValueError(
           "The specified device_name '%s' cannot be found." % device_name)
 
-    node_inputs_all_devices = (self._node_inputs if device_name is None
-                               else (self._node_inputs[device_name],))
-
-    return any(node_name in node_inputs_all_devices[dev_name]
-               for dev_name in node_inputs_all_devices)
+    for _, debug_graph in self._debug_graphs.items():
+      if node_name in debug_graph.node_inputs:
+        return True
+    return False
 
   def node_device(self, node_name):
     """Get the names of the devices that has nodes of the specified name.
@@ -1658,8 +1289,7 @@ class DebugDumpDir(object):
          from partition graphs yet.
       ValueError: If the node does not exist in partition graphs.
     """
-
-    if self._partition_graphs is None:
+    if not self._debug_graphs:
       raise LookupError(
           "Node devices are not loaded from partition graphs yet.")
 
@@ -1685,13 +1315,12 @@ class DebugDumpDir(object):
       LookupError: If node op types have not been loaded
          from partition graphs yet.
     """
-
-    if self._partition_graphs is None:
+    if not self._debug_graphs:
       raise LookupError(
           "Node op types are not loaded from partition graphs yet.")
 
     device_name = self._infer_device_name(device_name, node_name)
-    return self._node_op_types[device_name][node_name]
+    return self._debug_graphs[device_name].node_op_types[node_name]
 
   def debug_watch_keys(self, node_name, device_name=None):
     """Get all tensor watch keys of given node according to partition graphs.
@@ -1957,7 +1586,7 @@ class DebugDumpDir(object):
     if self._python_graph is None:
       raise LookupError("Python graph is not available for traceback lookup")
 
-    node_name = get_node_name(element_name)
+    node_name = debug_graphs.get_node_name(element_name)
     if node_name not in self._node_traceback:
       raise KeyError("Cannot find node \"%s\" in Python graph" % node_name)
 
diff --git a/tensorflow/python/debug/lib/debug_data_test.py b/tensorflow/python/debug/lib/debug_data_test.py
index 694010a23cde73f90f78c200c93657a9a15b71bd..7ce7ef6a979a1298fafc70c719fdd050ee2c2540 100644
--- a/tensorflow/python/debug/lib/debug_data_test.py
+++ b/tensorflow/python/debug/lib/debug_data_test.py
@@ -49,77 +49,6 @@ class DeviceNamePathConversionTest(test_util.TensorFlowTestCase):
             ",job_ps,replica_1,task_2,cpu_0"))
 
 
-class ParseNodeOrTensorNameTest(test_util.TensorFlowTestCase):
-
-  def testParseNodeName(self):
-    node_name, slot = debug_data.parse_node_or_tensor_name("namespace1/node_1")
-
-    self.assertEqual("namespace1/node_1", node_name)
-    self.assertIsNone(slot)
-
-  def testParseTensorName(self):
-    node_name, slot = debug_data.parse_node_or_tensor_name(
-        "namespace1/node_2:3")
-
-    self.assertEqual("namespace1/node_2", node_name)
-    self.assertEqual(3, slot)
-
-
-class NodeNameChecksTest(test_util.TensorFlowTestCase):
-
-  def testIsCopyNode(self):
-    self.assertTrue(debug_data.is_copy_node("__copy_ns1/ns2/node3_0"))
-
-    self.assertFalse(debug_data.is_copy_node("copy_ns1/ns2/node3_0"))
-    self.assertFalse(debug_data.is_copy_node("_copy_ns1/ns2/node3_0"))
-    self.assertFalse(debug_data.is_copy_node("_copyns1/ns2/node3_0"))
-    self.assertFalse(debug_data.is_copy_node("__dbg_ns1/ns2/node3_0"))
-
-  def testIsDebugNode(self):
-    self.assertTrue(
-        debug_data.is_debug_node("__dbg_ns1/ns2/node3:0_0_DebugIdentity"))
-
-    self.assertFalse(
-        debug_data.is_debug_node("dbg_ns1/ns2/node3:0_0_DebugIdentity"))
-    self.assertFalse(
-        debug_data.is_debug_node("_dbg_ns1/ns2/node3:0_0_DebugIdentity"))
-    self.assertFalse(
-        debug_data.is_debug_node("_dbgns1/ns2/node3:0_0_DebugIdentity"))
-    self.assertFalse(debug_data.is_debug_node("__copy_ns1/ns2/node3_0"))
-
-
-class ParseDebugNodeNameTest(test_util.TensorFlowTestCase):
-
-  def testParseDebugNodeName_valid(self):
-    debug_node_name_1 = "__dbg_ns_a/ns_b/node_c:1_0_DebugIdentity"
-    (watched_node, watched_output_slot, debug_op_index,
-     debug_op) = debug_data.parse_debug_node_name(debug_node_name_1)
-
-    self.assertEqual("ns_a/ns_b/node_c", watched_node)
-    self.assertEqual(1, watched_output_slot)
-    self.assertEqual(0, debug_op_index)
-    self.assertEqual("DebugIdentity", debug_op)
-
-  def testParseDebugNodeName_invalidPrefix(self):
-    invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity"
-
-    with self.assertRaisesRegexp(ValueError, "Invalid prefix"):
-      debug_data.parse_debug_node_name(invalid_debug_node_name_1)
-
-  def testParseDebugNodeName_missingDebugOpIndex(self):
-    invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity"
-
-    with self.assertRaisesRegexp(ValueError, "Invalid debug node name"):
-      debug_data.parse_debug_node_name(invalid_debug_node_name_1)
-
-  def testParseDebugNodeName_invalidWatchedTensorName(self):
-    invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity"
-
-    with self.assertRaisesRegexp(ValueError,
-                                 "Invalid tensor name in debug node name"):
-      debug_data.parse_debug_node_name(invalid_debug_node_name_1)
-
-
 class HasNanOrInfTest(test_util.TensorFlowTestCase):
 
   def setUp(self):
@@ -375,19 +304,5 @@ class DebugDumpDirTest(test_util.TensorFlowTestCase):
       fake.assert_has_calls(expected_calls, any_order=True)
 
 
-class GetNodeNameAndOutputSlotTest(test_util.TensorFlowTestCase):
-
-  def testParseTensorNameInputWorks(self):
-    self.assertEqual("a", debug_data.get_node_name("a:0"))
-    self.assertEqual(0, debug_data.get_output_slot("a:0"))
-
-    self.assertEqual("_b", debug_data.get_node_name("_b:1"))
-    self.assertEqual(1, debug_data.get_output_slot("_b:1"))
-
-  def testParseNodeNameInputWorks(self):
-    self.assertEqual("a", debug_data.get_node_name("a"))
-    self.assertEqual(0, debug_data.get_output_slot("a"))
-
-
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/debug/lib/debug_gradients.py b/tensorflow/python/debug/lib/debug_gradients.py
index 5306391613620090b9e991317256b18629e2acf3..b01a58719cb45b3a42052e0f3522f39a7c5c63c5 100644
--- a/tensorflow/python/debug/lib/debug_gradients.py
+++ b/tensorflow/python/debug/lib/debug_gradients.py
@@ -24,6 +24,7 @@ import uuid
 import six
 
 from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_graphs
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import variables
@@ -34,7 +35,7 @@ _gradient_debuggers = {}
 
 
 def _tensor_to_grad_debug_op_name(tensor, grad_debugger_uuid):
-  op_name, slot = debug_data.parse_node_or_tensor_name(tensor.name)
+  op_name, slot = debug_graphs.parse_node_or_tensor_name(tensor.name)
   return "%s_%d/%s%s" % (op_name, slot, _GRADIENT_DEBUG_TAG, grad_debugger_uuid)
 
 
@@ -407,7 +408,7 @@ def gradient_values_from_dump(grad_debugger, x_tensor, dump):
         (grad_debugger.graph, dump.python_graph))
 
   gradient_tensor = grad_debugger.gradient_tensor(x_tensor)
-  node_name, output_slot = debug_data.parse_node_or_tensor_name(
+  node_name, output_slot = debug_graphs.parse_node_or_tensor_name(
       gradient_tensor.name)
 
   try:
diff --git a/tensorflow/python/debug/lib/debug_graphs.py b/tensorflow/python/debug/lib/debug_graphs.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e2a6acfec656b98ae833ae052d9662abcd48f9
--- /dev/null
+++ b/tensorflow/python/debug/lib/debug_graphs.py
@@ -0,0 +1,430 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes and methods for processing debugger-decorated graphs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.python.framework import op_def_registry
+
+
+def parse_node_or_tensor_name(name):
+  """Get the node name from a string that can be node or tensor name.
+
+  Args:
+    name: An input node name (e.g., "node_a") or tensor name (e.g.,
+      "node_a:0"), as a str.
+
+  Returns:
+    1) The node name, as a str. If the input name is a tensor name, i.e.,
+      consists of a colon, the final colon and the following output slot
+      will be stripped.
+    2) If the input name is a tensor name, the output slot, as an int. If
+      the input name is not a tensor name, None.
+  """
+
+  if ":" in name and not name.endswith(":"):
+    node_name = name[:name.rfind(":")]
+    output_slot = int(name[name.rfind(":") + 1:])
+
+    return node_name, output_slot
+  else:
+    return name, None
+
+
+def get_node_name(element_name):
+  node_name, _ = parse_node_or_tensor_name(element_name)
+  return node_name
+
+
+def get_output_slot(element_name):
+  """Get the output slot number from the name of a graph element.
+
+  If element_name is a node name without output slot at the end, 0 will be
+  assumed.
+
+  Args:
+    element_name: (`str`) name of the graph element in question.
+
+  Returns:
+    (`int`) output slot number.
+  """
+  _, output_slot = parse_node_or_tensor_name(element_name)
+  return output_slot if output_slot is not None else 0
+
+
+def is_copy_node(node_name):
+  """Determine whether a node name is that of a debug Copy node.
+
+  Such nodes are inserted by TensorFlow core upon request in
+  RunOptions.debug_options.debug_tensor_watch_opts.
+
+  Args:
+    node_name: Name of the node.
+
+  Returns:
+    A bool indicating whether the input argument is the name of a debug Copy
+    node.
+  """
+  return node_name.startswith("__copy_")
+
+
+def is_debug_node(node_name):
+  """Determine whether a node name is that of a debug node.
+
+  Such nodes are inserted by TensorFlow core upon request in
+  RunOptions.debug_options.debug_tensor_watch_opts.
+
+  Args:
+    node_name: Name of the node.
+
+  Returns:
+    A bool indicating whether the input argument is the name of a debug node.
+  """
+  return node_name.startswith("__dbg_")
+
+
+def parse_debug_node_name(node_name):
+  """Parse the name of a debug node.
+
+  Args:
+    node_name: Name of the debug node.
+
+  Returns:
+    1. Name of the watched node, as a str.
+    2. Output slot index of the watched tensor, as an int.
+    3. Index of the debug node, as an int.
+    4. Name of the debug op, as a str, e.g, "DebugIdentity".
+
+  Raises:
+    ValueError: If the input node name is not a valid debug node name.
+  """
+  prefix = "__dbg_"
+
+  name = node_name
+  if not name.startswith(prefix):
+    raise ValueError("Invalid prefix in debug node name: '%s'" % node_name)
+
+  name = name[len(prefix):]
+
+  if name.count("_") < 2:
+    raise ValueError("Invalid debug node name: '%s'" % node_name)
+
+  debug_op = name[name.rindex("_") + 1:]
+  name = name[:name.rindex("_")]
+
+  debug_op_index = int(name[name.rindex("_") + 1:])
+  name = name[:name.rindex("_")]
+
+  if name.count(":") != 1:
+    raise ValueError("Invalid tensor name in debug node name: '%s'" % node_name)
+
+  watched_node_name = name[:name.index(":")]
+  watched_output_slot = int(name[name.index(":") + 1:])
+
+  return watched_node_name, watched_output_slot, debug_op_index, debug_op
+
+
+class GraphTracingReachedDestination(Exception):
+  pass
+
+
+class DFSGraphTracer(object):
+  """Graph input tracer using depth-first search."""
+
+  def __init__(self,
+               input_lists,
+               skip_node_names=None,
+               destination_node_name=None):
+    """Constructor of _DFSGraphTracer.
+
+    Args:
+      input_lists: A list of dicts. Each dict is an adjacency (input) map from
+        the recipient node name as the key and the list of input node names
+        as the value.
+      skip_node_names: Optional: a list of node names to skip tracing.
+      destination_node_name: Optional: destination node name. If not `None`, it
+        should be the name of a destination not as a str and the graph tracing
+        will raise GraphTracingReachedDestination as soon as the node has been
+        reached.
+
+    Raises:
+      GraphTracingReachedDestination: if stop_at_node_name is not None and
+        the specified node is reached.
+    """
+
+    self._input_lists = input_lists
+    self._skip_node_names = skip_node_names
+
+    self._inputs = []
+    self._visited_nodes = []
+    self._depth_count = 0
+    self._depth_list = []
+
+    self._destination_node_name = destination_node_name
+
+  def trace(self, graph_element_name):
+    """Trace inputs.
+
+    Args:
+      graph_element_name: Name of the node or an output tensor of the node, as a
+        str.
+
+    Raises:
+      GraphTracingReachedDestination: if destination_node_name of this tracer
+        object is not None and the specified node is reached.
+    """
+    self._depth_count += 1
+
+    node_name = get_node_name(graph_element_name)
+    if node_name == self._destination_node_name:
+      raise GraphTracingReachedDestination()
+
+    if node_name in self._skip_node_names:
+      return
+    if node_name in self._visited_nodes:
+      return
+
+    self._visited_nodes.append(node_name)
+
+    for input_list in self._input_lists:
+      for inp in input_list[node_name]:
+        if get_node_name(inp) in self._visited_nodes:
+          continue
+        self._inputs.append(inp)
+        self._depth_list.append(self._depth_count)
+        self.trace(inp)
+
+    self._depth_count -= 1
+
+  def inputs(self):
+    return self._inputs
+
+  def depth_list(self):
+    return self._depth_list
+
+
+class DebugGraph(object):
+  """Represents a debugger-decorated graph."""
+
+  def __init__(self, debug_graph_def, device_name=None):
+    self._debug_graph_def = debug_graph_def
+
+    self._node_attributes = {}
+    self._node_inputs = {}
+    self._node_reversed_ref_inputs = {}
+    self._node_ctrl_inputs = {}
+    self._node_recipients = {}
+    self._node_ctrl_recipients = {}
+    self._node_devices = {}
+    self._node_op_types = {}
+    self._copy_send_nodes = []
+    self._ref_args = {}
+
+    self._device_name = device_name
+    if not self._device_name and debug_graph_def.node:
+      self._device_name = debug_graph_def.node[0].device
+
+    for node in debug_graph_def.node:
+      self._process_debug_graph_node(node)
+
+    self._prune_non_control_edges_of_debug_ops()
+    self._prune_control_edges_of_debug_ops()
+
+    self._populate_recipient_maps()
+
+  def _process_debug_graph_node(self, node):
+    """Process a node from the debug GraphDef.
+
+    Args:
+      node: (NodeDef) A partition-graph node to be processed.
+
+    Raises:
+      ValueError: If duplicate node names are encountered.
+    """
+
+    if is_debug_node(node.name):
+      # This is a debug node. Parse the node name and retrieve the
+      # information about debug watches on tensors. But do not include
+      # the node in the graph.
+      return
+
+    if node.name in self._node_inputs:
+      raise ValueError("Duplicate node name on device %s: '%s'" %
+                       (self._device_name, node.name))
+
+    self._node_attributes[node.name] = node.attr
+
+    self._node_inputs[node.name] = []
+    self._node_ctrl_inputs[node.name] = []
+    self._node_recipients[node.name] = []
+    self._node_ctrl_recipients[node.name] = []
+
+    if node.name not in self._node_devices:
+      self._node_devices[node.name] = set()
+    self._node_devices[node.name].add(node.device)
+    self._node_op_types[node.name] = node.op
+    self._ref_args[node.name] = self._get_ref_args(node)
+
+    for inp in node.input:
+      if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
+        self._copy_send_nodes.append(node.name)
+
+      if inp.startswith("^"):
+        cinp = inp[1:]
+        self._node_ctrl_inputs[node.name].append(cinp)
+      else:
+        self._node_inputs[node.name].append(inp)
+
+  def _get_ref_args(self, node):
+    """Determine whether an input of an op is ref-type.
+
+    Args:
+      node: A `NodeDef`.
+
+    Returns:
+      A list of the arg names (as strs) that are ref-type.
+    """
+    op_def = op_def_registry.get_registered_ops().get(node.op)
+    ref_args = []
+    if op_def:
+      for i, output_arg in enumerate(op_def.output_arg):
+        if output_arg.is_ref:
+          arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
+          ref_args.append(arg_name)
+    return ref_args
+
+  def _prune_non_control_edges_of_debug_ops(self):
+    """Prune (non-control) edges related to debug ops.
+
+    Prune the Copy ops and associated _Send ops inserted by the debugger out
+    from the non-control inputs and output recipients map. Replace the inputs
+    and recipients with original ones.
+    """
+    copy_nodes = []
+    for node in self._node_inputs:
+      if node in self._copy_send_nodes:
+        continue
+
+      if is_copy_node(node):
+        copy_nodes.append(node)
+
+      inputs = self._node_inputs[node]
+
+      for i in xrange(len(inputs)):
+        inp = inputs[i]
+        if is_copy_node(inp):
+          # Find the input to the Copy node, which should be the original
+          # input to the node.
+          orig_inp = self._node_inputs[inp][0]
+          inputs[i] = orig_inp
+
+    self._prune_nodes_from_input_and_recipient_maps(copy_nodes)
+    self._prune_nodes_from_input_and_recipient_maps(self._copy_send_nodes)
+
+  def _prune_control_edges_of_debug_ops(self):
+    """Prune control edges related to the debug ops."""
+    for node in self._node_ctrl_inputs:
+      ctrl_inputs = self._node_ctrl_inputs[node]
+      debug_op_inputs = []
+      for ctrl_inp in ctrl_inputs:
+        if is_debug_node(ctrl_inp):
+          debug_op_inputs.append(ctrl_inp)
+      for debug_op_inp in debug_op_inputs:
+        ctrl_inputs.remove(debug_op_inp)
+
+  def _populate_recipient_maps(self):
+    """Populate the map from node name to recipient(s) of its output(s).
+
+    This method also populates the input map based on reversed ref edges.
+    """
+    for node in self._node_inputs:
+      inputs = self._node_inputs[node]
+      for inp in inputs:
+        inp = get_node_name(inp)
+        if inp not in self._node_recipients:
+          self._node_recipients[inp] = []
+        self._node_recipients[inp].append(node)
+
+        if inp in self._ref_args:
+          if inp not in self._node_reversed_ref_inputs:
+            self._node_reversed_ref_inputs[inp] = []
+          self._node_reversed_ref_inputs[inp].append(node)
+
+    for node in self._node_ctrl_inputs:
+      ctrl_inputs = self._node_ctrl_inputs[node]
+      for ctrl_inp in ctrl_inputs:
+        if ctrl_inp in self._copy_send_nodes:
+          continue
+
+        if ctrl_inp not in self._node_ctrl_recipients:
+          self._node_ctrl_recipients[ctrl_inp] = []
+        self._node_ctrl_recipients[ctrl_inp].append(node)
+
+  def _prune_nodes_from_input_and_recipient_maps(self, nodes_to_prune):
+    """Prune nodes out of input and recipient maps.
+
+    Args:
+      nodes_to_prune: (`list` of `str`) Names of the nodes to be pruned.
+    """
+    for node in nodes_to_prune:
+      del self._node_inputs[node]
+      del self._node_ctrl_inputs[node]
+      del self._node_recipients[node]
+      del self._node_ctrl_recipients[node]
+
+  @property
+  def device_name(self):
+    return self._device_name
+
+  @property
+  def debug_graph_def(self):
+    """The debugger-decorated GraphDef."""
+    return self._debug_graph_def
+
+  @property
+  def node_devices(self):
+    return self._node_devices
+
+  @property
+  def node_op_types(self):
+    return self._node_op_types
+
+  @property
+  def node_attributes(self):
+    return self._node_attributes
+
+  @property
+  def node_inputs(self):
+    return self._node_inputs
+
+  @property
+  def node_ctrl_inputs(self):
+    return self._node_ctrl_inputs
+
+  @property
+  def node_reversed_ref_inputs(self):
+    return self._node_reversed_ref_inputs
+
+  @property
+  def node_recipients(self):
+    return self._node_recipients
+
+  @property
+  def node_ctrl_recipients(self):
+    return self._node_ctrl_recipients
diff --git a/tensorflow/python/debug/lib/debug_graphs_test.py b/tensorflow/python/debug/lib/debug_graphs_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..34257794f11f022583d394e65a56e7f9ced22c6e
--- /dev/null
+++ b/tensorflow/python/debug/lib/debug_graphs_test.py
@@ -0,0 +1,112 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tfdbg module debug_data."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.debug.lib import debug_graphs
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class ParseNodeOrTensorNameTest(test_util.TensorFlowTestCase):
+
+  def testParseNodeName(self):
+    node_name, slot = debug_graphs.parse_node_or_tensor_name(
+        "namespace1/node_1")
+
+    self.assertEqual("namespace1/node_1", node_name)
+    self.assertIsNone(slot)
+
+  def testParseTensorName(self):
+    node_name, slot = debug_graphs.parse_node_or_tensor_name(
+        "namespace1/node_2:3")
+
+    self.assertEqual("namespace1/node_2", node_name)
+    self.assertEqual(3, slot)
+
+
+class GetNodeNameAndOutputSlotTest(test_util.TensorFlowTestCase):
+
+  def testParseTensorNameInputWorks(self):
+    self.assertEqual("a", debug_graphs.get_node_name("a:0"))
+    self.assertEqual(0, debug_graphs.get_output_slot("a:0"))
+
+    self.assertEqual("_b", debug_graphs.get_node_name("_b:1"))
+    self.assertEqual(1, debug_graphs.get_output_slot("_b:1"))
+
+  def testParseNodeNameInputWorks(self):
+    self.assertEqual("a", debug_graphs.get_node_name("a"))
+    self.assertEqual(0, debug_graphs.get_output_slot("a"))
+
+
+class NodeNameChecksTest(test_util.TensorFlowTestCase):
+
+  def testIsCopyNode(self):
+    self.assertTrue(debug_graphs.is_copy_node("__copy_ns1/ns2/node3_0"))
+
+    self.assertFalse(debug_graphs.is_copy_node("copy_ns1/ns2/node3_0"))
+    self.assertFalse(debug_graphs.is_copy_node("_copy_ns1/ns2/node3_0"))
+    self.assertFalse(debug_graphs.is_copy_node("_copyns1/ns2/node3_0"))
+    self.assertFalse(debug_graphs.is_copy_node("__dbg_ns1/ns2/node3_0"))
+
+  def testIsDebugNode(self):
+    self.assertTrue(
+        debug_graphs.is_debug_node("__dbg_ns1/ns2/node3:0_0_DebugIdentity"))
+
+    self.assertFalse(
+        debug_graphs.is_debug_node("dbg_ns1/ns2/node3:0_0_DebugIdentity"))
+    self.assertFalse(
+        debug_graphs.is_debug_node("_dbg_ns1/ns2/node3:0_0_DebugIdentity"))
+    self.assertFalse(
+        debug_graphs.is_debug_node("_dbgns1/ns2/node3:0_0_DebugIdentity"))
+    self.assertFalse(debug_graphs.is_debug_node("__copy_ns1/ns2/node3_0"))
+
+
+class ParseDebugNodeNameTest(test_util.TensorFlowTestCase):
+
+  def testParseDebugNodeName_valid(self):
+    debug_node_name_1 = "__dbg_ns_a/ns_b/node_c:1_0_DebugIdentity"
+    (watched_node, watched_output_slot, debug_op_index,
+     debug_op) = debug_graphs.parse_debug_node_name(debug_node_name_1)
+
+    self.assertEqual("ns_a/ns_b/node_c", watched_node)
+    self.assertEqual(1, watched_output_slot)
+    self.assertEqual(0, debug_op_index)
+    self.assertEqual("DebugIdentity", debug_op)
+
+  def testParseDebugNodeName_invalidPrefix(self):
+    invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity"
+
+    with self.assertRaisesRegexp(ValueError, "Invalid prefix"):
+      debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
+
+  def testParseDebugNodeName_missingDebugOpIndex(self):
+    invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity"
+
+    with self.assertRaisesRegexp(ValueError, "Invalid debug node name"):
+      debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
+
+  def testParseDebugNodeName_invalidWatchedTensorName(self):
+    invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity"
+
+    with self.assertRaisesRegexp(ValueError,
+                                 "Invalid tensor name in debug node name"):
+      debug_graphs.parse_debug_node_name(invalid_debug_node_name_1)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/debug/lib/grpc_debug_server.py b/tensorflow/python/debug/lib/grpc_debug_server.py
index 309fdb3bceda73a2b2a34cc329e3ae16baa0afec..64e4f00168151c8a37c06595932fe6402ecf3675 100644
--- a/tensorflow/python/debug/lib/grpc_debug_server.py
+++ b/tensorflow/python/debug/lib/grpc_debug_server.py
@@ -29,9 +29,10 @@ from six.moves import queue
 
 from tensorflow.core.debug import debug_service_pb2
 from tensorflow.core.framework import graph_pb2
-from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_graphs
 from tensorflow.python.debug.lib import debug_service_pb2_grpc
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import compat
 
 DebugWatch = collections.namedtuple("DebugWatch",
                                     ["node_name", "output_slot", "debug_op"])
@@ -219,7 +220,8 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
     """
 
     value = event.summary.value[0]
-    debugger_plugin_metadata = json.loads(value.metadata.plugin_data.content)
+    debugger_plugin_metadata = json.loads(
+        compat.as_text(value.metadata.plugin_data.content))
     device_name = debugger_plugin_metadata["device"]
     num_chunks = debugger_plugin_metadata["numChunks"]
     chunk_index = debugger_plugin_metadata["chunkIndex"]
@@ -294,10 +296,10 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
 
   def _process_graph_def(self, graph_def):
     for node_def in graph_def.node:
-      if (debug_data.is_debug_node(node_def.name) and
+      if (debug_graphs.is_debug_node(node_def.name) and
           node_def.attr["gated_grpc"].b):
         node_name, output_slot, _, debug_op = (
-            debug_data.parse_debug_node_name(node_def.name))
+            debug_graphs.parse_debug_node_name(node_def.name))
         self._gated_grpc_debug_watches.add(
             DebugWatch(node_name, output_slot, debug_op))
 
diff --git a/tensorflow/python/debug/lib/grpc_debug_test_server.py b/tensorflow/python/debug/lib/grpc_debug_test_server.py
index 5e3743d9d3015ffbe67cc5cde5e0e534e5564df9..2a87d861d25761a71dd098cb45533368fec14625 100644
--- a/tensorflow/python/debug/lib/grpc_debug_test_server.py
+++ b/tensorflow/python/debug/lib/grpc_debug_test_server.py
@@ -41,6 +41,7 @@ from tensorflow.python.debug.lib import grpc_debug_server
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.ops import variables
+from tensorflow.python.util import compat
 
 
 def _get_dump_file_path(dump_root, device_name, debug_node_name):
@@ -198,7 +199,7 @@ class EventListenerTestStreamHandler(
     if not summary_metadata.plugin_data:
       raise ValueError("The value lacks plugin data.")
     try:
-      content = json.loads(summary_metadata.plugin_data.content)
+      content = json.loads(compat.as_text(summary_metadata.plugin_data.content))
     except ValueError as err:
       raise ValueError("Could not parse content into JSON: %r, %r" % (content,
                                                                       err))
diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py
index 48f31771db8b6309883f3b9eac51ca51611d173f..aa5314dda590a6f7d8289e370e3aa04f3dfda1b8 100644
--- a/tensorflow/python/debug/lib/session_debug_file_test.py
+++ b/tensorflow/python/debug/lib/session_debug_file_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 
 
-class SessionDebugTest(session_debug_testlib.SessionDebugTestBase):
+class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
 
   def _no_rewrite_session_config(self):
     rewriter_config = rewriter_config_pb2.RewriterConfig(
diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py
index 08b3e75e7c89249bbcf169de23b83c6b77a9e551..d4b9d06b54e2dc5c8440510f887b384e4e7ed49d 100644
--- a/tensorflow/python/debug/lib/session_debug_testlib.py
+++ b/tensorflow/python/debug/lib/session_debug_testlib.py
@@ -33,6 +33,7 @@ from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.core.util import event_pb2
 from tensorflow.python.client import session
 from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_graphs
 from tensorflow.python.debug.lib import debug_utils
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -242,7 +243,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
       v_copy_node_def = None
       for partition_graph in run_metadata.partition_graphs:
         for node_def in partition_graph.node:
-          if debug_data.is_copy_node(node_def.name):
+          if debug_graphs.is_copy_node(node_def.name):
             if node_def.name == "__copy_u_0":
               u_copy_node_def = node_def
             elif node_def.name == "__copy_v_0":
diff --git a/tensorflow/python/debug/lib/stepper.py b/tensorflow/python/debug/lib/stepper.py
index c814520b7e7cba7e3661cab7bd2d2faede8eed1d..1fa0b3dba2b547bf1d311e42e1005a8e501f9829 100644
--- a/tensorflow/python/debug/lib/stepper.py
+++ b/tensorflow/python/debug/lib/stepper.py
@@ -27,6 +27,7 @@ import six
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_graphs
 from tensorflow.python.debug.lib import debug_utils
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import session_ops
@@ -706,8 +707,8 @@ class NodeStepper(object):
       if ":" in element_name:
         debug_utils.add_debug_tensor_watch(
             run_options,
-            debug_data.get_node_name(element_name),
-            output_slot=debug_data.get_output_slot(element_name),
+            debug_graphs.get_node_name(element_name),
+            output_slot=debug_graphs.get_output_slot(element_name),
             debug_urls=["file://" + dump_path])
 
     return dump_path, run_options
@@ -961,5 +962,5 @@ class NodeStepper(object):
       The node associated with element in the graph.
     """
 
-    node_name, _ = debug_data.parse_node_or_tensor_name(element.name)
+    node_name, _ = debug_graphs.parse_node_or_tensor_name(element.name)
     return self._sess.graph.as_graph_element(node_name)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index ecb0ca9dfd7cef3f4fda01270cdb1d3563345238..4bb810a6c8fb5e079157dd15f0a297d4a998247d 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -405,6 +405,7 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":context",
+        ":execute",
         ":tensor",
         ":test",
         "//tensorflow/python:array_ops",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 3a70eaeaa544574b81c58d3b93b3056f3476ace2..46872e617accc8b93b66e4b23765e4fbcc583770 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import threading
 
+from autograd import container_types
 from autograd import convenience_wrappers
 from autograd import core as ag_core
 
@@ -145,17 +146,17 @@ def _record_gradient(op_name, inputs, attrs, results, name):
   # It is imperative we make a copy of results here as otherwise we create a
   # dependency cycle in the captured function and this can delay garbage
   # collecting of the tensors arbitrarily.
-  result_copies = results[:]
+  results_size = len(results) if isinstance(results, (list, tuple)) else 1
 
-  def grad_fn(*outputs):
+  def grad_fn(*orig_outputs):
     """Generated gradient function."""
-    tensors = inputs + result_copies + list(outputs)
-    tensors = [ag_core.getval(x) for x in tensors]
+    tensors = inputs + list(orig_outputs)
+    tensors = container_types.make_sequence(tape.EagerList, *tensors)
     result = _magic_gradient_function(op_name, attrs, len(inputs),
                                       num_outputs, *(tensors))
     if _tracing:
       print("Gradient for", (name if name else op_name), "inputs", inputs,
-            "output_grads", outputs)
+            "output_grads", orig_outputs[results_size:], "gradients", result)
     return result
 
   results = tape.record_operation(results, inputs, [], grad_fn)
@@ -168,14 +169,12 @@ def _record_gradient(op_name, inputs, attrs, results, name):
 execute.record_gradient = _record_gradient
 
 
-def _ones(shape, dtype):
-  return array_ops.fill(shape, tensor.Tensor(1, dtype=dtype))
-
-
 def _aggregate_grads(gradients):
   """Aggregate gradients of the same tensor."""
   grad_lists = dict()
   for t, g in gradients:
+    if g is None:
+      continue
     if id(t) not in grad_lists:
       grad_lists[id(t)] = [(t, g)]
     else:
@@ -187,7 +186,7 @@ def _aggregate_grads(gradients):
       ret.append(g_list[0])
     else:
       # TODO(xpan): Aggregate IndexedSlices.
-      ret.append((g_list[0][0], math_ops.add_n(zip(*g_list)[1])))
+      ret.append((g_list[0][0], math_ops.add_n(list(zip(*g_list))[1])))
   return ret
 
 
@@ -222,7 +221,7 @@ def implicit_val_and_grad(f):
                        (end_node.progenitors, repr(start_node)))
     output_gradients = kwds.get("output_gradients", None)
     if output_gradients is None:
-      output_gradients = _ones(end_node.shape, end_node.dtype)
+      output_gradients = array_ops.ones_like(end_node.value)
     grad = ag_core.backward_pass(output_gradients, end_node, start_node)
     return end_node.value, _aggregate_grads(grad.gradients)
 
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 37007378a04f0784afce97c138d2a17a718daefd..b43790550969943b124f3565c0e540e0dc64cfbc 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -23,7 +23,9 @@ from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.eager import tensor
+from tensorflow.python.eager import tensor_node
 from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
@@ -74,12 +76,16 @@ class BackpropTest(test.TestCase):
 
       self.assertAllClose(grad.numpy(), tf_dense_grad.eval())
 
+  def testTensoVspaceNoneMutAdd(self):
+    t = tensor.Tensor(1.0)
+    self.assertEqual(tensor_node.TensorVSpace(t).mut_add(t, None).numpy(), 1.0)
+
   def testImplicitGradWithResourceVariable(self):
     x = resource_variable_ops.ResourceVariable(
         initial_value=tensor.Tensor(1.0), name='x')
 
     def fn():
-      tape.watch(x.handle)
+      tape.watch_variable(x)
       b = tensor.Tensor(2.0)
       c = math_ops.add(x.value(), b)
       return math_ops.add(c, tensor.Tensor(3.0))
@@ -155,6 +161,29 @@ class BackpropTest(test.TestCase):
     grad = backprop.gradients_function(second, [0])(f)[0]
     self.assertAllEqual([[0.0]], grad.numpy())
 
+  def testGradGrad(self):
+
+    def sq(x):
+      return x * x
+
+    def grad(x):
+      value = backprop.gradients_function(sq, [0])(x)[0]
+      return value
+
+    gradgrad = backprop.gradients_function(grad, [0])
+
+    self.assertAllEqual(gradgrad(tensor.Tensor(3.0))[0].numpy(), 2.0)
+
+  def testGradGradExp(self):
+
+    def grad(x):
+      value = backprop.gradients_function(math_ops.exp, [0])(x)[0]
+      return value
+
+    gradgrad = backprop.gradients_function(grad, [0])
+
+    self.assertAllEqual(gradgrad(tensor.Tensor(0.0))[0].numpy(), 1.0)
+
   def testGPU(self):
     if not context.context().num_gpus():
       self.skipTest('No GPUs found')
@@ -253,6 +282,15 @@ class BackpropTest(test.TestCase):
     self.assertEqual([dtypes.float32],
                      backprop.make_attr([pywrap_tensorflow.TF_ATTR_TYPE], [1]))
 
+  def testMulType(self):
+
+    def mul(x):
+      return math_ops._mul_dispatch(x, x)  # pylint: disable=protected-access
+
+    self.assertAllEqual(
+        backprop.gradients_function(mul)(constant_op.constant(3.0))[0].numpy(),
+        6.0)
+
   def testMakeAttrShape(self):
     for s in ([], None, [1, 2, 3], [None, None], [1, None, 3]):
       expected = tensor_shape.TensorShape(s).as_proto()
@@ -269,6 +307,20 @@ class BackpropTest(test.TestCase):
         [tensor_shape.TensorShape(s).as_proto() for s in shape_list],
         backprop.make_attr([pywrap_tensorflow.TF_ATTR_SHAPE], shape_list))
 
+  def testMultiValueConvertToTensor(self):
+    x = resource_variable_ops.ResourceVariable(
+        initial_value=array_ops.constant([1.0]), name='x')
+
+    def fn():
+      tape.watch_variable(x)
+      a = math_ops.add(x.value(), 1.0)
+      # Make sure convert_to_tensor works correctly with list of TensorNodes.
+      b = array_ops.stack([a, a], axis=0)
+      return math_ops.reduce_mean(b)
+
+    grad = backprop.implicit_grad(fn)()[0][1]
+    self.assertAllEqual([1.0], grad.numpy())
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 27ffdd981050ff73e33b0143562d30b0d53780c3..a5a93b7bbe09f68397ac17c0008935c86126360e 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -171,16 +171,6 @@ class Context(object):
     """Sets summary writer resource."""
     self._summary_writer_resource = resource
 
-  @property
-  def recording_summaries(self):
-    """Returns True if recording summaries is enabled in current thread.."""
-    return self._eager_context.recording_summaries
-
-  @recording_summaries.setter
-  def recording_summaries(self, val):
-    """Enables recording summaries is enabled in current thread.."""
-    self._eager_context.recording_summaries = val
-
   @property
   def device_name(self):
     """Returns the device name for the current thread."""
@@ -360,24 +350,6 @@ def device(name):
   return context().device(name)
 
 
-@contextlib.contextmanager
-def record_summaries():
-  """Context-manager to enable recording of summaries."""
-  ctx = context()
-  old = ctx.recording_summaries
-  ctx.recording_summaries = True
-  try:
-    yield
-  finally:
-    ctx.recording_summaries = old
-
-
-def should_record_summary():
-  """True if a summary should be recorded now."""
-  c = context()
-  return c.recording_summaries and c.summary_writer_resource is not None
-
-
 def run(main=None, argv=None):
   """Runs the program with an optional 'main' function and 'argv' list.
 
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index 7ae80aa156aef74a29bcc3c55cfb6c60e1a41c43..5de396f62c37585acca3dbc6117dc00e47fbc3e5 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -55,10 +55,6 @@ class TFETest(test_util.TensorFlowTestCase):
     ctx.summary_writer_resource = 'mock'
     self.assertEqual('mock', ctx.summary_writer_resource)
 
-    self.assertFalse(ctx.recording_summaries)
-    ctx.recording_summaries = True
-    self.assertTrue(ctx.recording_summaries)
-
     self.assertEqual('', ctx.device_name)
     self.assertEqual(ctx.device_name, ctx.device_spec.to_string())
     with ctx.device('GPU:0'):
@@ -95,8 +91,7 @@ class TFETest(test_util.TensorFlowTestCase):
       return [
           ctx.in_graph_mode(),
           ctx.in_eager_mode(), ctx.scope_name, ctx.summary_writer_resource,
-          ctx.recording_summaries, ctx.device_name,
-          ctx.num_gpus()
+          ctx.device_name, ctx.num_gpus()
       ]
 
     def get_values(ctx, values):
diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py
index 39fd845efc3e27c7e80371296ddc2999a1282ddd..0c921bb023628be1cd9980589299cf841497397f 100644
--- a/tensorflow/python/eager/custom_gradient.py
+++ b/tensorflow/python/eager/custom_gradient.py
@@ -56,6 +56,14 @@ def custom_gradient(f):
                      if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                      or ag_core.isnode(x)]
     result, grad_fn = f(*args, **kwargs)
+    result_size = len(result) if isinstance(result, (list, tuple)) else 1
+
+    # TODO(apassos): naive uses of custom_gradient will not get the correct
+    # second derivative this way if they capture any output tensors. Change the
+    # signature of custom_gradient.
+    def actual_grad_fn(*outputs):
+      outputs = outputs[result_size:]
+      return grad_fn(*outputs)
 
     flat_result = nest.flatten(result)
     flat_result = [ag_core.getval(x) for x in flat_result]
@@ -63,7 +71,7 @@ def custom_gradient(f):
         flat_result,
         input_tensors,
         [],
-        grad_fn)
+        actual_grad_fn)
     flat_result = list(flat_result)
     return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
 
diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py
index 2223e9833af7ab07461b0dd7bedfa21d1040b0dd..2b5a76ca121275e8cfca07eccc9aafbe2b9b093a 100644
--- a/tensorflow/python/eager/execute.py
+++ b/tensorflow/python/eager/execute.py
@@ -188,17 +188,12 @@ def args_to_matching_eager(l, default_dtype=None):
       dtype = t.dtype
       break
 
-  if dtype is None:
-    # TODO(josh11b): At the moment, I don't think this can fail, but at some
-    # point we likely should have some logic to prevent bad conversions.
-    dtype = default_dtype
-
   if dtype is None:
     # Infer a dtype based on the first value, and use that dtype for the
     # remaining values.
     ret = []
     for t in l:
-      ret.append(ops.convert_to_tensor(t, dtype))
+      ret.append(ops.convert_to_tensor(t, dtype, preferred_dtype=default_dtype))
       if dtype is None:
         dtype = ret[-1].dtype
   else:
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 60fd4957dae34684987de15a220cabd4f531bc4d..227520eea8a3413a861dc8a45de36cb1e171b038 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -276,9 +276,15 @@ class _GraphModeFunction(object):
           break
       else:  # Note: for-else here done on purpose
         watched_extra_inputs.append(t)
-    real_outputs = tape.record_operation(real_outputs,
-                                         (args + watched_extra_inputs),
-                                         side_outputs, self._backward_function)
+
+    def backward_function_wrapper(*outputs):
+      outputs = outputs[len(real_outputs):]
+      return self._backward_function(*outputs)
+    real_outputs = tape.record_operation(
+        real_outputs,
+        (args + watched_extra_inputs),
+        side_outputs,
+        backward_function_wrapper)
 
     return self._build_call_outputs(self._returns, real_outputs)
 
@@ -367,6 +373,13 @@ def _defun_internal(name, func, args, kwds):
   """Defines and returns graph-mode version of func."""
   with context.graph_mode():
     tmp_graph = ops.Graph()
+    # Copy the graph collections to ensure summaries and other things work. This
+    # lets the function access (but not mutate) collections of the containing
+    # graph, such as the global step and the summary writer collections.
+    curr_graph = ops.get_default_graph()
+    for collection in curr_graph.collections:
+      tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
+          collection)
     with tmp_graph.as_default():
       func_inputs = _get_defun_inputs(args)
 
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 18b722e7923c02fee148258d92e6bd8b3879a150..c15dde9e487b025180a3dbb5cd018981138c194f 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import function as tf_function
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 
 
 class FunctionTest(test.TestCase):
@@ -52,6 +53,19 @@ class FunctionTest(test.TestCase):
     out = sq(t)
     self.assertAllEqual(out.numpy(), math_ops.matmul(t, t).numpy())
 
+  def testGraphModeWithGradients(self):
+    v = resource_variable_ops.ResourceVariable(1.0)
+
+    @function.defun
+    def step():
+      def inner():
+        tape.watch(v.handle)
+        return v * v
+
+      return backprop.implicit_grad(inner)()[0][1]
+
+    self.assertAllEqual(step().numpy(), 2.0)
+
   def testTensorConversionWithDefun(self):
 
     @function.defun
diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py
index 78ff2f677718483718cdbab4eb0e894a912e3543..ffd5b39b8597fd038477d5cb581f63c7557d7ca9 100644
--- a/tensorflow/python/eager/ops_test.py
+++ b/tensorflow/python/eager/ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.python.eager import context
+from tensorflow.python.eager import execute
 from tensorflow.python.eager import tensor
 from tensorflow.python.eager import test
 from tensorflow.python.framework import dtypes
@@ -272,7 +273,7 @@ class TargetTest(test_util.TensorFlowTestCase):
 
   def testInvalidInputDataType(self):
     # Fill requires the first input to be an int32 tensor.
-    with self.assertRaisesRegexp(ValueError, 'int64'):
+    with self.assertRaisesRegexp(errors.InvalidArgumentError, 'int64'):
       array_ops.fill(tensor.Tensor([2], dtype=dtypes.int64), tensor.Tensor(1))
 
   def testOutputOnHostMemory(self):
@@ -295,6 +296,19 @@ class TargetTest(test_util.TensorFlowTestCase):
     self.assertLess(x.numpy(), 6)
     self.assertGreaterEqual(x.numpy(), 5)
 
+  def testArgsToMatchingEagerDefault(self):
+    # Uses default
+    t, r = execute.args_to_matching_eager([[3, 4]], dtypes.int32)
+    self.assertEquals(t, dtypes.int32)
+    self.assertEquals(r[0].dtype, dtypes.int32)
+    t, r = execute.args_to_matching_eager([[3, 4]], dtypes.int64)
+    self.assertEquals(t, dtypes.int64)
+    self.assertEquals(r[0].dtype, dtypes.int64)
+    # Doesn't use default
+    t, r = execute.args_to_matching_eager([['string', 'arg']], dtypes.int32)
+    self.assertEquals(t, dtypes.string)
+    self.assertEquals(r[0].dtype, dtypes.string)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc
index c46a3d8db37ba4cdeb14c82a77f7d11022901d29..a526856794d69bbb6541747e2a8528623c365d2a 100644
--- a/tensorflow/python/eager/python_eager_op_gen.cc
+++ b/tensorflow/python/eager/python_eager_op_gen.cc
@@ -659,14 +659,25 @@ void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) {
 
 string GetEagerPythonOps(const OpList& ops,
                          const std::vector& hidden_ops,
-                         bool require_shapes) {
+                         bool require_shapes,
+                         const string& source_file_name = "") {
   string result;
   // Header
   // TODO(josh11b): Mention the library for which wrappers are being generated.
-  strings::StrAppend(&result, R"("""Python wrappers for TensorFlow ops.
+  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.
 
 This file is MACHINE GENERATED! Do not edit.
-"""
+)");
+
+  // Mention the original source file so someone tracing back through generated
+  // Python code will know where to look next.
+  if (!source_file_name.empty()) {
+    strings::StrAppend(&result, "Original C++ source file: ");
+    strings::StrAppend(&result, source_file_name);
+    strings::StrAppend(&result, "\n");
+  }
+
+  strings::StrAppend(&result, R"("""
 
 import collections as _collections
 
@@ -747,8 +758,10 @@ from tensorflow.python.framework import op_def_library as _op_def_library
 
 void PrintEagerPythonOps(const OpList& ops,
                          const std::vector& hidden_ops,
-                         bool require_shapes) {
-  printf("%s", GetEagerPythonOps(ops, hidden_ops, require_shapes).c_str());
+                         bool require_shapes, const string& source_file_name) {
+  printf("%s",
+         GetEagerPythonOps(ops, hidden_ops, require_shapes, source_file_name)
+             .c_str());
 }
 
 string GetEagerPythonWrappers(const char* op_list_buf, size_t op_list_len) {
diff --git a/tensorflow/python/eager/python_eager_op_gen.h b/tensorflow/python/eager/python_eager_op_gen.h
index 9a7ed28cf942695fa12c68d6c50ffa482520931b..250623850f2c04d5deb0924cc4043226e089d425 100644
--- a/tensorflow/python/eager/python_eager_op_gen.h
+++ b/tensorflow/python/eager/python_eager_op_gen.h
@@ -24,9 +24,12 @@ namespace tensorflow {
 
 // hidden_ops should be a list of Op names that should get a leading _
 // in the output. Prints the output to stdout.
+// Optional fourth argument is the name of the original C++ source file
+// where the ops' REGISTER_OP() calls reside.
 void PrintEagerPythonOps(const OpList& ops,
                          const std::vector& hidden_ops,
-                         bool require_shapes);
+                         bool require_shapes,
+                         const string& source_file_name = "");
 
 // Get the python wrappers for a list of ops in a OpList.
 // `op_list_buf` should be a pointer to a buffer containing
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index f2915eba59d9270c5b7b67e51112e946a599ef29..9cd29f630df0eba3bb0fc761ff159120763966ed 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -101,6 +101,9 @@ class NoneVSpace(ag_core.VSpace):
   def __init__(self, _):
     self.size = 0
 
+  def zeros(self):
+    return 0
+
 
 ag_core.register_vspace(NoneVSpace, type(None))
 
@@ -148,6 +151,15 @@ def watch(tensor):
   return tensor
 
 
+def watch_variable(resource_variable):
+  """Marks this ResourceVariable to be watched by all tapes in the stack.
+
+  Args:
+    resource_variable: A ResourceVariable to be watched.
+  """
+  watch(resource_variable.handle)  # py-lint: disable=protected-access
+
+
 def pop_tape():
   """Pops the top tape in the stack, if any."""
   if _tape_stack.stack:
@@ -197,39 +209,43 @@ class _EagerSequenceVSpace(container_types.SequenceVSpace):
     return True
 
 
-class _EagerList(list):
-  """Type used to bypass SequenceVSpace."""
+class EagerList(list):
+  """Type used to bypass SequenceVSpace.
+
+  SequenceVSpace has a very strict equality check which does not match
+  tensorflow semantics.
+  """
 
   def __init__(self, value):
-    super(_EagerList, self).__init__(value)
+    super(EagerList, self).__init__(value)
     for v in value:
       assert not ag_core.isnode(v)
 
-ag_core.register_vspace(_EagerSequenceVSpace, _EagerList)
-ag_core.register_node(_EagerSequenceNode, _EagerList)
+ag_core.register_vspace(_EagerSequenceVSpace, EagerList)
+ag_core.register_node(_EagerSequenceNode, EagerList)
 
 
 @ag_core.primitive
 def _record_operation(output_tensors, input_tensors, side_outputs,
                       backward_function):
   del input_tensors, side_outputs, backward_function
-  return _EagerList(output_tensors)
+  return EagerList(output_tensors)
 
 
 def record_operation(o, i, s, b):
   """Primitive to trigger autograd tracing on outputs from inputs."""
-  inputs = container_types.make_sequence(_EagerList, *i)
+  inputs = container_types.make_sequence(EagerList, *i)
   return _record_operation(o, inputs, s, b)
 
 
 def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
                           side_outputs, backward_function):
   """Gradient for _record_operation."""
-  del ans, vs, gvs, output_tensors, input_tensors
+  del vs, gvs, input_tensors, output_tensors
   backward_args = tuple(g) + tuple(side_outputs)
-  if ag_core.isnode(backward_args):
-    backward_args = list(backward_args)
+  backward_args = container_types.make_sequence(
+      EagerList, *(tuple(ans) + backward_args))
   tensors = nest.flatten(backward_function(*backward_args))
-  return _EagerList([ag_core.getval(t) for t in tensors])
+  return container_types.make_sequence(EagerList, *tensors)
 
 _record_operation.defvjp(_record_operation_vjp, argnum=1)
diff --git a/tensorflow/python/eager/tensor.py b/tensorflow/python/eager/tensor.py
index 69269d1975f54a488006f6ed2824c21b0299295b..17e594d5c348dcd43d446cacad58be0a1221e490 100644
--- a/tensorflow/python/eager/tensor.py
+++ b/tensorflow/python/eager/tensor.py
@@ -51,3 +51,6 @@ class LazyZero(object):
 
   def numpy(self):
     return np.zeros(self.shape, self.dtype)
+
+  def _shape_tuple(self):
+    return self.shape
diff --git a/tensorflow/python/eager/tensor_node.py b/tensorflow/python/eager/tensor_node.py
index 8200761d03149c7ce7ffb67b197c4e88219e1e05..331bf7eef8eaf311b04b1af9047d2d67c4e94ad6 100644
--- a/tensorflow/python/eager/tensor_node.py
+++ b/tensorflow/python/eager/tensor_node.py
@@ -250,10 +250,18 @@ class TensorVSpace(ag_core.VSpace):
     if isinstance(value, ops.IndexedSlices):
       self.shape = tensor_shape.TensorShape(value.dense_shape.numpy())
       self.dtype = value.values.dtype
+      self.size = self.shape.num_elements()
     else:
-      self.shape = value.shape
+      self.shape = value._shape_tuple()  # pylint: disable=protected-access
+      if self.shape is None or None in self.shape:
+        # TODO(apassos) we currently don't check the size so this is fine, but
+        # presumably there should be a better way of doing this.
+        self.size = 1
+      else:
+        self.size = 1
+        for s in self.shape:
+          self.size *= s
       self.dtype = value.dtype
-    self.size = self.shape.num_elements()
     # TODO(apassos) put gradients on the same device as ops.
 
   def __eq__(self, other):
@@ -292,6 +300,10 @@ class TensorVSpace(ag_core.VSpace):
       x = _indexed_slices_to_tensor(x)
     if isinstance(y, ops.IndexedSlices):
       y = _indexed_slices_to_tensor(y)
+    if x is None:
+      return y
+    if y is None:
+      return x
     return math_ops.add(x, y)
 
 
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index 8d0f639ddcb604f17a3c662a8bf29320ca4351b7..bd8e653b976d7006c9642196f98913d3ed5a6483 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -77,8 +77,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
   def testMultiLineTensorStr(self):
     t = tensor.Tensor(np.eye(3))
     tensor_str = str(t)
-    self.assertIn("shape=%s, dtype=%s, " % (t.shape, t.dtype.name), tensor_str)
-    self.assertIn("numpy=\n%s" % t.numpy(), tensor_str)
+    self.assertIn("shape=%s, dtype=%s" % (t.shape, t.dtype.name), tensor_str)
+    self.assertIn(str(t.numpy()), tensor_str)
 
   def testMultiLineTensorRepr(self):
     t = tensor.Tensor(np.eye(3))
@@ -95,7 +95,7 @@ class TFETensorTest(test_util.TensorFlowTestCase):
     np.set_printoptions(threshold=2, edgeitems=1)
 
     t = tensor.Tensor(np.arange(10, dtype=np.int32))
-    self.assertIn("numpy=[0 ..., 9]", str(t))
+    self.assertIn("[0 ..., 9]", str(t))
     self.assertIn("[0, ..., 9]", repr(t))
 
     # Clean up: reset to previous printoptions.
@@ -103,7 +103,7 @@ class TFETensorTest(test_util.TensorFlowTestCase):
 
   def testZeroDimTensorStr(self):
     t = tensor.Tensor(42)
-    self.assertIn("shape=(), dtype=int32, numpy=42", str(t))
+    self.assertIn("42, shape=(), dtype=int32", str(t))
 
   def testZeroDimTensorRepr(self):
     t = tensor.Tensor(42)
@@ -113,7 +113,7 @@ class TFETensorTest(test_util.TensorFlowTestCase):
 
   def testZeroSizeTensorStr(self):
     t = tensor.Tensor(np.zeros(0, dtype=np.float32))
-    self.assertIn("shape=(0,), dtype=float32, numpy=[]", str(t))
+    self.assertIn("[], shape=(0,), dtype=float32", str(t))
 
   def testZeroSizeTensorRepr(self):
     t = tensor.Tensor(np.zeros(0, dtype=np.float32))
@@ -127,8 +127,8 @@ class TFETensorTest(test_util.TensorFlowTestCase):
     t = tensor.Tensor(42)
     # Force change dtype to a numpy-unprintable type.
     t._dtype = dtypes.resource
-    self.assertIn("numpy=", str(t))
-    self.assertIn("numpy=", repr(t))
+    self.assertIn("", str(t))
+    self.assertIn("", repr(t))
 
   def testStringTensor(self):
     t_np_orig = np.array([[b"a", b"ab"], [b"abc", b"abcd"]])
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 83eeeb35b674bbe61bc8bae4e2f499691f127445..167f9b105430480381279a9f8fba4850741a88c7 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -148,6 +148,7 @@ py_test(
     name = "dnn_test",
     size = "medium",
     srcs = ["canned/dnn_test.py"],
+    shard_count = 4,
     srcs_version = "PY2AND3",
     tags = ["no_pip"],
     deps = [
@@ -201,7 +202,7 @@ py_test(
     name = "dnn_linear_combined_test",
     size = "medium",
     srcs = ["canned/dnn_linear_combined_test.py"],
-    shard_count = 4,
+    shard_count = 8,
     srcs_version = "PY2AND3",
     tags = ["no_pip"],
     deps = [
@@ -552,11 +553,9 @@ py_test(
     name = "linear_test",
     size = "medium",
     srcs = ["canned/linear_test.py"],
+    shard_count = 4,
     srcs_version = "PY2AND3",
-    tags = [
-        "no_pip",
-        "noasan",  # times out b/63680444
-    ],
+    tags = ["no_pip"],
     deps = [
         ":linear",
         ":linear_testing_utils",
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index d2c5772483b718bcc6a0be00f10c2e0593e262d4..80d109d927ab406ce7bff00d6515d31d918cf2e9 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -200,8 +200,11 @@ def _check_labels(labels, expected_labels_dimension):
         dim1 = static_shape[1]
         if (dim1 is not None) and (dim1 != expected_labels_dimension):
           raise ValueError(
-              'labels shape must be [batch_size, labels_dimension], got %s.' %
-              (static_shape,))
+              'Mismatched label shape. '
+              'Classifier configured with n_classes=%s.  Received %s. '
+              'Suggested Fix: check your n_classes argument to the estimator '
+              'and/or the shape of your label.' %
+              (expected_labels_dimension, dim1))
       assert_dimension = check_ops.assert_equal(
           expected_labels_dimension, labels_shape[1], message=err_msg)
       with ops.control_dependencies([assert_dimension]):
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 23678013c66a12b063447533f81976b570cd83c9..fa3d5b44eb6e37bdf3367e40302529ccf12d0436 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -139,7 +139,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
     features = {'x': np.array(((42.,),))}
 
     # Static shape.
-    with self.assertRaisesRegexp(ValueError, 'labels shape'):
+    with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):
       head.create_loss(
           features=features,
           mode=model_fn.ModeKeys.EVAL,
@@ -889,7 +889,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
     logits_2x1 = np.array(((45.,), (41.,),))
 
     # Static shape.
-    with self.assertRaisesRegexp(ValueError, 'labels shape'):
+    with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):
       head.create_loss(
           features={'x': np.array(((42.,),))},
           mode=model_fn.ModeKeys.EVAL,
@@ -1692,7 +1692,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
     values_1d = np.array(((43.,), (44.,),))
 
     # Static shape.
-    with self.assertRaisesRegexp(ValueError, 'labels shape'):
+    with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):
       head.create_loss(
           features={'x': values_1d},
           mode=model_fn.ModeKeys.EVAL,
@@ -1737,7 +1737,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
     values_1d = np.array(((43.,), (44.,),))
 
     # Static shape.
-    with self.assertRaisesRegexp(ValueError, 'labels shape'):
+    with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):
       head.create_loss(
           features={'x': values_1d},
           mode=model_fn.ModeKeys.TRAIN,
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 1a4b0c5fc00e7000b1f9092d0ffd4ff5ec601d4c..cfa4be5c7d4b66943478456ed640d1a03f525e4c 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -131,7 +131,10 @@ class EstimatorSpec(
       train_op: Op for the training step.
       eval_metric_ops: Dict of metric results keyed by name. The values of the
         dict are the results of calling a metric function, namely a
-        `(metric_tensor, update_op)` tuple.
+        `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated
+        without any impact on state (typically is a pure computation results
+        based on variables.). For example, it should not trigger the `update_op`
+        or requires any input fetching.
       export_outputs: Describes the output signatures to be exported to
         `SavedModel` and used during serving.
         A dict `{name: output}` where:
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 2ba51ec9eb0e0ed762b0f44546309bd13cd54efa..e242a60aabf0baeb9eac19797214b0edae6ce869 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -19,10 +19,14 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
+import json
+import os
 
 import six
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import server_lib
 
 
 _USE_DEFAULT = object()
@@ -44,6 +48,56 @@ _SAVE_CKPT_ERR = (
     '`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set.'
 )
 
+_TF_CONFIG_ENV = 'TF_CONFIG'
+_TASK_ENV_KEY = 'task'
+_TASK_TYPE_KEY = 'type'
+_TASK_ID_KEY = 'index'
+_CLUSTER_KEY = 'cluster'
+_LOCAL_MASTER = ''
+_GRPC_SCHEME = 'grpc://'
+
+
+def _get_master(cluster_spec, task_type, task_id):
+  """Returns the appropriate string for the TensorFlow master."""
+  if not cluster_spec:
+    return _LOCAL_MASTER
+
+  jobs = cluster_spec.jobs
+  # Lookup the master in cluster_spec using task_type and task_id,
+  # if possible.
+  if task_type not in jobs:
+    raise ValueError(
+        '%s is not a valid task_type in the cluster_spec:\n'
+        '%s\n\n'
+        'Note that these values may be coming from the TF_CONFIG environment '
+        'variable.' % (task_type, cluster_spec))
+  addresses = cluster_spec.job_tasks(task_type)
+  if not 0 <= task_id < len(addresses):
+    raise ValueError(
+        '%d is not a valid task_id for task_type %s in the cluster_spec:\n'
+        '%s\n\n'
+        'Note that these values may be coming from the TF_CONFIG environment '
+        'variable.' % (task_id, task_type, cluster_spec))
+  return _GRPC_SCHEME + addresses[task_id]
+
+
+def _count_ps(cluster_spec):
+  """Counts the number of parameter servers in cluster_spec."""
+  if not cluster_spec:
+    return 0
+
+  return len(cluster_spec.as_dict().get(TaskType.PS, []))
+
+
+def _count_worker(cluster_spec):
+  """Counts the number of workers (including chief) in cluster_spec."""
+  if not cluster_spec:
+    raise RuntimeError(
+        'Internal error: `_count_worker` does not expect empty cluster_spec.')
+
+  return (len(cluster_spec.as_dict().get(TaskType.WORKER, [])) +
+          len(cluster_spec.as_dict().get(TaskType.CHIEF, [])))
+
 
 def _validate_save_ckpt_with_replaced_keys(new_copy, replaced_keys):
   """Validates the save ckpt properties."""
@@ -103,6 +157,8 @@ class TaskType(object):
   MASTER = 'master'
   PS = 'ps'
   WORKER = 'worker'
+  CHIEF = 'chief'
+  EVALUATOR = 'evaluator'
 
 
 class RunConfig(object):
@@ -120,6 +176,95 @@ class RunConfig(object):
                log_step_count_steps=100):
     """Constructs a RunConfig.
 
+    All distributed training related properties `cluster_spec`, `is_chief`,
+    `master` , `num_worker_replicas`, `num_ps_replicas`, `task_id`, and
+    `task_type` are set based on the `TF_CONFIG` environment variable, if the
+    pertinent information is present. The `TF_CONFIG` environment variable is a
+    JSON object with attributes: `cluster` and `task`.
+
+    `cluster` is a JSON serialized version of `ClusterSpec`'s Python dict from
+    `server_lib.py`, mapping task types (usually one of the `TaskType` enums) to
+    a list of task addresses.
+
+    `task` has two attributes: `type` and `index`, where `type` can be any of
+    the task types in `cluster`. ` When `TF_CONFIG` contains said information,
+    the following properties are set on this class:
+
+    * `cluster_spec` is parsed from `TF_CONFIG['cluster']`. Defaults to {}. If
+      present, must have one and only one node in the `chief` attribute of
+      `cluster_spec`.
+    * `task_type` is set to `TF_CONFIG['task']['type']`. Must set if
+      `cluster_spec` is present; must be `worker` (the default value) if
+      `cluster_spec` is not set.
+    * `task_id` is set to `TF_CONFIG['task']['index']`. Must set if
+      `cluster_spec` is present; must be 0 (the default value) if
+      `cluster_spec` is not set.
+    * `master` is determined by looking up `task_type` and `task_id` in the
+      `cluster_spec`. Defaults to ''.
+    * `num_ps_replicas` is set by counting the number of nodes listed
+      in the `ps` attribute of `cluster_spec`. Defaults to 0.
+    * `num_worker_replicas` is set by counting the number of nodes listed
+      in the `worker` and `chief` attributes of `cluster_spec`. Defaults to 1.
+    * `is_chief` is determined based on `task_type` and `cluster`.
+
+    There is a special node with `task_type` as `evaluator`, which is not part
+    of the (training) `cluster_spec`. It handles the distributed evaluation job.
+
+    Example of non-chief node:
+    ```
+      cluster = {'chief': ['host0:2222'],
+                 'ps': ['host1:2222', 'host2:2222'],
+                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
+      os.environ['TF_CONFIG'] = json.dumps(
+          {'cluster': cluster,
+           'task': {'type': 'worker', 'index': 1}})
+      config = ClusterConfig()
+      assert config.master == 'host4:2222'
+      assert config.task_id == 1
+      assert config.num_ps_replicas == 2
+      assert config.num_worker_replicas == 4
+      assert config.cluster_spec == server_lib.ClusterSpec(cluster)
+      assert config.task_type == 'worker'
+      assert not config.is_chief
+    ```
+
+    Example of chief node:
+    ```
+      cluster = {'chief': ['host0:2222'],
+                 'ps': ['host1:2222', 'host2:2222'],
+                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
+      os.environ['TF_CONFIG'] = json.dumps(
+          {'cluster': cluster,
+           'task': {'type': 'chief', 'index': 0}})
+      config = ClusterConfig()
+      assert config.master == 'host0:2222'
+      assert config.task_id == 0
+      assert config.num_ps_replicas == 2
+      assert config.num_worker_replicas == 4
+      assert config.cluster_spec == server_lib.ClusterSpec(cluster)
+      assert config.task_type == 'chief'
+      assert config.is_chief
+    ```
+
+    Example of evaluator node (evaluator is not part of training cluster):
+    ```
+      cluster = {'chief': ['host0:2222'],
+                 'ps': ['host1:2222', 'host2:2222'],
+                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
+      os.environ['TF_CONFIG'] = json.dumps(
+          {'cluster': cluster,
+           'task': {'type': 'evaluator', 'index': 0}})
+      config = ClusterConfig()
+      assert config.master == ''
+      assert config.evaluator_master == ''
+      assert config.task_id == 0
+      assert config.num_ps_replicas == 0
+      assert config.num_worker_replicas == 0
+      assert config.cluster_spec == {}
+      assert config.task_type == 'evaluator'
+      assert not config.is_chief
+    ```
+
     N.B.: If `save_checkpoints_steps` or `save_checkpoints_secs` is set,
     `keep_checkpoint_max` might need to be adjusted accordingly, especially in
     distributed training. For example, setting `save_checkpoints_secs` as 60
@@ -137,9 +282,10 @@ class RunConfig(object):
       save_checkpoints_steps: Save checkpoints every this many steps. Can not be
           specified with `save_checkpoints_secs`.
       save_checkpoints_secs: Save checkpoints every this many seconds. Can not
-          be specified with `save_checkpoints_steps`. Defaults to 600 seconds.
-          If both `save_checkpoints_steps` and `save_checkpoints_secs` are None,
-          then checkpoints are disabled.
+          be specified with `save_checkpoints_steps`. Defaults to 600 seconds if
+          both `save_checkpoints_steps` and `save_checkpoints_secs` are not set
+          in constructor.  If both `save_checkpoints_steps` and
+          `save_checkpoints_secs` are None, then checkpoints are disabled.
       session_config: a ConfigProto used to set session parameters, or None.
       keep_checkpoint_max: The maximum number of recent checkpoint files to
         keep. As new files are created, older files are deleted. If None or 0,
@@ -181,9 +327,79 @@ class RunConfig(object):
         keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
         log_step_count_steps=log_step_count_steps)
 
+    self._init_distributed_setting_from_environment_var()
+
+  def _init_distributed_setting_from_environment_var(self):
+    """Initialize distributed properties based on environment variable."""
+
+    tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV) or '{}')
+    if tf_config:
+      logging.info('TF_CONFIG environment variable: %s', tf_config)
+
+    self._cluster_spec = server_lib.ClusterSpec(tf_config.get(_CLUSTER_KEY, {}))
+    task_env = tf_config.get(_TASK_ENV_KEY, {})
+
+    if self._cluster_spec:
+      # Distributed mode.
+      if TaskType.CHIEF not in self._cluster_spec.jobs:
+        raise ValueError(
+            'If "cluster" is set in TF_CONFIG, it must have one "chief" node.')
+      if len(self._cluster_spec.job_tasks(TaskType.CHIEF)) > 1:
+        raise ValueError(
+            'The "cluster" in TF_CONFIG must have only one "chief" node.')
+
+      self._task_type = task_env.get(_TASK_TYPE_KEY, None)
+      task_id = task_env.get(_TASK_ID_KEY, None)
+
+      if not self._task_type:
+        raise ValueError(
+            'If "cluster" is set in TF_CONFIG, task type must be set.')
+      if task_id is None:
+        raise ValueError(
+            'If "cluster" is set in TF_CONFIG, task index must be set.')
+
+      self._task_id = int(task_id)
+
+      # Check the task id bounds. Upper bound is not necessary as
+      # - for evaluator, there is no upper bound.
+      # - for non-evaluator, task id is upper bounded by the number of jobs in
+      # cluster spec, which will be checked later (when retrieving the `master`)
+      if self._task_id < 0:
+        raise ValueError('Task index must be non-negative number.')
+
+      if self._task_type != TaskType.EVALUATOR:
+        self._master = _get_master(
+            self._cluster_spec, self._task_type, self._task_id)
+        self._num_ps_replicas = _count_ps(self._cluster_spec)
+        self._num_worker_replicas = _count_worker(self._cluster_spec)
+      else:
+        # Evaluator is not part of the training cluster.
+        self._cluster_spec = server_lib.ClusterSpec({})
+        self._master = _LOCAL_MASTER
+        self._num_ps_replicas = 0
+        self._num_worker_replicas = 0
+
+      self._is_chief = self._task_type == TaskType.CHIEF
+    else:
+      # Local mode.
+      self._task_type = task_env.get(_TASK_TYPE_KEY, TaskType.WORKER)
+      self._task_id = int(task_env.get(_TASK_ID_KEY, 0))
+
+      if self._task_type != TaskType.WORKER:
+        raise ValueError(
+            'If "cluster" is not set in TF_CONFIG, task type must be WORKER.')
+      if self._task_id != 0:
+        raise ValueError(
+            'If "cluster" is not set in TF_CONFIG, task index must be 0.')
+
+      self._master = ''
+      self._is_chief = True
+      self._num_ps_replicas = 0
+      self._num_worker_replicas = 1
+
   @property
   def cluster_spec(self):
-    return None
+    return self._cluster_spec
 
   @property
   def evaluation_master(self):
@@ -191,27 +407,27 @@ class RunConfig(object):
 
   @property
   def is_chief(self):
-    return True
+    return self._is_chief
 
   @property
   def master(self):
-    return ''
+    return self._master
 
   @property
   def num_ps_replicas(self):
-    return 0
+    return self._num_ps_replicas
 
   @property
   def num_worker_replicas(self):
-    return 1
+    return self._num_worker_replicas
 
   @property
   def task_id(self):
-    return 0
+    return self._task_id
 
   @property
   def task_type(self):
-    return TaskType.WORKER
+    return self._task_type
 
   @property
   def tf_random_seed(self):
diff --git a/tensorflow/python/estimator/run_config_test.py b/tensorflow/python/estimator/run_config_test.py
index 4a09417630aa6daeef8de283c5d44104dc4f1839..cd135a34680ed3b06e81b3f99fe8b5097e2b0517 100644
--- a/tensorflow/python/estimator/run_config_test.py
+++ b/tensorflow/python/estimator/run_config_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import json
+
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.estimator import run_config as run_config_lib
 from tensorflow.python.platform import test
@@ -36,6 +38,22 @@ _SESSION_CONFIG_ERR = 'session_config must be instance of ConfigProto'
 _KEEP_CKPT_MAX_ERR = 'keep_checkpoint_max should be >= 0'
 _KEEP_CKPT_HOURS_ERR = 'keep_checkpoint_every_n_hours should be > 0'
 _TF_RANDOM_SEED_ERR = 'tf_random_seed must be integer'
+_ONE_CHIEF_ERR = 'The "cluster" in TF_CONFIG must have only one "chief" node.'
+_MISSING_CHIEF_ERR = 'If "cluster" is set .* it must have one "chief" node'
+_MISSING_TASK_TYPE_ERR = 'If "cluster" is set .* task type must be set'
+_MISSING_TASK_ID_ERR = 'If "cluster" is set .* task index must be set'
+_INVALID_TASK_INDEX_ERR = 'is not a valid task_id'
+_NEGATIVE_TASK_INDEX_ERR = 'Task index must be non-negative number.'
+_INVALID_TASK_TYPE_ERR = 'is not a valid task_type'
+_INVALID_TASK_TYPE_FOR_LOCAL_ERR = (
+    'If "cluster" is not set in TF_CONFIG, task type must be WORKER.')
+_INVALID_TASK_INDEX_FOR_LOCAL_ERR = (
+    'If "cluster" is not set in TF_CONFIG, task index must be 0.')
+
+
+def _create_run_config_with_cluster_spec(tf_config, **kwargs):
+  with test.mock.patch.dict('os.environ', {'TF_CONFIG': json.dumps(tf_config)}):
+    return run_config_lib.RunConfig(**kwargs)
 
 
 class RunConfigTest(test.TestCase):
@@ -189,6 +207,283 @@ class RunConfigTest(test.TestCase):
       run_config_lib.RunConfig(tf_random_seed=1.0)
 
 
+class RunConfigDistributedSettingTest(test.TestCase):
+
+  def _assert_distributed_properties(self, run_config,
+                                     expected_cluster_spec,
+                                     expected_task_type,
+                                     expected_task_id,
+                                     expected_master,
+                                     expected_evaluation_master,
+                                     expected_is_chief,
+                                     expected_num_worker_replicas,
+                                     expected_num_ps_replicas):
+    self.assertEqual(expected_cluster_spec, run_config.cluster_spec.as_dict())
+    self.assertEqual(expected_task_type, run_config.task_type)
+    self.assertEqual(expected_task_id, run_config.task_id)
+    self.assertEqual(expected_master, run_config.master)
+    self.assertEqual(expected_evaluation_master, run_config.evaluation_master)
+    self.assertEqual(expected_is_chief, run_config.is_chief)
+    self.assertEqual(expected_num_worker_replicas,
+                     run_config.num_worker_replicas)
+    self.assertEqual(expected_num_ps_replicas, run_config.num_ps_replicas)
+
+  def test_default_values(self):
+    self._assert_distributed_properties(
+        run_config=run_config_lib.RunConfig(),
+        expected_cluster_spec={},
+        expected_task_type=run_config_lib.TaskType.WORKER,
+        expected_task_id=0,
+        expected_master='',
+        expected_evaluation_master='',
+        expected_is_chief=True,
+        expected_num_worker_replicas=1,
+        expected_num_ps_replicas=0)
+
+  def test_tf_config_for_local(self):
+    tf_config = {
+        'task': {
+            'type': run_config_lib.TaskType.WORKER,
+            'index': 0
+        }
+    }
+    self._assert_distributed_properties(
+        run_config=_create_run_config_with_cluster_spec(tf_config),
+        expected_cluster_spec={},
+        expected_task_type=run_config_lib.TaskType.WORKER,
+        expected_task_id=0,
+        expected_master='',
+        expected_evaluation_master='',
+        expected_is_chief=True,
+        expected_num_worker_replicas=1,
+        expected_num_ps_replicas=0)
+
+  def test_invalid_task_type_for_local(self):
+    tf_config = {
+        'task': {
+            'type': run_config_lib.TaskType.CHIEF,
+            'index': 0
+        }
+    }
+    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE_FOR_LOCAL_ERR):
+      _create_run_config_with_cluster_spec(tf_config)
+
+  def test_invalid_task_index_for_local(self):
+    tf_config = {
+        'task': {
+            'type': run_config_lib.TaskType.WORKER,
+            'index': 1
+        }
+    }
+    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_INDEX_FOR_LOCAL_ERR):
+      _create_run_config_with_cluster_spec(tf_config)
+
+  def test_chief_tf_config(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host0:0'],
+            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+        },
+        'task': {
+            'type': run_config_lib.TaskType.CHIEF,
+            'index': 0
+        }
+    }
+    self._assert_distributed_properties(
+        run_config=_create_run_config_with_cluster_spec(tf_config),
+        expected_cluster_spec=tf_config['cluster'],
+        expected_task_type=run_config_lib.TaskType.CHIEF,
+        expected_task_id=0,
+        expected_master='grpc://host0:0',
+        expected_evaluation_master='',
+        expected_is_chief=True,
+        expected_num_worker_replicas=4,
+        expected_num_ps_replicas=2)
+
+  def test_fail_with_multiple_chief_nodes(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host0:0', 'host:6:6'],
+            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+        },
+    }
+    with self.assertRaisesRegexp(ValueError, _ONE_CHIEF_ERR):
+      _create_run_config_with_cluster_spec(tf_config)
+
+  def test_fail_with_missing_chief_node(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+        },
+    }
+    with self.assertRaisesRegexp(ValueError, _MISSING_CHIEF_ERR):
+      _create_run_config_with_cluster_spec(tf_config)
+
+  def test_single_chief_node(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host0:0'],
+        },
+        'task': {
+            'type': run_config_lib.TaskType.CHIEF,
+            'index': 0
+        }
+    }
+    self._assert_distributed_properties(
+        run_config=_create_run_config_with_cluster_spec(tf_config),
+        expected_cluster_spec=tf_config['cluster'],
+        expected_task_type=run_config_lib.TaskType.CHIEF,
+        expected_task_id=0,
+        expected_master='grpc://host0:0',
+        expected_evaluation_master='',
+        expected_is_chief=True,
+        expected_num_worker_replicas=1,
+        expected_num_ps_replicas=0)
+
+  def test_fail_with_missing_task_type_for_distributed(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host3:3']
+        },
+    }
+    with self.assertRaisesRegexp(ValueError, _MISSING_TASK_TYPE_ERR):
+      _create_run_config_with_cluster_spec(tf_config)
+
+  def test_fail_with_missing_task_index_for_distributed(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host3:3']
+        },
+        'task': {
+            'type': run_config_lib.TaskType.CHIEF,
+        }
+    }
+    with self.assertRaisesRegexp(ValueError, _MISSING_TASK_ID_ERR):
+      _create_run_config_with_cluster_spec(tf_config)
+
+  def test_fail_with_index_is_too_large(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host3:3']
+        },
+        'task': {
+            'type': run_config_lib.TaskType.CHIEF,
+            'index': 1
+        }
+    }
+    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_INDEX_ERR):
+      _create_run_config_with_cluster_spec(tf_config)
+
+  def test_fail_with_invalid_task_index(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host3:3']
+        },
+        'task': {
+            'type': run_config_lib.TaskType.CHIEF,
+            'index': -1
+        }
+    }
+    with self.assertRaisesRegexp(ValueError, _NEGATIVE_TASK_INDEX_ERR):
+      _create_run_config_with_cluster_spec(tf_config)
+
+  def test_fail_with_invalid_task_type(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host3:3']
+        },
+        'task': {
+            'type': run_config_lib.TaskType.WORKER,
+            'index': 0
+        }
+    }
+    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE_ERR):
+      _create_run_config_with_cluster_spec(tf_config)
+
+  def test_worker_tf_config(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host0:0'],
+            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+        },
+        'task': {
+            'type': run_config_lib.TaskType.WORKER,
+            'index': 1
+        }
+    }
+    self._assert_distributed_properties(
+        run_config=_create_run_config_with_cluster_spec(tf_config),
+        expected_cluster_spec=tf_config['cluster'],
+        expected_task_type=run_config_lib.TaskType.WORKER,
+        expected_task_id=1,
+        expected_master='grpc://host4:4',
+        expected_evaluation_master='',
+        expected_is_chief=False,
+        expected_num_worker_replicas=4,
+        expected_num_ps_replicas=2)
+
+  def test_ps_tf_config(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host0:0'],
+            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+        },
+        'task': {
+            'type': run_config_lib.TaskType.PS,
+            'index': 0
+        }
+    }
+    self._assert_distributed_properties(
+        run_config=_create_run_config_with_cluster_spec(tf_config),
+        expected_cluster_spec=tf_config['cluster'],
+        expected_task_type=run_config_lib.TaskType.PS,
+        expected_task_id=0,
+        expected_master='grpc://host1:1',
+        expected_evaluation_master='',
+        expected_is_chief=False,
+        expected_num_worker_replicas=4,
+        expected_num_ps_replicas=2)
+
+  def test_evaluator_tf_config(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host0:0'],
+            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+        },
+        'task': {
+            'type': run_config_lib.TaskType.EVALUATOR,
+            'index': 12
+        }
+    }
+    self._assert_distributed_properties(
+        run_config=_create_run_config_with_cluster_spec(tf_config),
+        expected_cluster_spec={},
+        expected_task_type=run_config_lib.TaskType.EVALUATOR,
+        expected_task_id=12,
+        expected_master='',
+        expected_evaluation_master='',
+        expected_is_chief=False,  # evaluator is never chief.
+        expected_num_worker_replicas=0,  # evaluator is not in training cluster.
+        expected_num_ps_replicas=0)
+
+  def test_fail_with_invalid_task_index_for_evaluator(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host3:3']
+        },
+        'task': {
+            'type': run_config_lib.TaskType.EVALUATOR,
+            'index': -1
+        }
+    }
+    with self.assertRaisesRegexp(ValueError, _NEGATIVE_TASK_INDEX_ERR):
+      _create_run_config_with_cluster_spec(tf_config)
+
+
 class RunConfigSaveCheckpointsTest(test.TestCase):
 
   def test_save_checkpoint(self):
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index a8434d0c991bbc7c692e0fb37d750f44a6a9a2c0..965b35bc4c8b59e7fcde8133f31184154edb7533 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -2474,6 +2474,9 @@ class _IndicatorColumn(_DenseColumn,
           sp_ids=id_tensor,
           sp_values=weight_tensor,
           vocab_size=int(self._variable_shape[-1]))
+      # Remove (?, -1) index
+      weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
+                                                weighted_column.dense_shape)
       return sparse_ops.sparse_tensor_to_dense(weighted_column)
 
     dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 5138f31e981717f835175f86ae01b1abb15b5b26..926e78acee713c0d823922ef898fbdf3dccba2e5 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -3213,13 +3213,39 @@ class IndicatorColumnTest(test.TestCase):
     weights = fc.weighted_categorical_column(ids, 'weights')
     indicator = fc.indicator_column(weights)
     features = {
-      'ids': constant_op.constant(['c', 'b', 'a'], shape=(1, 3)),
-      'weights': constant_op.constant([2., 4., 6.], shape=(1, 3))
+        'ids': constant_op.constant([['c', 'b', 'a']]),
+        'weights': constant_op.constant([[2., 4., 6.]])
     }
     indicator_tensor = _transform_features(features, [indicator])[indicator]
     with _initialized_session():
       self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval())
 
+  def test_transform_with_missing_value_in_weighted_column(self):
+    # Github issue 12583
+    ids = fc.categorical_column_with_vocabulary_list(
+        key='ids', vocabulary_list=('a', 'b', 'c'))
+    weights = fc.weighted_categorical_column(ids, 'weights')
+    indicator = fc.indicator_column(weights)
+    features = {
+        'ids': constant_op.constant([['c', 'b', 'unknown']]),
+        'weights': constant_op.constant([[2., 4., 6.]])
+    }
+    indicator_tensor = _transform_features(features, [indicator])[indicator]
+    with _initialized_session():
+      self.assertAllEqual([[0., 4., 2.]], indicator_tensor.eval())
+
+  def test_transform_with_missing_value_in_categorical_column(self):
+    # Github issue 12583
+    ids = fc.categorical_column_with_vocabulary_list(
+        key='ids', vocabulary_list=('a', 'b', 'c'))
+    indicator = fc.indicator_column(ids)
+    features = {
+        'ids': constant_op.constant([['c', 'b', 'unknown']]),
+    }
+    indicator_tensor = _transform_features(features, [indicator])[indicator]
+    with _initialized_session():
+      self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())
+
   def test_linear_model(self):
     animal = fc.indicator_column(
         fc.categorical_column_with_identity('animal', num_buckets=4))
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 2f35f0e04b6c69f2dabd713f59518c5d3ce3f937..7a866ee6e8a6ec9521295cf123975db49cf0d172 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -26,7 +26,9 @@ import hashlib
 
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import op_def_pb2
+from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import graph_to_function_def
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -290,6 +292,7 @@ class _DefinedFunction(object):
     self._shape_func = shape_func
     self._extra_kwargs = kwargs
     self._definition = None  # Constructed lazily.
+    self._c_func = None  # Constructed with definition.
     self._sub_functions = dict()  # Constructed with definition.
 
     self._args = []
@@ -396,6 +399,22 @@ class _DefinedFunction(object):
     if self._func.__doc__:
       self._definition.signature.description = self._func.__doc__
 
+    # pylint: disable=protected-access
+    if temp_graph._c_graph:
+      with errors.raise_exception_on_not_ok_status() as status:
+        output_names = ([compat.as_bytes(x) for x in self._out_names]
+                        if self._out_names else [])
+        self._c_func = c_api.TF_GraphToFunction_wrapper(
+            temp_graph._c_graph,
+            self._func_name,
+            None,  # opers
+            [t._as_tf_output() for t in inputs],
+            [t._as_tf_output() for t in outputs],
+            output_names,
+            None,  # opts
+            status)
+    # pylint: enable=protected-access
+
   def _create_hash_str(self, input_arg, output_arg, node_def):
     """Creates an 8-character string unique to this input.
 
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 589db9ef4dc1389d7822ba812ae8fb8f78e40456..40205ddf0532a666ba61d3b6c2642ead7365bd84 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import function
 from tensorflow.python.framework import graph_to_function_def
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import control_flow_ops
@@ -63,7 +64,51 @@ def _OptimizerOptions():
                 do_constant_folding=cfold)))
 
 
-class FunctionTest(test.TestCase):
+class FunctionTestMethods(object):
+  """Test methods for verifying Function support.
+
+  These test methods are used as mix-ins in two test cases: with
+  and without C API support.
+  """
+
+  def testIdentity(self):
+
+    @function.Defun(dtypes.float32, func_name="MyIdentity")
+    def MyIdentityFunc(a):
+      return a
+
+    with ops.Graph().as_default():
+      call = MyIdentityFunc([18.0])
+      self.assertEqual("MyIdentity", call.op.name)
+      with session.Session() as sess:
+        self.assertAllEqual([18.0], sess.run(call))
+
+  def testIdentityOutputName(self):
+
+    @function.Defun(
+        dtypes.float32, func_name="MyIdentity", out_names=["my_result_name"])
+    def MyIdentityFunc(a):
+      return a
+
+    with ops.Graph().as_default():
+      call = MyIdentityFunc([18.0])
+      self.assertEqual("MyIdentity", call.op.name)
+      with session.Session() as sess:
+        self.assertAllEqual([18.0], sess.run(call))
+
+  def testTooManyOutputNames(self):
+
+    @function.Defun(
+        dtypes.float32, func_name="MyIdentity",
+        out_names=["my_result1", "my_result2"])
+    def MyIdentityFunc(a):
+      return a
+
+    with ops.Graph().as_default():
+      with self.assertRaisesRegexp(
+          ValueError, (r"Length of out_names \(2\) does not match number of "
+                       r"outputs \(1\): my_result1, my_result2")):
+        MyIdentityFunc([18.0])
 
   def testDefineFunction2Args(self):
 
@@ -77,6 +122,35 @@ class FunctionTest(test.TestCase):
       with session.Session() as sess:
         self.assertAllEqual([5.0], sess.run(call))
 
+  def testValueErrorOnFunctionWithNoOutput(self):
+    # TODO(iga): Remove this restriction and this test
+
+    @function.Defun(dtypes.float32, dtypes.float32)
+    def APlus2B(a, b):
+      print(a + b * 2)  # Create some ops to have nodes in the body
+                        # Using 'print' to make lint happy
+
+    with ops.Graph().as_default():
+      with self.assertRaisesRegexp(ValueError,
+                                   "Function can not return None"):
+        APlus2B([1.0], [2.0])
+
+  def testDefineFunction2ArgsOutputName(self):
+
+    @function.Defun(
+        dtypes.float32,
+        dtypes.float32,
+        func_name="APlus2B",
+        out_names=["my_result_name"])
+    def APlus2B(a, b):
+      return a + b * 2
+
+    with ops.Graph().as_default():
+      call = APlus2B([1.0], [2.0])
+      self.assertEqual("APlus2B", call.op.name)
+      with session.Session() as sess:
+        self.assertAllEqual([5.0], sess.run(call))
+
   def testDefineFunctionDuplicateOutputs(self):
 
     @function.Defun(dtypes.float32, func_name="Duplicate")
@@ -137,6 +211,7 @@ class FunctionTest(test.TestCase):
       out, = sess.run(dx, feed)
     self.assertAllClose(1 - np.square(np.tanh(inp)), out)
 
+  @test_util.disable_c_api   # Function gradients don't work with C API
   def testCustomGradient(self):
     dtype = dtypes.float32
 
@@ -169,6 +244,7 @@ class FunctionTest(test.TestCase):
         out, = sess.run(dlogits, {logits: x, labels: y})
       self.assertAllClose(out, np.exp(prob - y))
 
+  @test_util.disable_c_api   # Function gradients don't work with C API
   def testCustomGradientError(self):
     dtype = dtypes.float32
 
@@ -194,6 +270,7 @@ class FunctionTest(test.TestCase):
           "SymGrad expects to return 1.*but get 2.*instead"):
         _ = sess.run(dinp, {inp: x})
 
+  @test_util.disable_c_api   # Function gradients don't work with C API
   def testSymGradShape(self):
     g = ops.Graph()
     with g.as_default():
@@ -209,6 +286,7 @@ class FunctionTest(test.TestCase):
       self.assertEqual(x.get_shape(), dx.get_shape())
       self.assertEqual(y.get_shape(), dy.get_shape())
 
+  @test_util.disable_c_api   # Function gradients don't work with C API
   def testSymGradAttr(self):
 
     @function.Defun(noinline=True)
@@ -312,6 +390,7 @@ class FunctionTest(test.TestCase):
                                    "assertion failed.*-3"):
         self.assertAllEqual(Foo(constant_op.constant(-3.0)).eval(), 6.0)
 
+  @test_util.disable_c_api   # Op._add_control_inputs doesn't work with C API
   def testAssertWrapper(self):
 
     @function.Defun(dtypes.float32)
@@ -326,6 +405,7 @@ class FunctionTest(test.TestCase):
                                    "assertion"):
         _ = MyFn(100.0).eval()
 
+  @test_util.disable_c_api   # Op._add_control_inputs doesn't work with C API
   def testWhileLoopCallsFunc(self):
     with self.test_session(use_gpu=True) as sess:
 
@@ -345,6 +425,7 @@ class FunctionTest(test.TestCase):
       ans = sess.run(loop)
       self.assertAllClose(ans, 131072.)
 
+  @test_util.disable_c_api   # Op._add_control_inputs doesn't work with C API
   def testControlFlowStrictness(self):
     """Inlined functions must not execute in a untaken control flow branch."""
 
@@ -607,6 +688,7 @@ class FunctionTest(test.TestCase):
       self.assertAllClose(vals[0], vals[1])
       self.assertAllClose(vals[2], vals[3])
 
+  @test_util.disable_c_api   # Function Declaration doesn't work with C API
   def testDeclare(self):
     foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
                                                              dtypes.float32)])
@@ -626,6 +708,7 @@ class FunctionTest(test.TestCase):
       expected = rand * rand + 1.0
       self.assertAllClose(expected, y.eval(feed_dict={x: rand}))
 
+  @test_util.disable_c_api   # Function Declaration doesn't work with C API
   def testDeclareUsedInDefun(self):
     foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
                                                              dtypes.float32)])
@@ -649,6 +732,7 @@ class FunctionTest(test.TestCase):
       expected = rand * rand + 1.0
       self.assertAllClose(expected, y.eval(feed_dict={x: rand}))
 
+  @test_util.disable_c_api   # Function Declaration doesn't work with C API
   def testDeclareTypeMistake(self):
     foo = function.Declare("Foo", [("x", dtypes.float32)], [("y",
                                                              dtypes.float32)])
@@ -861,6 +945,32 @@ class FunctionTest(test.TestCase):
     self.assertEqual(len(f.signature.input_arg), 3)
 
 
+class FunctionTest(FunctionTestMethods, test.TestCase):
+  """Test case that invokes test methods with _USE_C_API=False."""
+
+  def setUp(self):
+    self.prev_use_c_api = ops._USE_C_API
+    ops._USE_C_API = False
+    super(FunctionTest, self).setUp()
+
+  def tearDown(self):
+    ops._USE_C_API = self.prev_use_c_api
+    super(FunctionTest, self).tearDown()
+
+
+class FunctionWithCApiTest(FunctionTestMethods, test.TestCase):
+  """Test case that invokes test methods with _USE_C_API=True."""
+
+  def setUp(self):
+    self.prev_use_c_api = ops._USE_C_API
+    ops._USE_C_API = True
+    super(FunctionWithCApiTest, self).setUp()
+
+  def tearDown(self):
+    ops._USE_C_API = self.prev_use_c_api
+    super(FunctionWithCApiTest, self).tearDown()
+
+
 class FunctionsFromProtos(test.TestCase):
 
   def expectFunctionsEqual(self, func, grad_func=None, new_func=None):
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index aa373600669c16c4e2c0b2078705d3f9a2fc0823..76424ef579bd8853020f8070430fc0def51d6fde 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -784,6 +784,7 @@ class OpDefLibrary(object):
                               if arg.is_ref]
       with _MaybeColocateWith(must_colocate_inputs):
         # Add Op to graph
+        inputs = [ag_core.getval(x) for x in inputs]
         op = g.create_op(op_type_name, inputs, output_types, name=scope,
                          input_types=input_types, attrs=attr_protos,
                          op_def=op_def)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 72d9bd115ebf2143e6ff5076de1b753963e3554e..b197e96886e4f3d34f1e99fa746156a6c3305ee4 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -49,6 +49,7 @@ from tensorflow.python.framework import versions
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import nest
 from tensorflow.python.util import tf_contextlib
 
 # Temporary global switch determining if we should enable the work-in-progress
@@ -372,6 +373,12 @@ class Tensor(_TensorLike):
     else:
       return None
 
+  def _shape_tuple(self):
+    shape = self._shape_as_list()
+    if shape is None:
+      return None
+    return tuple(shape)
+
   def get_shape(self):
     """Alias of Tensor.shape."""
     return self.shape
@@ -598,6 +605,13 @@ def _maybe_modify_numpy_dtype_determination(np_array):
   return np_array
 
 
+def _has_string(value):
+  if isinstance(value, compat.bytes_or_text_types): return True
+  if isinstance(value, collections.Sequence) and value:
+    return _has_string(value[0])
+  return False
+
+
 # TODO(agarwal): rename to TensorHandle.
 class EagerTensor(Tensor):
   """A TensorFlow Eager Tensor."""
@@ -619,6 +633,8 @@ class EagerTensor(Tensor):
     # https://www.tensorflow.org/code/tensorflow/python/framework/constant_op.py
     self._id = uid()
     if not isinstance(value, np.ndarray):
+      if dtype is None and _has_string(value):
+        dtype = dtypes.string
       npt = None if dtype is None else dtype.as_numpy_dtype
       try:
         value = np.array(value, dtype=npt)
@@ -706,12 +722,12 @@ class EagerTensor(Tensor):
     return numpy_text
 
   def __str__(self):
-    return "tfe.Tensor(shape=%s, dtype=%s, numpy=%s)" % (self.shape,
-                                                         self.dtype.name,
-                                                         self._numpy_text())
+    return "tf.Tensor(%s, shape=%s, dtype=%s)" % (self._numpy_text(),
+                                                  self.shape,
+                                                  self.dtype.name)
 
   def __repr__(self):
-    return "" % (
+    return "" % (
         self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True))
 
   @staticmethod
@@ -919,6 +935,8 @@ def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
 _tensor_conversion_func_registry = {
     0: [(Tensor, _TensorTensorConversionFunction)]
 }
+_tensor_conversion_func_cache = {}
+_tensor_conversion_func_lock = threading.Lock()
 register_dense_tensor_like_type(Tensor)
 
 
@@ -975,6 +993,10 @@ def convert_to_tensor(value, dtype=None, name=None, preferred_dtype=None):
       as_ref=False)
 
 
+def _error_prefix(name):
+  return "" if name is None else "%s: " % name
+
+
 def internal_convert_to_tensor(value,
                                dtype=None,
                                name=None,
@@ -1011,51 +1033,77 @@ def internal_convert_to_tensor(value,
     RuntimeError: If a registered conversion function returns an invalid value.
 
   """
-  error_prefix = "" if name is None else "%s: " % name
+  # Note we check the type of the object unwrapped from an autograd node, if
+  # tracing gradients, to ensure the same behavior happens with and without
+  # tracing.
+  unwrapped = ag_core.getval(value)
+
+  if context.in_eager_mode():
+    # Fast path for EagerTensors that don't need any conversion.
+    if isinstance(unwrapped, EagerTensor):
+      # Note that we don't check that value's dtype matches the dtype
+      # argument.  We exepct that the C runtime will do that checking
+      # when we execute the kernel.
+      return value
+    values = nest.flatten(value)
+    if (len(values) > 1 and
+        any(isinstance(ag_core.getval(v), EagerTensor) for v in values)):
+      raise TypeError("Cannot convert to a eager tensor.")
+
   if dtype is not None:
     dtype = dtypes.as_dtype(dtype)
-  for _, funcs_at_priority in sorted(_tensor_conversion_func_registry.items()):
-    for base_type, conversion_func in funcs_at_priority:
-      if isinstance(value, base_type):
-        # If dtype is None but preferred_dtype is not None, we try to
-        # cast to preferred_dtype first.
+  unwrapped_type = type(unwrapped)
+  conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
+  if conversion_func_list is None:
+    with _tensor_conversion_func_lock:
+      conversion_func_list = []
+      for _, funcs_at_priority in sorted(
+          _tensor_conversion_func_registry.items()):
+        for base_type, conversion_func in funcs_at_priority:
+          if isinstance(unwrapped, base_type):
+            conversion_func_list.append((base_type, conversion_func))
+      _tensor_conversion_func_cache[unwrapped_type] = conversion_func_list
+
+  for base_type, conversion_func in conversion_func_list:
+    # If dtype is None but preferred_dtype is not None, we try to
+    # cast to preferred_dtype first.
+    ret = None
+    if dtype is None and preferred_dtype is not None:
+      try:
+        ret = conversion_func(
+            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
+      except (TypeError, ValueError):
+        # Could not coerce the conversion to use the preferred dtype.
         ret = None
-        if dtype is None and preferred_dtype is not None:
-          try:
-            ret = conversion_func(
-                value, dtype=preferred_dtype, name=name, as_ref=as_ref)
-          except (TypeError, ValueError):
-            # Could not coerce the conversion to use the preferred dtype.
-            ret = None
-
-          if ret is not None and ret is not NotImplemented:
-            if (ret.dtype.base_dtype !=
-                dtypes.as_dtype(preferred_dtype).base_dtype):
-              raise TypeError("convert_to_tensor did not convert to "
-                              "the preferred dtype: %s vs %s " %
-                              (ret.dtype.base_dtype,
-                               dtypes.as_dtype(preferred_dtype).base_dtype))
-
-        if ret is None:
-          ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
-
-        if ret is NotImplemented:
-          continue
-
-        if not isinstance(ag_core.getval(ret), Tensor):
-          raise RuntimeError(
-              "%sConversion function %r for type %s returned non-Tensor: %r" %
-              (error_prefix, conversion_func, base_type, ret))
-        if dtype and not dtype.is_compatible_with(ret.dtype):
-          raise RuntimeError(
-              "%sConversion function %r for type %s returned incompatible "
-              "dtype: requested = %s, actual = %s" %
-              (error_prefix, conversion_func, base_type, dtype.name,
-               ret.dtype.name))
-        return ret
+
+      if ret is not None and ret is not NotImplemented:
+        if (ret.dtype.base_dtype !=
+            dtypes.as_dtype(preferred_dtype).base_dtype):
+          raise TypeError("convert_to_tensor did not convert to "
+                          "the preferred dtype: %s vs %s " %
+                          (ret.dtype.base_dtype,
+                           dtypes.as_dtype(preferred_dtype).base_dtype))
+
+    if ret is None:
+      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
+
+    if ret is NotImplemented:
+      continue
+
+    if not isinstance(ag_core.getval(ret), Tensor):
+      raise RuntimeError(
+          "%sConversion function %r for type %s returned non-Tensor: %r" %
+          (_error_prefix(name), conversion_func, base_type, ret))
+    if dtype and not dtype.is_compatible_with(ret.dtype):
+      raise RuntimeError(
+          "%sConversion function %r for type %s returned incompatible "
+          "dtype: requested = %s, actual = %s" %
+          (_error_prefix(name), conversion_func, base_type, dtype.name,
+           ret.dtype.name))
+    return ret
   raise TypeError("%sCannot convert %r with type %s to Tensor: "
-                  "no conversion function registered." % (error_prefix, value,
-                                                          type(value)))
+                  "no conversion function registered." %
+                  (_error_prefix(name), value, unwrapped_type))
 
 
 def internal_convert_n_to_tensor(values,
@@ -1305,19 +1353,22 @@ def register_tensor_conversion_function(base_type,
     TypeError: If the arguments do not have the appropriate type.
 
   """
-  if not (isinstance(base_type, type) or
-          (isinstance(base_type, tuple) and
-           all(isinstance(x, type) for x in base_type))):
-    raise TypeError("base_type must be a type or a tuple of types.")
-  if not callable(conversion_func):
-    raise TypeError("conversion_func must be callable.")
+  global _tensor_conversion_func_cache
+  with _tensor_conversion_func_lock:
+    if not (isinstance(base_type, type) or
+            (isinstance(base_type, tuple) and
+             all(isinstance(x, type) for x in base_type))):
+      raise TypeError("base_type must be a type or a tuple of types.")
+    if not callable(conversion_func):
+      raise TypeError("conversion_func must be callable.")
 
-  try:
-    funcs_at_priority = _tensor_conversion_func_registry[priority]
-  except KeyError:
-    funcs_at_priority = []
-    _tensor_conversion_func_registry[priority] = funcs_at_priority
-  funcs_at_priority.append((base_type, conversion_func))
+    try:
+      funcs_at_priority = _tensor_conversion_func_registry[priority]
+    except KeyError:
+      funcs_at_priority = []
+      _tensor_conversion_func_registry[priority] = funcs_at_priority
+    funcs_at_priority.append((base_type, conversion_func))
+    _tensor_conversion_func_cache = {}
 
 
 class IndexedSlices(_TensorLike):
@@ -2905,6 +2956,14 @@ class Graph(object):
     if self._graph_def_versions.min_consumer < 12:
       self._graph_def_versions.min_consumer = 12
     self._functions[name] = function
+    if self._c_graph:
+      # pylint: disable=protected-access
+      assert function._c_func, (
+          "Cannot add function created without C API support to graph "
+          "created with C API support")
+      with errors.raise_exception_on_not_ok_status() as status:
+        c_api.TF_GraphAddFunction(self._c_graph, function._c_func, status)
+      # pylint: enable=protected-access
 
   @property
   def building_function(self):
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 61cc1ff31da8c23f036c64e087ff8b163521ccee..dc036598cb94c7125b4bfb686a3d90c0f56acf41 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -291,6 +291,11 @@ class OperationTest(test_util.TensorFlowTestCase):
       self.assertAllEqual((4, 1), tensor.get_shape().as_list())
       self.assertAllEqual(values, tensor.eval())
 
+  def testShapeTuple(self):
+    with self.test_session():
+      c = constant_op.constant(1)
+      self.assertEqual(c._shape_tuple(), ())  # pylint: disable=protected-access
+
   def testConvertToTensorEager(self):
     with context.eager_mode():
       t = ops.EagerTensor(1)
@@ -394,7 +399,7 @@ class OperationTest(test_util.TensorFlowTestCase):
       self.assertIsInstance(x, dtypes.DType)
     self.assertEqual([dtypes.string, dtypes.double], l)
 
-  # TODO(skyewm): test adding cycles, other error cases
+  # TODO(nolivia): test all error cases
   @test_util.enable_c_api
   def testAddControlInput(self):
     with ops.Graph().as_default():
@@ -403,6 +408,22 @@ class OperationTest(test_util.TensorFlowTestCase):
     y._add_control_input(x)  # pylint: disable=protected-access
     self.assertEqual(y.control_inputs, [x])
 
+  @test_util.enable_c_api
+  def testControlInputCycle(self):
+    graph = ops.Graph()
+    with graph.as_default():
+      z = constant_op.constant(0)
+      x = constant_op.constant(1)
+      y = constant_op.constant(2)
+      y.op._add_control_input(z.op)  # pylint: disable=protected-access
+      y.op._add_control_input(x.op)  # pylint: disable=protected-access
+      x.op._add_control_input(y.op)  # pylint: disable=protected-access
+    with self.test_session(graph=graph) as sess:
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          "Graph is invalid, contains a cycle with 2 nodes"):
+        sess.run(x)
+
 
 class CreateOpTest(test_util.TensorFlowTestCase):
 
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
index 83665422885ba95393bd4d9ffd4084ad6d580b08..f681daa7e46474c9478cf9c52098158bfb357862 100644
--- a/tensorflow/python/framework/python_op_gen_main.cc
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_def.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/scanner.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/init_main.h"
@@ -80,7 +81,29 @@ Status ParseOpListCommandLine(const char* arg, std::vector* op_list) {
   return Status::OK();
 }
 
-void PrintAllPythonOps(const std::vector& op_list, bool require_shapes,
+// Use the name of the current executable to infer the C++ source file
+// where the REGISTER_OP() call for the operator can be found.
+// Returns the name of the file.
+// Returns an empty string if the current executable's name does not
+// follow a known pattern.
+string InferSourceFileName(const char* argv_zero) {
+  StringPiece command_str = io::Basename(argv_zero);
+
+  // For built-in ops, the Bazel build creates a separate executable
+  // with the name gen__ops_py_wrappers_cc containing the
+  // operators defined in _ops.cc
+  const char* kExecPrefix = "gen_";
+  const char* kExecSuffix = "_py_wrappers_cc";
+  if (command_str.Consume(kExecPrefix) && command_str.ends_with(kExecSuffix)) {
+    command_str.remove_suffix(strlen(kExecSuffix));
+    return strings::StrCat(command_str, ".cc");
+  } else {
+    return string("");
+  }
+}
+
+void PrintAllPythonOps(const std::vector& op_list,
+                       const string& source_file_name, bool require_shapes,
                        bool op_list_is_whitelist) {
   OpList ops;
   OpRegistry::Global()->Export(false, &ops);
@@ -93,9 +116,9 @@ void PrintAllPythonOps(const std::vector& op_list, bool require_shapes,
         *pruned_ops.mutable_op()->Add() = op_def;
       }
     }
-    PrintEagerPythonOps(pruned_ops, {}, require_shapes);
+    PrintEagerPythonOps(pruned_ops, {}, require_shapes, source_file_name);
   } else {
-    PrintEagerPythonOps(ops, op_list, require_shapes);
+    PrintEagerPythonOps(ops, op_list, require_shapes, source_file_name);
   }
 }
 
@@ -105,20 +128,26 @@ void PrintAllPythonOps(const std::vector& op_list, bool require_shapes,
 int main(int argc, char* argv[]) {
   tensorflow::port::InitMain(argv[0], &argc, &argv);
 
+  tensorflow::string source_file_name =
+      tensorflow::InferSourceFileName(argv[0]);
+
   // Usage:
   //   gen_main [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
   if (argc == 2) {
-    tensorflow::PrintAllPythonOps({}, {}, tensorflow::string(argv[1]) == "1");
+    tensorflow::PrintAllPythonOps({}, source_file_name,
+                                  tensorflow::string(argv[1]) == "1",
+                                  false /* op_list_is_whitelist */);
   } else if (argc == 3) {
     std::vector hidden_ops;
     TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &hidden_ops));
-    tensorflow::PrintAllPythonOps(hidden_ops,
+    tensorflow::PrintAllPythonOps(hidden_ops, source_file_name,
                                   tensorflow::string(argv[2]) == "1",
                                   false /* op_list_is_whitelist */);
   } else if (argc == 4) {
     std::vector op_list;
     TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &op_list));
-    tensorflow::PrintAllPythonOps(op_list, tensorflow::string(argv[2]) == "1",
+    tensorflow::PrintAllPythonOps(op_list, source_file_name,
+                                  tensorflow::string(argv[2]) == "1",
                                   tensorflow::string(argv[3]) == "1");
   } else {
     return -1;
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 3bdc72140284db83d82fee3d800be28ea88ac4d0..3e13b825f853e3bb287e25b33fa23ff72ac909d9 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from autograd import core as ag_core
 import numpy as np
 import six
 
@@ -235,8 +236,8 @@ def _FilterTuple(v):
 def _FilterInt(v):
   if isinstance(v, (list, tuple)):
     return _FirstNotNone([_FilterInt(x) for x in v])
-  return None if isinstance(v, compat.integral_types) else _NotNone(v)
-
+  return None if isinstance(v, (compat.integral_types,
+                                tensor_shape.Dimension)) else _NotNone(v)
 
 def _FilterFloat(v):
   if isinstance(v, (list, tuple)):
@@ -605,7 +606,7 @@ def ShapeEquals(tensor_proto, shape):
 
 def _ConstantValue(tensor, partial):
   # TODO(touts): Support Variables?
-  if not isinstance(tensor, ops.Tensor):
+  if not isinstance(ag_core.getval(tensor), ops.Tensor):
     raise TypeError("tensor is not a Tensor")
   if tensor.op.type == "Const":
     return MakeNdarray(tensor.op.get_attr("value"))
@@ -688,6 +689,22 @@ def _ConstantValue(tensor, partial):
       return np.full(fill_shape.as_list(), fill_value, dtype=fill_value.dtype)
     else:
       return None
+  elif tensor.op.type == "Equal":
+    value1 = constant_value(tensor.op.inputs[0])
+    if value1 is None:
+      return None
+    value2 = constant_value(tensor.op.inputs[1])
+    if value2 is None:
+      return None
+    return np.equal(value1, value2)
+  elif tensor.op.type == "NotEqual":
+    value1 = constant_value(tensor.op.inputs[0])
+    if value1 is None:
+      return None
+    value2 = constant_value(tensor.op.inputs[1])
+    if value2 is None:
+      return None
+    return np.not_equal(value1, value2)
   else:
     return None
 
@@ -719,7 +736,7 @@ def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
   Raises:
     TypeError: if tensor is not an ops.Tensor.
   """
-  if isinstance(tensor, ops.EagerTensor):
+  if isinstance(ag_core.getval(tensor), ops.EagerTensor):
     return tensor.numpy()
   ret = _ConstantValue(tensor, partial)
   if ret is not None:
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 3a6dce11f47744415513ef7653c72fb1a965b90f..f66af3adc650d8c78a41648015d2b14a17376499 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -314,6 +314,17 @@ class TensorUtilTest(test.TestCase):
                   shape=[3, 4],
                   dtype=dtype)))
 
+  def testIntMixedWithDimension(self):
+    # Github issue: 11974
+    dtype = dtypes.int32
+    nptype = np.int32
+    t = tensor_util.make_tensor_proto(
+        [10, tensor_shape.Dimension(20), 30], dtype=dtype)
+    self.assertEquals(dtype, t.dtype)
+    a = tensor_util.MakeNdarray(t)
+    self.assertEquals(nptype, a.dtype)
+    self.assertAllClose(np.array([10, 20, 30], dtype=nptype), a)
+
   def testLong(self):
     t = tensor_util.make_tensor_proto(10, dtype=dtypes.int64)
     self.assertProtoEquals("""
@@ -800,6 +811,36 @@ class ConstantValueTest(test.TestCase):
     self.assertAllClose(input_, c_val[0])
     self.assertIsNone(c_val[1])
 
+  def testEqual(self):
+    # Scalar inputs.
+    tf_val = math_ops.equal(constant_op.constant(1), constant_op.constant(1))
+    self.assertEqual(tensor_util.constant_value(tf_val), True)
+
+    tf_val = math_ops.equal(constant_op.constant(1), constant_op.constant(0))
+    self.assertEqual(tensor_util.constant_value(tf_val), False)
+
+    # Shaped inputs with broadcast semantics.
+    tf_val = math_ops.equal(constant_op.constant([[0, 1]]),
+                            constant_op.constant([[0], [1]]))
+    c_val = tensor_util.constant_value(tf_val)
+    self.assertAllEqual(c_val, [[True, False], [False, True]])
+
+  def testNotEqual(self):
+    # Scalar inputs.
+    tf_val = math_ops.not_equal(constant_op.constant(1),
+                                constant_op.constant(1))
+    self.assertEqual(tensor_util.constant_value(tf_val), False)
+
+    tf_val = math_ops.not_equal(constant_op.constant(1),
+                                constant_op.constant(0))
+    self.assertEqual(tensor_util.constant_value(tf_val), True)
+
+    # Shaped inputs with broadcast semantics.
+    tf_val = math_ops.not_equal(constant_op.constant([[0, 1]]),
+                                constant_op.constant([[0], [1]]))
+    c_val = tensor_util.constant_value(tf_val)
+    self.assertAllEqual(c_val, [[False, True], [True, False]])
+
 
 class ConstantValueAsShapeTest(test.TestCase):
 
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 8a24d10c593775ce6c588aee15604f237980ebe4..9cf222a63ab337b6f00dfc96ac873f9d56e8d160 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -53,6 +53,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import versions
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
@@ -64,7 +65,7 @@ def gpu_device_name():
   """Returns the name of a GPU device if available or the empty string."""
   for x in device_lib.list_local_devices():
     if x.device_type == "GPU" or x.device_type == "SYCL":
-      return x.name
+      return compat.as_str(x.name)
   return ""
 
 
@@ -275,7 +276,8 @@ def enable_c_api(fn):
 
 
 def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
-                                 use_gpu=False, force_gpu=False):
+                                 use_gpu=False, force_gpu=False,
+                                 reset_test=True):
   """Runs the test in both graph and eager modes.
 
   Args:
@@ -285,6 +287,7 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
       session.
     use_gpu: If True, attempt to run as many ops as possible on GPU.
     force_gpu: If True, pin all ops to `/device:GPU:0`.
+    reset_test: If True, tearDown and SetUp the test case again.
 
   Returns:
     Returns a decorator that will run the decorated test function
@@ -295,11 +298,17 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
 
   def decorator(f):
     """Test method decorator."""
-    def decorated(self):
+    def decorated(self, **kwargs):
       """Decorated the test method."""
       with context.graph_mode():
         with self.test_session(graph, config, use_gpu, force_gpu):
-          f(self)
+          f(self, **kwargs)
+
+      if reset_test:
+        # This decorator runs the wrapped test twice.
+        # Reset the test environment between runs.
+        self.tearDown()
+        self.setUp()
 
       def run_eager_mode():
         if force_gpu:
@@ -310,17 +319,15 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
             f(self)
         elif use_gpu:
           # TODO(xpan): Support softplacement and gpu by default when available.
-          f(self)
+          f(self, **kwargs)
         else:
           with context.device("/device:CPU:0"):
-            f(self)
+            f(self, **kwargs)
 
+      eager_graph = graph or ops.Graph()
       with context.eager_mode():
-        if graph is None:
+        with eager_graph.as_default():
           run_eager_mode()
-        else:
-          with graph.as_default():
-            run_eager_mode()
 
     return decorated
   return decorator
@@ -385,6 +392,7 @@ class TensorFlowTestCase(googletest.TestCase):
     self._cached_session = None
 
   def setUp(self):
+    logging.info("SET UP: %s" % str(self))
     self._ClearCachedSession()
     random.seed(random_seed.DEFAULT_GRAPH_SEED)
     np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
@@ -399,6 +407,7 @@ class TensorFlowTestCase(googletest.TestCase):
     ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED
 
   def tearDown(self):
+    logging.info("TEAR DOWN: %s" % str(self))
     for thread in self._threads:
       self.assertFalse(thread.is_alive(), "A checkedThread did not terminate")
 
@@ -490,6 +499,9 @@ class TensorFlowTestCase(googletest.TestCase):
   def _eval_helper(self, tensors):
     if isinstance(tensors, ops.EagerTensor):
       return tensors.numpy()
+    if isinstance(tensors, resource_variable_ops.ResourceVariable):
+      return tensors.read_value().numpy()
+
     if isinstance(tensors, tuple):
       return tuple([self._eval_helper(t) for t in tensors])
     elif isinstance(tensors, list):
@@ -586,6 +598,8 @@ class TensorFlowTestCase(googletest.TestCase):
       config.graph_options.optimizer_options.opt_level = -1
       config.graph_options.rewrite_options.constant_folding = (
           rewriter_config_pb2.RewriterConfig.OFF)
+      config.graph_options.rewrite_options.arithmetic_optimization = (
+          rewriter_config_pb2.RewriterConfig.OFF)
       return config
 
     if graph is None:
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 5dbaf76edb660f67a78b78aa6de11d74770f2665..bda9502cd115eaca0df06e80ee716726d91f13c4 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -22,8 +22,10 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
@@ -51,9 +53,7 @@ def max_pool_2x2(x):
 
 
 # Taken from tensorflow/examples/tutorials/mnist/mnist_deep.py
-def two_layer_model():
-  random_seed.set_random_seed(0)
-  x = random_ops.truncated_normal([1, 784], seed=0)
+def two_layer_model(x):
   x_image = array_ops.reshape(x, [-1, 28, 28, 1])
   w_conv1 = weight([5, 5, 1, 32])
   b_conv1 = bias([32])
@@ -66,24 +66,39 @@ def two_layer_model():
   return h_pool2
 
 
+def loop():
+  random_seed.set_random_seed(0)
+  x1 = random_ops.truncated_normal([1, 784], seed=0)
+  x2 = random_ops.truncated_normal([1, 784], seed=0)
+  x3 = random_ops.truncated_normal([1, 784], seed=0)
+  x4 = random_ops.truncated_normal([1, 784], seed=0)
+  elems = (x1, x2, x3, x4)
+  outputs = functional_ops.map_fn(two_layer_model, elems, dtype=dtypes.float32)
+  return outputs
+
+
+def get_config():
+  rewrite_options = rewriter_config_pb2.RewriterConfig(
+      optimize_tensor_layout=True)
+  graph_options = config_pb2.GraphOptions(
+      rewrite_options=rewrite_options, build_cost_model=1)
+  config = config_pb2.ConfigProto(graph_options=graph_options)
+  return config
+
+
 class LayoutOptimizerTest(test.TestCase):
   """Tests the Grappler layout optimizer."""
 
   def testTwoConvLayers(self):
     if test.is_gpu_available(cuda_only=True):
-      output = two_layer_model()
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      output = two_layer_model(x)
 
       with session.Session() as sess:
         output_val_ref = sess.run(output)
 
-      rewrite_options = rewriter_config_pb2.RewriterConfig(
-          optimize_tensor_layout=True)
-      graph_options = config_pb2.GraphOptions(
-          rewrite_options=rewrite_options,
-          build_cost_model=1)
-      config = config_pb2.ConfigProto(graph_options=graph_options)
-
-      with session.Session(config=config) as sess:
+      with session.Session(config=get_config()) as sess:
         metadata = config_pb2.RunMetadata()
         output_val = sess.run(output, run_metadata=metadata)
 
@@ -105,6 +120,19 @@ class LayoutOptimizerTest(test.TestCase):
 
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  def testLoop(self):
+    if test.is_gpu_available(cuda_only=True):
+      output = loop()
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py
index 78f819e8c4f5fda522ed50402ba655901c53e705..46f03952fa912786f207bdd0cfa21c2b0f5d63f2 100644
--- a/tensorflow/python/grappler/memory_optimizer_test.py
+++ b/tensorflow/python/grappler/memory_optimizer_test.py
@@ -50,6 +50,7 @@ class MemoryOptimizerSwapTest(test.TestCase):
 
     rewriter_config = rewriter_config_pb2.RewriterConfig(
         disable_model_pruning=True,
+        constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
         memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
     graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
 
@@ -72,6 +73,7 @@ class MemoryOptimizerSwapTest(test.TestCase):
 
     rewriter_config = rewriter_config_pb2.RewriterConfig(
         disable_model_pruning=True,
+        constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
         memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
     graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
 
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index 522d1baf2926f9ce5e2b3795f50a188473b84244..1a348ede9bc4631a930d0ad14c089bc0df44e713 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -36,6 +36,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
     c = math_ops.add_n([a, b], name='c')
     d = math_ops.add_n([b, c], name='d')
     train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+    # Being a train_op will make 'd' to be added as a fetch node.
     train_op.append(d)
     mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
 
@@ -44,9 +45,8 @@ class PyWrapOptimizeGraphTest(test.TestCase):
 
     graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
 
-    self.assertEqual(len(graph.node), 4)
-    self.assertItemsEqual([node.name
-                           for node in graph.node], ['a', 'b', 'c', 'd'])
+    self.assertEqual(len(graph.node), 3)
+    self.assertItemsEqual([node.name for node in graph.node], ['b', 'c', 'd'])
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..a7daab8335a9af1adbab047cf6b88a407a0e93f4
--- /dev/null
+++ b/tensorflow/python/keras/BUILD
@@ -0,0 +1,694 @@
+# Description:
+#   Contains the Keras API (internal TensorFlow version).
+
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//visibility:public"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+    name = "keras",
+    srcs = [
+        "__init__.py",
+        "_impl/keras/__init__.py",
+        "_impl/keras/activations.py",
+        "_impl/keras/applications/__init__.py",
+        "_impl/keras/applications/imagenet_utils.py",
+        "_impl/keras/applications/inception_v3.py",
+        "_impl/keras/applications/mobilenet.py",
+        "_impl/keras/applications/resnet50.py",
+        "_impl/keras/applications/vgg16.py",
+        "_impl/keras/applications/vgg19.py",
+        "_impl/keras/applications/xception.py",
+        "_impl/keras/backend.py",
+        "_impl/keras/callbacks.py",
+        "_impl/keras/constraints.py",
+        "_impl/keras/datasets/__init__.py",
+        "_impl/keras/datasets/boston_housing.py",
+        "_impl/keras/datasets/cifar.py",
+        "_impl/keras/datasets/cifar10.py",
+        "_impl/keras/datasets/cifar100.py",
+        "_impl/keras/datasets/imdb.py",
+        "_impl/keras/datasets/mnist.py",
+        "_impl/keras/datasets/reuters.py",
+        "_impl/keras/engine/__init__.py",
+        "_impl/keras/engine/topology.py",
+        "_impl/keras/engine/training.py",
+        "_impl/keras/initializers.py",
+        "_impl/keras/layers/__init__.py",
+        "_impl/keras/layers/advanced_activations.py",
+        "_impl/keras/layers/convolutional.py",
+        "_impl/keras/layers/convolutional_recurrent.py",
+        "_impl/keras/layers/core.py",
+        "_impl/keras/layers/embeddings.py",
+        "_impl/keras/layers/local.py",
+        "_impl/keras/layers/merge.py",
+        "_impl/keras/layers/noise.py",
+        "_impl/keras/layers/normalization.py",
+        "_impl/keras/layers/pooling.py",
+        "_impl/keras/layers/recurrent.py",
+        "_impl/keras/layers/serialization.py",
+        "_impl/keras/layers/wrappers.py",
+        "_impl/keras/losses.py",
+        "_impl/keras/metrics.py",
+        "_impl/keras/models.py",
+        "_impl/keras/optimizers.py",
+        "_impl/keras/preprocessing/__init__.py",
+        "_impl/keras/preprocessing/image.py",
+        "_impl/keras/preprocessing/sequence.py",
+        "_impl/keras/preprocessing/text.py",
+        "_impl/keras/regularizers.py",
+        "_impl/keras/testing_utils.py",
+        "_impl/keras/utils/__init__.py",
+        "_impl/keras/utils/conv_utils.py",
+        "_impl/keras/utils/data_utils.py",
+        "_impl/keras/utils/generic_utils.py",
+        "_impl/keras/utils/io_utils.py",
+        "_impl/keras/utils/layer_utils.py",
+        "_impl/keras/utils/np_utils.py",
+        "_impl/keras/utils/vis_utils.py",
+        "_impl/keras/wrappers/__init__.py",
+        "_impl/keras/wrappers/scikit_learn.py",
+        "activations/__init__.py",
+        "applications/__init__.py",
+        "applications/inception_v3/__init__.py",
+        "applications/mobilenet/__init__.py",
+        "applications/resnet50/__init__.py",
+        "applications/vgg16/__init__.py",
+        "applications/vgg19/__init__.py",
+        "applications/xception/__init__.py",
+        "backend/__init__.py",
+        "callbacks/__init__.py",
+        "constraints/__init__.py",
+        "datasets/__init__.py",
+        "datasets/boston_housing/__init__.py",
+        "datasets/cifar10/__init__.py",
+        "datasets/cifar100/__init__.py",
+        "datasets/imdb/__init__.py",
+        "datasets/mnist/__init__.py",
+        "datasets/reuters/__init__.py",
+        "initializers/__init__.py",
+        "layers/__init__.py",
+        "losses/__init__.py",
+        "metrics/__init__.py",
+        "models/__init__.py",
+        "optimizers/__init__.py",
+        "preprocessing/__init__.py",
+        "preprocessing/image/__init__.py",
+        "preprocessing/sequence/__init__.py",
+        "preprocessing/text/__init__.py",
+        "regularizers/__init__.py",
+        "utils/__init__.py",
+        "wrappers/__init__.py",
+        "wrappers/scikit_learn/__init__.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:check_ops",
+        "//tensorflow/python:client",
+        "//tensorflow/python:clip_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:ctc_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:functional_ops",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:image_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:layers_base",
+        "//tensorflow/python:logging_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:nn",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:random_ops",
+        "//tensorflow/python:sparse_ops",
+        "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python:state_ops",
+        "//tensorflow/python:summary",
+        "//tensorflow/python:tensor_array_grad",
+        "//tensorflow/python:tensor_array_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:training",
+        "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+        "@six_archive//:six",
+    ],
+)
+
+py_test(
+    name = "integration_test",
+    size = "medium",
+    srcs = ["_impl/keras/integration_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:nn",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "activations_test",
+    size = "small",
+    srcs = ["_impl/keras/activations_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "constraints_test",
+    size = "small",
+    srcs = ["_impl/keras/constraints_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "initializers_test",
+    size = "small",
+    srcs = ["_impl/keras/initializers_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:init_ops",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "regularizers_test",
+    size = "small",
+    srcs = ["_impl/keras/regularizers_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "optimizers_test",
+    size = "medium",
+    srcs = ["_impl/keras/optimizers_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:training",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "losses_test",
+    size = "small",
+    srcs = ["_impl/keras/losses_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "metrics_test",
+    size = "small",
+    srcs = ["_impl/keras/metrics_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "inception_v3_test",
+    size = "medium",
+    srcs = ["_impl/keras/applications/inception_v3_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "mobilenet_test",
+    size = "medium",
+    srcs = ["_impl/keras/applications/mobilenet_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "resnet50_test",
+    size = "small",
+    srcs = ["_impl/keras/applications/resnet50_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "vgg16_test",
+    size = "small",
+    srcs = ["_impl/keras/applications/vgg16_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "vgg19_test",
+    size = "small",
+    srcs = ["_impl/keras/applications/vgg19_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "xception_test",
+    size = "medium",
+    srcs = ["_impl/keras/applications/xception_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "advanced_activations_test",
+    size = "small",
+    srcs = ["_impl/keras/layers/advanced_activations_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "convolutional_recurrent_test",
+    size = "medium",
+    srcs = ["_impl/keras/layers/convolutional_recurrent_test.py"],
+    shard_count = 2,
+    srcs_version = "PY2AND3",
+    tags = ["noasan"],  # times out b/63678675
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "convolutional_test",
+    size = "medium",
+    srcs = ["_impl/keras/layers/convolutional_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "manual",
+        "noasan",  # times out b/63678675
+        "notsan",
+    ],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "pooling_test",
+    size = "small",
+    srcs = ["_impl/keras/layers/pooling_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "core_test",
+    size = "small",
+    srcs = ["_impl/keras/layers/core_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "embeddings_test",
+    size = "small",
+    srcs = ["_impl/keras/layers/embeddings_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "local_test",
+    size = "medium",
+    srcs = ["_impl/keras/layers/local_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "merge_test",
+    size = "small",
+    srcs = ["_impl/keras/layers/merge_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "noise_test",
+    size = "small",
+    srcs = ["_impl/keras/layers/noise_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "normalization_test",
+    size = "small",
+    srcs = ["_impl/keras/layers/normalization_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "simplernn_test",
+    size = "medium",
+    srcs = ["_impl/keras/layers/simplernn_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "gru_test",
+    size = "medium",
+    srcs = ["_impl/keras/layers/gru_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],  # http://b/62136390
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "lstm_test",
+    size = "medium",
+    srcs = ["_impl/keras/layers/lstm_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "noasan",  # times out b/63678675
+        "notsan",  # http://b/62189182
+    ],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "serialization_test",
+    size = "small",
+    srcs = ["_impl/keras/layers/serialization_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "wrappers_test",
+    size = "small",
+    srcs = ["_impl/keras/layers/wrappers_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "scikit_learn_test",
+    size = "small",
+    srcs = ["_impl/keras/wrappers/scikit_learn_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "data_utils_test",
+    size = "small",
+    srcs = ["_impl/keras/utils/data_utils_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "noasan",  # times out
+        "notsan",
+    ],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "generic_utils_test",
+    size = "small",
+    srcs = ["_impl/keras/utils/generic_utils_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "io_utils_test",
+    size = "small",
+    srcs = ["_impl/keras/utils/io_utils_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "imagenet_utils_test",
+    size = "small",
+    srcs = ["_impl/keras/applications/imagenet_utils_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "image_test",
+    size = "medium",
+    srcs = ["_impl/keras/preprocessing/image_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "sequence_test",
+    size = "small",
+    srcs = ["_impl/keras/preprocessing/sequence_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "text_test",
+    size = "small",
+    srcs = ["_impl/keras/preprocessing/text_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "callbacks_test",
+    size = "medium",
+    srcs = ["_impl/keras/callbacks_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "training_test",
+    size = "medium",
+    srcs = ["_impl/keras/engine/training_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["notsan"],
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "topology_test",
+    size = "small",
+    srcs = ["_impl/keras/engine/topology_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "models_test",
+    size = "small",
+    srcs = ["_impl/keras/models_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:training",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "backend_test",
+    size = "small",
+    srcs = ["_impl/keras/backend_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:util",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_library(
+    name = "testing_utils",
+    srcs = [
+        "_impl/keras/testing_utils.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:util",
+        "//third_party/py/numpy",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/python/keras/README.md b/tensorflow/python/keras/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..db2556fe422c179737178f622a53d69d57282b8e
--- /dev/null
+++ b/tensorflow/python/keras/README.md
@@ -0,0 +1,6 @@
+Keras is an object-oriented API for defining and training neural networks.
+
+This module contains a pure-TensorFlow implementation of the Keras API,
+allowing for deep integration with TensorFlow functionality.
+
+See [keras.io](https://keras.io) for complete documentation and user guides.
diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..962c7678dd285c009975c4df9d858be043a4bd2f
--- /dev/null
+++ b/tensorflow/python/keras/__init__.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of the Keras API meant to be a high-level API for TensorFlow.
+
+Detailed documentation and user guides are available at
+[keras.io](https://keras.io).
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import applications
+from tensorflow.python.keras import backend
+from tensorflow.python.keras import callbacks
+from tensorflow.python.keras import constraints
+from tensorflow.python.keras import datasets
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import losses
+from tensorflow.python.keras import metrics
+from tensorflow.python.keras import models
+from tensorflow.python.keras import optimizers
+from tensorflow.python.keras import preprocessing
+from tensorflow.python.keras import regularizers
+from tensorflow.python.keras import utils
+from tensorflow.python.keras import wrappers
+from tensorflow.python.keras._impl.keras import __version__
+from tensorflow.python.keras.layers import Input
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1aa4415a19c9358c721c90c83b7eb07e604a243
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/__init__.py
@@ -0,0 +1,40 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Keras API.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras import activations
+from tensorflow.python.keras._impl.keras import applications
+from tensorflow.python.keras._impl.keras import backend
+from tensorflow.python.keras._impl.keras import callbacks
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import datasets
+from tensorflow.python.keras._impl.keras import engine
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import layers
+from tensorflow.python.keras._impl.keras import losses
+from tensorflow.python.keras._impl.keras import metrics
+from tensorflow.python.keras._impl.keras import models
+from tensorflow.python.keras._impl.keras import optimizers
+from tensorflow.python.keras._impl.keras import preprocessing
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras import utils
+from tensorflow.python.keras._impl.keras import wrappers
+from tensorflow.python.keras._impl.keras.layers import Input
+
+__version__ = '2.0.8-tf'
diff --git a/tensorflow/contrib/keras/python/keras/activations.py b/tensorflow/python/keras/_impl/keras/activations.py
similarity index 93%
rename from tensorflow/contrib/keras/python/keras/activations.py
rename to tensorflow/python/keras/_impl/keras/activations.py
index 7f04234e018676ac036d5f56bd712bc5a21ef6d5..4e35b79869f5ec1005bf5dfd8cac985942a18837 100644
--- a/tensorflow/contrib/keras/python/keras/activations.py
+++ b/tensorflow/python/keras/_impl/keras/activations.py
@@ -20,9 +20,9 @@ from __future__ import print_function
 
 import six
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.engine import Layer
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.platform import tf_logging as logging
 
 
diff --git a/tensorflow/contrib/keras/python/keras/activations_test.py b/tensorflow/python/keras/_impl/keras/activations_test.py
similarity index 99%
rename from tensorflow/contrib/keras/python/keras/activations_test.py
rename to tensorflow/python/keras/_impl/keras/activations_test.py
index 8efa464b03647b22452904bacdd47402e919f1f3..fb0bb5f1269d112e3f268ce211a2ddeb24b417bf 100644
--- a/tensorflow/contrib/keras/python/keras/activations_test.py
+++ b/tensorflow/python/keras/_impl/keras/activations_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/applications/__init__.py b/tensorflow/python/keras/_impl/keras/applications/__init__.py
similarity index 64%
rename from tensorflow/contrib/keras/python/keras/applications/__init__.py
rename to tensorflow/python/keras/_impl/keras/applications/__init__.py
index 9139df30a6e8db86cef752f7739f8bd047dc16a7..f78bbdc148145591cc16e3231bd9d2b7c06d208b 100644
--- a/tensorflow/contrib/keras/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/applications/__init__.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras.applications.inception_v3 import InceptionV3
-from tensorflow.contrib.keras.python.keras.applications.mobilenet import MobileNet
-from tensorflow.contrib.keras.python.keras.applications.resnet50 import ResNet50
-from tensorflow.contrib.keras.python.keras.applications.vgg16 import VGG16
-from tensorflow.contrib.keras.python.keras.applications.vgg19 import VGG19
-from tensorflow.contrib.keras.python.keras.applications.xception import Xception
+from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3
+from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet
+from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50
+from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16
+from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19
+from tensorflow.python.keras._impl.keras.applications.xception import Xception
diff --git a/tensorflow/contrib/keras/python/keras/applications/imagenet_utils.py b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
similarity index 63%
rename from tensorflow/contrib/keras/python/keras/applications/imagenet_utils.py
rename to tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
index a64021ae499f70eb521c557a780c478d8d07c43c..43628341cb522090caf6eb996a5f0c9b44488424 100644
--- a/tensorflow/contrib/keras/python/keras/applications/imagenet_utils.py
+++ b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py
@@ -20,8 +20,9 @@ from __future__ import print_function
 
 import json
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.platform import tf_logging as logging
 
 
 CLASS_INDEX = None
@@ -43,19 +44,25 @@ def preprocess_input(x, data_format=None):
   assert data_format in {'channels_last', 'channels_first'}
 
   if data_format == 'channels_first':
-    # 'RGB'->'BGR'
-    x = x[:, ::-1, :, :]
-    # Zero-center by mean pixel
-    x[:, 0, :, :] -= 103.939
-    x[:, 1, :, :] -= 116.779
-    x[:, 2, :, :] -= 123.68
+    if x.ndim == 3:
+      # 'RGB'->'BGR'
+      x = x[::-1, ...]
+      # Zero-center by mean pixel
+      x[0, :, :] -= 103.939
+      x[1, :, :] -= 116.779
+      x[2, :, :] -= 123.68
+    else:
+      x = x[:, ::-1, ...]
+      x[:, 0, :, :] -= 103.939
+      x[:, 1, :, :] -= 116.779
+      x[:, 2, :, :] -= 123.68
   else:
     # 'RGB'->'BGR'
-    x = x[:, :, :, ::-1]
+    x = x[..., ::-1]
     # Zero-center by mean pixel
-    x[:, :, :, 0] -= 103.939
-    x[:, :, :, 1] -= 116.779
-    x[:, :, :, 2] -= 123.68
+    x[..., 0] -= 103.939
+    x[..., 1] -= 116.779
+    x[..., 2] -= 123.68
   return x
 
 
@@ -94,8 +101,12 @@ def decode_predictions(preds, top=5):
   return results
 
 
-def _obtain_input_shape(input_shape, default_size, min_size, data_format,
-                        include_top):
+def _obtain_input_shape(input_shape,
+                        default_size,
+                        min_size,
+                        data_format,
+                        require_flatten,
+                        weights=None):
   """Internal utility to compute/validate an ImageNet model's input shape.
 
   Arguments:
@@ -104,8 +115,11 @@ def _obtain_input_shape(input_shape, default_size, min_size, data_format,
       default_size: default input width/height for the model.
       min_size: minimum input width/height accepted by the model.
       data_format: image data format to use.
-      include_top: whether the model is expected to
+      require_flatten: whether the model is expected to
           be linked to a classifier via a Flatten layer.
+      weights: one of `None` (random initialization)
+          or 'imagenet' (pre-training on ImageNet).
+          If weights='imagenet' input channels must be equal to 3.
 
   Returns:
       An integer shape tuple (may include None entries).
@@ -113,43 +127,67 @@ def _obtain_input_shape(input_shape, default_size, min_size, data_format,
   Raises:
       ValueError: in case of invalid argument values.
   """
-  if data_format == 'channels_first':
-    default_shape = (3, default_size, default_size)
+  if weights != 'imagenet' and input_shape and len(input_shape) == 3:
+    if data_format == 'channels_first':
+      if input_shape[0] not in {1, 3}:
+        logging.warning('This model usually expects 1 or 3 input channels. '
+                        'However, it was passed an input_shape with ' +
+                        str(input_shape[0]) + ' input channels.')
+      default_shape = (input_shape[0], default_size, default_size)
+    else:
+      if input_shape[-1] not in {1, 3}:
+        logging.warning('This model usually expects 1 or 3 input channels. '
+                        'However, it was passed an input_shape with ' +
+                        str(input_shape[-1]) + ' input channels.')
+      default_shape = (default_size, default_size, input_shape[-1])
   else:
-    default_shape = (default_size, default_size, 3)
-  if include_top:
+    if data_format == 'channels_first':
+      default_shape = (3, default_size, default_size)
+    else:
+      default_shape = (default_size, default_size, 3)
+  if weights == 'imagenet' and require_flatten:
     if input_shape is not None:
       if input_shape != default_shape:
-        raise ValueError('When setting`include_top=True`, '
+        raise ValueError('When setting`include_top=True` '
+                         'and loading `imagenet` weights, '
                          '`input_shape` should be ' + str(default_shape) + '.')
-    input_shape = default_shape
-  else:
+    return default_shape
+  if input_shape:
     if data_format == 'channels_first':
       if input_shape is not None:
         if len(input_shape) != 3:
           raise ValueError('`input_shape` must be a tuple of three integers.')
-        if input_shape[0] != 3:
+        if input_shape[0] != 3 and weights == 'imagenet':
           raise ValueError('The input must have 3 channels; got '
                            '`input_shape=' + str(input_shape) + '`')
         if ((input_shape[1] is not None and input_shape[1] < min_size) or
             (input_shape[2] is not None and input_shape[2] < min_size)):
           raise ValueError('Input size must be at least ' + str(min_size) + 'x'
-                           + str(min_size) + ', got '
+                           + str(min_size) + '; got '
                            '`input_shape=' + str(input_shape) + '`')
-      else:
-        input_shape = (3, None, None)
     else:
       if input_shape is not None:
         if len(input_shape) != 3:
           raise ValueError('`input_shape` must be a tuple of three integers.')
-        if input_shape[-1] != 3:
+        if input_shape[-1] != 3 and weights == 'imagenet':
           raise ValueError('The input must have 3 channels; got '
                            '`input_shape=' + str(input_shape) + '`')
         if ((input_shape[0] is not None and input_shape[0] < min_size) or
             (input_shape[1] is not None and input_shape[1] < min_size)):
           raise ValueError('Input size must be at least ' + str(min_size) + 'x'
-                           + str(min_size) + ', got '
+                           + str(min_size) + '; got '
                            '`input_shape=' + str(input_shape) + '`')
+  else:
+    if require_flatten:
+      input_shape = default_shape
+    else:
+      if data_format == 'channels_first':
+        input_shape = (3, None, None)
       else:
         input_shape = (None, None, 3)
+  if require_flatten:
+    if None in input_shape:
+      raise ValueError('If `include_top` is True, '
+                       'you should specify a static `input_shape`. '
+                       'Got `input_shape=' + str(input_shape) + '`')
   return input_shape
diff --git a/tensorflow/contrib/keras/python/keras/applications/imagenet_utils_test.py b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils_test.py
similarity index 72%
rename from tensorflow/contrib/keras/python/keras/applications/imagenet_utils_test.py
rename to tensorflow/python/keras/_impl/keras/applications/imagenet_utils_test.py
index 378c06d30d894f963a1f14b15f0f6b880bc58577..517ba91219fc0ec0b61ccd673b420021a0db483d 100644
--- a/tensorflow/contrib/keras/python/keras/applications/imagenet_utils_test.py
+++ b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils_test.py
@@ -20,23 +20,33 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
 class ImageNetUtilsTest(test.TestCase):
 
   def test_preprocess_input(self):
-    x = np.random.uniform(0, 255, (2, 3, 2, 3))
+    # Test batch of images
+    x = np.random.uniform(0, 255, (2, 10, 10, 3))
     self.assertEqual(
         keras.applications.imagenet_utils.preprocess_input(x).shape, x.shape)
-
     out1 = keras.applications.imagenet_utils.preprocess_input(
         x, 'channels_last')
     out2 = keras.applications.imagenet_utils.preprocess_input(
         np.transpose(x, (0, 3, 1, 2)), 'channels_first')
     self.assertAllClose(out1, out2.transpose(0, 2, 3, 1))
 
+    # Test single image
+    x = np.random.uniform(0, 255, (10, 10, 3))
+    self.assertEqual(
+        keras.applications.imagenet_utils.preprocess_input(x).shape, x.shape)
+    out1 = keras.applications.imagenet_utils.preprocess_input(
+        x, 'channels_last')
+    out2 = keras.applications.imagenet_utils.preprocess_input(
+        np.transpose(x, (2, 0, 1)), 'channels_first')
+    self.assertAllClose(out1, out2.transpose(1, 2, 0))
+
   def test_obtain_input_shape(self):
     # input_shape and default_size are not identical.
     with self.assertRaises(ValueError):
@@ -45,7 +55,8 @@ class ImageNetUtilsTest(test.TestCase):
           default_size=299,
           min_size=139,
           data_format='channels_last',
-          include_top=True)
+          require_flatten=True,
+          weights='imagenet')
 
     # Test invalid use cases
     for data_format in ['channels_last', 'channels_first']:
@@ -61,7 +72,7 @@ class ImageNetUtilsTest(test.TestCase):
             default_size=None,
             min_size=139,
             data_format=data_format,
-            include_top=False)
+            require_flatten=False)
 
       # shape is 1D.
       shape = (100,)
@@ -75,7 +86,7 @@ class ImageNetUtilsTest(test.TestCase):
             default_size=None,
             min_size=139,
             data_format=data_format,
-            include_top=False)
+            require_flatten=False)
 
       # the number of channels is 5 not 3.
       shape = (100, 100)
@@ -89,43 +100,60 @@ class ImageNetUtilsTest(test.TestCase):
             default_size=None,
             min_size=139,
             data_format=data_format,
-            include_top=False)
+            require_flatten=False)
+
+      # require_flatten=True with dynamic input shape.
+      with self.assertRaises(ValueError):
+        keras.applications.imagenet_utils._obtain_input_shape(
+            input_shape=None,
+            default_size=None,
+            min_size=139,
+            data_format='channels_first',
+            require_flatten=True)
+
+    assert keras.applications.imagenet_utils._obtain_input_shape(
+        input_shape=(3, 200, 200),
+        default_size=None,
+        min_size=139,
+        data_format='channels_first',
+        require_flatten=True) == (3, 200, 200)
 
     assert keras.applications.imagenet_utils._obtain_input_shape(
         input_shape=None,
         default_size=None,
         min_size=139,
         data_format='channels_last',
-        include_top=False) == (None, None, 3)
+        require_flatten=False) == (None, None, 3)
 
     assert keras.applications.imagenet_utils._obtain_input_shape(
         input_shape=None,
         default_size=None,
         min_size=139,
         data_format='channels_first',
-        include_top=False) == (3, None, None)
+        require_flatten=False) == (3, None, None)
 
     assert keras.applications.imagenet_utils._obtain_input_shape(
         input_shape=None,
         default_size=None,
         min_size=139,
         data_format='channels_last',
-        include_top=False) == (None, None, 3)
+        require_flatten=False) == (None, None, 3)
 
     assert keras.applications.imagenet_utils._obtain_input_shape(
         input_shape=(150, 150, 3),
         default_size=None,
         min_size=139,
         data_format='channels_last',
-        include_top=False) == (150, 150, 3)
+        require_flatten=False) == (150, 150, 3)
 
     assert keras.applications.imagenet_utils._obtain_input_shape(
         input_shape=(3, None, None),
         default_size=None,
         min_size=139,
         data_format='channels_first',
-        include_top=False) == (3, None, None)
+        require_flatten=False) == (3, None, None)
 
 
 if __name__ == '__main__':
   test.main()
+
diff --git a/tensorflow/contrib/keras/python/keras/applications/inception_v3.py b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
similarity index 91%
rename from tensorflow/contrib/keras/python/keras/applications/inception_v3.py
rename to tensorflow/python/keras/_impl/keras/applications/inception_v3.py
index f77e4a83416c3f6918527c98c6280566492514af..edb4c60f8a58553a355245558b30d815000b3e11 100644
--- a/tensorflow/contrib/keras/python/keras/applications/inception_v3.py
+++ b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py
@@ -29,22 +29,22 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import layers
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.engine.topology import get_source_inputs
-from tensorflow.contrib.keras.python.keras.layers import Activation
-from tensorflow.contrib.keras.python.keras.layers import AveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers import BatchNormalization
-from tensorflow.contrib.keras.python.keras.layers import Conv2D
-from tensorflow.contrib.keras.python.keras.layers import Dense
-from tensorflow.contrib.keras.python.keras.layers import GlobalAveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers import GlobalMaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers import Input
-from tensorflow.contrib.keras.python.keras.layers import MaxPooling2D
-from tensorflow.contrib.keras.python.keras.models import Model
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import layers
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.layers import Activation
+from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import BatchNormalization
+from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import Dense
+from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import Input
+from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 
 WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.5/inception_v3_weights_tf_dim_ordering_tf_kernels.h5'
@@ -125,7 +125,7 @@ def InceptionV3(include_top=True,
           if `include_top` is False (otherwise the input shape
           has to be `(299, 299, 3)` (with `channels_last` data format)
           or `(3, 299, 299)` (with `channels_first` data format).
-          It should have exactly 3 inputs channels,
+          It should have exactly 3 input channels,
           and width and height should be no smaller than 139.
           E.g. `(150, 150, 3)` would be one valid value.
       pooling: Optional pooling mode for feature extraction
@@ -165,7 +165,8 @@ def InceptionV3(include_top=True,
       default_size=299,
       min_size=139,
       data_format=K.image_data_format(),
-      include_top=include_top)
+      require_flatten=False,
+      weights=weights)
 
   if input_tensor is None:
     img_input = Input(shape=input_shape)
diff --git a/tensorflow/contrib/keras/python/keras/applications/inception_v3_test.py b/tensorflow/python/keras/_impl/keras/applications/inception_v3_test.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/applications/inception_v3_test.py
rename to tensorflow/python/keras/_impl/keras/applications/inception_v3_test.py
index 890df612ff55e28f72144a67143d9aa13aab3af3..20e11fa019134423cc7c0499e7507680e13cb86d 100644
--- a/tensorflow/contrib/keras/python/keras/applications/inception_v3_test.py
+++ b/tensorflow/python/keras/_impl/keras/applications/inception_v3_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
similarity index 93%
rename from tensorflow/contrib/keras/python/keras/applications/mobilenet.py
rename to tensorflow/python/keras/_impl/keras/applications/mobilenet.py
index 37240234d37677e9ad676be192b606c4d4fa3d8b..9375e436f2d2e04a7ca5d8539c9df0207b6f173b 100644
--- a/tensorflow/contrib/keras/python/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
@@ -69,25 +69,25 @@ from __future__ import print_function
 
 import warnings
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import constraints
-from tensorflow.contrib.keras.python.keras import initializers
-from tensorflow.contrib.keras.python.keras import regularizers
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.engine import InputSpec
-from tensorflow.contrib.keras.python.keras.engine.topology import get_source_inputs
-from tensorflow.contrib.keras.python.keras.layers import Activation
-from tensorflow.contrib.keras.python.keras.layers import BatchNormalization
-from tensorflow.contrib.keras.python.keras.layers import Conv2D
-from tensorflow.contrib.keras.python.keras.layers import Dropout
-from tensorflow.contrib.keras.python.keras.layers import GlobalAveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers import GlobalMaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers import Input
-from tensorflow.contrib.keras.python.keras.layers import Reshape
-from tensorflow.contrib.keras.python.keras.models import Model
-from tensorflow.contrib.keras.python.keras.utils import conv_utils
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.layers import Activation
+from tensorflow.python.keras._impl.keras.layers import BatchNormalization
+from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import Dropout
+from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import Input
+from tensorflow.python.keras._impl.keras.layers import Reshape
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils import conv_utils
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 BASE_WEIGHT_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.6/'
 
@@ -327,7 +327,7 @@ def MobileNet(input_shape=None,  # pylint: disable=invalid-name
           if `include_top` is False (otherwise the input shape
           has to be `(224, 224, 3)` (with `channels_last` data format)
           or (3, 224, 224) (with `channels_first` data format).
-          It should have exactly 3 inputs channels,
+          It should have exactly 3 input channels,
           and width and height should be no smaller than 32.
           E.g. `(200, 200, 3)` would be one valid value.
       alpha: controls the width of the network.
@@ -388,12 +388,26 @@ def MobileNet(input_shape=None,  # pylint: disable=invalid-name
                      'as true, `classes` should be 1000')
 
   # Determine proper input shape.
+  if input_shape is None:
+    default_size = 224
+  else:
+    if K.image_data_format() == 'channels_first':
+      rows = input_shape[1]
+      cols = input_shape[2]
+    else:
+      rows = input_shape[0]
+      cols = input_shape[1]
+    if rows == cols and rows in [128, 160, 192, 224]:
+      default_size = rows
+    else:
+      default_size = 224
   input_shape = _obtain_input_shape(
       input_shape,
-      default_size=224,
+      default_size=default_size,
       min_size=32,
       data_format=K.image_data_format(),
-      include_top=include_top or weights)
+      require_flatten=include_top,
+      weights=weights)
   if K.image_data_format() == 'channels_last':
     row_axis, col_axis = (0, 1)
   else:
diff --git a/tensorflow/contrib/keras/python/keras/applications/mobilenet_test.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet_test.py
similarity index 55%
rename from tensorflow/contrib/keras/python/keras/applications/mobilenet_test.py
rename to tensorflow/python/keras/_impl/keras/applications/mobilenet_test.py
index d67964c02bda01dc27d0c0d31f39477b71ebf637..601d417e496b8230a2ad846eab204763ff5564b8 100644
--- a/tensorflow/contrib/keras/python/keras/applications/mobilenet_test.py
+++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
@@ -59,5 +59,43 @@ class MobileNetTest(test.TestCase):
     self.assertEqual(model.output_shape, (None, 1000))
     keras.backend.set_image_data_format('channels_last')
 
+  def test_mobilenet_variable_input_channels(self):
+    input_shape = (None, None, 1)
+    model = keras.applications.MobileNet(weights=None,
+                                         include_top=False,
+                                         input_shape=input_shape)
+    self.assertEqual(model.output_shape, (None, None, None, 1024))
+
+    input_shape = (None, None, 4)
+    model = keras.applications.MobileNet(weights=None,
+                                         include_top=False,
+                                         input_shape=input_shape)
+    self.assertEqual(model.output_shape, (None, None, None, 1024))
+
+  def test_mobilenet_image_size(self):
+    with self.test_session():
+      valid_image_sizes = [128, 160, 192, 224]
+      for size in valid_image_sizes:
+        keras.backend.set_image_data_format('channels_last')
+        input_shape = (size, size, 3)
+        model = keras.applications.MobileNet(input_shape=input_shape,
+                                             weights=None,
+                                             include_top=True)
+        self.assertEqual(model.input_shape, (None,) + input_shape)
+
+        keras.backend.set_image_data_format('channels_first')
+        input_shape = (3, size, size)
+        model = keras.applications.MobileNet(input_shape=input_shape,
+                                             weights=None,
+                                             include_top=True)
+        self.assertEqual(model.input_shape, (None,) + input_shape)
+
+      keras.backend.set_image_data_format('channels_last')
+      invalid_image_shape = (112, 112, 3)
+      with self.assertRaises(ValueError):
+        model = keras.applications.MobileNet(input_shape=invalid_image_shape,
+                                             weights='imagenet',
+                                             include_top=True)
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/keras/python/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
similarity index 85%
rename from tensorflow/contrib/keras/python/keras/applications/resnet50.py
rename to tensorflow/python/keras/_impl/keras/applications/resnet50.py
index 0de13c9592e57f7a346173245ee412c9242fc2a3..f0cff2d686f321e4ec86a85efc8a844576fc7fcf 100644
--- a/tensorflow/contrib/keras/python/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
@@ -26,25 +26,24 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import layers
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import preprocess_input  # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.engine.topology import get_source_inputs
-from tensorflow.contrib.keras.python.keras.layers import Activation
-from tensorflow.contrib.keras.python.keras.layers import AveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers import BatchNormalization
-from tensorflow.contrib.keras.python.keras.layers import Conv2D
-from tensorflow.contrib.keras.python.keras.layers import Dense
-from tensorflow.contrib.keras.python.keras.layers import Flatten
-from tensorflow.contrib.keras.python.keras.layers import GlobalAveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers import GlobalMaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers import Input
-from tensorflow.contrib.keras.python.keras.layers import MaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers import ZeroPadding2D
-from tensorflow.contrib.keras.python.keras.models import Model
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import layers
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input  # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.layers import Activation
+from tensorflow.python.keras._impl.keras.layers import AveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import BatchNormalization
+from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import Dense
+from tensorflow.python.keras._impl.keras.layers import Flatten
+from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import Input
+from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 
 WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5'
@@ -170,7 +169,7 @@ def ResNet50(include_top=True,
           if `include_top` is False (otherwise the input shape
           has to be `(224, 224, 3)` (with `channels_last` data format)
           or `(3, 224, 224)` (with `channels_first` data format).
-          It should have exactly 3 inputs channels,
+          It should have exactly 3 input channels,
           and width and height should be no smaller than 197.
           E.g. `(200, 200, 3)` would be one valid value.
       pooling: Optional pooling mode for feature extraction
@@ -210,7 +209,8 @@ def ResNet50(include_top=True,
       default_size=224,
       min_size=197,
       data_format=K.image_data_format(),
-      include_top=include_top)
+      require_flatten=include_top,
+      weights=weights)
 
   if input_tensor is None:
     img_input = Input(shape=input_shape)
@@ -222,8 +222,8 @@ def ResNet50(include_top=True,
   else:
     bn_axis = 1
 
-  x = ZeroPadding2D((3, 3))(img_input)
-  x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
+  x = Conv2D(64, (7, 7),
+             strides=(2, 2), padding='same', name='conv1')(img_input)
   x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
   x = Activation('relu')(x)
   x = MaxPooling2D((3, 3), strides=(2, 2))(x)
diff --git a/tensorflow/contrib/keras/python/keras/applications/resnet50_test.py b/tensorflow/python/keras/_impl/keras/applications/resnet50_test.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/applications/resnet50_test.py
rename to tensorflow/python/keras/_impl/keras/applications/resnet50_test.py
index 2b00170652a2f30664844811efa82600c88f04a5..07f9ffd73f55ee39351af71223e7919b08ca66e1 100644
--- a/tensorflow/contrib/keras/python/keras/applications/resnet50_test.py
+++ b/tensorflow/python/keras/_impl/keras/applications/resnet50_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/applications/vgg16.py b/tensorflow/python/keras/_impl/keras/applications/vgg16.py
similarity index 85%
rename from tensorflow/contrib/keras/python/keras/applications/vgg16.py
rename to tensorflow/python/keras/_impl/keras/applications/vgg16.py
index 89bbb040e6a29001d3f2a5ad6404f7abc3e7c9dc..485b486e9d826795d3499b978d119609050bd7de 100644
--- a/tensorflow/contrib/keras/python/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg16.py
@@ -25,21 +25,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import preprocess_input  # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.engine.topology import get_source_inputs
-from tensorflow.contrib.keras.python.keras.layers import Conv2D
-from tensorflow.contrib.keras.python.keras.layers import Dense
-from tensorflow.contrib.keras.python.keras.layers import Flatten
-from tensorflow.contrib.keras.python.keras.layers import GlobalAveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers import GlobalMaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers import Input
-from tensorflow.contrib.keras.python.keras.layers import MaxPooling2D
-from tensorflow.contrib.keras.python.keras.models import Model
-from tensorflow.contrib.keras.python.keras.utils import layer_utils
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input  # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import Dense
+from tensorflow.python.keras._impl.keras.layers import Flatten
+from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import Input
+from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils import layer_utils
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 
 WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5'
@@ -76,7 +76,7 @@ def VGG16(include_top=True,
           if `include_top` is False (otherwise the input shape
           has to be `(224, 224, 3)` (with `channels_last` data format)
           or `(3, 224, 224)` (with `channels_first` data format).
-          It should have exactly 3 inputs channels,
+          It should have exactly 3 input channels,
           and width and height should be no smaller than 48.
           E.g. `(200, 200, 3)` would be one valid value.
       pooling: Optional pooling mode for feature extraction
@@ -115,7 +115,8 @@ def VGG16(include_top=True,
       default_size=224,
       min_size=48,
       data_format=K.image_data_format(),
-      include_top=include_top)
+      require_flatten=include_top,
+      weights=weights)
 
   if input_tensor is None:
     img_input = Input(shape=input_shape)
diff --git a/tensorflow/contrib/keras/python/keras/applications/vgg16_test.py b/tensorflow/python/keras/_impl/keras/applications/vgg16_test.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/applications/vgg16_test.py
rename to tensorflow/python/keras/_impl/keras/applications/vgg16_test.py
index 4ba5dabd5ab9fe03ccd98bb762d1a1690726106f..e6eba83678def582c1a9fb477399790dbded8a15 100644
--- a/tensorflow/contrib/keras/python/keras/applications/vgg16_test.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg16_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/applications/vgg19.py b/tensorflow/python/keras/_impl/keras/applications/vgg19.py
similarity index 85%
rename from tensorflow/contrib/keras/python/keras/applications/vgg19.py
rename to tensorflow/python/keras/_impl/keras/applications/vgg19.py
index 522a516ecfc95c3ad49cf94f5d42eed466adac1b..3af6417c8444453a9e9c3eef70097f520757f264 100644
--- a/tensorflow/contrib/keras/python/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg19.py
@@ -25,21 +25,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import preprocess_input  # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.engine.topology import get_source_inputs
-from tensorflow.contrib.keras.python.keras.layers import Conv2D
-from tensorflow.contrib.keras.python.keras.layers import Dense
-from tensorflow.contrib.keras.python.keras.layers import Flatten
-from tensorflow.contrib.keras.python.keras.layers import GlobalAveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers import GlobalMaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers import Input
-from tensorflow.contrib.keras.python.keras.layers import MaxPooling2D
-from tensorflow.contrib.keras.python.keras.models import Model
-from tensorflow.contrib.keras.python.keras.utils import layer_utils
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input  # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import Dense
+from tensorflow.python.keras._impl.keras.layers import Flatten
+from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import Input
+from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils import layer_utils
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 
 WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg19_weights_tf_dim_ordering_tf_kernels.h5'
@@ -76,7 +76,7 @@ def VGG19(include_top=True,
           if `include_top` is False (otherwise the input shape
           has to be `(224, 224, 3)` (with `channels_last` data format)
           or `(3, 224, 224)` (with `channels_first` data format).
-          It should have exactly 3 inputs channels,
+          It should have exactly 3 input channels,
           and width and height should be no smaller than 48.
           E.g. `(200, 200, 3)` would be one valid value.
       pooling: Optional pooling mode for feature extraction
@@ -115,7 +115,8 @@ def VGG19(include_top=True,
       default_size=224,
       min_size=48,
       data_format=K.image_data_format(),
-      include_top=include_top)
+      require_flatten=include_top,
+      weights=weights)
 
   if input_tensor is None:
     img_input = Input(shape=input_shape)
diff --git a/tensorflow/contrib/keras/python/keras/applications/vgg19_test.py b/tensorflow/python/keras/_impl/keras/applications/vgg19_test.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/applications/vgg19_test.py
rename to tensorflow/python/keras/_impl/keras/applications/vgg19_test.py
index 604d4bb2d8b181704caf811a654f31c51dd81d28..25100a2993f8a650b9ec441bf0c2c528f13364a4 100644
--- a/tensorflow/contrib/keras/python/keras/applications/vgg19_test.py
+++ b/tensorflow/python/keras/_impl/keras/applications/vgg19_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/applications/xception.py b/tensorflow/python/keras/_impl/keras/applications/xception.py
similarity index 89%
rename from tensorflow/contrib/keras/python/keras/applications/xception.py
rename to tensorflow/python/keras/_impl/keras/applications/xception.py
index 49fb6008f6e57d518006b93ff90a4caffd6c394b..6e521daa2d3b9eae33faecde6057e9fcc3222edc 100644
--- a/tensorflow/contrib/keras/python/keras/applications/xception.py
+++ b/tensorflow/python/keras/_impl/keras/applications/xception.py
@@ -36,22 +36,22 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import layers
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import _obtain_input_shape
-from tensorflow.contrib.keras.python.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.engine.topology import get_source_inputs
-from tensorflow.contrib.keras.python.keras.layers import Activation
-from tensorflow.contrib.keras.python.keras.layers import BatchNormalization
-from tensorflow.contrib.keras.python.keras.layers import Conv2D
-from tensorflow.contrib.keras.python.keras.layers import Dense
-from tensorflow.contrib.keras.python.keras.layers import GlobalAveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers import GlobalMaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers import Input
-from tensorflow.contrib.keras.python.keras.layers import MaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers import SeparableConv2D
-from tensorflow.contrib.keras.python.keras.models import Model
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import layers
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
+from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions  # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.layers import Activation
+from tensorflow.python.keras._impl.keras.layers import BatchNormalization
+from tensorflow.python.keras._impl.keras.layers import Conv2D
+from tensorflow.python.keras._impl.keras.layers import Dense
+from tensorflow.python.keras._impl.keras.layers import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import Input
+from tensorflow.python.keras._impl.keras.layers import MaxPooling2D
+from tensorflow.python.keras._impl.keras.layers import SeparableConv2D
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 from tensorflow.python.platform import tf_logging as logging
 
 
@@ -86,7 +86,7 @@ def Xception(include_top=True,
       input_shape: optional shape tuple, only to be specified
           if `include_top` is False (otherwise the input shape
           has to be `(299, 299, 3)`.
-          It should have exactly 3 inputs channels,
+          It should have exactly 3 input channels,
           and width and height should be no smaller than 71.
           E.g. `(150, 150, 3)` would be one valid value.
       pooling: Optional pooling mode for feature extraction
@@ -147,7 +147,8 @@ def Xception(include_top=True,
       default_size=299,
       min_size=71,
       data_format=K.image_data_format(),
-      include_top=include_top)
+      require_flatten=False,
+      weights=weights)
 
   if input_tensor is None:
     img_input = Input(shape=input_shape)
diff --git a/tensorflow/contrib/keras/python/keras/applications/xception_test.py b/tensorflow/python/keras/_impl/keras/applications/xception_test.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/applications/xception_test.py
rename to tensorflow/python/keras/_impl/keras/applications/xception_test.py
index a941514c3e8f340462f44a820a8db5655bae3be6..7ebdc30010aa48362046b3c0c281fe1f2be64a84 100644
--- a/tensorflow/contrib/keras/python/keras/applications/xception_test.py
+++ b/tensorflow/python/keras/_impl/keras/applications/xception_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py
similarity index 95%
rename from tensorflow/contrib/keras/python/keras/backend.py
rename to tensorflow/python/keras/_impl/keras/backend.py
index 99570797af727835aeea58f1ae9d949d5d996eed..76704d5d3d15dfd58b1d2e5d6068c16e4504572d 100644
--- a/tensorflow/contrib/keras/python/keras/backend.py
+++ b/tensorflow/python/keras/_impl/keras/backend.py
@@ -266,7 +266,7 @@ def get_uid(prefix=''):
   graph = ops.get_default_graph()
   if graph not in tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS:
     tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS[graph] = collections.defaultdict(
-      int)
+        int)
   layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS[graph]
   layer_name_uids[prefix] += 1
   return layer_name_uids[prefix]
@@ -490,13 +490,15 @@ def to_dense(tensor):
 name_scope = ops.name_scope
 
 
-def variable(value, dtype=None, name=None):
+def variable(value, dtype=None, name=None, constraint=None):
   """Instantiates a variable and returns it.
 
   Arguments:
       value: Numpy array, initial value of the tensor.
       dtype: Tensor type.
       name: Optional name string for the tensor.
+      constraint: Optional projection function to be
+          applied to the variable after an optimizer update.
 
   Returns:
       A variable instance (with Keras metadata included).
@@ -523,10 +525,18 @@ def variable(value, dtype=None, name=None):
         sparse_coo.col, 1)), 1)
     v = sparse_tensor.SparseTensor(
         indices=indices, values=sparse_coo.data, dense_shape=sparse_coo.shape)
+    v._keras_shape = sparse_coo.shape
     v._uses_learning_phase = False
     return v
   v = variables_module.Variable(
-      value, dtype=_convert_string_dtype(dtype), name=name)
+      value,
+      dtype=_convert_string_dtype(dtype),
+      name=name,
+      constraint=constraint)
+  if isinstance(value, np.ndarray):
+    v._keras_shape = value.shape
+  elif hasattr(value, 'get_shape'):
+    v._keras_shape = int_shape(value)
   v._uses_learning_phase = False
   return v
 
@@ -562,6 +572,57 @@ def constant(value, dtype=None, shape=None, name=None):
   return constant_op.constant(value, dtype=dtype, shape=shape, name=name)
 
 
+def is_keras_tensor(x):
+  """Returns whether `x` is a Keras tensor.
+
+  A "Keras tensor" is a tensor that was returned by a Keras layer,
+  (`Layer` class) or by `Input`.
+
+  Arguments:
+      x: A candidate tensor.
+
+  Returns:
+      A boolean: Whether the argument is a Keras tensor.
+
+  Raises:
+      ValueError: In case `x` is not a symbolic tensor.
+
+  Examples:
+  ```python
+      >>> from keras import backend as K
+      >>> from keras.layers import Input, Dense
+      >>> np_var = numpy.array([1, 2])
+      >>> K.is_keras_tensor(np_var) # A numpy array is not a symbolic tensor.
+      ValueError
+      >>> k_var = tf.placeholder('float32', shape=(1,1))
+      >>> K.is_keras_tensor(k_var) # A variable indirectly created outside of
+      keras is not a Keras tensor.
+      False
+      >>> keras_var = K.variable(np_var)
+      >>> K.is_keras_tensor(keras_var)  # A variable created with the keras
+      backend is not a Keras tensor.
+      False
+      >>> keras_placeholder = K.placeholder(shape=(2, 4, 5))
+      >>> K.is_keras_tensor(keras_placeholder)  # A placeholder is not a Keras
+      tensor.
+      False
+      >>> keras_input = Input([10])
+      >>> K.is_keras_tensor(keras_input) # An Input is a Keras tensor.
+      True
+      >>> keras_layer_output = Dense(10)(keras_input)
+      >>> K.is_keras_tensor(keras_layer_output) # Any Keras layer output is a
+      Keras tensor.
+      True
+  ```
+  """
+  if not isinstance(x, (ops.Tensor,
+                        variables_module.Variable,
+                        sparse_tensor.SparseTensor)):
+    raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
+                     '`. Expected a symbolic tensor instance.')
+  return hasattr(x, '_keras_history')
+
+
 def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
   """Instantiates a placeholder tensor and returns it.
 
@@ -599,6 +660,21 @@ def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
   return x
 
 
+def is_placeholder(x):
+  """Returns whether `x` is a placeholder.
+
+  Arguments:
+      x: A candidate placeholder.
+
+  Returns:
+      Boolean.
+  """
+  try:
+    return x.op.type == 'Placeholder'
+  except AttributeError:
+    return False
+
+
 def shape(x):
   """Returns the symbolic shape of a tensor or variable.
 
@@ -609,7 +685,8 @@ def shape(x):
       A symbolic shape (which is itself a tensor).
 
   Examples:
-  ```
+
+  ```python
       # TensorFlow example
       >>> from keras import backend as K
       >>> tf_session = K.get_session()
@@ -651,9 +728,8 @@ def int_shape(x):
       (2, 2)
   ```
   """
-  shape = x.get_shape()
   try:
-    return tuple(shape.as_list())
+    return tuple(x.get_shape().as_list())
   except ValueError:
     return None
 
@@ -759,7 +835,6 @@ def zeros(shape, dtype=None, name=None):
   """
   if dtype is None:
     dtype = floatx()
-  shape = tuple(map(int, shape))
   tf_dtype = _convert_string_dtype(dtype)
   return variable(
       init_ops.constant_initializer(0., dtype=tf_dtype)(shape), dtype, name)
@@ -788,7 +863,6 @@ def ones(shape, dtype=None, name=None):
   """
   if dtype is None:
     dtype = floatx()
-  shape = tuple(map(int, shape))
   tf_dtype = _convert_string_dtype(dtype)
   return variable(
       init_ops.constant_initializer(1., dtype=tf_dtype)(shape), dtype, name)
@@ -908,7 +982,6 @@ def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
   """
   if dtype is None:
     dtype = floatx()
-  shape = tuple(map(int, shape))
   tf_dtype = _convert_string_dtype(dtype)
   if seed is None:
     # ensure that randomness is conditioned by the Numpy RNG
@@ -946,7 +1019,6 @@ def random_normal_variable(shape, mean, scale, dtype=None, name=None,
   """
   if dtype is None:
     dtype = floatx()
-  shape = tuple(map(int, shape))
   tf_dtype = _convert_string_dtype(dtype)
   if seed is None:
     # ensure that randomness is conditioned by the Numpy RNG
@@ -1276,28 +1348,6 @@ def gather(reference, indices):
 # ELEMENT-WISE OPERATIONS
 
 
-def _normalize_axis(axis, ndim):
-  """Converts negative axes to positive values.
-
-  Arguments:
-      axis: Integer axis (possibly negative).
-      ndim: Rank of the tensor considered.
-
-  Returns:
-      Positive integer axis.
-  """
-  if isinstance(axis, tuple):
-    axis = list(axis)
-  if isinstance(axis, list):
-    for i, a in enumerate(axis):
-      if a is not None and a < 0:
-        axis[i] = a % ndim
-  else:
-    if axis is not None and axis < 0:
-      axis %= ndim
-  return axis
-
-
 def max(x, axis=None, keepdims=False):
   """Maximum value in a tensor.
 
@@ -1312,7 +1362,6 @@ def max(x, axis=None, keepdims=False):
   Returns:
       A tensor with maximum values of `x`.
   """
-  axis = _normalize_axis(axis, ndim(x))
   return math_ops.reduce_max(x, axis=axis, keep_dims=keepdims)
 
 
@@ -1330,7 +1379,6 @@ def min(x, axis=None, keepdims=False):
   Returns:
       A tensor with miminum values of `x`.
   """
-  axis = _normalize_axis(axis, ndim(x))
   return math_ops.reduce_min(x, axis=axis, keep_dims=keepdims)
 
 
@@ -1348,7 +1396,6 @@ def sum(x, axis=None, keepdims=False):
   Returns:
       A tensor with sum of `x`.
   """
-  axis = _normalize_axis(axis, ndim(x))
   return math_ops.reduce_sum(x, axis=axis, keep_dims=keepdims)
 
 
@@ -1366,7 +1413,6 @@ def prod(x, axis=None, keepdims=False):
   Returns:
       A tensor with the product of elements of `x`.
   """
-  axis = _normalize_axis(axis, ndim(x))
   return math_ops.reduce_prod(x, axis=axis, keep_dims=keepdims)
 
 
@@ -1380,7 +1426,6 @@ def cumsum(x, axis=0):
   Returns:
       A tensor of the cumulative sum of values of `x` along `axis`.
   """
-  axis = _normalize_axis(axis, ndim(x))
   return math_ops.cumsum(x, axis=axis)
 
 
@@ -1394,7 +1439,6 @@ def cumprod(x, axis=0):
   Returns:
       A tensor of the cumulative product of values of `x` along `axis`.
   """
-  axis = _normalize_axis(axis, ndim(x))
   return math_ops.cumprod(x, axis=axis)
 
 
@@ -1412,7 +1456,6 @@ def var(x, axis=None, keepdims=False):
   Returns:
       A tensor with the variance of elements of `x`.
   """
-  axis = _normalize_axis(axis, ndim(x))
   if x.dtype.base_dtype == dtypes_module.bool:
     x = math_ops.cast(x, floatx())
   m = math_ops.reduce_mean(x, axis=axis, keep_dims=True)
@@ -1452,7 +1495,6 @@ def mean(x, axis=None, keepdims=False):
   Returns:
       A tensor with the mean of elements of `x`.
   """
-  axis = _normalize_axis(axis, ndim(x))
   if x.dtype.base_dtype == dtypes_module.bool:
     x = math_ops.cast(x, floatx())
   return math_ops.reduce_mean(x, axis=axis, keep_dims=keepdims)
@@ -1469,7 +1511,6 @@ def any(x, axis=None, keepdims=False):
   Returns:
       A uint8 tensor (0s and 1s).
   """
-  axis = _normalize_axis(axis, ndim(x))
   x = math_ops.cast(x, dtypes_module.bool)
   return math_ops.reduce_any(x, axis=axis, keep_dims=keepdims)
 
@@ -1485,7 +1526,6 @@ def all(x, axis=None, keepdims=False):
   Returns:
       A uint8 tensor (0s and 1s).
   """
-  axis = _normalize_axis(axis, ndim(x))
   x = math_ops.cast(x, dtypes_module.bool)
   return math_ops.reduce_all(x, axis=axis, keep_dims=keepdims)
 
@@ -1500,7 +1540,6 @@ def argmax(x, axis=-1):
   Returns:
       A tensor.
   """
-  axis = _normalize_axis(axis, ndim(x))
   return math_ops.argmax(x, axis)
 
 
@@ -1514,7 +1553,6 @@ def argmin(x, axis=-1):
   Returns:
       A tensor.
   """
-  axis = _normalize_axis(axis, ndim(x))
   return math_ops.argmin(x, axis)
 
 
@@ -1599,7 +1637,6 @@ def logsumexp(x, axis=None, keepdims=False):
   Returns:
       The reduced tensor.
   """
-  axis = _normalize_axis(axis, ndim(x))
   return math_ops.reduce_logsumexp(x, axis=axis, keep_dims=keepdims)
 
 
@@ -1992,24 +2029,45 @@ def repeat_elements(x, rep, axis):
       rep: Python integer, number of times to repeat.
       axis: Axis along which to repeat.
 
-  Raises:
-      ValueError: In case `x.shape[axis]` is undefined.
-
   Returns:
       A tensor.
   """
   x_shape = x.get_shape().as_list()
-  if x_shape[axis] is None:
-    raise ValueError('Axis ' + str(axis) + ' of input tensor '
-                     'should have a defined dimension, but is None. '
-                     'Full tensor shape: ' + str(tuple(x_shape)) + '. '
-                     'Typically you need to pass a fully-defined '
-                     '`input_shape` argument to your first layer.')
-  # slices along the repeat axis
-  splits = array_ops.split(value=x, num_or_size_splits=x_shape[axis], axis=axis)
-  # repeat each slice the given number of reps
-  x_rep = [s for s in splits for _ in range(rep)]
-  return concatenate(x_rep, axis)
+  # For static axis
+  if x_shape[axis] is not None:
+    # slices along the repeat axis
+    splits = array_ops.split(value=x,
+                             num_or_size_splits=x_shape[axis],
+                             axis=axis)
+    # repeat each slice the given number of reps
+    x_rep = [s for s in splits for _ in range(rep)]
+    return concatenate(x_rep, axis)
+
+  # Here we use tf.tile to mimic behavior of np.repeat so that
+  # we can handle dynamic shapes (that include None).
+  # To do that, we need an auxiliary axis to repeat elements along
+  # it and then merge them along the desired axis.
+
+  # Repeating
+  auxiliary_axis = axis + 1
+  x_shape = array_ops.shape(x)
+  x_rep = array_ops.expand_dims(x, axis=auxiliary_axis)
+  reps = np.ones(len(x.get_shape()) + 1)
+  reps[auxiliary_axis] = rep
+  x_rep = array_ops.tile(x_rep, reps)
+
+  # Merging
+  reps = np.delete(reps, auxiliary_axis)
+  reps[axis] = rep
+  reps = array_ops.constant(reps, dtype='int32')
+  x_shape *= reps
+  x_rep = array_ops.reshape(x_rep, x_shape)
+
+  # Fix shape representation
+  x_shape = x.get_shape().as_list()
+  x_rep.set_shape(x_shape)
+  x_rep._keras_shape = tuple(x_shape)
+  return x_rep
 
 
 def repeat(x, n):
@@ -2303,7 +2361,7 @@ def set_value(x, value):
       value: Value to set the tensor to, as a Numpy array
           (of the same shape).
   """
-  value = np.asarray(value)
+  value = np.asarray(value, dtype=dtype(x))
   tf_dtype = _convert_string_dtype(x.dtype.name.split('_')[0])
   if hasattr(x, '_assign_placeholder'):
     assign_placeholder = x._assign_placeholder
@@ -2327,7 +2385,7 @@ def batch_set_value(tuples):
     assign_ops = []
     feed_dict = {}
     for x, value in tuples:
-      value = np.asarray(value)
+      value = np.asarray(value, dtype=dtype(x))
       tf_dtype = _convert_string_dtype(x.dtype.name.split('_')[0])
       if hasattr(x, '_assign_placeholder'):
         assign_placeholder = x._assign_placeholder
@@ -2457,11 +2515,16 @@ def stop_gradient(variables):
   """Returns `variables` but with zero gradient w.r.t. every other variable.
 
   Arguments:
-      variables: List of variables.
+      variables: Tensor or list of tensors to consider constant with respect
+        to any other variable.
+
 
   Returns:
-      The same list of variables.
+      A single tensor or a list of tensors (depending on the passed argument)
+      that has no gradient with respect to any other variable.
   """
+  if isinstance(variables, (list, tuple)):
+    return map(array_ops.stop_gradient, variables)
   return array_ops.stop_gradient(variables)
 
 
@@ -2874,14 +2937,14 @@ def softsign(x):
   return nn.softsign(x)
 
 
-def categorical_crossentropy(output, target, from_logits=False):
+def categorical_crossentropy(target, output, from_logits=False):
   """Categorical crossentropy between an output tensor and a target tensor.
 
   Arguments:
+      target: A tensor of the same shape as `output`.
       output: A tensor resulting from a softmax
           (unless `from_logits` is True, in which
           case `output` is expected to be the logits).
-      target: A tensor of the same shape as `output`.
       from_logits: Boolean, whether `output` is the
           result of a softmax, or is a tensor of logits.
 
@@ -2895,8 +2958,8 @@ def categorical_crossentropy(output, target, from_logits=False):
     output /= math_ops.reduce_sum(
         output, axis=len(output.get_shape()) - 1, keep_dims=True)
     # manual computation of crossentropy
-    epsilon = _to_tensor(_EPSILON, output.dtype.base_dtype)
-    output = clip_ops.clip_by_value(output, epsilon, 1. - epsilon)
+    epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
+    output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
     return -math_ops.reduce_sum(
         target * math_ops.log(output),
         axis=len(output.get_shape()) - 1)
@@ -2904,14 +2967,14 @@ def categorical_crossentropy(output, target, from_logits=False):
     return nn.softmax_cross_entropy_with_logits(labels=target, logits=output)
 
 
-def sparse_categorical_crossentropy(output, target, from_logits=False):
+def sparse_categorical_crossentropy(target, output, from_logits=False):
   """Categorical crossentropy with integer targets.
 
   Arguments:
+      target: An integer tensor.
       output: A tensor resulting from a softmax
           (unless `from_logits` is True, in which
           case `output` is expected to be the logits).
-      target: An integer tensor.
       from_logits: Boolean, whether `output` is the
           result of a softmax, or is a tensor of logits.
 
@@ -2921,8 +2984,8 @@ def sparse_categorical_crossentropy(output, target, from_logits=False):
   # Note: nn.sparse_softmax_cross_entropy_with_logits
   # expects logits, Keras expects probabilities.
   if not from_logits:
-    epsilon = _to_tensor(_EPSILON, output.dtype.base_dtype)
-    output = clip_ops.clip_by_value(output, epsilon, 1 - epsilon)
+    epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
+    output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
     output = math_ops.log(output)
 
   output_shape = output.get_shape()
@@ -2937,12 +3000,12 @@ def sparse_categorical_crossentropy(output, target, from_logits=False):
     return res
 
 
-def binary_crossentropy(output, target, from_logits=False):
+def binary_crossentropy(target, output, from_logits=False):
   """Binary crossentropy between an output tensor and a target tensor.
 
   Arguments:
-      output: A tensor.
       target: A tensor with the same shape as `output`.
+      output: A tensor.
       from_logits: Whether `output` is expected to be a logits tensor.
           By default, we consider that `output`
           encodes a probability distribution.
@@ -2954,8 +3017,8 @@ def binary_crossentropy(output, target, from_logits=False):
   # expects logits, Keras expects probabilities.
   if not from_logits:
     # transform back to logits
-    epsilon = _to_tensor(_EPSILON, output.dtype.base_dtype)
-    output = clip_ops.clip_by_value(output, epsilon, 1 - epsilon)
+    epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
+    output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
     output = math_ops.log(output / (1 - output))
   return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
 
@@ -3026,7 +3089,7 @@ def dropout(x, level, noise_shape=None, seed=None):
   return nn.dropout(x * 1., retain_prob, noise_shape, seed=seed)
 
 
-def l2_normalize(x, axis):
+def l2_normalize(x, axis=None):
   """Normalizes a tensor wrt the L2 norm alongside the specified axis.
 
   Arguments:
@@ -3036,8 +3099,6 @@ def l2_normalize(x, axis):
   Returns:
       A tensor.
   """
-  if axis < 0:
-    axis %= len(x.get_shape())
   return nn.l2_normalize(x, dim=axis)
 
 
@@ -3807,7 +3868,7 @@ def ctc_label_dense_to_sparse(labels, label_lengths):
       label_lengths: length of the labels.
 
   Returns:
-      A sparse tensor representation of the lablels.
+      A sparse tensor representation of the labels.
   """
   label_shape = array_ops.shape(labels)
   num_batches_tns = array_ops.stack([label_shape[0]])
diff --git a/tensorflow/contrib/keras/python/keras/backend_test.py b/tensorflow/python/keras/_impl/keras/backend_test.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/backend_test.py
rename to tensorflow/python/keras/_impl/keras/backend_test.py
index 69dcf3f094e3604e85a91a871e9bccdc5d286a5b..d914490f7e42aa9dc67af44afde160572fcb8642 100644
--- a/tensorflow/contrib/keras/python/keras/backend_test.py
+++ b/tensorflow/python/keras/_impl/keras/backend_test.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for Keras backend."""
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -21,8 +20,8 @@ from __future__ import print_function
 import numpy as np
 import scipy.sparse
 
-from tensorflow.contrib.keras.python import keras
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 from tensorflow.python.util import tf_inspect
 
@@ -139,6 +138,33 @@ class BackendUtilsTest(test.TestCase):
       y_val = f([1])[0]
       self.assertAllClose(y_val, 1)
 
+  def test_is_keras_tensor(self):
+    x = keras.backend.variable(1)
+    self.assertEqual(keras.backend.is_keras_tensor(x), False)
+    x = keras.Input(shape=(1,))
+    self.assertEqual(keras.backend.is_keras_tensor(x), True)
+    with self.assertRaises(ValueError):
+      keras.backend.is_keras_tensor(0)
+
+  def test_is_placeholder(self):
+    x = keras.backend.placeholder(shape=(1,))
+    self.assertEqual(keras.backend.is_placeholder(x), True)
+    # Test with TF placeholder
+    x = keras.backend.array_ops.placeholder(dtype='float32', shape=(1,))
+    self.assertEqual(keras.backend.is_placeholder(x), True)
+    x = keras.backend.variable(1)
+    self.assertEqual(keras.backend.is_placeholder(x), False)
+
+  def test_stop_gradient(self):
+    x = keras.backend.variable(1)
+    y = keras.backend.stop_gradient(x)
+    self.assertEqual(y.op.name[:12], 'StopGradient')
+
+    xs = [keras.backend.variable(1) for _ in range(3)]
+    ys = keras.backend.stop_gradient(xs)
+    for y in ys:
+      self.assertEqual(y.op.name[:12], 'StopGradient')
+
 
 class BackendVariableTest(test.TestCase):
 
@@ -408,10 +434,10 @@ class BackendShapeOpsTest(test.TestCase):
     y = keras.backend.repeat_elements(x, 3, axis=1)
     self.assertEqual(y.get_shape().as_list(), [1, 9, 2])
 
-    # Invalid use:
-    with self.assertRaises(ValueError):
-      x = keras.backend.placeholder(shape=(2, None, 2))
-      keras.backend.repeat_elements(x, 3, axis=1)
+    # Use with a dynamic axis:
+    x = keras.backend.placeholder(shape=(2, None, 2))
+    y = keras.backend.repeat_elements(x, 3, axis=1)
+    self.assertEqual(y.get_shape().as_list(), [2, None, 2])
 
   def test_repeat(self):
     x = keras.backend.variable(np.ones((1, 3)))
diff --git a/tensorflow/contrib/keras/python/keras/callbacks.py b/tensorflow/python/keras/_impl/keras/callbacks.py
similarity index 94%
rename from tensorflow/contrib/keras/python/keras/callbacks.py
rename to tensorflow/python/keras/_impl/keras/callbacks.py
index 06a5f4ad8f6aac41ae37dc377704913da29cf377..eb678c4d1d9fe2ed9367417b9134756768d86b37 100644
--- a/tensorflow/contrib/keras/python/keras/callbacks.py
+++ b/tensorflow/python/keras/_impl/keras/callbacks.py
@@ -29,13 +29,11 @@ import time
 import numpy as np
 import six
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import Progbar
-from tensorflow.contrib.tensorboard.plugins import projector
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.summary import summary as tf_summary
-from tensorflow.python.training import saver as saver_lib
 
 
 # pylint: disable=g-import-not-at-top
@@ -397,7 +395,7 @@ class ModelCheckpoint(Callback):
 
     if mode not in ['auto', 'min', 'max']:
       logging.warning('ModelCheckpoint mode %s is unknown, '
-                      'fallback to auto mode.' % (mode))
+                      'fallback to auto mode.' % mode)
       mode = 'auto'
 
     if mode == 'min':
@@ -618,7 +616,7 @@ class TensorBoard(Callback):
   If you have installed TensorFlow with pip, you should be able
   to launch TensorBoard from the command line:
 
-  ```
+  ```sh
   tensorboard --logdir=/full_path_to_your_logs
   ```
 
@@ -660,10 +658,7 @@ class TensorBoard(Callback):
                batch_size=32,
                write_graph=True,
                write_grads=False,
-               write_images=False,
-               embeddings_freq=0,
-               embeddings_layer_names=None,
-               embeddings_metadata=None):
+               write_images=False):
     super(TensorBoard, self).__init__()
     self.log_dir = log_dir
     self.histogram_freq = histogram_freq
@@ -671,9 +666,6 @@ class TensorBoard(Callback):
     self.write_graph = write_graph
     self.write_grads = write_grads
     self.write_images = write_images
-    self.embeddings_freq = embeddings_freq
-    self.embeddings_layer_names = embeddings_layer_names
-    self.embeddings_metadata = embeddings_metadata or {}
     self.batch_size = batch_size
 
   def set_model(self, model):
@@ -686,6 +678,12 @@ class TensorBoard(Callback):
           tf_summary.histogram(mapped_weight_name, weight)
           if self.write_grads:
             grads = model.optimizer.get_gradients(model.total_loss, weight)
+
+            def is_indexed_slices(grad):
+              return type(grad).__name__ == 'IndexedSlices'
+
+            grads = [grad.values if is_indexed_slices(grad) else grad
+                     for grad in grads]
             tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)
           if self.write_images:
             w_img = array_ops.squeeze(weight)
@@ -722,48 +720,12 @@ class TensorBoard(Callback):
     else:
       self.writer = tf_summary.FileWriter(self.log_dir)
 
-    if self.embeddings_freq:
-      embeddings_layer_names = self.embeddings_layer_names
-
-      if not embeddings_layer_names:
-        embeddings_layer_names = [
-            layer.name for layer in self.model.layers
-            if type(layer).__name__ == 'Embedding'
-        ]
-
-      embeddings = {
-          layer.name: layer.weights[0]
-          for layer in self.model.layers if layer.name in embeddings_layer_names
-      }
-
-      self.saver = saver_lib.Saver(list(embeddings.values()))
-
-      embeddings_metadata = {}
-
-      if not isinstance(self.embeddings_metadata, str):
-        embeddings_metadata = self.embeddings_metadata
-      else:
-        embeddings_metadata = {
-            layer_name: self.embeddings_metadata
-            for layer_name in embeddings.keys()
-        }
-
-      config = projector.ProjectorConfig()
-      self.embeddings_ckpt_path = os.path.join(self.log_dir,
-                                               'keras_embedding.ckpt')
-
-      for layer_name, tensor in embeddings.items():
-        embedding = config.embeddings.add()
-        embedding.tensor_name = tensor.name
-
-        if layer_name in embeddings_metadata:
-          embedding.metadata_path = embeddings_metadata[layer_name]
-
-      projector.visualize_embeddings(self.writer, config)
-
   def on_epoch_end(self, epoch, logs=None):
     logs = logs or {}
 
+    if not self.validation_data and self.histogram_freq:
+      raise ValueError('If printing histograms, validation_data must be '
+                       'provided, and cannot be a generator.')
     if self.validation_data and self.histogram_freq:
       if epoch % self.histogram_freq == 0:
 
@@ -784,17 +746,17 @@ class TensorBoard(Callback):
           batch_val.append(val_data[1][i:i + step])
           batch_val.append(val_data[2][i:i + step])
           if self.model.uses_learning_phase:
-            batch_val.append(val_data[3])
+            # do not slice the learning phase
+            batch_val = [x[i:i + step] for x in val_data[:-1]]
+            batch_val.append(val_data[-1])
+          else:
+            batch_val = [x[i:i + step] for x in val_data]
           feed_dict = dict(zip(tensors, batch_val))
           result = self.sess.run([self.merged], feed_dict=feed_dict)
           summary_str = result[0]
           self.writer.add_summary(summary_str, epoch)
           i += self.batch_size
 
-    if self.embeddings_freq and self.embeddings_ckpt_path:
-      if epoch % self.embeddings_freq == 0:
-        self.saver.save(self.sess, self.embeddings_ckpt_path, epoch)
-
     for name, value in logs.items():
       if name in ['batch', 'size']:
         continue
@@ -978,6 +940,10 @@ class CSVLogger(Callback):
       else:
         return k
 
+    if self.model.stop_training:
+      # We set NA so that csv parsers do not fail for this last epoch.
+      logs = dict([(k, logs[k]) if k in logs else (k, 'NA') for k in self.keys])
+
     if not self.writer:
       self.keys = sorted(logs.keys())
 
diff --git a/tensorflow/contrib/keras/python/keras/callbacks_test.py b/tensorflow/python/keras/_impl/keras/callbacks_test.py
similarity index 70%
rename from tensorflow/contrib/keras/python/keras/callbacks_test.py
rename to tensorflow/python/keras/_impl/keras/callbacks_test.py
index d8c5c0337f67ecd07dfe34992e5c0d2742312f60..d9d7fb5a9fb767a93019217ba16321c72f2a47ad 100644
--- a/tensorflow/contrib/keras/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/_impl/keras/callbacks_test.py
@@ -26,8 +26,8 @@ import shutil
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 try:
@@ -347,7 +347,6 @@ class KerasCallbacksTest(test.TestCase):
         return model
 
       model = make_model()
-
       # This should reduce the LR after the first epoch (due to high epsilon).
       cbks = [
           keras.callbacks.ReduceLROnPlateau(
@@ -365,28 +364,10 @@ class KerasCallbacksTest(test.TestCase):
           callbacks=cbks,
           epochs=5,
           verbose=0)
-      assert np.allclose(
+      self.assertAllClose(
           float(keras.backend.get_value(model.optimizer.lr)),
           0.01,
-          atol=keras.backend.epsilon())
-
-      model = make_model()
-      cbks = [
-          keras.callbacks.ReduceLROnPlateau(
-              monitor='val_loss', factor=0.1, epsilon=0, patience=1, cooldown=5)
-      ]
-      model.fit(
-          x_train,
-          y_train,
-          batch_size=BATCH_SIZE,
-          validation_data=(x_test, y_test),
-          callbacks=cbks,
-          epochs=5,
-          verbose=0)
-      assert np.allclose(
-          float(keras.backend.get_value(model.optimizer.lr)),
-          0.1,
-          atol=keras.backend.epsilon())
+          atol=1e-4)
 
   def test_CSVLogger(self):
     with self.test_session():
@@ -465,6 +446,61 @@ class KerasCallbacksTest(test.TestCase):
 
       os.remove(filepath)
 
+  def test_stop_training_csv(self):
+    # Test that using the CSVLogger callback with the TerminateOnNaN callback
+    # does not result in invalid CSVs.
+    np.random.seed(1337)
+    tmpdir = self.get_temp_dir()
+    self.addCleanup(shutil.rmtree, tmpdir)
+
+    with self.test_session():
+      fp = os.path.join(tmpdir, 'test.csv')
+      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+          train_samples=TRAIN_SAMPLES,
+          test_samples=TEST_SAMPLES,
+          input_shape=(INPUT_DIM,),
+          num_classes=NUM_CLASSES)
+
+      y_test = keras.utils.to_categorical(y_test)
+      y_train = keras.utils.to_categorical(y_train)
+      cbks = [keras.callbacks.TerminateOnNaN(), keras.callbacks.CSVLogger(fp)]
+      model = keras.models.Sequential()
+      for _ in range(5):
+        model.add(keras.layers.Dense(2, input_dim=INPUT_DIM, activation='relu'))
+      model.add(keras.layers.Dense(NUM_CLASSES, activation='linear'))
+      model.compile(loss='mean_squared_error',
+                    optimizer='rmsprop')
+
+      def data_generator():
+        i = 0
+        max_batch_index = len(x_train) // BATCH_SIZE
+        tot = 0
+        while 1:
+          if tot > 3 * len(x_train):
+            yield (np.ones([BATCH_SIZE, INPUT_DIM]) * np.nan,
+                   np.ones([BATCH_SIZE, NUM_CLASSES]) * np.nan)
+          else:
+            yield (x_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE],
+                   y_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE])
+          i += 1
+          tot += 1
+          i %= max_batch_index
+
+      history = model.fit_generator(data_generator(),
+                                    len(x_train) // BATCH_SIZE,
+                                    validation_data=(x_test, y_test),
+                                    callbacks=cbks,
+                                    epochs=20)
+      loss = history.history['loss']
+      assert len(loss) > 1
+      assert loss[-1] == np.inf or np.isnan(loss[-1])
+
+      values = []
+      with open(fp) as f:
+        for x in csv.reader(f):
+          values.append(x)
+      assert 'nan' in values[-1], 'The last epoch was not logged.'
+
   def test_TerminateOnNaN(self):
     np.random.seed(1337)
     (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
@@ -538,8 +574,7 @@ class KerasCallbacksTest(test.TestCase):
 
       tsb = keras.callbacks.TensorBoard(
           log_dir=temp_dir, histogram_freq=1, write_images=True,
-          write_grads=True, embeddings_freq=1,
-          embeddings_layer_names=['dense_1'], batch_size=5)
+          write_grads=True, batch_size=5)
       cbks = [tsb]
 
       # fit with validation data
@@ -593,6 +628,146 @@ class KerasCallbacksTest(test.TestCase):
           data_generator(True), len(x_train), epochs=2, callbacks=cbks)
       assert os.path.exists(temp_dir)
 
+  def test_TensorBoard_histogram_freq_must_have_validation_data(self):
+    np.random.seed(1337)
+    tmpdir = self.get_temp_dir()
+    self.addCleanup(shutil.rmtree, tmpdir)
+
+    with self.test_session():
+      filepath = os.path.join(tmpdir, 'logs')
+
+      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+          train_samples=TRAIN_SAMPLES,
+          test_samples=TEST_SAMPLES,
+          input_shape=(INPUT_DIM,),
+          num_classes=NUM_CLASSES)
+      y_test = keras.utils.to_categorical(y_test)
+      y_train = keras.utils.to_categorical(y_train)
+
+      def data_generator(train):
+        if train:
+          max_batch_index = len(x_train) // BATCH_SIZE
+        else:
+          max_batch_index = len(x_test) // BATCH_SIZE
+        i = 0
+        while 1:
+          if train:
+            # simulate multi-input/output models
+            yield (x_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE],
+                   y_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE])
+          else:
+            yield (x_test[i * BATCH_SIZE: (i + 1) * BATCH_SIZE],
+                   y_test[i * BATCH_SIZE: (i + 1) * BATCH_SIZE])
+          i += 1
+          i %= max_batch_index
+
+      inp = keras.Input((INPUT_DIM,))
+      hidden = keras.layers.Dense(2, activation='relu')(inp)
+      hidden = keras.layers.Dropout(0.1)(hidden)
+      output = keras.layers.Dense(NUM_CLASSES, activation='softmax')(hidden)
+      model = keras.models.Model(inputs=inp, outputs=output)
+      model.compile(loss='categorical_crossentropy',
+                    optimizer='sgd',
+                    metrics=['accuracy'])
+
+      # we must generate new callbacks for each test, as they aren't stateless
+      def callbacks_factory(histogram_freq):
+        return [keras.callbacks.TensorBoard(
+            log_dir=filepath,
+            histogram_freq=histogram_freq,
+            write_images=True, write_grads=True,
+            batch_size=5)]
+
+      # fit w/o validation data should raise ValueError if histogram_freq > 0
+      with self.assertRaises(ValueError):
+        model.fit(x_train, y_train, batch_size=BATCH_SIZE,
+                  callbacks=callbacks_factory(histogram_freq=1), epochs=3)
+
+      # fit generator without validation data should raise ValueError if
+      # histogram_freq > 0
+      with self.assertRaises(ValueError):
+        model.fit_generator(data_generator(True), len(x_train), epochs=2,
+                            callbacks=callbacks_factory(histogram_freq=1))
+
+      # fit generator with validation data generator should raise ValueError if
+      # histogram_freq > 0
+      with self.assertRaises(ValueError):
+        model.fit_generator(data_generator(True), len(x_train), epochs=2,
+                            validation_data=data_generator(False),
+                            validation_steps=1,
+                            callbacks=callbacks_factory(histogram_freq=1))
+
+  def test_TensorBoard_multi_input_output(self):
+    np.random.seed(1337)
+    tmpdir = self.get_temp_dir()
+    self.addCleanup(shutil.rmtree, tmpdir)
+
+    with self.test_session():
+      filepath = os.path.join(tmpdir, 'logs')
+
+      (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+          train_samples=TRAIN_SAMPLES,
+          test_samples=TEST_SAMPLES,
+          input_shape=(INPUT_DIM,),
+          num_classes=NUM_CLASSES)
+      y_test = keras.utils.to_categorical(y_test)
+      y_train = keras.utils.to_categorical(y_train)
+
+      def data_generator(train):
+        if train:
+          max_batch_index = len(x_train) // BATCH_SIZE
+        else:
+          max_batch_index = len(x_test) // BATCH_SIZE
+        i = 0
+        while 1:
+          if train:
+            # simulate multi-input/output models
+            yield ([x_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE]] * 2,
+                   [y_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE]] * 2)
+          else:
+            yield ([x_test[i * BATCH_SIZE: (i + 1) * BATCH_SIZE]] * 2,
+                   [y_test[i * BATCH_SIZE: (i + 1) * BATCH_SIZE]] * 2)
+          i += 1
+          i %= max_batch_index
+
+      inp1 = keras.Input((INPUT_DIM,))
+      inp2 = keras.Input((INPUT_DIM,))
+      inp = keras.layers.add([inp1, inp2])
+      hidden = keras.layers.Dense(2, activation='relu')(inp)
+      hidden = keras.layers.Dropout(0.1)(hidden)
+      output1 = keras.layers.Dense(NUM_CLASSES, activation='softmax')(hidden)
+      output2 = keras.layers.Dense(NUM_CLASSES, activation='softmax')(hidden)
+      model = keras.models.Model([inp1, inp2], [output1, output2])
+      model.compile(loss='categorical_crossentropy',
+                    optimizer='sgd',
+                    metrics=['accuracy'])
+
+      # we must generate new callbacks for each test, as they aren't stateless
+      def callbacks_factory(histogram_freq):
+        return [keras.callbacks.TensorBoard(log_dir=filepath,
+                                            histogram_freq=histogram_freq,
+                                            write_images=True, write_grads=True,
+                                            batch_size=5)]
+
+      # fit without validation data
+      model.fit([x_train] * 2, [y_train] * 2, batch_size=BATCH_SIZE,
+                callbacks=callbacks_factory(histogram_freq=0), epochs=3)
+
+      # fit with validation data and accuracy
+      model.fit([x_train] * 2, [y_train] * 2, batch_size=BATCH_SIZE,
+                validation_data=([x_test] * 2, [y_test] * 2),
+                callbacks=callbacks_factory(histogram_freq=1), epochs=2)
+
+      # fit generator without validation data
+      model.fit_generator(data_generator(True), len(x_train), epochs=2,
+                          callbacks=callbacks_factory(histogram_freq=0))
+
+      # fit generator with validation data and accuracy
+      model.fit_generator(data_generator(True), len(x_train), epochs=2,
+                          validation_data=([x_test] * 2, [y_test] * 2),
+                          callbacks=callbacks_factory(histogram_freq=1))
+      assert os.path.isdir(filepath)
+
   def test_LambdaCallback(self):
     with self.test_session():
       np.random.seed(1337)
diff --git a/tensorflow/contrib/keras/python/keras/constraints.py b/tensorflow/python/keras/_impl/keras/constraints.py
similarity index 96%
rename from tensorflow/contrib/keras/python/keras/constraints.py
rename to tensorflow/python/keras/_impl/keras/constraints.py
index 0a59dd92c114f1dc431c0c644e5788e064fa9131..e58e3b0377b4b0fcad923095177c54d9c3ee1c0b 100644
--- a/tensorflow/contrib/keras/python/keras/constraints.py
+++ b/tensorflow/python/keras/_impl/keras/constraints.py
@@ -20,9 +20,9 @@ from __future__ import print_function
 
 import six
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
 
 
 class Constraint(object):
diff --git a/tensorflow/contrib/keras/python/keras/constraints_test.py b/tensorflow/python/keras/_impl/keras/constraints_test.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/constraints_test.py
rename to tensorflow/python/keras/_impl/keras/constraints_test.py
index 36fbee7fd56093b7130f5a8a6f4d242721075909..87905693caa900a2cc565cef4bcea3fa30a4bc6c 100644
--- a/tensorflow/contrib/keras/python/keras/constraints_test.py
+++ b/tensorflow/python/keras/_impl/keras/constraints_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/datasets/__init__.py b/tensorflow/python/keras/_impl/keras/datasets/__init__.py
similarity index 68%
rename from tensorflow/contrib/keras/python/keras/datasets/__init__.py
rename to tensorflow/python/keras/_impl/keras/datasets/__init__.py
index fe8dee54db3f74f407805dd951432f233885b9e4..22afb6a55343ce1cba66785ebc792434060eda02 100644
--- a/tensorflow/contrib/keras/python/keras/datasets/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/__init__.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras.datasets import boston_housing
-from tensorflow.contrib.keras.python.keras.datasets import cifar10
-from tensorflow.contrib.keras.python.keras.datasets import cifar100
-from tensorflow.contrib.keras.python.keras.datasets import imdb
-from tensorflow.contrib.keras.python.keras.datasets import mnist
-from tensorflow.contrib.keras.python.keras.datasets import reuters
+from tensorflow.python.keras._impl.keras.datasets import boston_housing
+from tensorflow.python.keras._impl.keras.datasets import cifar10
+from tensorflow.python.keras._impl.keras.datasets import cifar100
+from tensorflow.python.keras._impl.keras.datasets import imdb
+from tensorflow.python.keras._impl.keras.datasets import mnist
+from tensorflow.python.keras._impl.keras.datasets import reuters
 
diff --git a/tensorflow/contrib/keras/python/keras/datasets/boston_housing.py b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py
similarity index 96%
rename from tensorflow/contrib/keras/python/keras/datasets/boston_housing.py
rename to tensorflow/python/keras/_impl/keras/datasets/boston_housing.py
index 36b20451ff670753c29ce9d68b72b1f2e8962a0d..e4f7fb9d2128d305ee7e26777c7627725001cf92 100644
--- a/tensorflow/contrib/keras/python/keras/datasets/boston_housing.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/boston_housing.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 
 def load_data(path='boston_housing.npz', seed=113, test_split=0.2):
diff --git a/tensorflow/contrib/keras/python/keras/datasets/cifar.py b/tensorflow/python/keras/_impl/keras/datasets/cifar.py
similarity index 100%
rename from tensorflow/contrib/keras/python/keras/datasets/cifar.py
rename to tensorflow/python/keras/_impl/keras/datasets/cifar.py
diff --git a/tensorflow/contrib/keras/python/keras/datasets/cifar10.py b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py
similarity index 89%
rename from tensorflow/contrib/keras/python/keras/datasets/cifar10.py
rename to tensorflow/python/keras/_impl/keras/datasets/cifar10.py
index 11618b8552b7d939adfcbb70653e9592277e98e4..672249ff20f37e701e276ab3c2489de4630867be 100644
--- a/tensorflow/contrib/keras/python/keras/datasets/cifar10.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/cifar10.py
@@ -22,9 +22,9 @@ import os
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.datasets.cifar import load_batch
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.datasets.cifar import load_batch
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 
 def load_data():
diff --git a/tensorflow/contrib/keras/python/keras/datasets/cifar100.py b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py
similarity index 89%
rename from tensorflow/contrib/keras/python/keras/datasets/cifar100.py
rename to tensorflow/python/keras/_impl/keras/datasets/cifar100.py
index eba3ee641506870f58c574ed3eccff0883e17af0..1be7483d27332cb89fbc02e2f4a502de7200e828 100644
--- a/tensorflow/contrib/keras/python/keras/datasets/cifar100.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/cifar100.py
@@ -22,9 +22,9 @@ import os
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.datasets.cifar import load_batch
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.datasets.cifar import load_batch
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 
 def load_data(label_mode='fine'):
diff --git a/tensorflow/contrib/keras/python/keras/datasets/imdb.py b/tensorflow/python/keras/_impl/keras/datasets/imdb.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/datasets/imdb.py
rename to tensorflow/python/keras/_impl/keras/datasets/imdb.py
index 04ab154f9f3accf2ee7c8c6b6263024303a36d88..0db9d61f6d58448fb33851623991a0587d1db84e 100644
--- a/tensorflow/contrib/keras/python/keras/datasets/imdb.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/imdb.py
@@ -23,7 +23,7 @@ import json
 import numpy as np
 from six.moves import zip  # pylint: disable=redefined-builtin
 
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 
 def load_data(path='imdb.npz',
diff --git a/tensorflow/contrib/keras/python/keras/datasets/mnist.py b/tensorflow/python/keras/_impl/keras/datasets/mnist.py
similarity index 94%
rename from tensorflow/contrib/keras/python/keras/datasets/mnist.py
rename to tensorflow/python/keras/_impl/keras/datasets/mnist.py
index aaced003d0f33feebc35abeb7442bd9d8b397a35..02be5e2a407be89d93f3c20f6a01c476a35697bf 100644
--- a/tensorflow/contrib/keras/python/keras/datasets/mnist.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/mnist.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 
 def load_data(path='mnist.npz'):
diff --git a/tensorflow/contrib/keras/python/keras/datasets/reuters.py b/tensorflow/python/keras/_impl/keras/datasets/reuters.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/datasets/reuters.py
rename to tensorflow/python/keras/_impl/keras/datasets/reuters.py
index 2904eb5bf6f45a0a845affd9e8af198f614c4fea..c36bac5cc7df157b8bbb1416ca3715a041586e27 100644
--- a/tensorflow/contrib/keras/python/keras/datasets/reuters.py
+++ b/tensorflow/python/keras/_impl/keras/datasets/reuters.py
@@ -24,7 +24,7 @@ import json
 import numpy as np
 from six.moves import zip  # pylint: disable=redefined-builtin
 
-from tensorflow.contrib.keras.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
 
 
 def load_data(path='reuters.npz',
diff --git a/tensorflow/contrib/keras/python/keras/engine/__init__.py b/tensorflow/python/keras/_impl/keras/engine/__init__.py
similarity index 67%
rename from tensorflow/contrib/keras/python/keras/engine/__init__.py
rename to tensorflow/python/keras/_impl/keras/engine/__init__.py
index 0a1dc3dd2de91658aedd558bdd7e86afa36a9389..31f624f9af65cac60b6466d4eb5753cbdee984c6 100644
--- a/tensorflow/contrib/keras/python/keras/engine/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/engine/__init__.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras.engine.topology import get_source_inputs
-from tensorflow.contrib.keras.python.keras.engine.topology import Input
-from tensorflow.contrib.keras.python.keras.engine.topology import InputLayer
-from tensorflow.contrib.keras.python.keras.engine.topology import InputSpec
-from tensorflow.contrib.keras.python.keras.engine.topology import Layer
-from tensorflow.contrib.keras.python.keras.engine.training import Model
+from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs
+from tensorflow.python.keras._impl.keras.engine.topology import Input
+from tensorflow.python.keras._impl.keras.engine.topology import InputLayer
+from tensorflow.python.keras._impl.keras.engine.topology import InputSpec
+from tensorflow.python.keras._impl.keras.engine.topology import Layer
+from tensorflow.python.keras._impl.keras.engine.training import Model
 
 
 # Note: topology.Node is an internal class,
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py
similarity index 93%
rename from tensorflow/contrib/keras/python/keras/engine/topology.py
rename to tensorflow/python/keras/_impl/keras/engine/topology.py
index 8f69dbf49c0d505ae4c7ae9379f954659e502465..b6d341f7c997952316415eaa43788db39272206f 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology.py
@@ -26,11 +26,11 @@ import os
 import numpy as np
 from six.moves import zip  # pylint: disable=redefined-builtin
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.utils import conv_utils
-from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
-from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.utils import conv_utils
+from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
+from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary
 from tensorflow.python.layers import base as tf_base_layers
 from tensorflow.python.platform import tf_logging as logging
 
@@ -89,7 +89,6 @@ class Layer(tf_base_layers.Layer):
       non_trainable_weights: List of variables.
       weights: The concatenation of the lists trainable_weights and
           non_trainable_weights (in this order).
-      constraints: Dict mapping weights to constraints.
 
   # Methods
       call(x, mask=None): Where the layer's logic lives.
@@ -647,7 +646,6 @@ class Network(tf_base_layers.Network, Layer):
       outbound_nodes: list of nodes
       trainable_weights (list of variables)
       non_trainable_weights (list of variables)
-      constraints (list of tuples (weight, constraint))
 
   # Methods
       summary
@@ -843,6 +841,8 @@ class Network(tf_base_layers.Network, Layer):
       layer, node_index, tensor_index = self._input_coordinates[i]
       node_key = tf_base_layers._make_node_key(layer.name,
                                                node_index)
+      if node_key not in self._network_nodes:
+        continue
       new_node_index = node_conversion_map[node_key]
       model_inputs.append([layer.name, new_node_index, tensor_index])
     config['input_layers'] = model_inputs
@@ -851,6 +851,8 @@ class Network(tf_base_layers.Network, Layer):
       layer, node_index, tensor_index = self._output_coordinates[i]
       node_key = tf_base_layers._make_node_key(layer.name,
                                                node_index)
+      if node_key not in self._network_nodes:
+        continue
       new_node_index = node_conversion_map[node_key]
       model_outputs.append([layer.name, new_node_index, tensor_index])
     config['output_layers'] = model_outputs
@@ -872,10 +874,61 @@ class Network(tf_base_layers.Network, Layer):
     Raises:
         ValueError: In case of improperly formatted config dict.
     """
-    # layer instances created during
+    # Layer instances created during
     # the graph reconstruction process
     created_layers = {}
 
+    # Dictionary mapping layer instances to
+    # node data that specifies a layer call.
+    # It acts as a queue that maintains any unprocessed
+    # layer call until it becomes possible to process it
+    # (i.e. until the input tensors to the call all exist).
+    unprocessed_nodes = {}
+
+    def add_unprocessed_node(layer, node_data):
+      if layer not in unprocessed_nodes:
+        unprocessed_nodes[layer] = [node_data]
+      else:
+        unprocessed_nodes[layer].append(node_data)
+
+    def process_node(layer, node_data):
+      """Deserialize a node.
+
+      Arguments:
+          layer: layer instance.
+          node_data: node config dict.
+
+      Raises:
+          ValueError: In case of improperly formatted `node_data` dict.
+      """
+      input_tensors = []
+      for input_data in node_data:
+        inbound_layer_name = input_data[0]
+        inbound_node_index = input_data[1]
+        inbound_tensor_index = input_data[2]
+        if len(input_data) == 3:
+          kwargs = {}
+        elif len(input_data) == 4:
+          kwargs = input_data[3]
+        else:
+          raise ValueError('Improperly formatted model config.')
+        if inbound_layer_name not in created_layers:
+          add_unprocessed_node(layer, node_data)
+          return
+        inbound_layer = created_layers[inbound_layer_name]
+        if len(inbound_layer.inbound_nodes) <= inbound_node_index:
+          add_unprocessed_node(layer, node_data)
+          return
+        inbound_node = inbound_layer.inbound_nodes[inbound_node_index]
+        input_tensors.append(inbound_node.output_tensors[inbound_tensor_index])
+      # Call layer on its inputs, thus creating the node
+      # and building the layer if needed.
+      if input_tensors:
+        if len(input_tensors) == 1:
+          layer(input_tensors[0], **kwargs)
+        else:
+          layer(input_tensors, **kwargs)
+
     def process_layer(layer_data):
       """Deserialize a layer, then call it on appropriate inputs.
 
@@ -888,40 +941,33 @@ class Network(tf_base_layers.Network, Layer):
       layer_name = layer_data['name']
 
       # Instantiate layer.
-      from tensorflow.contrib.keras.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
+      from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
+
       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
       created_layers[layer_name] = layer
 
       # Gather layer inputs.
       inbound_nodes_data = layer_data['inbound_nodes']
       for node_data in inbound_nodes_data:
-        input_tensors = []
-        for input_data in node_data:
-          inbound_layer_name = input_data[0]
-          inbound_node_index = input_data[1]
-          inbound_tensor_index = input_data[2]
-          if len(input_data) == 3:
-            kwargs = {}
-          elif len(input_data) == 4:
-            kwargs = input_data[3]
-          else:
-            raise ValueError('Improperly formatted model config.')
-          if inbound_layer_name not in created_layers:
-            raise ValueError('Missing layer: ' + inbound_layer_name)
-          inbound_layer = created_layers[inbound_layer_name]
-          inbound_node = inbound_layer.inbound_nodes[inbound_node_index]
-          input_tensors.append(
-              inbound_node.output_tensors[inbound_tensor_index])
-        # Call layer on its inputs, thus creating the node
-        # and building the layer if needed.
-        if input_tensors:
-          if len(input_tensors) == 1:
-            layer(input_tensors[0], **kwargs)
-          else:
-            layer(input_tensors, **kwargs)
+        # We don't process nodes (i.e. make layer calls)
+        # on the fly because the inbound node may not yet exist,
+        # in case of layer shared at different topological depths
+        # (e.g. a model such as A(B(A(B(x)))))
+        add_unprocessed_node(layer, node_data)
 
+    # First, we create all layers and enqueue nodes to be processed
     for layer_data in config['layers']:
       process_layer(layer_data)
+    # Then we process nodes in order of layer depth.
+    # Nodes that cannot yet be processed (if the inbound node
+    # does not yet exist) are re-enqueued, and the process
+    # is repeated until all nodes are processed.
+    while unprocessed_nodes:
+      for layer_data in config['layers']:
+        layer = created_layers[layer_data['name']]
+        if layer in unprocessed_nodes:
+          for node_data in unprocessed_nodes.pop(layer):
+            process_node(layer, node_data)
 
     name = config.get('name')
     input_tensors = []
@@ -976,7 +1022,7 @@ class Network(tf_base_layers.Network, Layer):
     model = load_model('my_model.h5')
     ```
     """
-    from tensorflow.contrib.keras.python.keras.models import save_model  # pylint: disable=g-import-not-at-top
+    from tensorflow.python.keras._impl.keras.models import save_model  # pylint: disable=g-import-not-at-top
     save_model(self, filepath, overwrite, include_optimizer)
 
   def save_weights(self, filepath, overwrite=True):
@@ -1054,7 +1100,7 @@ class Network(tf_base_layers.Network, Layer):
     Returns:
         Model config with Keras version information added.
     """
-    from tensorflow.contrib.keras.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
+    from tensorflow.python.keras._impl.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
 
     config = self.get_config()
     model_config = {
@@ -1201,7 +1247,7 @@ def _to_list(x):
 
 
 def save_weights_to_hdf5_group(f, layers):
-  from tensorflow.contrib.keras.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.keras._impl.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
 
   f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers]
   f.attrs['backend'] = K.backend().encode('utf8')
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
similarity index 94%
rename from tensorflow/contrib/keras/python/keras/engine/topology_test.py
rename to tensorflow/python/keras/_impl/keras/engine/topology_test.py
index fa099515abc961955ece189eb266dace14949b54..e5ec01ed716d617a01a5320b809b45fa37250182 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
@@ -23,8 +23,8 @@ import shutil
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
 from tensorflow.python.framework import dtypes
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
@@ -637,5 +637,53 @@ class TopologyConstructionTest(test.TestCase):
     _ = keras.engine.topology.preprocess_weights_for_loading(
         model, model.weights, original_keras_version='1')
 
+  def test_layer_sharing_at_heterogenous_depth(self):
+    with self.test_session():
+      x_val = np.random.random((10, 5))
+
+      x = keras.Input(shape=(5,))
+      a = keras.layers.Dense(5, name='A')
+      b = keras.layers.Dense(5, name='B')
+      output = a(b(a(b(x))))
+      m = keras.models.Model(x, output)
+
+      output_val = m.predict(x_val)
+
+      config = m.get_config()
+      weights = m.get_weights()
+
+      m2 = keras.models.Model.from_config(config)
+      m2.set_weights(weights)
+
+      output_val_2 = m2.predict(x_val)
+      self.assertAllClose(output_val, output_val_2, atol=1e-6)
+
+  def test_layer_sharing_at_heterogenous_depth_with_concat(self):
+    with self.test_session():
+      input_shape = (16, 9, 3)
+      input_layer = keras.Input(shape=input_shape)
+
+      a = keras.layers.Dense(3, name='dense_A')
+      b = keras.layers.Dense(3, name='dense_B')
+      c = keras.layers.Dense(3, name='dense_C')
+
+      x1 = b(a(input_layer))
+      x2 = a(c(input_layer))
+      output = keras.layers.concatenate([x1, x2])
+
+      m = keras.models.Model(inputs=input_layer, outputs=output)
+
+      x_val = np.random.random((10, 16, 9, 3))
+      output_val = m.predict(x_val)
+
+      config = m.get_config()
+      weights = m.get_weights()
+
+      m2 = keras.models.Model.from_config(config)
+      m2.set_weights(weights)
+
+      output_val_2 = m2.predict(x_val)
+      self.assertAllClose(output_val, output_val_2, atol=1e-6)
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/keras/python/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
similarity index 74%
rename from tensorflow/contrib/keras/python/keras/engine/training.py
rename to tensorflow/python/keras/_impl/keras/engine/training.py
index fabfa537d838a9538487c24b2b32f9268a80b22c..0b04c17ad7007602e5c1d3b7241953952ad63aaf 100644
--- a/tensorflow/contrib/keras/python/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -22,18 +22,17 @@ from __future__ import print_function
 import copy
 
 import numpy as np
-import six
-
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import callbacks as cbks
-from tensorflow.contrib.keras.python.keras import losses
-from tensorflow.contrib.keras.python.keras import metrics as metrics_module
-from tensorflow.contrib.keras.python.keras import optimizers
-from tensorflow.contrib.keras.python.keras.engine.topology import Container
-from tensorflow.contrib.keras.python.keras.utils.data_utils import GeneratorEnqueuer
-from tensorflow.contrib.keras.python.keras.utils.data_utils import OrderedEnqueuer
-from tensorflow.contrib.keras.python.keras.utils.data_utils import Sequence
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import Progbar
+
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import callbacks as cbks
+from tensorflow.python.keras._impl.keras import losses
+from tensorflow.python.keras._impl.keras import metrics as metrics_module
+from tensorflow.python.keras._impl.keras import optimizers
+from tensorflow.python.keras._impl.keras.engine.topology import Container
+from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer
+from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer
+from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
+from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
 from tensorflow.python.platform import tf_logging as logging
 
 
@@ -65,6 +64,9 @@ def _standardize_input_data(data,
       ValueError: in case of improperly formatted user-provided data.
   """
   if not names:
+    if data is not None and hasattr(data, '__len__') and len(data):
+      raise ValueError('Error when checking model ' + exception_prefix + ': '
+                       'expected no data, but got:', data)
     return []
   if data is None:
     return [None for _ in range(len(names))]
@@ -83,7 +85,7 @@ def _standardize_input_data(data,
             ': the list of Numpy arrays '
             'that you are passing to your model '
             'is not the size the model expected. '
-            'Expected to see ' + str(len(names)) + ' arrays but instead got '
+            'Expected to see ' + str(len(names)) + ' array(s), but instead got '
             'the following list of ' + str(len(data)) + ' arrays: ' +
             str(data)[:200] + '...')
       else:
@@ -441,6 +443,7 @@ def _weighted_masked_objective(fn):
     # score_array has ndim >= 2
     score_array = fn(y_true, y_pred)
     if mask is not None:
+      # Cast the mask to floatX to avoid float64 upcasting in theano
       mask = K.cast(mask, K.floatx())
       # mask should have the same shape as score_array
       score_array *= mask
@@ -461,47 +464,6 @@ def _weighted_masked_objective(fn):
   return weighted
 
 
-def _masked_objective(fn):
-  """Adds support for masking to an objective function.
-
-  It transforms an objective function `fn(y_true, y_pred)`
-  into a cost-masked objective function
-  `fn(y_true, y_pred, mask)`.
-
-  Arguments:
-      fn: The objective function to wrap,
-          with signature `fn(y_true, y_pred)`.
-
-  Returns:
-      A function with signature `fn(y_true, y_pred, mask)`.
-  """
-
-  def masked(y_true, y_pred, mask=None):
-    """Wrapper function.
-
-    Arguments:
-        y_true: `y_true` argument of `fn`.
-        y_pred: `y_pred` argument of `fn`.
-        mask: Mask tensor.
-
-    Returns:
-        Scalar tensor.
-    """
-    # score_array has ndim >= 2
-    score_array = fn(y_true, y_pred)
-    if mask is not None:
-      mask = K.cast(mask, K.floatx())
-      # mask should have the same shape as score_array
-      score_array *= mask
-      #  the loss per batch should be proportional
-      #  to the number of unmasked samples.
-      score_array /= K.mean(mask)
-
-    return K.mean(score_array)
-
-  return masked
-
-
 def _standardize_weights(y,
                          sample_weight=None,
                          class_weight=None,
@@ -607,19 +569,21 @@ class Model(Container):
               metrics=None,
               loss_weights=None,
               sample_weight_mode=None,
+              weighted_metrics=None,
+              target_tensors=None,
               **kwargs):
     """Configures the model for training.
 
     Arguments:
-        optimizer: str (name of optimizer) or optimizer object.
+        optimizer: String (name of optimizer) or optimizer object.
             See [optimizers](/optimizers).
-        loss: str (name of objective function) or objective function.
+        loss: String (name of objective function) or objective function.
             See [losses](/losses).
             If the model has multiple outputs, you can use a different loss
             on each output by passing a dictionary or a list of losses.
             The loss value that will be minimized by the model
             will then be the sum of all individual losses.
-        metrics: list of metrics to be evaluated by the model
+        metrics: List of metrics to be evaluated by the model
             during training and testing.
             Typically you will use `metrics=['accuracy']`.
             To specify different metrics for different outputs of a
@@ -634,18 +598,29 @@ class Model(Container):
             If a list, it is expected to have a 1:1 mapping
             to the model's outputs. If a tensor, it is expected to map
             output names (strings) to scalar coefficients.
-        sample_weight_mode: if you need to do timestep-wise
+        sample_weight_mode: If you need to do timestep-wise
             sample weighting (2D weights), set this to `"temporal"`.
             `None` defaults to sample-wise weights (1D).
             If the model has multiple outputs, you can use a different
             `sample_weight_mode` on each output by passing a
             dictionary or a list of modes.
-        **kwargs: Additional arguments passed to `tf.Session.run`.
+        weighted_metrics: List of metrics to be evaluated and weighted
+            by sample_weight or class_weight during training and testing.
+        target_tensors: By default, Keras will create placeholders for the
+            model's target, which will be fed with the target data during
+            training. If instead you would like to use your own
+            target tensors (in turn, Keras will not expect external
+            Numpy data for these targets at training time), you
+            can specify them via the `target_tensors` argument. It can be
+            a single tensor (for a single-output model), a list of tensors,
+            or a dict mapping output names to target tensors.
+        **kwargs: When using the Theano/CNTK backends, these arguments
+            are passed into K.function. When using the TensorFlow backend,
+            these arguments are passed into `tf.Session.run`.
 
     Raises:
         ValueError: In case of invalid arguments for
             `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
-        RuntimeError: In case of ill-formulated optimization problem.
     """
     loss = loss or {}
     self.optimizer = optimizers.get(optimizer)
@@ -682,19 +657,16 @@ class Model(Container):
       loss_functions = [loss_function for _ in range(len(self.outputs))]
     self.loss_functions = loss_functions
     weighted_losses = [_weighted_masked_objective(fn) for fn in loss_functions]
-    skip_indices = []
+    skip_target_indices = []
+    skip_target_weighing_indices = []
     self._feed_outputs = []
     self._feed_output_names = []
     self._feed_output_shapes = []
     self._feed_loss_fns = []
     for i in range(len(weighted_losses)):
       if weighted_losses[i] is None:
-        skip_indices.append(i)
-      else:
-        self._feed_outputs.append(self.outputs[i])
-        self._feed_output_names.append(self.output_names[i])
-        self._feed_output_shapes.append(self.internal_output_shapes[i])
-        self._feed_loss_fns.append(self.loss_functions[i])
+        skip_target_indices.append(i)
+        skip_target_weighing_indices.append(i)
 
     # Prepare output masks.
     masks = self.compute_mask(self.inputs, mask=None)
@@ -728,6 +700,57 @@ class Model(Container):
       raise TypeError('Could not interpret loss_weights argument: ' +
                       str(loss_weights) + ' - expected a list of dicts.')
 
+    # Prepare targets of model.
+    self.targets = []
+    self._feed_targets = []
+    if target_tensors is not None:
+      if isinstance(target_tensors, list):
+        if len(target_tensors) != len(self.outputs):
+          raise ValueError('When passing a list as `target_tensors`, '
+                           'it should have one entry per model outputs. '
+                           'The model has ' + str(len(self.outputs)) +
+                           ' outputs, but you passed target_tensors=' +
+                           str(target_tensors))
+      elif isinstance(target_tensors, dict):
+        for name in target_tensors:
+          if name not in self.output_names:
+            raise ValueError('Unknown entry in `target_tensors` '
+                             'dictionary: "' + name + '". '
+                             'Only expected the following keys: ' +
+                             str(self.output_names))
+        target_tensors_ = []
+        for name in self.output_names:
+          target_tensors_.append(target_tensors.get(name, None))
+        target_tensors = target_tensors_
+      else:
+        raise TypeError('Expected `target_tensors` to be '
+                        'a list or dict, but got:', target_tensors)
+    for i in range(len(self.outputs)):
+      if i in skip_target_indices:
+        self.targets.append(None)
+      else:
+        shape = self.internal_output_shapes[i]
+        name = self.output_names[i]
+        if target_tensors is not None:
+          target = target_tensors[i]
+        else:
+          target = None
+        if target is None or K.is_placeholder(target):
+          if target is None:
+            target = K.placeholder(
+                ndim=len(shape),
+                name=name + '_target',
+                sparse=K.is_sparse(self.outputs[i]),
+                dtype=K.dtype(self.outputs[i]))
+          self._feed_targets.append(target)
+          self._feed_outputs.append(self.outputs[i])
+          self._feed_output_names.append(name)
+          self._feed_output_shapes.append(shape)
+          self._feed_loss_fns.append(self.loss_functions[i])
+        else:
+          skip_target_weighing_indices.append(i)
+        self.targets.append(target)
+
     # Prepare sample weights.
     sample_weights = []
     sample_weight_modes = []
@@ -739,7 +762,7 @@ class Model(Container):
                            'Only expected the following keys: ' +
                            str(self.output_names))
       for i, name in enumerate(self.output_names):
-        if i in skip_indices:
+        if i in skip_target_weighing_indices:
           weight = None
           sample_weight_modes.append(None)
         else:
@@ -762,7 +785,7 @@ class Model(Container):
                          ' outputs, but you passed '
                          'sample_weight_mode=' + str(sample_weight_mode))
       for i in range(len(self.output_names)):
-        if i in skip_indices:
+        if i in skip_target_weighing_indices:
           weight = None
           sample_weight_modes.append(None)
         else:
@@ -777,7 +800,7 @@ class Model(Container):
         sample_weights.append(weight)
     else:
       for i, name in enumerate(self.output_names):
-        if i in skip_indices:
+        if i in skip_target_weighing_indices:
           sample_weight_modes.append(None)
           sample_weights.append(None)
         else:
@@ -792,111 +815,112 @@ class Model(Container):
     self.sample_weight_modes = sample_weight_modes
     self._feed_sample_weight_modes = []
     for i in range(len(self.outputs)):
-      if i not in skip_indices:
+      if i not in skip_target_weighing_indices:
         self._feed_sample_weight_modes.append(self.sample_weight_modes[i])
 
-    # Prepare targets of model.
-    self.targets = []
-    self._feed_targets = []
-    for i in range(len(self.outputs)):
-      if i in skip_indices:
-        self.targets.append(None)
-      else:
-        shape = self.internal_output_shapes[i]
-        name = self.output_names[i]
-        target = K.placeholder(
-            ndim=len(shape),
-            name=name + '_target',
-            sparse=K.is_sparse(self.outputs[i]),
-            dtype=K.dtype(self.outputs[i]))
-        self.targets.append(target)
-        self._feed_targets.append(target)
-
     # Prepare metrics.
     self.metrics = metrics
+    self.weighted_metrics = weighted_metrics
     self.metrics_names = ['loss']
     self.metrics_tensors = []
 
     # Compute total loss.
     total_loss = None
-    for i in range(len(self.outputs)):
-      if i in skip_indices:
-        continue
-      y_true = self.targets[i]
-      y_pred = self.outputs[i]
-      weighted_loss = weighted_losses[i]
-      sample_weight = sample_weights[i]
-      mask = masks[i]
-      loss_weight = loss_weights_list[i]
-      output_loss = weighted_loss(y_true, y_pred, sample_weight, mask)
-      if len(self.outputs) > 1:
-        self.metrics_tensors.append(output_loss)
-        self.metrics_names.append(self.output_names[i] + '_loss')
+    with K.name_scope('loss'):
+      for i in range(len(self.outputs)):
+        if i in skip_target_indices:
+          continue
+        y_true = self.targets[i]
+        y_pred = self.outputs[i]
+        weighted_loss = weighted_losses[i]
+        sample_weight = sample_weights[i]
+        mask = masks[i]
+        loss_weight = loss_weights_list[i]
+        with K.name_scope(self.output_names[i] + '_loss'):
+          output_loss = weighted_loss(y_true, y_pred, sample_weight, mask)
+        if len(self.outputs) > 1:
+          self.metrics_tensors.append(output_loss)
+          self.metrics_names.append(self.output_names[i] + '_loss')
+        if total_loss is None:
+          total_loss = loss_weight * output_loss
+        else:
+          total_loss += loss_weight * output_loss
       if total_loss is None:
-        total_loss = loss_weight * output_loss
-      else:
-        total_loss += loss_weight * output_loss
-    if total_loss is None:
-      if not self.losses:
-        raise RuntimeError('The model cannot be compiled '
+        if not self.losses:
+          raise ValueError('The model cannot be compiled '
                            'because it has no loss to optimize.')
-      else:
-        total_loss = 0.
+        else:
+          total_loss = 0.
 
-    # Add regularization penalties
-    # and other layer-specific losses.
-    for loss_tensor in self.losses:
-      total_loss += loss_tensor
+      # Add regularization penalties
+      # and other layer-specific losses.
+      for loss_tensor in self.losses:
+        total_loss += loss_tensor
 
     # List of same size as output_names.
     # contains tuples (metrics for output, names of metrics).
     nested_metrics = _collect_metrics(metrics, self.output_names)
+    nested_weighted_metrics = _collect_metrics(weighted_metrics,
+                                               self.output_names)
 
-    def append_metric(layer_num, metric_name, metric_tensor):
+    def append_metric(layer_index, metric_name, metric_tensor):
       """Helper function used in loop below."""
       if len(self.output_names) > 1:
-        metric_name = self._output_layers[layer_num].name + '_' + metric_name
+        metric_name = self.output_names[layer_index] + '_' + metric_name
       self.metrics_names.append(metric_name)
       self.metrics_tensors.append(metric_tensor)
 
-    for i in range(len(self.outputs)):
-      if i in skip_indices:
-        continue
-      y_true = self.targets[i]
-      y_pred = self.outputs[i]
-      output_metrics = nested_metrics[i]
-      for metric in output_metrics:
-        if metric == 'accuracy' or metric == 'acc':
-          # custom handling of accuracy
-          # (because of class mode duality)
-          output_shape = self.internal_output_shapes[i]
-          acc_fn = None
-          if (output_shape[-1] == 1 or
-              self.loss_functions[i] == losses.binary_crossentropy):
-            # case: binary accuracy
-            acc_fn = metrics_module.binary_accuracy
-          elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
-            # case: categorical accuracy with sparse targets
-            acc_fn = metrics_module.sparse_categorical_accuracy
-          else:
-            acc_fn = metrics_module.categorical_accuracy
+    with K.name_scope('metrics'):
+      for i in range(len(self.outputs)):
+        if i in skip_target_indices:
+          continue
 
-          masked_fn = _masked_objective(acc_fn)
-          append_metric(i, 'acc', masked_fn(y_true, y_pred, mask=masks[i]))
-        else:
-          metric_fn = metrics_module.get(metric)
-          masked_metric_fn = _masked_objective(metric_fn)
-          metric_result = masked_metric_fn(y_true, y_pred, mask=masks[i])
-          metric_result = {metric_fn.__name__: metric_result}
-          for name, tensor in six.iteritems(metric_result):
-            append_metric(i, name, tensor)
+        y_true = self.targets[i]
+        y_pred = self.outputs[i]
+        weights = sample_weights[i]
+        output_metrics = nested_metrics[i]
+        output_weighted_metrics = nested_weighted_metrics[i]
+
+        def handle_metrics(metrics, weights=None):
+          metric_name_prefix = 'weighted_' if weights is not None else ''
+
+          for metric in metrics:
+            if metric == 'accuracy' or metric == 'acc':
+              # custom handling of accuracy
+              # (because of class mode duality)
+              output_shape = self.internal_output_shapes[i]
+              if (output_shape[-1] == 1 or
+                  self.loss_functions[i] == losses.binary_crossentropy):
+                # case: binary accuracy
+                acc_fn = metrics_module.binary_accuracy
+              elif self.loss_functions[
+                  i] == losses.sparse_categorical_crossentropy:
+                # case: categorical accuracy with sparse targets
+                acc_fn = metrics_module.sparse_categorical_accuracy
+              else:
+                acc_fn = metrics_module.categorical_accuracy
+
+              weighted_metric_fn = _weighted_masked_objective(acc_fn)
+              metric_name = metric_name_prefix + 'acc'
+            else:
+              metric_fn = metrics_module.get(metric)
+              weighted_metric_fn = _weighted_masked_objective(metric_fn)
+              metric_name = metric_name_prefix + metric_fn.__name__
+
+            with K.name_scope(metric_name):
+              metric_result = weighted_metric_fn(
+                  y_true, y_pred, weights=weights, mask=masks[i])
+            append_metric(i, metric_name, metric_result)
+
+        handle_metrics(output_metrics)
+        handle_metrics(output_weighted_metrics, weights=weights)
 
     # Prepare gradient updates and state updates.
     self.total_loss = total_loss
     self.sample_weights = sample_weights
     self._feed_sample_weights = []
     for i in range(len(self.sample_weights)):
-      if i not in skip_indices:
+      if i not in skip_target_weighing_indices:
         self._feed_sample_weights.append(sample_weights[i])
 
     # Functions for train, test and predict will
@@ -917,30 +941,30 @@ class Model(Container):
       raise RuntimeError('You must compile your model before using it.')
     if self.train_function is None:
       inputs = (self._feed_inputs +
-                self._feed_targets + self._feed_sample_weights)
+                self._feed_targets +
+                self._feed_sample_weights)
       if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
         inputs += [K.learning_phase()]
 
-      constraints = {}
-      for w in self._collected_trainable_weights:
-        if hasattr(w, 'constraint') and w.constraint is not None:
-          constraints[w] = w.constraint
-      training_updates = self.optimizer.get_updates(
-          self._collected_trainable_weights, constraints, self.total_loss)
-      updates = self.updates + training_updates
-      # Gets loss and metrics. Updates weights at each call.
-      self.train_function = K.function(
-          inputs, [self.total_loss] + self.metrics_tensors,
-          updates=updates,
-          name='train_function',
-          **self._function_kwargs)
+      with K.name_scope('training'):
+        with K.name_scope(self.optimizer.__class__.__name__):
+          training_updates = self.optimizer.get_updates(
+              params=self._collected_trainable_weights, loss=self.total_loss)
+        updates = self.updates + training_updates
+        # Gets loss and metrics. Updates weights at each call.
+        self.train_function = K.function(
+            inputs, [self.total_loss] + self.metrics_tensors,
+            updates=updates,
+            name='train_function',
+            **self._function_kwargs)
 
   def _make_test_function(self):
     if not hasattr(self, 'test_function'):
       raise RuntimeError('You must compile your model before using it.')
     if self.test_function is None:
       inputs = (self._feed_inputs +
-                self._feed_targets + self._feed_sample_weights)
+                self._feed_targets +
+                self._feed_sample_weights)
       if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
         inputs += [K.learning_phase()]
       # Return loss and metrics, no gradient updates.
@@ -969,11 +993,54 @@ class Model(Container):
           name='predict_function',
           **kwargs)
 
+  def _check_num_samples(self,
+                         ins,
+                         batch_size=None,
+                         steps=None,
+                         steps_name='steps'):
+    """Determine the number of samples provided for training and evaluation.
+
+    The number of samples is not defined when running with `steps`,
+    in which case the number of samples is set to `None`.
+
+    Arguments:
+        ins: List of tensors to be fed to the Keras function.
+        batch_size: Integer batch size or `None` if not defined.
+        steps: Total number of steps (batches of samples)
+            before declaring `_predict_loop` finished.
+            Ignored with the default value of `None`.
+        steps_name: The public API's parameter name for `steps`.
+
+    Raises:
+        ValueError: when `steps` is `None` and the attribute `ins.shape`
+        does not exist. Also raises ValueError when `steps` is not `None`
+        and `batch_size` is not `None` because they are mutually
+        exclusive.
+
+    Returns:
+        When steps is `None`, returns the number of samples to be
+        processed based on the size of the first dimension of the
+        first input numpy array. When steps is not `None` and
+        `batch_size` is `None`, returns `None`.
+    """
+    if steps is not None:
+      num_samples = None
+      if batch_size is not None:
+        raise ValueError('If ' + steps_name +
+                         ' is set, the `batch_size` must be None.')
+    elif ins and hasattr(ins[0], 'shape'):
+      num_samples = ins[0].shape[0]
+    else:
+      raise ValueError('Either the input data should have '
+                       'a defined shape, or ' + steps_name +
+                       ' should be specified.')
+    return num_samples
+
   def _fit_loop(self,
                 f,
                 ins,
                 out_labels=None,
-                batch_size=32,
+                batch_size=None,
                 epochs=100,
                 verbose=1,
                 callbacks=None,
@@ -981,55 +1048,70 @@ class Model(Container):
                 val_ins=None,
                 shuffle=True,
                 callback_metrics=None,
-                initial_epoch=0):
+                initial_epoch=0,
+                steps_per_epoch=None,
+                validation_steps=None):
     """Abstract fit function for `f(ins)`.
 
     Assume that f returns a list, labeled by out_labels.
 
     Arguments:
         f: Keras function returning a list of tensors
-        ins: list of tensors to be fed to `f`
-        out_labels: list of strings, display names of
+        ins: List of tensors to be fed to `f`
+        out_labels: List of strings, display names of
             the outputs of `f`
-        batch_size: integer batch size
-        epochs: number of times to iterate over the data
-        verbose: verbosity mode, 0, 1 or 2
-        callbacks: list of callbacks to be called during training
+        batch_size: Integer batch size or None if unknown.
+        epochs: Number of times to iterate over the data
+        verbose: Verbosity mode, 0, 1 or 2
+        callbacks: List of callbacks to be called during training
         val_f: Keras function to call for validation
-        val_ins: list of tensors to be fed to `val_f`
-        shuffle: whether to shuffle the data at the beginning of each epoch
-        callback_metrics: list of strings, the display names of the metrics
+        val_ins: List of tensors to be fed to `val_f`
+        shuffle: Whether to shuffle the data at the beginning of each epoch
+        callback_metrics: List of strings, the display names of the metrics
             passed to the callbacks. They should be the
             concatenation of list the display names of the outputs of
              `f` and the list of display names of the outputs of `f_val`.
-        initial_epoch: epoch at which to start training
+        initial_epoch: Epoch at which to start training
             (useful for resuming a previous training run)
+        steps_per_epoch: Total number of steps (batches of samples)
+            before declaring one epoch finished and starting the
+            next epoch. Ignored with the default value of `None`.
+        validation_steps: Number of steps to run validation for (only if doing
+          validation from data tensors). Ignored with default value of `None`.
 
     Returns:
         `History` object.
+
+    Raises:
+      ValueError: In case of invalid argument values.
     """
     do_validation = False
     if val_f and val_ins:
       do_validation = True
-      if verbose:
+      if (verbose and ins and
+          hasattr(ins[0], 'shape') and hasattr(val_ins[0], 'shape')):
         print('Train on %d samples, validate on %d samples' %
               (ins[0].shape[0], val_ins[0].shape[0]))
+    if validation_steps:
+      if steps_per_epoch is None:
+        raise ValueError('Can only use `validation_steps` when doing step-wise '
+                         'training, i.e. `steps_per_epoch` must be set.')
+      do_validation = True
 
-    if ins and hasattr(ins[0], 'shape'):
-      num_train_samples = ins[0].shape[0]
-    else:
-      # May happen if we are running `fit` without Numpy input data,
-      # i.e. if all inputs to the models are data tensors
-      # instead of placeholders.
-      # In that case we will run `fit` over a single batch.
-      num_train_samples = batch_size
-      verbose = 2
-    index_array = np.arange(num_train_samples)
+    num_train_samples = self._check_num_samples(
+        ins, batch_size, steps_per_epoch, 'steps_per_epoch')
+
+    if num_train_samples is not None:
+      index_array = np.arange(num_train_samples)
 
     self.history = cbks.History()
     callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history]
     if verbose:
-      callbacks += [cbks.ProgbarLogger()]
+      if steps_per_epoch is not None:
+        count_mode = 'steps'
+      else:
+        count_mode = 'samples'
+      callbacks += [cbks.ProgbarLogger(count_mode)]
     callbacks = cbks.CallbackList(callbacks)
     out_labels = out_labels or []
 
@@ -1044,6 +1126,7 @@ class Model(Container):
     callbacks.set_params({
         'batch_size': batch_size,
         'epochs': epochs,
+        'steps': steps_per_epoch,
         'samples': num_train_samples,
         'verbose': verbose,
         'do_validation': do_validation,
@@ -1056,55 +1139,85 @@ class Model(Container):
 
     for epoch in range(initial_epoch, epochs):
       callbacks.on_epoch_begin(epoch)
-      if shuffle == 'batch':
-        index_array = _batch_shuffle(index_array, batch_size)
-      elif shuffle:
-        np.random.shuffle(index_array)
-
-      batches = _make_batches(num_train_samples, batch_size)
       epoch_logs = {}
-      for batch_index, (batch_start, batch_end) in enumerate(batches):
-        batch_ids = index_array[batch_start:batch_end]
-        try:
-          if isinstance(ins[-1], float):
-            # Do not slice the training phase flag.
-            ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
-          else:
-            ins_batch = _slice_arrays(ins, batch_ids)
-        except TypeError:
-          raise TypeError('TypeError while preparing batch. '
-                          'If using HDF5 input data, '
-                          'pass shuffle="batch".')
-        batch_logs = {}
-        batch_logs['batch'] = batch_index
-        batch_logs['size'] = len(batch_ids)
-        callbacks.on_batch_begin(batch_index, batch_logs)
-        outs = f(ins_batch)
-        if not isinstance(outs, list):
-          outs = [outs]
-        for l, o in zip(out_labels, outs):
-          batch_logs[l] = o
+      if steps_per_epoch is not None:
+        for step_index in range(steps_per_epoch):
+          batch_logs = {}
+          batch_logs['batch'] = step_index
+          batch_logs['size'] = 1
+          callbacks.on_batch_begin(step_index, batch_logs)
+          outs = f(ins)
 
-        callbacks.on_batch_end(batch_index, batch_logs)
-        if callback_model.stop_training:
-          break
+          if not isinstance(outs, list):
+            outs = [outs]
+          for l, o in zip(out_labels, outs):
+            batch_logs[l] = o
 
-        if batch_index == len(batches) - 1:  # Last batch.
-          if do_validation:
-            val_outs = self._test_loop(
-                val_f, val_ins, batch_size=batch_size, verbose=0)
-            if not isinstance(val_outs, list):
-              val_outs = [val_outs]
-            # Same labels assumed.
-            for l, o in zip(out_labels, val_outs):
-              epoch_logs['val_' + l] = o
+          callbacks.on_batch_end(step_index, batch_logs)
+          if callback_model.stop_training:
+            break
+
+        if do_validation:
+          val_outs = self._test_loop(
+              val_f,
+              val_ins,
+              batch_size=batch_size,
+              steps=validation_steps,
+              verbose=0)
+          if not isinstance(val_outs, list):
+            val_outs = [val_outs]
+          # Same labels assumed.
+          for l, o in zip(out_labels, val_outs):
+            epoch_logs['val_' + l] = o
+      else:
+        if shuffle == 'batch':
+          index_array = _batch_shuffle(index_array, batch_size)
+        elif shuffle:
+          np.random.shuffle(index_array)
+
+        batches = _make_batches(num_train_samples, batch_size)
+        for batch_index, (batch_start, batch_end) in enumerate(batches):
+          batch_ids = index_array[batch_start:batch_end]
+          try:
+            if isinstance(ins[-1], float):
+              # Do not slice the training phase flag.
+              ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
+            else:
+              ins_batch = _slice_arrays(ins, batch_ids)
+          except TypeError:
+            raise TypeError('TypeError while preparing batch. '
+                            'If using HDF5 input data, '
+                            'pass shuffle="batch".')
+          batch_logs = {}
+          batch_logs['batch'] = batch_index
+          batch_logs['size'] = len(batch_ids)
+          callbacks.on_batch_begin(batch_index, batch_logs)
+          outs = f(ins_batch)
+          if not isinstance(outs, list):
+            outs = [outs]
+          for l, o in zip(out_labels, outs):
+            batch_logs[l] = o
+
+          callbacks.on_batch_end(batch_index, batch_logs)
+          if callback_model.stop_training:
+            break
+
+          if batch_index == len(batches) - 1:  # Last batch.
+            if do_validation:
+              val_outs = self._test_loop(
+                  val_f, val_ins, batch_size=batch_size, verbose=0)
+              if not isinstance(val_outs, list):
+                val_outs = [val_outs]
+              # Same labels assumed.
+              for l, o in zip(out_labels, val_outs):
+                epoch_logs['val_' + l] = o
       callbacks.on_epoch_end(epoch, epoch_logs)
       if callback_model.stop_training:
         break
     callbacks.on_train_end()
     return self.history
 
-  def _predict_loop(self, f, ins, batch_size=32, verbose=0):
+  def _predict_loop(self, f, ins, batch_size=32, verbose=0, steps=None):
     """Abstract method to loop over some data in batches.
 
     Arguments:
@@ -1112,58 +1225,85 @@ class Model(Container):
         ins: list of tensors to be fed to `f`.
         batch_size: integer batch size.
         verbose: verbosity mode.
+        steps: Total number of steps (batches of samples)
+            before declaring `_predict_loop` finished.
+            Ignored with the default value of `None`.
 
     Returns:
         Array of predictions (if the model has a single output)
         or list of arrays of predictions
         (if the model has multiple outputs).
     """
-    if ins and hasattr(ins[0], 'shape'):
-      samples = ins[0].shape[0]
-    else:
-      # May happen if we are running `predict` without Numpy input data,
-      # i.e. if all inputs to the models are data tensors
-      # instead of placeholders.
-      # In that case we will run `predict` over a single batch.
-      samples = batch_size
-      verbose = 2
-    outs = []
+    num_samples = self._check_num_samples(ins, batch_size, steps, 'steps')
     if verbose == 1:
-      progbar = Progbar(target=samples)
-    batches = _make_batches(samples, batch_size)
-    index_array = np.arange(samples)
-    for batch_index, (batch_start, batch_end) in enumerate(batches):
-      batch_ids = index_array[batch_start:batch_end]
-      if ins and isinstance(ins[-1], float):
-        # Do not slice the training phase flag.
-        ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
+      if steps is not None:
+        progbar = Progbar(target=steps)
       else:
-        ins_batch = _slice_arrays(ins, batch_ids)
-
-      batch_outs = f(ins_batch)
-      if not isinstance(batch_outs, list):
-        batch_outs = [batch_outs]
-      if batch_index == 0:
-        for batch_out in batch_outs:
-          shape = (samples,) + batch_out.shape[1:]
-          outs.append(np.zeros(shape, dtype=batch_out.dtype))
-
-      for i, batch_out in enumerate(batch_outs):
-        outs[i][batch_start:batch_end] = batch_out
-      if verbose == 1:
-        progbar.update(batch_end)
-    if len(outs) == 1:
-      return outs[0]
-    return outs
+        progbar = Progbar(target=num_samples)
+    if steps is not None:
+      # Step-based predictions.
+      # Since we do not know how many samples
+      # we will see, we cannot pre-allocate
+      # the returned Numpy arrays.
+      # Instead, we store one array per batch seen
+      # and concatenate them upon returning.
+      unconcatenated_outs = []
+      for step in range(steps):
+        batch_outs = f(ins)
+        if not isinstance(batch_outs, list):
+          batch_outs = [batch_outs]
+        if step == 0:
+          for batch_out in batch_outs:
+            unconcatenated_outs.append([])
+        for i, batch_out in enumerate(batch_outs):
+          unconcatenated_outs[i].append(batch_out)
+        if verbose == 1:
+          progbar.update(step)
+      if len(unconcatenated_outs) == 1:
+        return np.concatenate(unconcatenated_outs[0], axis=0)
+      return [
+          np.concatenate(unconcatenated_outs[i], axis=0)
+          for i in range(len(unconcatenated_outs))
+      ]
+    else:
+      # Sample-based predictions.
+      outs = []
+      batches = _make_batches(num_samples, batch_size)
+      index_array = np.arange(num_samples)
+      for batch_index, (batch_start, batch_end) in enumerate(batches):
+        batch_ids = index_array[batch_start:batch_end]
+        if ins and isinstance(ins[-1], float):
+          # Do not slice the training phase flag.
+          ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
+        else:
+          ins_batch = _slice_arrays(ins, batch_ids)
+        batch_outs = f(ins_batch)
+        if not isinstance(batch_outs, list):
+          batch_outs = [batch_outs]
+        if batch_index == 0:
+          # Pre-allocate the results arrays.
+          for batch_out in batch_outs:
+            shape = (num_samples,) + batch_out.shape[1:]
+            outs.append(np.zeros(shape, dtype=batch_out.dtype))
+        for i, batch_out in enumerate(batch_outs):
+          outs[i][batch_start:batch_end] = batch_out
+        if verbose == 1:
+          progbar.update(batch_end)
+      if len(outs) == 1:
+        return outs[0]
+      return outs
 
-  def _test_loop(self, f, ins, batch_size=32, verbose=0):
+  def _test_loop(self, f, ins, batch_size=None, verbose=0, steps=None):
     """Abstract method to loop over some data in batches.
 
     Arguments:
         f: Keras function returning a list of tensors.
         ins: list of tensors to be fed to `f`.
-        batch_size: integer batch size.
+        batch_size: integer batch size or `None`.
         verbose: verbosity mode.
+        steps: Total number of steps (batches of samples)
+            before declaring predictions finished.
+            Ignored with the default value of `None`.
 
     Returns:
         Scalar loss (if the model has a single output and no metrics)
@@ -1171,45 +1311,56 @@ class Model(Container):
         and/or metrics). The attribute `model.metrics_names` will give you
         the display labels for the scalar outputs.
     """
-    if ins and hasattr(ins[0], 'shape'):
-      samples = ins[0].shape[0]
-    else:
-      # May happen if we are running `evaluate` without Numpy input data,
-      # i.e. if all inputs to the models are data tensors
-      # instead of placeholders.
-      # In that case we will run `evaluate` over a single batch.
-      samples = batch_size
-      verbose = 2
-
+    num_samples = self._check_num_samples(ins, batch_size, steps, 'steps')
     outs = []
-    if verbose == 1:
-      progbar = Progbar(target=samples)
-    batches = _make_batches(samples, batch_size)
-    index_array = np.arange(samples)
-    for batch_index, (batch_start, batch_end) in enumerate(batches):
-      batch_ids = index_array[batch_start:batch_end]
-      if isinstance(ins[-1], float):
-        # Do not slice the training phase flag.
-        ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
-      else:
-        ins_batch = _slice_arrays(ins, batch_ids)
-
-      batch_outs = f(ins_batch)
-      if isinstance(batch_outs, list):
-        if batch_index == 0:
-          for batch_out in enumerate(batch_outs):
+    if steps is not None:
+      if verbose == 1:
+        progbar = Progbar(target=steps)
+      for step in range(steps):
+        batch_outs = f(ins)
+        if isinstance(batch_outs, list):
+          if step == 0:
+            for _ in enumerate(batch_outs):
+              outs.append(0.)
+          for i, batch_out in enumerate(batch_outs):
+            outs[i] += batch_out
+        else:
+          if step == 0:
             outs.append(0.)
-        for i, batch_out in enumerate(batch_outs):
-          outs[i] += batch_out * len(batch_ids)
-      else:
-        if batch_index == 0:
-          outs.append(0.)
-        outs[0] += batch_outs * len(batch_ids)
-
+          outs[0] += batch_outs
+        if verbose == 1:
+          progbar.update(step)
+      for i in range(len(outs)):
+        outs[i] /= steps
+    else:
       if verbose == 1:
-        progbar.update(batch_end)
-    for i in range(len(outs)):
-      outs[i] /= samples
+        progbar = Progbar(target=num_samples)
+      batches = _make_batches(num_samples, batch_size)
+      index_array = np.arange(num_samples)
+      for batch_index, (batch_start, batch_end) in enumerate(batches):
+        batch_ids = index_array[batch_start:batch_end]
+        if isinstance(ins[-1], float):
+          # Do not slice the training phase flag.
+          ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
+        else:
+          ins_batch = _slice_arrays(ins, batch_ids)
+
+        batch_outs = f(ins_batch)
+        if isinstance(batch_outs, list):
+          if batch_index == 0:
+            for batch_out in enumerate(batch_outs):
+              outs.append(0.)
+          for i, batch_out in enumerate(batch_outs):
+            outs[i] += batch_out * len(batch_ids)
+        else:
+          if batch_index == 0:
+            outs.append(0.)
+          outs[0] += batch_outs * len(batch_ids)
+
+        if verbose == 1:
+          progbar.update(batch_end)
+      for i in range(len(outs)):
+        outs[i] /= num_samples
     if len(outs) == 1:
       return outs[0]
     return outs
@@ -1285,7 +1436,7 @@ class Model(Container):
   def fit(self,
           x=None,
           y=None,
-          batch_size=32,
+          batch_size=None,
           epochs=1,
           verbose=1,
           callbacks=None,
@@ -1294,7 +1445,9 @@ class Model(Container):
           shuffle=True,
           class_weight=None,
           sample_weight=None,
-          initial_epoch=0):
+          initial_epoch=0,
+          steps_per_epoch=None,
+          validation_steps=None):
     """Trains the model for a fixed number of epochs (iterations on a dataset).
 
     Arguments:
@@ -1308,42 +1461,54 @@ class Model(Container):
             If all outputs in the model are named,
             you can also pass a dictionary
             mapping output names to Numpy arrays.
-        batch_size: integer. Number of samples per gradient update.
-        epochs: integer, the number of times to iterate
+        batch_size: Integer or `None`.
+            Number of samples per gradient update.
+            If unspecified, it will default to 32.
+        epochs: Integer, the number of times to iterate
             over the training data arrays.
         verbose: 0, 1, or 2. Verbosity mode.
             0 = silent, 1 = verbose, 2 = one log line per epoch.
-        callbacks: list of callbacks to be called during training.
+        callbacks: List of callbacks to be called during training.
             See [callbacks](/callbacks).
-        validation_split: float between 0 and 1:
+        validation_split: Float between 0 and 1:
             fraction of the training data to be used as validation data.
             The model will set apart this fraction of the training data,
             will not train on it, and will evaluate
             the loss and any model metrics
             on this data at the end of each epoch.
-        validation_data: data on which to evaluate
+        validation_data: Data on which to evaluate
             the loss and any model metrics
             at the end of each epoch. The model will not
             be trained on this data.
             This could be a tuple (x_val, y_val)
             or a tuple (x_val, y_val, val_sample_weights).
-        shuffle: boolean, whether to shuffle the training data
-            before each epoch.
-        class_weight: optional dictionary mapping
+        shuffle: Boolean, whether to shuffle the training data
+            before each epoch. Has no effect when `steps_per_epoch`
+            is not `None`.
+        class_weight: Optional dictionary mapping
             class indices (integers) to
             a weight (float) to apply to the model's loss for the samples
             from this class during training.
             This can be useful to tell the model to "pay more attention" to
             samples from an under-represented class.
-        sample_weight: optional array of the same length as x, containing
+        sample_weight: Optional array of the same length as x, containing
             weights to apply to the model's loss for each sample.
             In the case of temporal data, you can pass a 2D array
             with shape (samples, sequence_length),
             to apply a different weight to every timestep of every sample.
             In this case you should make sure to specify
             sample_weight_mode="temporal" in compile().
-        initial_epoch: epoch at which to start training
+        initial_epoch: Epoch at which to start training
             (useful for resuming a previous training run)
+        steps_per_epoch: Total number of steps (batches of samples)
+            before declaring one epoch finished and starting the
+            next epoch. When training with Input Tensors such as
+            TensorFlow data tensors, the default `None` is equal to
+            the number of unique samples in your dataset divided by
+            the batch size, or 1 if that cannot be determined.
+        validation_steps: Only relevant if `steps_per_epoch`
+            is specified. Total number of steps (batches of samples)
+            to validate before stopping.
 
     Returns:
         A `History` instance. Its `history` attribute contains
@@ -1353,7 +1518,13 @@ class Model(Container):
         ValueError: In case of mismatch between the provided input data
             and what the model expects.
     """
-
+    # Backwards compatibility
+    if batch_size is None and steps_per_epoch is None:
+      batch_size = 32
+    if x is None and y is None and steps_per_epoch is None:
+      raise ValueError('If fitting from data tensors, '
+                       'you should specify the `steps_per_epoch` '
+                       'argument.')
     # Validate user data.
     x, y, sample_weights = self._standardize_user_data(
         x,
@@ -1362,7 +1533,10 @@ class Model(Container):
         class_weight=class_weight,
         check_batch_axis=False,
         batch_size=batch_size)
+
     # Prepare validation data.
+    do_validation = False
+    val_ins = []
     if validation_data:
       do_validation = True
       if len(validation_data) == 2:
@@ -1383,8 +1557,6 @@ class Model(Container):
           sample_weight=val_sample_weight,
           check_batch_axis=False,
           batch_size=batch_size)
-      self._make_test_function()
-      val_f = self.test_function
       if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
         val_ins = val_x + val_y + val_sample_weights + [0.]
       else:
@@ -1400,16 +1572,15 @@ class Model(Container):
       y, val_y = (_slice_arrays(y, 0, split_at), _slice_arrays(y, split_at))
       sample_weights, val_sample_weights = (_slice_arrays(
           sample_weights, 0, split_at), _slice_arrays(sample_weights, split_at))
-      self._make_test_function()
-      val_f = self.test_function
       if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
         val_ins = val_x + val_y + val_sample_weights + [0.]
       else:
         val_ins = val_x + val_y + val_sample_weights
-    else:
-      do_validation = False
-      val_f = None
-      val_ins = None
+
+    elif validation_steps:
+      do_validation = True
+      if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+        val_ins = [0.]
 
     # Prepare input arrays and training function.
     if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
@@ -1423,10 +1594,13 @@ class Model(Container):
     out_labels = self._get_deduped_metrics_names()
 
     if do_validation:
+      self._make_test_function()
+      val_f = self.test_function
       callback_metrics = copy.copy(out_labels) + [
           'val_' + n for n in out_labels
       ]
     else:
+      val_f = None
       callback_metrics = copy.copy(out_labels)
 
     # Delegate logic to `_fit_loop`.
@@ -1442,9 +1616,17 @@ class Model(Container):
         val_ins=val_ins,
         shuffle=shuffle,
         callback_metrics=callback_metrics,
-        initial_epoch=initial_epoch)
-
-  def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None):
+        initial_epoch=initial_epoch,
+        steps_per_epoch=steps_per_epoch,
+        validation_steps=validation_steps)
+
+  def evaluate(self,
+               x,
+               y,
+               batch_size=None,
+               verbose=1,
+               sample_weight=None,
+               steps=None):
     """Returns the loss value & metrics values for the model in test mode.
 
     Computation is done in batches.
@@ -1460,17 +1642,30 @@ class Model(Container):
             If all outputs in the model are named,
             you can also pass a dictionary
             mapping output names to Numpy arrays.
-        batch_size: integer. Number of samples per gradient update.
-        verbose: verbosity mode, 0 or 1.
+        batch_size: Integer. If unspecified, it will default to 32.
+        verbose: Verbosity mode, 0 or 1.
         sample_weight: Array of weights to weight the contribution
             of different samples to the loss and metrics.
+        steps: Total number of steps (batches of samples)
+            before declaring the evaluation round finished.
+            Ignored with the default value of `None`.
 
     Returns:
         Scalar test loss (if the model has a single output and no metrics)
         or list of scalars (if the model has multiple outputs
         and/or metrics). The attribute `model.metrics_names` will give you
         the display labels for the scalar outputs.
+
+    Raises:
+      ValueError: In case of invalid argument values.
     """
+    # Backwards compatibility.
+    if batch_size is None and steps is None:
+      batch_size = 32
+    if x is None and y is None and steps is None:
+      raise ValueError('If evaluating from data tensors, '
+                       'you should specify the `steps` '
+                       'argument.')
     # Validate user data.
     x, y, sample_weights = self._standardize_user_data(
         x,
@@ -1485,18 +1680,22 @@ class Model(Container):
       ins = x + y + sample_weights
     self._make_test_function()
     f = self.test_function
-    return self._test_loop(f, ins, batch_size=batch_size, verbose=verbose)
+    return self._test_loop(
+        f, ins, batch_size=batch_size, verbose=verbose, steps=steps)
 
-  def predict(self, x, batch_size=32, verbose=0):
+  def predict(self, x, batch_size=None, verbose=0, steps=None):
     """Generates output predictions for the input samples.
 
     Computation is done in batches.
 
     Arguments:
-        x: the input data, as a Numpy array
+        x: The input data, as a Numpy array
             (or list of Numpy arrays if the model has multiple outputs).
-        batch_size: integer.
-        verbose: verbosity mode, 0 or 1.
+        batch_size: Integer. If unspecified, it will default to 32.
+        verbose: Verbosity mode, 0 or 1.
+        steps: Total number of steps (batches of samples)
+            before declaring the prediction round finished.
+            Ignored with the default value of `None`.
 
     Returns:
         Numpy array(s) of predictions.
@@ -1507,6 +1706,13 @@ class Model(Container):
             or in case a stateful model receives a number of samples
             that is not a multiple of the batch size.
     """
+    # Backwards compatibility.
+    if batch_size is None and steps is None:
+      batch_size = 32
+    if x is None and steps is None:
+      raise ValueError('If predicting from data tensors, '
+                       'you should specify the `steps` '
+                       'argument.')
     # Validate user data.
     x = _standardize_input_data(
         x,
@@ -1529,7 +1735,8 @@ class Model(Container):
       ins = x
     self._make_predict_function()
     f = self.predict_function
-    return self._predict_loop(f, ins, batch_size=batch_size, verbose=verbose)
+    return self._predict_loop(
+        f, ins, batch_size=batch_size, verbose=verbose, steps=steps)
 
   def train_on_batch(self, x, y, sample_weight=None, class_weight=None):
     """Runs a single gradient update on a single batch of data.
@@ -1545,14 +1752,14 @@ class Model(Container):
             If all outputs in the model are named,
             you can also pass a dictionary
             mapping output names to Numpy arrays.
-        sample_weight: optional array of the same length as x, containing
+        sample_weight: Optional array of the same length as x, containing
             weights to apply to the model's loss for each sample.
             In the case of temporal data, you can pass a 2D array
             with shape (samples, sequence_length),
             to apply a different weight to every timestep of every sample.
             In this case you should make sure to specify
             sample_weight_mode="temporal" in compile().
-        class_weight: optional dictionary mapping
+        class_weight: Optional dictionary mapping
             class indices (integers) to
             a weight (float) to apply to the model's loss for the samples
             from this class during training.
@@ -1596,7 +1803,7 @@ class Model(Container):
             If all outputs in the model are named,
             you can also pass a dictionary
             mapping output names to Numpy arrays.
-        sample_weight: optional array of the same length as x, containing
+        sample_weight: Optional array of the same length as x, containing
             weights to apply to the model's loss for each sample.
             In the case of temporal data, you can pass a 2D array
             with shape (samples, sequence_length),
@@ -1655,6 +1862,7 @@ class Model(Container):
                     max_queue_size=10,
                     workers=1,
                     use_multiprocessing=False,
+                    shuffle=True,
                     initial_epoch=0,
                     **kwargs):
     """Fits the model on data yielded batch-by-batch by a Python generator.
@@ -1668,7 +1876,7 @@ class Model(Container):
     using `use_multiprocessing=True`.
 
     Arguments:
-        generator: a generator or an instance of Sequence (keras.utils.Sequence)
+        generator: A generator or an instance of Sequence (keras.utils.Sequence)
                 object in order to avoid duplicate data
                 when using multiprocessing.
             The output of the generator must be either
@@ -1683,29 +1891,32 @@ class Model(Container):
             finished and starting the next epoch. It should typically
             be equal to the number of unique samples if your dataset
             divided by the batch size.
-        epochs: integer, total number of iterations on the data.
-        verbose: verbosity mode, 0, 1, or 2.
-        callbacks: list of callbacks to be called during training.
-        validation_data: this can be either
+        epochs: Integer, total number of iterations on the data.
+        verbose: Verbosity mode, 0, 1, or 2.
+        callbacks: List of callbacks to be called during training.
+        validation_data: This can be either
             - a generator for the validation data
             - a tuple (inputs, targets)
             - a tuple (inputs, targets, sample_weights).
         validation_steps: Only relevant if `validation_data`
             is a generator. Total number of steps (batches of samples)
             to yield from `generator` before stopping.
-        class_weight: dictionary mapping class indices to a weight
+        class_weight: Dictionary mapping class indices to a weight
             for the class.
-        max_queue_size: maximum size for the generator queue
-        workers: maximum number of processes to spin up
+        max_queue_size: Maximum size for the generator queue
+        workers: Maximum number of processes to spin up
             when using process based threading
-        use_multiprocessing: if True, use process based threading.
+        use_multiprocessing: If True, use process based threading.
             Note that because
             this implementation relies on multiprocessing,
             you should not pass
             non picklable arguments to the generator
             as they can't be passed
             easily to children processes.
-        initial_epoch: epoch at which to start training
+        shuffle: Whether to shuffle the data at the beginning of each
+            epoch. Only used with instances of `Sequence` (
+            keras.utils.Sequence).
+        initial_epoch: Epoch at which to start training
             (useful for resuming a previous training run)
         **kwargs: support for legacy arguments.
 
@@ -1733,7 +1944,7 @@ class Model(Container):
         ValueError: In case the generator yields
             data in an invalid format.
     """
-    # Legacy support
+     # Legacy support
     if 'max_q_size' in kwargs:
       max_queue_size = kwargs.pop('max_q_size')
       logging.warning('The argument `max_q_size` has been renamed '
@@ -1810,15 +2021,16 @@ class Model(Container):
     is_sequence = isinstance(generator, Sequence)
     if not is_sequence and use_multiprocessing and workers > 1:
       logging.warning(
-          'Using a generator with `use_multiprocessing=True`'
-          ' may duplicate your data.Please consider using '
-          'the `keras.utils.Sequence` class.')
+          logging.warning('Using a generator with `use_multiprocessing=True`'
+                          ' and multiple workers may duplicate your data.'
+                          ' Please consider using the`keras.utils.Sequence'
+                          ' class.'))
     enqueuer = None
 
     try:
       if is_sequence:
         enqueuer = OrderedEnqueuer(
-            generator, use_multiprocessing=use_multiprocessing)
+            generator, use_multiprocessing=use_multiprocessing, shuffle=shuffle)
       else:
         enqueuer = GeneratorEnqueuer(
             generator,
@@ -1899,6 +2111,9 @@ class Model(Container):
             for l, o in zip(out_labels, val_outs):
               epoch_logs['val_' + l] = o
 
+          if callback_model.stop_training:
+            break
+
         callbacks.on_epoch_end(epoch, epoch_logs)
         epoch += 1
         if callback_model.stop_training:
@@ -1975,9 +2190,10 @@ class Model(Container):
     is_sequence = isinstance(generator, Sequence)
     if not is_sequence and use_multiprocessing and workers > 1:
       logging.warning(
-          'Using a generator with `use_multiprocessing=True`'
-          ' may duplicate your data.Please consider using '
-          'the `keras.utils.Sequence` class.')
+          logging.warning('Using a generator with `use_multiprocessing=True`'
+                          ' and multiple workers may duplicate your data.'
+                          ' Please consider using the`keras.utils.Sequence'
+                          ' class.'))
     enqueuer = None
 
     try:
@@ -2086,8 +2302,6 @@ class Model(Container):
       logging.warning('The argument `pickle_safe` has been renamed '
                       '`use_multiprocessing`. '
                       'Update your method calls accordingly.')
-    if kwargs:
-      raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
 
     self._make_predict_function()
 
@@ -2097,9 +2311,10 @@ class Model(Container):
     is_sequence = isinstance(generator, Sequence)
     if not is_sequence and use_multiprocessing and workers > 1:
       logging.warning(
-          'Using a generator with `use_multiprocessing=True`'
-          ' may duplicate your data.Please consider using '
-          'the `keras.utils.Sequence` class.')
+          logging.warning('Using a generator with `use_multiprocessing=True`'
+                          ' and multiple workers may duplicate your data.'
+                          ' Please consider using the`keras.utils.Sequence'
+                          ' class.'))
     enqueuer = None
 
     try:
diff --git a/tensorflow/contrib/keras/python/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
similarity index 67%
rename from tensorflow/contrib/keras/python/keras/engine/training_test.py
rename to tensorflow/python/keras/_impl/keras/engine/training_test.py
index ad6812ddaf9a0b37c2acc30825f8fd380317e6ab..bc9ad6693e540585751b12fdaf63007078637547 100644
--- a/tensorflow/contrib/keras/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -20,9 +20,9 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
-from tensorflow.contrib.keras.python.keras.engine.training import _weighted_masked_objective
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
+from tensorflow.python.keras._impl.keras.engine.training import _weighted_masked_objective
 from tensorflow.python.platform import test
 
 
@@ -305,7 +305,7 @@ class TrainingTest(test.TestCase):
                       optimizer='rmsprop',
                       metrics=set(0))
 
-      with self.assertRaises(RuntimeError):
+      with self.assertRaises(ValueError):
         model.compile(loss=None,
                       optimizer='rmsprop')
 
@@ -981,5 +981,443 @@ class TestTrainingUtils(test.TestCase):
     keras.engine.training._slice_arrays(input_a, stop=2)
 
 
+class TestTrainingWithDataTensors(test.TestCase):
+
+  def test_model_with_input_feed_tensor(self):
+    """We test building a model with a TF variable as input.
+
+    We should be able to call fit, evaluate, predict,
+    by only passing them data for the placeholder inputs
+    in the model.
+    """
+    with self.test_session():
+      input_a_np = np.random.random((10, 3))
+      input_b_np = np.random.random((10, 3))
+
+      output_a_np = np.random.random((10, 4))
+      output_b_np = np.random.random((10, 3))
+
+      a = keras.Input(
+          tensor=keras.backend.variables_module.Variable(input_a_np,
+                                                         dtype='float32'))
+      b = keras.Input(shape=(3,), name='input_b')
+
+      a_2 = keras.layers.Dense(4, name='dense_1')(a)
+      dp = keras.layers.Dropout(0.5, name='dropout')
+      b_2 = dp(b)
+
+      model = keras.models.Model([a, b], [a_2, b_2])
+      model.summary()
+
+      optimizer = 'rmsprop'
+      loss = 'mse'
+      loss_weights = [1., 0.5]
+      model.compile(optimizer, loss, metrics=['mean_squared_error'],
+                    loss_weights=loss_weights,
+                    sample_weight_mode=None)
+
+      # test train_on_batch
+      out = model.train_on_batch(input_b_np,
+                                 [output_a_np, output_b_np])
+      out = model.train_on_batch({'input_b': input_b_np},
+                                 [output_a_np, output_b_np])
+      out = model.test_on_batch({'input_b': input_b_np},
+                                [output_a_np, output_b_np])
+      out = model.predict_on_batch({'input_b': input_b_np})
+
+      # test fit
+      out = model.fit({'input_b': input_b_np},
+                      [output_a_np, output_b_np], epochs=1, batch_size=10)
+      out = model.fit(input_b_np,
+                      [output_a_np, output_b_np], epochs=1, batch_size=10)
+
+      # test evaluate
+      out = model.evaluate({'input_b': input_b_np},
+                           [output_a_np, output_b_np], batch_size=10)
+      out = model.evaluate(input_b_np,
+                           [output_a_np, output_b_np], batch_size=10)
+
+      # test predict
+      out = model.predict({'input_b': input_b_np}, batch_size=10)
+      out = model.predict(input_b_np, batch_size=10)
+      self.assertEqual(len(out), 2)
+
+      # Now test a model with a single input
+      # i.e. we don't pass any data to fit the model.
+      a = keras.Input(
+          tensor=keras.backend.variables_module.Variable(input_a_np,
+                                                         dtype='float32'))
+      a_2 = keras.layers.Dense(4, name='dense_1')(a)
+      a_2 = keras.layers.Dropout(0.5, name='dropout')(a_2)
+      model = keras.models.Model(a, a_2)
+      model.summary()
+
+      optimizer = 'rmsprop'
+      loss = 'mse'
+      model.compile(optimizer, loss, metrics=['mean_squared_error'])
+
+      # test train_on_batch
+      out = model.train_on_batch(None,
+                                 output_a_np)
+      out = model.train_on_batch(None,
+                                 output_a_np)
+      out = model.test_on_batch(None,
+                                output_a_np)
+      out = model.predict_on_batch(None)
+      out = model.train_on_batch([],
+                                 output_a_np)
+      out = model.train_on_batch({},
+                                 output_a_np)
+
+      # test fit
+      out = model.fit(None,
+                      output_a_np, epochs=1, batch_size=10)
+      out = model.fit(None,
+                      output_a_np, epochs=1, batch_size=10)
+
+      # test evaluate
+      out = model.evaluate(None,
+                           output_a_np, batch_size=10)
+      out = model.evaluate(None,
+                           output_a_np, batch_size=10)
+
+      # test predict
+      out = model.predict(None, steps=3)
+      out = model.predict(None, steps=3)
+      self.assertEqual(out.shape, (10 * 3, 4))
+
+      # Same, without learning phase
+      # i.e. we don't pass any data to fit the model.
+      a = keras.Input(
+          tensor=keras.backend.variables_module.Variable(input_a_np,
+                                                         dtype='float32'))
+      a_2 = keras.layers.Dense(4, name='dense_1')(a)
+      model = keras.models.Model(a, a_2)
+      model.summary()
+
+      optimizer = 'rmsprop'
+      loss = 'mse'
+      model.compile(optimizer, loss, metrics=['mean_squared_error'])
+
+      # test train_on_batch
+      out = model.train_on_batch(None,
+                                 output_a_np)
+      out = model.train_on_batch(None,
+                                 output_a_np)
+      out = model.test_on_batch(None,
+                                output_a_np)
+      out = model.predict_on_batch(None)
+      out = model.train_on_batch([],
+                                 output_a_np)
+      out = model.train_on_batch({},
+                                 output_a_np)
+
+      # test fit
+      out = model.fit(None,
+                      output_a_np, epochs=1, batch_size=10)
+      out = model.fit(None,
+                      output_a_np, epochs=1, batch_size=10)
+
+      # test evaluate
+      out = model.evaluate(None,
+                           output_a_np, batch_size=10)
+      out = model.evaluate(None,
+                           output_a_np, batch_size=10)
+
+      # test predict
+      out = model.predict(None, steps=3)
+      out = model.predict(None, steps=3)
+      self.assertEqual(out.shape, (10 * 3, 4))
+
+  def test_model_with_partial_loss(self):
+    with self.test_session():
+      a = keras.Input(shape=(3,), name='input_a')
+      a_2 = keras.layers.Dense(4, name='dense_1')(a)
+      dp = keras.layers.Dropout(0.5, name='dropout')
+      a_3 = dp(a_2)
+      model = keras.models.Model(a, [a_2, a_3])
+
+      optimizer = 'rmsprop'
+      loss = {'dropout': 'mse'}
+      model.compile(optimizer, loss, metrics=['mae'])
+
+      input_a_np = np.random.random((10, 3))
+      output_a_np = np.random.random((10, 4))
+
+      # test train_on_batch
+      _ = model.train_on_batch(input_a_np, output_a_np)
+      _ = model.test_on_batch(input_a_np, output_a_np)
+      # fit
+      _ = model.fit(input_a_np, [output_a_np])
+      # evaluate
+      _ = model.evaluate(input_a_np, [output_a_np])
+
+      # Same without dropout.
+      a = keras.Input(shape=(3,), name='input_a')
+      a_2 = keras.layers.Dense(4, name='dense_1')(a)
+      a_3 = keras.layers.Dense(4, name='dense_2')(a_2)
+      model = keras.models.Model(a, [a_2, a_3])
+
+      optimizer = 'rmsprop'
+      loss = {'dense_2': 'mse'}
+      model.compile(optimizer, loss, metrics={'dense_1': 'mae'})
+
+      # test train_on_batch
+      _ = model.train_on_batch(input_a_np, output_a_np)
+      _ = model.test_on_batch(input_a_np, output_a_np)
+      # fit
+      _ = model.fit(input_a_np, [output_a_np])
+      # evaluate
+      _ = model.evaluate(input_a_np, [output_a_np])
+
+  def test_model_with_external_loss(self):
+    with self.test_session():
+      # None loss, only regularization loss.
+      a = keras.Input(shape=(3,), name='input_a')
+      a_2 = keras.layers.Dense(4, name='dense_1',
+                               kernel_regularizer='l1',
+                               bias_regularizer='l2')(a)
+      dp = keras.layers.Dropout(0.5, name='dropout')
+      a_3 = dp(a_2)
+
+      model = keras.models.Model(a, [a_2, a_3])
+
+      optimizer = 'rmsprop'
+      loss = None
+      model.compile(optimizer, loss, metrics=['mae'])
+
+      input_a_np = np.random.random((10, 3))
+
+      # test train_on_batch
+      out = model.train_on_batch(input_a_np, None)
+      out = model.test_on_batch(input_a_np, None)
+      # fit
+      out = model.fit(input_a_np, None)
+      # evaluate
+      out = model.evaluate(input_a_np, None)
+
+      # No dropout, external loss.
+      a = keras.Input(shape=(3,), name='input_a')
+      a_2 = keras.layers.Dense(4, name='dense_1')(a)
+      a_3 = keras.layers.Dense(4, name='dense_2')(a)
+
+      model = keras.models.Model(a, [a_2, a_3])
+      model.add_loss(keras.backend.mean(a_3 + a_2))
+
+      optimizer = 'rmsprop'
+      loss = None
+      model.compile(optimizer, loss, metrics=['mae'])
+
+      # test train_on_batch
+      out = model.train_on_batch(input_a_np, None)
+      out = model.test_on_batch(input_a_np, None)
+      # fit
+      out = model.fit(input_a_np, None)
+      # evaluate
+      out = model.evaluate(input_a_np, None)
+
+      # Test model with no external data at all.
+      a = keras.Input(
+          tensor=keras.backend.variables_module.Variable(input_a_np,
+                                                         dtype='float32'))
+      a_2 = keras.layers.Dense(4, name='dense_1')(a)
+      a_2 = keras.layers.Dropout(0.5, name='dropout')(a_2)
+      model = keras.models.Model(a, a_2)
+      model.add_loss(keras.backend.mean(a_2))
+
+      model.compile(optimizer='rmsprop',
+                    loss=None,
+                    metrics=['mean_squared_error'])
+
+      # test train_on_batch
+      out = model.train_on_batch(None, None)
+      out = model.test_on_batch(None, None)
+      out = model.predict_on_batch(None)
+
+      # test fit
+      with self.assertRaises(ValueError):
+        out = model.fit(None, None, epochs=1, batch_size=10)
+      out = model.fit(None, None, epochs=1, steps_per_epoch=1)
+
+      # test fit with validation data
+      with self.assertRaises(ValueError):
+        out = model.fit(None, None, epochs=1,
+                        steps_per_epoch=None,
+                        validation_steps=2)
+      out = model.fit(None, None, epochs=1,
+                      steps_per_epoch=2,
+                      validation_steps=2)
+
+      # test evaluate
+      with self.assertRaises(ValueError):
+        out = model.evaluate(None, None, batch_size=10)
+      out = model.evaluate(None, None, steps=3)
+
+      # test predict
+      with self.assertRaises(ValueError):
+        out = model.predict(None, batch_size=10)
+      out = model.predict(None, steps=3)
+      self.assertEqual(out.shape, (10 * 3, 4))
+
+      # Test multi-output model with no external data at all.
+      a = keras.Input(
+          tensor=keras.backend.variables_module.Variable(input_a_np,
+                                                         dtype='float32'))
+      a_1 = keras.layers.Dense(4, name='dense_1')(a)
+      a_2 = keras.layers.Dropout(0.5, name='dropout')(a_1)
+      model = keras.models.Model(a, [a_1, a_2])
+      model.add_loss(keras.backend.mean(a_2))
+
+      model.compile(optimizer='rmsprop',
+                    loss=None,
+                    metrics=['mean_squared_error'])
+
+      # test train_on_batch
+      out = model.train_on_batch(None, None)
+      out = model.test_on_batch(None, None)
+      out = model.predict_on_batch(None)
+
+      # test fit
+      with self.assertRaises(ValueError):
+        out = model.fit(None, None, epochs=1, batch_size=10)
+      out = model.fit(None, None, epochs=1, steps_per_epoch=1)
+
+      # test fit with validation data
+      out = model.fit(None, None, epochs=1,
+                      steps_per_epoch=2,
+                      validation_steps=2)
+
+      # test evaluate
+      with self.assertRaises(ValueError):
+        out = model.evaluate(None, None, batch_size=10)
+      out = model.evaluate(None, None, steps=3)
+
+      # test predict
+      with self.assertRaises(ValueError):
+        out = model.predict(None, batch_size=10, verbose=1)
+      out = model.predict(None, steps=3)
+      self.assertEqual(len(out), 2)
+      self.assertEqual(out[0].shape, (10 * 3, 4))
+      self.assertEqual(out[1].shape, (10 * 3, 4))
+
+  def test_target_tensors(self):
+    with self.test_session():
+      # single-output, as list
+      model = keras.models.Sequential()
+      model.add(keras.layers.Dense(4, input_shape=(4,), name='dense'))
+      input_val = np.random.random((10, 4))
+      target_val = np.random.random((10, 4))
+      target = keras.backend.variable(target_val)
+      model.compile(optimizer='rmsprop', loss='mse', target_tensors=[target])
+      model.train_on_batch(input_val, None)
+
+      # single-output, as dict
+      model.compile(optimizer='rmsprop', loss='mse',
+                    target_tensors={'dense': target})
+      model.train_on_batch(input_val, None)
+
+      # test invalid arguments
+      with self.assertRaises(TypeError):
+        model.compile(optimizer='rmsprop', loss='mse',
+                      target_tensors=set())
+      with self.assertRaises(ValueError):
+        model.compile(optimizer='rmsprop', loss='mse',
+                      target_tensors=[target, target])
+      with self.assertRaises(ValueError):
+        model.compile(optimizer='rmsprop', loss='mse',
+                      target_tensors={'dense2': None})
+      with self.assertRaises(ValueError):
+        model.compile(optimizer='rmsprop', loss='mse',
+                      target_tensors=[target])
+        model.train_on_batch(input_val, target_val)
+
+      # multi-output, as list
+      input_val = np.random.random((10, 4))
+      target_val_a = np.random.random((10, 4))
+      target_val_b = np.random.random((10, 4))
+      target_a = keras.backend.variable(target_val_a)
+      target_b = keras.backend.variable(target_val_b)
+
+      inputs = keras.layers.Input(shape=(4,))
+      output_a = keras.layers.Dense(4, name='dense_a')(inputs)
+      output_b = keras.layers.Dense(4, name='dense_b')(inputs)
+      model = keras.models.Model(inputs, [output_a, output_b])
+      model.compile(optimizer='rmsprop', loss='mse',
+                    target_tensors=[target_a, target_b])
+      model.train_on_batch(input_val, None)
+
+      # multi-output, as dict
+      model.compile(optimizer='rmsprop', loss='mse',
+                    target_tensors={'dense_a': target_a,
+                                    'dense_b': target_b})
+      model.train_on_batch(input_val, None)
+
+      # test with sample weights
+      model.compile(optimizer='rmsprop', loss='mse',
+                    target_tensors=[target_a, target_b])
+      model.train_on_batch(input_val, None,
+                           sample_weight={'dense_a': np.random.random((10,))})
+
+  def test_model_custom_target_tensors(self):
+    with self.test_session():
+      a = keras.Input(shape=(3,), name='input_a')
+      b = keras.Input(shape=(3,), name='input_b')
+
+      a_2 = keras.layers.Dense(4, name='dense_1')(a)
+      dp = keras.layers.Dropout(0.5, name='dropout')
+      b_2 = dp(b)
+
+      y = keras.backend.placeholder([10, 4], name='y')
+      y1 = keras.backend.placeholder([10, 3], name='y1')
+      y2 = keras.backend.placeholder([7, 5], name='y2')
+      model = keras.models.Model([a, b], [a_2, b_2])
+
+      optimizer = 'rmsprop'
+      loss = 'mse'
+      loss_weights = [1., 0.5]
+
+      # test list of target tensors
+      with self.assertRaises(ValueError):
+        model.compile(optimizer, loss, metrics=[], loss_weights=loss_weights,
+                      sample_weight_mode=None, target_tensors=[y, y1, y2])
+      model.compile(optimizer, loss, metrics=[], loss_weights=loss_weights,
+                    sample_weight_mode=None, target_tensors=[y, y1])
+      input_a_np = np.random.random((10, 3))
+      input_b_np = np.random.random((10, 3))
+
+      output_a_np = np.random.random((10, 4))
+      output_b_np = np.random.random((10, 3))
+
+      _ = model.train_on_batch([input_a_np, input_b_np],
+                               [output_a_np, output_b_np],
+                               {y: np.random.random((10, 4)),
+                                y1: np.random.random((10, 3))})
+      # test dictionary of target_tensors
+      with self.assertRaises(ValueError):
+        model.compile(optimizer, loss,
+                      metrics=[],
+                      loss_weights=loss_weights,
+                      sample_weight_mode=None,
+                      target_tensors={'does_not_exist': y2})
+      # test dictionary of target_tensors
+      model.compile(optimizer, loss,
+                    metrics=[],
+                    loss_weights=loss_weights,
+                    sample_weight_mode=None,
+                    target_tensors={'dense_1': y, 'dropout': y1})
+      _ = model.train_on_batch([input_a_np, input_b_np],
+                               [output_a_np, output_b_np],
+                               {y: np.random.random((10, 4)),
+                                y1: np.random.random((10, 3))})
+
+      # test with custom TF placeholder as target
+      pl_target_a = keras.backend.array_ops.placeholder('float32',
+                                                        shape=(None, 4))
+      model.compile(optimizer='rmsprop', loss='mse',
+                    target_tensors={'dense_1': pl_target_a})
+      model.train_on_batch([input_a_np, input_b_np],
+                           [output_a_np, output_b_np])
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/keras/python/keras/initializers.py b/tensorflow/python/keras/_impl/keras/initializers.py
similarity index 96%
rename from tensorflow/contrib/keras/python/keras/initializers.py
rename to tensorflow/python/keras/_impl/keras/initializers.py
index ae76c079f307ecbff1e85b7bfa40a0f8193def37..8752faa534a3d6094ce530e490571ff939f86dbb 100644
--- a/tensorflow/contrib/keras/python/keras/initializers.py
+++ b/tensorflow/python/keras/_impl/keras/initializers.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import six
 
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
 from tensorflow.python.ops.init_ops import Constant
 from tensorflow.python.ops.init_ops import Identity
 from tensorflow.python.ops.init_ops import Initializer  # pylint: disable=unused-import
diff --git a/tensorflow/contrib/keras/python/keras/initializers_test.py b/tensorflow/python/keras/_impl/keras/initializers_test.py
similarity index 99%
rename from tensorflow/contrib/keras/python/keras/initializers_test.py
rename to tensorflow/python/keras/_impl/keras/initializers_test.py
index f39d2bfd525fbd37055a85ffdf583629a79dcae1..7b4e6b4d5b115bc788469bf1afe2a43f8dd86f04 100644
--- a/tensorflow/contrib/keras/python/keras/initializers_test.py
+++ b/tensorflow/python/keras/_impl/keras/initializers_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.ops import init_ops
 from tensorflow.python.platform import test
 
diff --git a/tensorflow/contrib/keras/python/keras/integration_test.py b/tensorflow/python/keras/_impl/keras/integration_test.py
similarity index 99%
rename from tensorflow/contrib/keras/python/keras/integration_test.py
rename to tensorflow/python/keras/_impl/keras/integration_test.py
index 5c42ffcfbd583ae4593aaa465b83ab48b53abdbd..d7d20e5698afa1428dfb786f1c9b82298f250045 100644
--- a/tensorflow/contrib/keras/python/keras/integration_test.py
+++ b/tensorflow/python/keras/_impl/keras/integration_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.layers import base as tf_base_layers
 from tensorflow.python.layers import core as tf_core_layers
 from tensorflow.python.ops import nn
diff --git a/tensorflow/python/keras/_impl/keras/layers/__init__.py b/tensorflow/python/keras/_impl/keras/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b2faf106925d974749af3149c5b40d10d49e99
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/layers/__init__.py
@@ -0,0 +1,40 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras layers module.
+"""
+# pylint: disable=wildcard-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.engine import Input
+from tensorflow.python.keras._impl.keras.engine import InputLayer
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import *
+from tensorflow.python.keras._impl.keras.layers.convolutional import *
+from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import *
+from tensorflow.python.keras._impl.keras.layers.core import *
+from tensorflow.python.keras._impl.keras.layers.embeddings import *
+from tensorflow.python.keras._impl.keras.layers.local import *
+from tensorflow.python.keras._impl.keras.layers.merge import *
+from tensorflow.python.keras._impl.keras.layers.noise import *
+from tensorflow.python.keras._impl.keras.layers.normalization import *
+from tensorflow.python.keras._impl.keras.layers.pooling import *
+from tensorflow.python.keras._impl.keras.layers.recurrent import *
+from tensorflow.python.keras._impl.keras.layers.serialization import deserialize
+from tensorflow.python.keras._impl.keras.layers.serialization import serialize
+from tensorflow.python.keras._impl.keras.layers.wrappers import *
+
diff --git a/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
similarity index 94%
rename from tensorflow/contrib/keras/python/keras/layers/advanced_activations.py
rename to tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
index 55f17ac4e29eee361e29be0c4cd6ee6d33bc5d22..1cb881a13f348fedc55ee48518a54b852d680876 100644
--- a/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
@@ -19,13 +19,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import constraints
-from tensorflow.contrib.keras.python.keras import initializers
-from tensorflow.contrib.keras.python.keras import regularizers
-from tensorflow.contrib.keras.python.keras.engine import InputSpec
-from tensorflow.contrib.keras.python.keras.engine import Layer
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
 
 
 class LeakyReLU(Layer):
diff --git a/tensorflow/contrib/keras/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py
similarity index 94%
rename from tensorflow/contrib/keras/python/keras/layers/advanced_activations_test.py
rename to tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py
index 1be56123d8b897a8522b7baa77d7e2ea87db9f0e..91efab30edf99901b25dc0085b7d49e70d1b6d6d 100644
--- a/tensorflow/contrib/keras/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/layers/convolutional.py
rename to tensorflow/python/keras/_impl/keras/layers/convolutional.py
index 9174f6df160e20aa7e2765c73537c30eebb9c4f0..ce96bc66f7cc932bae84f746276cbed98961c127 100644
--- a/tensorflow/contrib/keras/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
@@ -19,24 +19,24 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import activations
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import constraints
-from tensorflow.contrib.keras.python.keras import initializers
-from tensorflow.contrib.keras.python.keras import regularizers
-from tensorflow.contrib.keras.python.keras.engine import InputSpec
-from tensorflow.contrib.keras.python.keras.engine import Layer
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import activations
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
 # imports for backwards namespace compatibility
 # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.layers.pooling import AveragePooling1D
-from tensorflow.contrib.keras.python.keras.layers.pooling import AveragePooling2D
-from tensorflow.contrib.keras.python.keras.layers.pooling import AveragePooling3D
-from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling1D
-from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling2D
-from tensorflow.contrib.keras.python.keras.layers.pooling import MaxPooling3D
+from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D
+from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D
+from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D
 # pylint: enable=unused-import
-from tensorflow.contrib.keras.python.keras.utils import conv_utils
-from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras.utils import conv_utils
 from tensorflow.python.layers import convolutional as tf_convolutional_layers
 
 
@@ -763,7 +763,7 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer):
       depthwise_regularizer: Regularizer function applied to
           the depthwise kernel matrix.
       pointwise_regularizer: Regularizer function applied to
-          the depthwise kernel matrix.
+          the pointwise kernel matrix.
       bias_regularizer: Regularizer function applied to the bias vector.
       activity_regularizer: Regularizer function applied to
           the output of the layer (its "activation")..
@@ -1473,14 +1473,14 @@ class Cropping3D(Layer):
   spatial or spatio-temporal).
 
   Arguments:
-      cropping: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
+      cropping: int, or tuple of 23ints, or tuple of 3 tuples of 2 ints.
           - If int: the same symmetric cropping
-              is applied to width and height.
-          - If tuple of 2 ints:
+              is applied to depth, height, and width.
+          - If tuple of 3 ints:
               interpreted as two different
-              symmetric cropping values for height and width:
+              symmetric cropping values for depth, height, and width:
               `(symmetric_dim1_crop, symmetric_dim2_crop, symmetric_dim3_crop)`.
-          - If tuple of 2 tuples of 2 ints:
+          - If tuple of 3 tuples of 2 ints:
               interpreted as
               `((left_dim1_crop, right_dim1_crop), (left_dim2_crop,
                 right_dim2_crop), (left_dim3_crop, right_dim3_crop))`
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
similarity index 91%
rename from tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
rename to tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
index 9ab2e72bf1b5ae130cdbcd35d6a07b6b24537abc..74757532e17af39cbd1b30cac39c730f3f450eb0 100644
--- a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
@@ -20,15 +20,15 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras import activations
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import constraints
-from tensorflow.contrib.keras.python.keras import initializers
-from tensorflow.contrib.keras.python.keras import regularizers
-from tensorflow.contrib.keras.python.keras.engine import InputSpec
-from tensorflow.contrib.keras.python.keras.layers.recurrent import Recurrent
-from tensorflow.contrib.keras.python.keras.utils import conv_utils
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import activations
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent
+from tensorflow.python.keras._impl.keras.utils import conv_utils
 
 
 class ConvRecurrent2D(Recurrent):
@@ -151,18 +151,30 @@ class ConvRecurrent2D(Recurrent):
         dilation=self.dilation_rate[1])
     if self.return_sequences:
       if self.data_format == 'channels_first':
-        return tensor_shape.TensorShape(
-            [input_shape[0], input_shape[1], self.filters, rows, cols])
+        output_shape = [input_shape[0], input_shape[1],
+                        self.filters, rows, cols]
       elif self.data_format == 'channels_last':
-        return tensor_shape.TensorShape(
-            [input_shape[0], input_shape[1], rows, cols, self.filters])
+        output_shape = [input_shape[0], input_shape[1],
+                        rows, cols, self.filters]
     else:
       if self.data_format == 'channels_first':
-        return tensor_shape.TensorShape(
-            [input_shape[0], self.filters, rows, cols])
+        output_shape = [input_shape[0], self.filters, rows, cols]
       elif self.data_format == 'channels_last':
-        return tensor_shape.TensorShape(
-            [input_shape[0], rows, cols, self.filters])
+        output_shape = [input_shape[0], rows, cols, self.filters]
+
+    if self.return_state:
+      if self.data_format == 'channels_first':
+        output_shapes = [output_shape] + [(input_shape[0],
+                                           self.filters,
+                                           rows,
+                                           cols) for _ in range(2)]
+      elif self.data_format == 'channels_last':
+        output_shapes = [output_shape] + [(input_shape[0],
+                                           rows,
+                                           cols,
+                                           self.filters) for _ in range(2)]
+      return [tensor_shape.TensorShape(shape) for shape in output_shapes]
+    return tensor_shape.TensorShape(output_shape)
 
   def get_config(self):
     config = {
@@ -447,7 +459,6 @@ class ConvLSTM2D(ConvRecurrent2D):
     if not self.stateful:
       raise RuntimeError('Layer must be stateful.')
     input_shape = self.input_spec[0].shape
-    output_shape = self._compute_output_shape(input_shape)
 
     if not input_shape[0]:
       raise ValueError('If a RNN is stateful, a complete '
@@ -455,20 +466,24 @@ class ConvLSTM2D(ConvRecurrent2D):
                        '(including batch size). '
                        'Got input shape: ' + str(input_shape))
 
+    if self.return_state:
+      output_shape = tuple(self._compute_output_shape(input_shape)[0].as_list())
+    else:
+      output_shape = tuple(self._compute_output_shape(input_shape).as_list())
     if self.return_sequences:
-      out_row, out_col, out_filter = output_shape[2:]
+      output_shape = (input_shape[0],) + output_shape[2:]
     else:
-      out_row, out_col, out_filter = output_shape[1:]
+      output_shape = (input_shape[0],) + output_shape[1:]
 
     if hasattr(self, 'states'):
       K.set_value(self.states[0],
-                  np.zeros((input_shape[0], out_row, out_col, out_filter)))
+                  np.zeros(output_shape))
       K.set_value(self.states[1],
-                  np.zeros((input_shape[0], out_row, out_col, out_filter)))
+                  np.zeros(output_shape))
     else:
       self.states = [
-          K.zeros((input_shape[0], out_row, out_col, out_filter)),
-          K.zeros((input_shape[0], out_row, out_col, out_filter))
+          K.zeros(output_shape),
+          K.zeros(output_shape)
       ]
 
   def get_constants(self, inputs, training=None):
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py
similarity index 85%
rename from tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent_test.py
rename to tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py
index 06b2be6b68060844ef8096bbcd0b0d078d2827b6..60137bdd724676af2c89bb7531cf4ea4e529b2a1 100644
--- a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
@@ -47,8 +47,27 @@ class ConvLSTMTest(test.TestCase):
                                 input_channel)
 
       for return_sequences in [True, False]:
-        # test for output shape:
         with self.test_session():
+          # test for return state:
+          x = keras.Input(batch_shape=inputs.shape)
+          kwargs = {'data_format': data_format,
+                    'return_sequences': return_sequences,
+                    'return_state': True,
+                    'stateful': True,
+                    'filters': filters,
+                    'kernel_size': (num_row, num_col),
+                    'padding': 'valid'}
+          layer = keras.layers.ConvLSTM2D(**kwargs)
+          layer.build(inputs.shape)
+          outputs = layer(x)
+          _, states = outputs[0], outputs[1:]
+          self.assertEqual(len(states), 2)
+          model = keras.models.Model(x, states[0])
+          state = model.predict(inputs)
+          self.assertAllClose(
+              keras.backend.eval(layer.states[0]), state, atol=1e-4)
+
+          # test for output shape:
           testing_utils.layer_test(
               keras.layers.ConvLSTM2D,
               kwargs={'data_format': data_format,
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
similarity index 99%
rename from tensorflow/contrib/keras/python/keras/layers/convolutional_test.py
rename to tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
index 00a7fbf8fb93ab5ac39c01ca4194e3c069b10249..be7da6f2b409aa57e3f1328441f0e37ede924c11 100644
--- a/tensorflow/contrib/keras/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/python/keras/_impl/keras/layers/core.py
similarity index 96%
rename from tensorflow/contrib/keras/python/keras/layers/core.py
rename to tensorflow/python/keras/_impl/keras/layers/core.py
index c3df1c85d7273b9cd97d9719ab68be10f87cf8ca..e7b87a09aa23351e245e35ae6c51f72c81e536b2 100644
--- a/tensorflow/contrib/keras/python/keras/layers/core.py
+++ b/tensorflow/python/keras/_impl/keras/layers/core.py
@@ -23,18 +23,18 @@ import types as python_types
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras import activations
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import constraints
-from tensorflow.contrib.keras.python.keras import initializers
-from tensorflow.contrib.keras.python.keras import regularizers
-from tensorflow.contrib.keras.python.keras.engine import InputSpec
-from tensorflow.contrib.keras.python.keras.engine import Layer
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import has_arg
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import activations
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.generic_utils import func_dump
+from tensorflow.python.keras._impl.keras.utils.generic_utils import func_load
+from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
 from tensorflow.python.layers import core as tf_core_layers
 
 
@@ -77,7 +77,7 @@ class Masking(Layer):
   def call(self, inputs):
     boolean_mask = K.any(
         K.not_equal(inputs, self.mask_value), axis=-1, keepdims=True)
-    return inputs * K.cast(boolean_mask, K.floatx())
+    return inputs * K.cast(boolean_mask, inputs.dtype)
 
   def get_config(self):
     config = {'mask_value': self.mask_value}
@@ -107,7 +107,10 @@ class Dropout(tf_core_layers.Dropout, Layer):
     self.supports_masking = True
     # Inheritance call order:
     # 1) tf.layers.Dropout, 2) keras.layers.Layer, 3) tf.layers.Layer
-    super(Dropout, self).__init__(rate=rate, noise_shape=noise_shape, seed=seed, **kwargs)
+    super(Dropout, self).__init__(rate=rate,
+                                  noise_shape=noise_shape,
+                                  seed=seed,
+                                  **kwargs)
 
   def call(self, inputs, training=None):
     if training is None:
@@ -349,7 +352,7 @@ class Reshape(Layer):
         The new output shape with a -1 replaced with its computed value.
 
         Raises a ValueError if the total array size of the output_shape is
-        different then the input_shape, or more then one unknown dimension
+        different then the input_shape, or more than one unknown dimension
         is specified.
 
     Raises:
diff --git a/tensorflow/contrib/keras/python/keras/layers/core_test.py b/tensorflow/python/keras/_impl/keras/layers/core_test.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/layers/core_test.py
rename to tensorflow/python/keras/_impl/keras/layers/core_test.py
index 818c55afe479ae2093a6d337e25dc26156479d24..5b15895c4111fb7f69d3e187065010adb5bed534 100644
--- a/tensorflow/contrib/keras/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/core_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
similarity index 95%
rename from tensorflow/contrib/keras/python/keras/layers/embeddings.py
rename to tensorflow/python/keras/_impl/keras/layers/embeddings.py
index 9f617fd3e425ae7eb03a5f92f0ac850a5f3e3cb0..65d63550774830ea13e26c9d493ffa04978179d2 100644
--- a/tensorflow/contrib/keras/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import constraints
-from tensorflow.contrib.keras.python.keras import initializers
-from tensorflow.contrib.keras.python.keras import regularizers
-from tensorflow.contrib.keras.python.keras.engine import Layer
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras.engine import Layer
 
 
 class Embedding(Layer):
diff --git a/tensorflow/contrib/keras/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py
similarity index 95%
rename from tensorflow/contrib/keras/python/keras/layers/embeddings_test.py
rename to tensorflow/python/keras/_impl/keras/layers/embeddings_test.py
index 5d6d386862bf9921a6f0f8b58e494ee04e1643ed..1712111b877cf1fee4353c5542f33a973a26de95 100644
--- a/tensorflow/contrib/keras/python/keras/layers/embeddings_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/gru_test.py b/tensorflow/python/keras/_impl/keras/layers/gru_test.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/layers/gru_test.py
rename to tensorflow/python/keras/_impl/keras/layers/gru_test.py
index 9af32904801cffcfef4d1e4e401e9f299c27685c..03f0736161e6d1ce91b1efab8cfddef71e0360d3 100644
--- a/tensorflow/contrib/keras/python/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/gru_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/local.py b/tensorflow/python/keras/_impl/keras/layers/local.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/layers/local.py
rename to tensorflow/python/keras/_impl/keras/layers/local.py
index 31a29cdaf467bde6ad6076f398c98987d67435b6..040fe40c57a53d47418b040ffa4770265664c838 100644
--- a/tensorflow/contrib/keras/python/keras/layers/local.py
+++ b/tensorflow/python/keras/_impl/keras/layers/local.py
@@ -14,20 +14,19 @@
 # ==============================================================================
 """Locally-connected layers.
 """
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import activations
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import constraints
-from tensorflow.contrib.keras.python.keras import initializers
-from tensorflow.contrib.keras.python.keras import regularizers
-from tensorflow.contrib.keras.python.keras.engine import InputSpec
-from tensorflow.contrib.keras.python.keras.engine import Layer
-from tensorflow.contrib.keras.python.keras.utils import conv_utils
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import activations
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.utils import conv_utils
 
 
 class LocallyConnected1D(Layer):
diff --git a/tensorflow/contrib/keras/python/keras/layers/local_test.py b/tensorflow/python/keras/_impl/keras/layers/local_test.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/layers/local_test.py
rename to tensorflow/python/keras/_impl/keras/layers/local_test.py
index 6da20d8f83f690b8f63e52b4034fece4729a8adb..a815a0fadc8215c00f3db4749e323f96e44b66f3 100644
--- a/tensorflow/contrib/keras/python/keras/layers/local_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/local_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/lstm_test.py b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
similarity index 91%
rename from tensorflow/contrib/keras/python/keras/layers/lstm_test.py
rename to tensorflow/python/keras/_impl/keras/layers/lstm_test.py
index 7858b0e6b54265e3125dfba794d68eac3fa3d460..94049d4066a576256dd4ef12c85abc78cdfdb93c 100644
--- a/tensorflow/contrib/keras/python/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
@@ -337,6 +337,33 @@ class LSTMLayerTest(test.TestCase):
       inputs = np.random.random((num_samples, timesteps, embedding_dim))
       outputs = model.predict(inputs)
 
+  def test_initial_states_as_other_inputs(self):
+    timesteps = 3
+    embedding_dim = 4
+    units = 3
+    num_samples = 2
+    num_states = 2
+    layer_class = keras.layers.LSTM
+
+    with self.test_session():
+      # Test with Keras tensor
+      main_inputs = keras.Input((timesteps, embedding_dim))
+      initial_state = [keras.Input((units,)) for _ in range(num_states)]
+      inputs = [main_inputs] + initial_state
+
+      layer = layer_class(units)
+      output = layer(inputs)
+      assert initial_state[0] in layer.inbound_nodes[0].input_tensors
+
+      model = keras.models.Model(inputs, output)
+      model.compile(loss='categorical_crossentropy', optimizer='adam')
+
+      main_inputs = np.random.random((num_samples, timesteps, embedding_dim))
+      initial_state = [np.random.random((num_samples, units))
+                       for _ in range(num_states)]
+      targets = np.random.random((num_samples, units))
+      model.train_on_batch([main_inputs] + initial_state, targets)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/keras/python/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py
similarity index 91%
rename from tensorflow/contrib/keras/python/keras/layers/merge.py
rename to tensorflow/python/keras/_impl/keras/layers/merge.py
index 64d0c40e615d59e176f857a2f3b3b5300ebefa7b..b6391dba2514cc8ece54159204079175cf71c19d 100644
--- a/tensorflow/contrib/keras/python/keras/layers/merge.py
+++ b/tensorflow/python/keras/_impl/keras/layers/merge.py
@@ -20,9 +20,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.engine.topology import Layer
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.engine.topology import Layer
 
 
 class _Merge(Layer):
@@ -223,6 +223,37 @@ class Add(_Merge):
     return output
 
 
+class Subtract(_Merge):
+  """Layer that subtracts two inputs.
+
+  It takes as input a list of tensors of size 2,
+  both of the same shape, and returns a single tensor, (inputs[0] - inputs[1]),
+  also of the same shape.
+
+  Examples:
+
+  ```python
+      import keras
+
+      input1 = keras.layers.Input(shape=(16,))
+      x1 = keras.layers.Dense(8, activation='relu')(input1)
+      input2 = keras.layers.Input(shape=(32,))
+      x2 = keras.layers.Dense(8, activation='relu')(input2)
+      # Equivalent to subtracted = keras.layers.subtract([x1, x2])
+      subtracted = keras.layers.Subtract()([x1, x2])
+
+      out = keras.layers.Dense(4)(subtracted)
+      model = keras.models.Model(inputs=[input1, input2], outputs=out)
+  ```
+  """
+
+  def _merge_function(self, inputs):
+    if len(inputs) != 2:
+      raise ValueError('`Subtract` layer should be called '
+                       'on exactly 2 inputs. Received: %s' % inputs)
+    return inputs[0] - inputs[1]
+
+
 class Multiply(_Merge):
   """Layer that multiplies (element-wise) a list of inputs.
 
@@ -486,6 +517,34 @@ def add(inputs, **kwargs):
   return Add(**kwargs)(inputs)
 
 
+def subtract(inputs, **kwargs):
+  """Functional interface to the `Subtract` layer.
+
+  Arguments:
+      inputs: A list of input tensors (exactly 2).
+      **kwargs: Standard layer keyword arguments.
+
+  Returns:
+      A tensor, the difference of the inputs.
+
+  Examples:
+
+  ```python
+      import keras
+
+      input1 = keras.layers.Input(shape=(16,))
+      x1 = keras.layers.Dense(8, activation='relu')(input1)
+      input2 = keras.layers.Input(shape=(32,))
+      x2 = keras.layers.Dense(8, activation='relu')(input2)
+      subtracted = keras.layers.subtract([x1, x2])
+
+      out = keras.layers.Dense(4)(subtracted)
+      model = keras.models.Model(inputs=[input1, input2], outputs=out)
+  ```
+  """
+  return Subtract(**kwargs)(inputs)
+
+
 def multiply(inputs, **kwargs):
   """Functional interface to the `Multiply` layer.
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/merge_test.py b/tensorflow/python/keras/_impl/keras/layers/merge_test.py
similarity index 93%
rename from tensorflow/contrib/keras/python/keras/layers/merge_test.py
rename to tensorflow/python/keras/_impl/keras/layers/merge_test.py
index 4a365c2c44eda1cae9b31e39cb4b7ab6caee2175..ea76337317936e48985f61b74fe174d84d9db065 100644
--- a/tensorflow/contrib/keras/python/keras/layers/merge_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/merge_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
@@ -194,6 +194,20 @@ class MergeLayersTest(test.TestCase):
       dot = keras.layers.Dot(1)
       dot._compute_output_shape(1)
 
+  def test_merge_subtract(self):
+    i1 = keras.layers.Input(shape=(4, 5))
+    i2 = keras.layers.Input(shape=(4, 5))
+    y = keras.layers.subtract([i1, i2])
+    self.assertEqual(y.get_shape().as_list(), [None, 4, 5])
+
+    # Test invalid use cases
+    i1 = keras.layers.Input(shape=(4, 5))
+    i2 = keras.layers.Input(shape=(3, 5))
+    with self.assertRaises(ValueError):
+      keras.layers.subtract([i1, i2])
+    with self.assertRaises(ValueError):
+      keras.layers.subtract([i1, i1, i1])
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/keras/python/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/layers/noise.py
rename to tensorflow/python/keras/_impl/keras/layers/noise.py
index e3cfa1f711cd6f4f62e230db5984da2c1e833440..9caa8b7024aa31118802a5bac5edac756dccc0f9 100644
--- a/tensorflow/contrib/keras/python/keras/layers/noise.py
+++ b/tensorflow/python/keras/_impl/keras/layers/noise.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.engine import Layer
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.engine import Layer
 
 
 class GaussianNoise(Layer):
diff --git a/tensorflow/contrib/keras/python/keras/layers/noise_test.py b/tensorflow/python/keras/_impl/keras/layers/noise_test.py
similarity index 93%
rename from tensorflow/contrib/keras/python/keras/layers/noise_test.py
rename to tensorflow/python/keras/_impl/keras/layers/noise_test.py
index 8fb1339c2ef5134ce936338c2e6baaadae43142a..f9b4d9cd090ffec1a5acd9118ea6a65798bd72a6 100644
--- a/tensorflow/contrib/keras/python/keras/layers/noise_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/noise_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/normalization.py b/tensorflow/python/keras/_impl/keras/layers/normalization.py
similarity index 94%
rename from tensorflow/contrib/keras/python/keras/layers/normalization.py
rename to tensorflow/python/keras/_impl/keras/layers/normalization.py
index 7b98fe9e850c5976b2b98786d9ade89dbd4b225a..965ef70e6e6cb488aa4832462da4a2cb43e964a6 100644
--- a/tensorflow/contrib/keras/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/_impl/keras/layers/normalization.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import constraints
-from tensorflow.contrib.keras.python.keras import initializers
-from tensorflow.contrib.keras.python.keras import regularizers
-from tensorflow.contrib.keras.python.keras.engine import Layer
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras.engine import Layer
 from tensorflow.python.layers import normalization as tf_normalization_layers
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/normalization_test.py b/tensorflow/python/keras/_impl/keras/layers/normalization_test.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/layers/normalization_test.py
rename to tensorflow/python/keras/_impl/keras/layers/normalization_test.py
index eaeafb0c6299e92ce4f3d540a405b5a657a9990f..39a90e597089b30d110f26f074eba5d6895e52df 100644
--- a/tensorflow/contrib/keras/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/normalization_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/pooling.py b/tensorflow/python/keras/_impl/keras/layers/pooling.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/layers/pooling.py
rename to tensorflow/python/keras/_impl/keras/layers/pooling.py
index 704f05e494e9d7109f234e5e73cb08937cdc7f9e..e773e396796d1d69cc5699f882384ee4b24bdbf1 100644
--- a/tensorflow/contrib/keras/python/keras/layers/pooling.py
+++ b/tensorflow/python/keras/_impl/keras/layers/pooling.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.engine import InputSpec
-from tensorflow.contrib.keras.python.keras.engine import Layer
-from tensorflow.contrib.keras.python.keras.utils import conv_utils
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.utils import conv_utils
 from tensorflow.python.layers import pooling as tf_pooling_layers
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/pooling_test.py b/tensorflow/python/keras/_impl/keras/layers/pooling_test.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/layers/pooling_test.py
rename to tensorflow/python/keras/_impl/keras/layers/pooling_test.py
index d8a6a1673bb32a2e4a9328ad889df0ca4eaba799..ec0a5ae560f49ee39ecffb64f4ac65d3e800024c 100644
--- a/tensorflow/contrib/keras/python/keras/layers/pooling_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/pooling_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/layers/recurrent.py
rename to tensorflow/python/keras/_impl/keras/layers/recurrent.py
index 592e5f5e3aae74d6baf5d0c985f5a8a81c7de4c8..f0f5e564959463b428e5acf520a255ceb22a7c17 100644
--- a/tensorflow/contrib/keras/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -21,14 +21,14 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras import activations
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import constraints
-from tensorflow.contrib.keras.python.keras import initializers
-from tensorflow.contrib.keras.python.keras import regularizers
-from tensorflow.contrib.keras.python.keras.engine import InputSpec
-from tensorflow.contrib.keras.python.keras.engine import Layer
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import activations
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
 
 
 # pylint: disable=access-member-before-definition
@@ -48,7 +48,7 @@ def _time_distributed_dense(x,
       x: input tensor.
       w: weight matrix.
       b: optional bias vector.
-      dropout: wether to apply dropout (same dropout mask
+      dropout: whether to apply dropout (same dropout mask
           for every temporal slice of the input).
       input_dim: integer; optional dimensionality of the input.
       output_dim: integer; optional dimensionality of the output.
@@ -279,6 +279,12 @@ class Recurrent(Layer):
     return inputs
 
   def __call__(self, inputs, initial_state=None, **kwargs):
+    if (isinstance(inputs, (list, tuple)) and
+        len(inputs) > 1
+        and initial_state is None):
+      initial_state = inputs[1:]
+      inputs = inputs[0]
+
     # If `initial_state` is specified,
     # and if it a Keras tensor,
     # then add it to the inputs and temporarily
diff --git a/tensorflow/contrib/keras/python/keras/layers/serialization.py b/tensorflow/python/keras/_impl/keras/layers/serialization.py
similarity index 59%
rename from tensorflow/contrib/keras/python/keras/layers/serialization.py
rename to tensorflow/python/keras/_impl/keras/layers/serialization.py
index f9c21a3e671e907c45287ebcc6a5ae3a85c4fad1..928feaadbf3554fdeec61527d730c475f25c0e5a 100644
--- a/tensorflow/contrib/keras/python/keras/layers/serialization.py
+++ b/tensorflow/python/keras/_impl/keras/layers/serialization.py
@@ -20,21 +20,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras.engine import Input
-from tensorflow.contrib.keras.python.keras.engine import InputLayer
-from tensorflow.contrib.keras.python.keras.layers.advanced_activations import *
-from tensorflow.contrib.keras.python.keras.layers.convolutional import *
-from tensorflow.contrib.keras.python.keras.layers.convolutional_recurrent import *
-from tensorflow.contrib.keras.python.keras.layers.core import *
-from tensorflow.contrib.keras.python.keras.layers.embeddings import *
-from tensorflow.contrib.keras.python.keras.layers.local import *
-from tensorflow.contrib.keras.python.keras.layers.merge import *
-from tensorflow.contrib.keras.python.keras.layers.noise import *
-from tensorflow.contrib.keras.python.keras.layers.normalization import *
-from tensorflow.contrib.keras.python.keras.layers.pooling import *
-from tensorflow.contrib.keras.python.keras.layers.recurrent import *
-from tensorflow.contrib.keras.python.keras.layers.wrappers import *
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras.engine import Input
+from tensorflow.python.keras._impl.keras.engine import InputLayer
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import *
+from tensorflow.python.keras._impl.keras.layers.convolutional import *
+from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import *
+from tensorflow.python.keras._impl.keras.layers.core import *
+from tensorflow.python.keras._impl.keras.layers.embeddings import *
+from tensorflow.python.keras._impl.keras.layers.local import *
+from tensorflow.python.keras._impl.keras.layers.merge import *
+from tensorflow.python.keras._impl.keras.layers.noise import *
+from tensorflow.python.keras._impl.keras.layers.normalization import *
+from tensorflow.python.keras._impl.keras.layers.pooling import *
+from tensorflow.python.keras._impl.keras.layers.recurrent import *
+from tensorflow.python.keras._impl.keras.layers.wrappers import *
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
 
 
 def serialize(layer):
@@ -52,7 +52,7 @@ def deserialize(config, custom_objects=None):
   Returns:
       Layer instance (may be Model, Sequential, Layer...)
   """
-  from tensorflow.contrib.keras.python.keras import models  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.keras._impl.keras import models  # pylint: disable=g-import-not-at-top
   globs = globals()  # All layers.
   globs['Model'] = models.Model
   globs['Sequential'] = models.Sequential
diff --git a/tensorflow/contrib/keras/python/keras/layers/serialization_test.py b/tensorflow/python/keras/_impl/keras/layers/serialization_test.py
similarity index 96%
rename from tensorflow/contrib/keras/python/keras/layers/serialization_test.py
rename to tensorflow/python/keras/_impl/keras/layers/serialization_test.py
index fb2e506a4c33679d0d1589112b1df3dea77a4588..787160d1e71f570479144c5afd45cd41f38f0e91 100644
--- a/tensorflow/contrib/keras/python/keras/layers/serialization_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/serialization_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/layers/simplernn_test.py
rename to tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
index 3d67011352e34a013adea0148cd870a36c63261c..9833485236b68095402cc2921ba7050591d44a55 100644
--- a/tensorflow/contrib/keras/python/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
similarity index 90%
rename from tensorflow/contrib/keras/python/keras/layers/wrappers.py
rename to tensorflow/python/keras/_impl/keras/layers/wrappers.py
index aee02f432e76ef102cf130084c67d1bd29e8b2b5..79e144869eb3775340d2ea063295f76145a7d9fb 100644
--- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
@@ -21,11 +21,12 @@ from __future__ import print_function
 
 import copy
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.engine import InputSpec
-from tensorflow.contrib.keras.python.keras.engine import Layer
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import has_arg
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
+from tensorflow.python.layers import base as tf_base_layers
 
 
 class Wrapper(Layer):
@@ -41,6 +42,10 @@ class Wrapper(Layer):
 
   def __init__(self, layer, **kwargs):
     self.layer = layer
+    # Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
+    # the inner layer has update ops that depend on its inputs (as opposed
+    # to the inputs to the Wrapper layer).
+    self._input_map = {}
     super(Wrapper, self).__init__(**kwargs)
 
   def build(self, input_shape=None):
@@ -68,10 +73,17 @@ class Wrapper(Layer):
     return []
 
   def get_updates_for(self, inputs=None):
-    if inputs is None:
-      updates = self.layer.get_updates_for(None)
-      return updates + super(Wrapper, self).get_updates_for(None)
-    return super(Wrapper, self).get_updates_for(inputs)
+    # If the wrapper modifies the inputs, use the modified inputs to
+    # get the updates from the inner layer.
+    inner_inputs = inputs
+    if inputs is not None:
+      uid = tf_base_layers._object_list_uid(inputs)
+      if uid in self._input_map:
+        inner_inputs = self._input_map[uid]
+
+    updates = self.layer.get_updates_for(inner_inputs)
+    updates += super(Wrapper, self).get_updates_for(inputs)
+    return updates
 
   @property
   def losses(self):
@@ -107,7 +119,7 @@ class Wrapper(Layer):
 
   @classmethod
   def from_config(cls, config, custom_objects=None):
-    from tensorflow.contrib.keras.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
+    from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
     layer = deserialize_layer(
         config.pop('layer'), custom_objects=custom_objects)
     return cls(layer, **config)
@@ -213,8 +225,11 @@ class TimeDistributed(Wrapper):
       input_length = input_shape[1]
       if not input_length:
         input_length = K.shape(inputs)[1]
-      # Shape: (num_samples * timesteps, ...)
+      # Shape: (num_samples * timesteps, ...). And track the
+      # transformation in self._input_map.
+      input_uid = tf_base_layers._object_list_uid(inputs)
       inputs = K.reshape(inputs, (-1,) + input_shape[2:])
+      self._input_map[input_uid] = inputs
       # (num_samples * timesteps, ...)
       y = self.layer.call(inputs, **kwargs)
       if hasattr(y, '_uses_learning_phase'):
diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
similarity index 82%
rename from tensorflow/contrib/keras/python/keras/layers/wrappers_test.py
rename to tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
index 531fa76dd8d2be73dfa1b50dc30254d97e0d0158..a0951b8240dac5162161962456c34df4c2a16595 100644
--- a/tensorflow/contrib/keras/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
@@ -101,14 +101,37 @@ class TimeDistributedTest(test.TestCase):
       self.assertEqual(len(model.losses), 1)
 
   def test_TimeDistributed_learning_phase(self):
-    # test layers that need learning_phase to be set
-    np.random.seed(1234)
-    x = keras.layers.Input(shape=(3, 2))
-    y = keras.layers.TimeDistributed(
-        keras.layers.Dropout(.999))(x, training=True)
-    model = keras.models.Model(x, y)
-    y = model.predict(np.random.random((10, 3, 2)))
-    self.assertAllClose(np.mean(y), 0., atol=1e-1, rtol=1e-1)
+    with self.test_session():
+      # test layers that need learning_phase to be set
+      np.random.seed(1234)
+      x = keras.layers.Input(shape=(3, 2))
+      y = keras.layers.TimeDistributed(
+          keras.layers.Dropout(.999))(x, training=True)
+      model = keras.models.Model(x, y)
+      y = model.predict(np.random.random((10, 3, 2)))
+      self.assertAllClose(np.mean(y), 0., atol=1e-1, rtol=1e-1)
+
+  def test_TimeDistributed_batchnorm(self):
+    with self.test_session():
+      # test that wrapped BN updates still work.
+      model = keras.models.Sequential()
+      model.add(keras.layers.TimeDistributed(
+          keras.layers.BatchNormalization(center=True, scale=True),
+          name='bn',
+          input_shape=(10, 2)))
+      model.compile(optimizer='rmsprop', loss='mse')
+      # Assert that mean and variance are 0 and 1.
+      td = model.layers[0]
+      self.assertAllClose(td.get_weights()[2], np.array([0, 0]))
+      assert np.array_equal(td.get_weights()[3], np.array([1, 1]))
+      # Train
+      model.train_on_batch(np.random.normal(loc=2, scale=2, size=(1, 10, 2)),
+                           np.broadcast_to(np.array([0, 1]), (1, 10, 2)))
+      # Assert that mean and variance changed.
+      assert not np.array_equal(td.get_weights()[2], np.array([0, 0]))
+      assert not np.array_equal(td.get_weights()[3], np.array([1, 1]))
+      # Verify input_map has one mapping from inputs to reshaped inputs.
+      self.assertEqual(len(td._input_map.keys()), 1)
 
 
 class BidirectionalTest(test.TestCase):
diff --git a/tensorflow/contrib/keras/python/keras/losses.py b/tensorflow/python/keras/_impl/keras/losses.py
similarity index 91%
rename from tensorflow/contrib/keras/python/keras/losses.py
rename to tensorflow/python/keras/_impl/keras/losses.py
index 777ec440ac3abdebf8f31289dda4348b6fba68ff..7c6b304622a3ec6995483bfafef1c865ce6520cc 100644
--- a/tensorflow/contrib/keras/python/keras/losses.py
+++ b/tensorflow/python/keras/_impl/keras/losses.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import six
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
 
 
 def mean_squared_error(y_true, y_pred):
@@ -67,15 +67,15 @@ def logcosh(y_true, y_pred):
 
 
 def categorical_crossentropy(y_true, y_pred):
-  return K.categorical_crossentropy(y_pred, y_true)
+  return K.categorical_crossentropy(y_true, y_pred)
 
 
 def sparse_categorical_crossentropy(y_true, y_pred):
-  return K.sparse_categorical_crossentropy(y_pred, y_true)
+  return K.sparse_categorical_crossentropy(y_true, y_pred)
 
 
 def binary_crossentropy(y_true, y_pred):
-  return K.mean(K.binary_crossentropy(y_pred, y_true), axis=-1)
+  return K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
 
 
 def kullback_leibler_divergence(y_true, y_pred):
diff --git a/tensorflow/contrib/keras/python/keras/losses_test.py b/tensorflow/python/keras/_impl/keras/losses_test.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/losses_test.py
rename to tensorflow/python/keras/_impl/keras/losses_test.py
index 6bdcc0b5ff3fa260c244816c88a3fb869d529ea4..b295356ec19c28af3ca80c81f3669bd6bec005b6 100644
--- a/tensorflow/contrib/keras/python/keras/losses_test.py
+++ b/tensorflow/python/keras/_impl/keras/losses_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/metrics.py b/tensorflow/python/keras/_impl/keras/metrics.py
similarity index 67%
rename from tensorflow/contrib/keras/python/keras/metrics.py
rename to tensorflow/python/keras/_impl/keras/metrics.py
index 999e9cb9d4d03c74da7f4775398afb2135980421..202048f26d2ad201b4762d3b2b32638f9d041e88 100644
--- a/tensorflow/contrib/keras/python/keras/metrics.py
+++ b/tensorflow/python/keras/_impl/keras/metrics.py
@@ -20,23 +20,23 @@ from __future__ import print_function
 
 import six
 
-from tensorflow.contrib.keras.python.keras import backend as K
+from tensorflow.python.keras._impl.keras import backend as K
 # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.losses import binary_crossentropy
-from tensorflow.contrib.keras.python.keras.losses import categorical_crossentropy
-from tensorflow.contrib.keras.python.keras.losses import cosine_proximity
-from tensorflow.contrib.keras.python.keras.losses import hinge
-from tensorflow.contrib.keras.python.keras.losses import kullback_leibler_divergence
-from tensorflow.contrib.keras.python.keras.losses import logcosh
-from tensorflow.contrib.keras.python.keras.losses import mean_absolute_error
-from tensorflow.contrib.keras.python.keras.losses import mean_absolute_percentage_error
-from tensorflow.contrib.keras.python.keras.losses import mean_squared_error
-from tensorflow.contrib.keras.python.keras.losses import mean_squared_logarithmic_error
-from tensorflow.contrib.keras.python.keras.losses import poisson
-from tensorflow.contrib.keras.python.keras.losses import sparse_categorical_crossentropy
-from tensorflow.contrib.keras.python.keras.losses import squared_hinge
+from tensorflow.python.keras._impl.keras.losses import binary_crossentropy
+from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy
+from tensorflow.python.keras._impl.keras.losses import cosine_proximity
+from tensorflow.python.keras._impl.keras.losses import hinge
+from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence
+from tensorflow.python.keras._impl.keras.losses import logcosh
+from tensorflow.python.keras._impl.keras.losses import mean_absolute_error
+from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error
+from tensorflow.python.keras._impl.keras.losses import mean_squared_error
+from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error
+from tensorflow.python.keras._impl.keras.losses import poisson
+from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy
+from tensorflow.python.keras._impl.keras.losses import squared_hinge
 # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
 
 
 def binary_accuracy(y_true, y_pred):
diff --git a/tensorflow/contrib/keras/python/keras/metrics_test.py b/tensorflow/python/keras/_impl/keras/metrics_test.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/metrics_test.py
rename to tensorflow/python/keras/_impl/keras/metrics_test.py
index 84c6528174ddf27e44cb3dbc8d1f1d2fcab41890..f4792f3543cc5ca8e5e7ad03d9906bbfadd1fb04 100644
--- a/tensorflow/contrib/keras/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/_impl/keras/metrics_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/python/keras/_impl/keras/models.py
similarity index 75%
rename from tensorflow/contrib/keras/python/keras/models.py
rename to tensorflow/python/keras/_impl/keras/models.py
index 1a0d95c7ff24f3ff658cce7ef86204c04820806b..9a4578b89b3d51512eddcb3b2dfa4f5489370824 100644
--- a/tensorflow/contrib/keras/python/keras/models.py
+++ b/tensorflow/python/keras/_impl/keras/models.py
@@ -15,7 +15,6 @@
 # pylint: disable=protected-access
 """Home of the Sequential model, and the `save_model`/`load_model` functions.
 """
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -26,16 +25,17 @@ import os
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import layers as layer_module
-from tensorflow.contrib.keras.python.keras import optimizers
-from tensorflow.contrib.keras.python.keras.engine import topology
-from tensorflow.contrib.keras.python.keras.engine.topology import Input
-from tensorflow.contrib.keras.python.keras.engine.topology import Layer
-from tensorflow.contrib.keras.python.keras.engine.topology import TFBaseLayer
-from tensorflow.contrib.keras.python.keras.engine.training import Model
-from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.framework import ops
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import layers as layer_module
+from tensorflow.python.keras._impl.keras import optimizers
+from tensorflow.python.keras._impl.keras.engine import topology
+from tensorflow.python.keras._impl.keras.engine.topology import Input
+from tensorflow.python.keras._impl.keras.engine.topology import Layer
+from tensorflow.python.keras._impl.keras.engine.topology import TFBaseLayer
+from tensorflow.python.keras._impl.keras.engine.training import Model
+from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
+from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.platform import tf_logging as logging
 
 
@@ -113,7 +113,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
 
     raise TypeError('Not JSON Serializable:', obj)
 
-  from tensorflow.contrib.keras.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.keras._impl.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
 
   # If file exists and should not be overwritten.
   if not overwrite and os.path.isfile(filepath):
@@ -121,66 +121,65 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
     if not proceed:
       return
 
-  f = h5py.File(filepath, 'w')
-  f.attrs['keras_version'] = str(keras_version).encode('utf8')
-  f.attrs['backend'] = K.backend().encode('utf8')
-  f.attrs['model_config'] = json.dumps(
-      {
-          'class_name': model.__class__.__name__,
-          'config': model.get_config()
-      },
-      default=get_json_type).encode('utf8')
-
-  model_weights_group = f.create_group('model_weights')
-  model_layers = model.layers
-  topology.save_weights_to_hdf5_group(model_weights_group, model_layers)
-
-  if include_optimizer and hasattr(model, 'optimizer'):
-    if isinstance(model.optimizer, optimizers.TFOptimizer):
-      logging.warning(
-          'TensorFlow optimizers do not '
-          'make it possible to access '
-          'optimizer attributes or optimizer state '
-          'after instantiation. '
-          'As a result, we cannot save the optimizer '
-          'as part of the model save file.'
-          'You will have to compile your model again after loading it. '
-          'Prefer using a Keras optimizer instead '
-          '(see keras.io/optimizers).')
-    else:
-      f.attrs['training_config'] = json.dumps(
-          {
-              'optimizer_config': {
-                  'class_name': model.optimizer.__class__.__name__,
-                  'config': model.optimizer.get_config()
-              },
-              'loss': model.loss,
-              'metrics': model.metrics,
-              'sample_weight_mode': model.sample_weight_mode,
-              'loss_weights': model.loss_weights,
-          },
-          default=get_json_type).encode('utf8')
-
-      # Save optimizer weights.
-      symbolic_weights = getattr(model.optimizer, 'weights')
-      if symbolic_weights:
-        optimizer_weights_group = f.create_group('optimizer_weights')
-        weight_values = K.batch_get_value(symbolic_weights)
-        weight_names = []
-        for w, val in zip(symbolic_weights, weight_values):
-          name = str(w.name)
-          weight_names.append(name.encode('utf8'))
-        optimizer_weights_group.attrs['weight_names'] = weight_names
-        for name, val in zip(weight_names, weight_values):
-          param_dset = optimizer_weights_group.create_dataset(
-              name, val.shape, dtype=val.dtype)
-          if not val.shape:
-            # scalar
-            param_dset[()] = val
-          else:
-            param_dset[:] = val
-  f.flush()
-  f.close()
+  with h5py.File(filepath, mode='w') as f:
+    f.attrs['keras_version'] = str(keras_version).encode('utf8')
+    f.attrs['backend'] = K.backend().encode('utf8')
+    f.attrs['model_config'] = json.dumps(
+        {
+            'class_name': model.__class__.__name__,
+            'config': model.get_config()
+        },
+        default=get_json_type).encode('utf8')
+
+    model_weights_group = f.create_group('model_weights')
+    model_layers = model.layers
+    topology.save_weights_to_hdf5_group(model_weights_group, model_layers)
+
+    if include_optimizer and hasattr(model, 'optimizer'):
+      if isinstance(model.optimizer, optimizers.TFOptimizer):
+        logging.warning(
+            'TensorFlow optimizers do not '
+            'make it possible to access '
+            'optimizer attributes or optimizer state '
+            'after instantiation. '
+            'As a result, we cannot save the optimizer '
+            'as part of the model save file.'
+            'You will have to compile your model again after loading it. '
+            'Prefer using a Keras optimizer instead '
+            '(see keras.io/optimizers).')
+      else:
+        f.attrs['training_config'] = json.dumps(
+            {
+                'optimizer_config': {
+                    'class_name': model.optimizer.__class__.__name__,
+                    'config': model.optimizer.get_config()
+                },
+                'loss': model.loss,
+                'metrics': model.metrics,
+                'sample_weight_mode': model.sample_weight_mode,
+                'loss_weights': model.loss_weights,
+            },
+            default=get_json_type).encode('utf8')
+
+        # Save optimizer weights.
+        symbolic_weights = getattr(model.optimizer, 'weights')
+        if symbolic_weights:
+          optimizer_weights_group = f.create_group('optimizer_weights')
+          weight_values = K.batch_get_value(symbolic_weights)
+          weight_names = []
+          for w, val in zip(symbolic_weights, weight_values):
+            name = str(w.name)
+            weight_names.append(name.encode('utf8'))
+          optimizer_weights_group.attrs['weight_names'] = weight_names
+          for name, val in zip(weight_names, weight_values):
+            param_dset = optimizer_weights_group.create_dataset(
+                name, val.shape, dtype=val.dtype)
+            if not val.shape:
+              # scalar
+              param_dset[()] = val
+            else:
+              param_dset[:] = val
+    f.flush()
 
 
 def load_model(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
@@ -545,12 +544,12 @@ class Sequential(Model):
     Returns:
         A layer instance.
     """
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.get_layer(name, index)
 
   def call(self, inputs, mask=None):
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.call(inputs, mask)
 
@@ -586,7 +585,7 @@ class Sequential(Model):
 
   @property
   def uses_learning_phase(self):
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.uses_learning_phase
 
@@ -622,35 +621,35 @@ class Sequential(Model):
 
   @property
   def updates(self):
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.updates
 
   @property
   def state_updates(self):
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.state_updates
 
   def get_updates_for(self, inputs):
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.get_updates_for(inputs)
 
   @property
   def losses(self):
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.losses
 
   def get_losses_for(self, inputs):
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.get_losses_for(inputs)
 
   @property
   def regularizers(self):
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.regularizers
 
@@ -661,7 +660,7 @@ class Sequential(Model):
         A flat list of Numpy arrays
         (one array per model weight).
     """
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.get_weights()
 
@@ -673,7 +672,7 @@ class Sequential(Model):
             of Numpy arrays with shapes and types matching
             the output of `model.get_weights()`.
     """
-    if self.model is None:
+    if not self.built:
       self.build()
     self.model.set_weights(weights)
 
@@ -710,6 +709,7 @@ class Sequential(Model):
               loss,
               metrics=None,
               sample_weight_mode=None,
+              weighted_metrics=None,
               **kwargs):
     """Configures the learning process.
 
@@ -725,9 +725,9 @@ class Sequential(Model):
         sample_weight_mode: if you need to do timestep-wise
             sample weighting (2D weights), set this to "temporal".
             "None" defaults to sample-wise weights (1D).
-        **kwargs: for Theano backend, these are passed into K.function.
-            When using the Tensorflow backend, these are passed into
-            `tf.Session.run`.
+        weighted_metrics: list of metrics to be evaluated and weighted
+             by `sample_weight` or `class_weight` during training and testing.
+        **kwargs: These are passed into `tf.Session.run`.
 
     Example:
         ```python
@@ -747,12 +747,14 @@ class Sequential(Model):
         loss,
         metrics=metrics,
         sample_weight_mode=sample_weight_mode,
+        weighted_metrics=weighted_metrics,
         **kwargs)
     self.optimizer = self.model.optimizer
     self.loss = self.model.loss
     self.total_loss = self.model.total_loss
     self.loss_weights = self.model.loss_weights
     self.metrics = self.model.metrics
+    self.weighted_metrics = self.model.weighted_metrics
     self.metrics_tensors = self.model.metrics_tensors
     self.metrics_names = self.model.metrics_names
     self.sample_weight_mode = self.model.sample_weight_mode
@@ -818,7 +820,7 @@ class Sequential(Model):
     Raises:
         RuntimeError: if the model was never compiled.
     """
-    if self.model is None:
+    if not self.built:
       raise RuntimeError('The model needs to be compiled ' 'before being used.')
     return self.model.fit(
         x,
@@ -854,7 +856,7 @@ class Sequential(Model):
     Raises:
         RuntimeError: if the model was never compiled.
     """
-    if self.model is None:
+    if not self.built:
       raise RuntimeError('The model needs to be compiled ' 'before being used.')
     return self.model.evaluate(
         x,
@@ -876,7 +878,7 @@ class Sequential(Model):
     Returns:
         A Numpy array of predictions.
     """
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.predict(x, batch_size=batch_size, verbose=verbose)
 
@@ -890,7 +892,7 @@ class Sequential(Model):
     Returns:
         A Numpy array of predictions.
     """
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.predict_on_batch(x)
 
@@ -914,7 +916,7 @@ class Sequential(Model):
     Raises:
         RuntimeError: if the model was never compiled.
     """
-    if self.model is None:
+    if not self.built:
       raise RuntimeError('The model needs to be compiled ' 'before being used.')
     return self.model.train_on_batch(
         x, y, sample_weight=sample_weight, class_weight=class_weight)
@@ -937,7 +939,7 @@ class Sequential(Model):
     Raises:
         RuntimeError: if the model was never compiled.
     """
-    if self.model is None:
+    if not self.built:
       raise RuntimeError('The model needs to be compiled ' 'before being used.')
     return self.model.test_on_batch(x, y, sample_weight=sample_weight)
 
@@ -1083,7 +1085,7 @@ class Sequential(Model):
     if kwargs:
       raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
 
-    if self.model is None:
+    if not self.built:
       raise RuntimeError('The model needs to be compiled ' 'before being used.')
     return self.model.fit_generator(
         generator,
@@ -1149,7 +1151,7 @@ class Sequential(Model):
     if kwargs:
       raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
 
-    if self.model is None:
+    if not self.built:
       raise RuntimeError('The model needs to be compiled ' 'before being used.')
     return self.model.evaluate_generator(
         generator,
@@ -1205,7 +1207,7 @@ class Sequential(Model):
     if kwargs:
       raise ValueError('Unrecognized keyword arguments: ' + str(kwargs))
 
-    if self.model is None:
+    if not self.built:
       self.build()
     return self.model.predict_generator(
         generator,
@@ -1231,3 +1233,225 @@ class Sequential(Model):
       layer = layer_module.deserialize(conf, custom_objects=custom_objects)
       model.add(layer)
     return model
+
+
+def _clone_functional_model(model, input_tensors=None):
+  """Clone a functional `Model` instance.
+
+  Model cloning is similar to calling a model on new inputs,
+  except that it creates new layers (and thus new weights) instead
+  of sharing the weights of the existing layers.
+
+  Arguments:
+      model: Instance of `Model`.
+      input_tensors: optional list of input tensors
+          to build the model upon. If not provided,
+          placeholders will be created.
+
+  Returns:
+      An instance of `Model` reproducing the behavior
+      of the original model, on top of new inputs tensors,
+      using newly instantiated weights.
+
+  Raises:
+      ValueError: in case of invalid `model` argument value.
+  """
+  if not isinstance(model, Model):
+    raise ValueError('Expected `model` argument '
+                     'to be a `Model` instance, got ', model)
+  if isinstance(model, Sequential):
+    raise ValueError('Expected `model` argument '
+                     'to be a functional `Model` instance, '
+                     'got a `Sequential` instance instead:', model)
+
+  layer_map = {}  # Cache for created layers.
+  tensor_map = {}  # Map {reference_tensor: (corresponding_tensor, mask)}
+  if input_tensors is None:
+    # Create placeholders to build the model on top of.
+    input_layers = []
+    input_tensors = []
+    for layer in model._input_layers:
+      input_tensor = Input(
+          batch_shape=layer.batch_input_shape,
+          dtype=layer.dtype,
+          sparse=layer.sparse,
+          name=layer.name)
+      input_tensors.append(input_tensor)
+      # Cache newly created input layer.
+      newly_created_input_layer = input_tensor._keras_history[0]
+      layer_map[layer] = newly_created_input_layer
+    for original_input_layer, cloned_input_layer in zip(model._input_layers,
+                                                        input_layers):
+      layer_map[original_input_layer] = cloned_input_layer
+  else:
+    # Make sure that all input tensors come from a Keras layer.
+    # If tensor comes from an input layer: cache the input layer.
+    input_tensors = topology._to_list(input_tensors)
+    input_tensors_ = []
+    for i, x in enumerate(input_tensors):
+      if not K.is_keras_tensor(x):
+        name = model._input_layers[i].name
+        input_tensor = Input(tensor=x, name='input_wrapper_for_' + name)
+        input_tensors_.append(input_tensor)
+        # Cache newly created input layer.
+        original_input_layer = x._keras_history[0]
+        newly_created_input_layer = input_tensor._keras_history[0]
+        layer_map[original_input_layer] = newly_created_input_layer
+      else:
+        input_tensors_.append(x)
+    input_tensors = input_tensors_
+
+  for x, y in zip(model.inputs, input_tensors):
+    tensor_map[x] = (y, None)  # tensor, mask
+
+  # Iterated over every node in the reference model, in depth order.
+  depth_keys = list(model._nodes_by_depth.keys())
+  depth_keys.sort(reverse=True)
+  for depth in depth_keys:
+    nodes = model._nodes_by_depth[depth]
+    for node in nodes:
+      # Recover the corresponding layer.
+      layer = node.outbound_layer
+
+      # Get or create layer.
+      if layer not in layer_map:
+        # Clone layer.
+        new_layer = layer.__class__.from_config(layer.get_config())
+        layer_map[layer] = new_layer
+        layer = new_layer
+      else:
+        # Reuse previously cloned layer.
+        layer = layer_map[layer]
+        # Don't call InputLayer multiple times.
+        if isinstance(layer, topology.InputLayer):
+          continue
+
+      # Gather inputs to call the new layer.
+      referenceinput_tensors_ = node.input_tensors
+      reference_output_tensors = node.output_tensors
+
+      # If all previous input tensors are available in tensor_map,
+      # then call node.inbound_layer on them.
+      computed_data = []  # List of tuples (input, mask).
+      for x in referenceinput_tensors_:
+        if x in tensor_map:
+          computed_data.append(tensor_map[x])
+
+      if len(computed_data) == len(referenceinput_tensors_):
+        # Call layer.
+        if node.arguments:
+          kwargs = node.arguments
+        else:
+          kwargs = {}
+        if len(computed_data) == 1:
+          computed_tensor, computed_mask = computed_data[0]
+          if has_arg(layer.call, 'mask'):
+            if 'mask' not in kwargs:
+              kwargs['mask'] = computed_mask
+          output_tensors = topology._to_list(layer(computed_tensor, **kwargs))
+          output_masks = topology._to_list(
+              layer.compute_mask(computed_tensor, computed_mask))
+          computed_tensors = [computed_tensor]
+          computed_masks = [computed_mask]
+        else:
+          computed_tensors = [x[0] for x in computed_data]
+          computed_masks = [x[1] for x in computed_data]
+          if has_arg(layer.call, 'mask'):
+            if 'mask' not in kwargs:
+              kwargs['mask'] = computed_masks
+          output_tensors = topology._to_list(layer(computed_tensors, **kwargs))
+          output_masks = topology._to_list(
+              layer.compute_mask(computed_tensors, computed_masks))
+        # Update tensor_map.
+        for x, y, mask in zip(reference_output_tensors, output_tensors,
+                              output_masks):
+          tensor_map[x] = (y, mask)
+
+  # Check that we did compute the model outputs,
+  # then instantiate a new model from inputs and outputs.
+  output_tensors = []
+  for x in model.outputs:
+    assert x in tensor_map, 'Could not compute output ' + str(x)
+    tensor, _ = tensor_map[x]
+    output_tensors.append(tensor)
+  return Model(input_tensors, output_tensors, name=model.name)
+
+
+def _clone_sequential_model(model, input_tensors=None):
+  """Clone a `Sequential` model instance.
+
+  Model cloning is similar to calling a model on new inputs,
+  except that it creates new layers (and thus new weights) instead
+  of sharing the weights of the existing layers.
+
+  Arguments:
+      model: Instance of `Sequential`.
+      input_tensors: optional list of input tensors
+          to build the model upon. If not provided,
+          placeholders will be created.
+
+  Returns:
+      An instance of `Sequential` reproducing the behavior
+      of the original model, on top of new inputs tensors,
+      using newly instantiated weights.
+
+  Raises:
+      ValueError: in case of invalid `model` argument value.
+  """
+  if not isinstance(model, Sequential):
+    raise ValueError('Expected `model` argument '
+                     'to be a `Sequential` model instance, '
+                     'but got:', model)
+
+  def clone(layer):
+    return layer.__class__.from_config(layer.get_config())
+
+  layers = [clone(layer) for layer in model.layers]
+  if input_tensors is None:
+    return Sequential(layers=layers, name=model.name)
+  else:
+    if len(topology._to_list(input_tensors)) != 1:
+      raise ValueError('To clone a `Sequential` model, we expect '
+                       ' at most one tensor '
+                       'as part of `input_tensors`.')
+    x = topology._to_list(input_tensors)[0]
+    if K.is_keras_tensor(x):
+      origin_layer = x._keras_history[0]
+      if isinstance(origin_layer, topology.InputLayer):
+        return Sequential(layers=[origin_layer] + layers, name=model.name)
+      else:
+        raise ValueError('Cannot clone a `Sequential` model on top '
+                         'of a tensor that comes from a Keras layer '
+                         'other than an `InputLayer`. '
+                         'Use the functional API instead.')
+    input_tensor = Input(tensor=x, name='input_wrapper_for_' + str(x.name))
+    input_layer = input_tensor._keras_history[0]
+    return Sequential(layers=[input_layer] + layers, name=model.name)
+
+
+def clone_model(model, input_tensors=None):
+  """Clone any `Model` instance.
+
+  Model cloning is similar to calling a model on new inputs,
+  except that it creates new layers (and thus new weights) instead
+  of sharing the weights of the existing layers.
+
+  Arguments:
+      model: Instance of `Model`
+          (could be a functional model or a Sequential model).
+      input_tensors: optional list of input tensors
+          to build the model upon. If not provided,
+          placeholders will be created.
+
+  Returns:
+      An instance of `Model` reproducing the behavior
+      of the original model, on top of new inputs tensors,
+      using newly instantiated weights.
+
+  Raises:
+      ValueError: in case of invalid `model` argument value.
+  """
+  if isinstance(model, Sequential):
+    return _clone_sequential_model(model, input_tensors=input_tensors)
+  else:
+    return _clone_functional_model(model, input_tensors=input_tensors)
diff --git a/tensorflow/contrib/keras/python/keras/models_test.py b/tensorflow/python/keras/_impl/keras/models_test.py
similarity index 74%
rename from tensorflow/contrib/keras/python/keras/models_test.py
rename to tensorflow/python/keras/_impl/keras/models_test.py
index 44088a1b32eb2c0bb260ad92c9d4b83a9d7fd41b..fd6b20e0edc024a4e90f16bc23bdb26b4ffbb019 100644
--- a/tensorflow/contrib/keras/python/keras/models_test.py
+++ b/tensorflow/python/keras/_impl/keras/models_test.py
@@ -24,7 +24,7 @@ import tempfile
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 from tensorflow.python.training import training as training_module
 
@@ -316,5 +316,104 @@ class TestSequential(test.TestCase):
         model.build()
 
 
+class TestModelCloning(test.TestCase):
+
+  def test_clone_sequential_model(self):
+    with self.test_session():
+      val_a = np.random.random((10, 4))
+      val_out = np.random.random((10, 4))
+
+      model = keras.models.Sequential()
+      model.add(keras.layers.Dense(4, input_shape=(4,)))
+      model.add(keras.layers.Dropout(0.5))
+      model.add(keras.layers.Dense(4))
+
+    # Everything should work in a new session.
+    keras.backend.clear_session()
+
+    with self.test_session():
+      # With placeholder creation
+      new_model = keras.models.clone_model(model)
+      new_model.compile('rmsprop', 'mse')
+      new_model.train_on_batch(val_a, val_out)
+
+      # On top of new tensor
+      input_a = keras.Input(shape=(4,))
+      new_model = keras.models.clone_model(
+          model, input_tensors=input_a)
+      new_model.compile('rmsprop', 'mse')
+      new_model.train_on_batch(val_a, val_out)
+
+      # On top of new, non-Keras tensor
+      input_a = keras.backend.variable(val_a)
+      new_model = keras.models.clone_model(
+          model, input_tensors=input_a)
+      new_model.compile('rmsprop', 'mse')
+      new_model.train_on_batch(None, val_out)
+
+  def test_clone_functional_model(self):
+    with self.test_session():
+      val_a = np.random.random((10, 4))
+      val_b = np.random.random((10, 4))
+      val_out = np.random.random((10, 4))
+
+      input_a = keras.Input(shape=(4,))
+      input_b = keras.Input(shape=(4,))
+      dense_1 = keras.layers.Dense(4,)
+      dense_2 = keras.layers.Dense(4,)
+
+      x_a = dense_1(input_a)
+      x_a = keras.layers.Dropout(0.5)(x_a)
+      x_b = dense_1(input_b)
+      x_a = dense_2(x_a)
+      outputs = keras.layers.add([x_a, x_b])
+      model = keras.models.Model([input_a, input_b], outputs)
+
+    # Everything should work in a new session.
+    keras.backend.clear_session()
+
+    with self.test_session():
+      # With placeholder creation
+      new_model = keras.models.clone_model(model)
+      new_model.compile('rmsprop', 'mse')
+      new_model.train_on_batch([val_a, val_b], val_out)
+
+      # On top of new tensors
+      input_a = keras.Input(shape=(4,), name='a')
+      input_b = keras.Input(shape=(4,), name='b')
+      new_model = keras.models.clone_model(
+          model, input_tensors=[input_a, input_b])
+      new_model.compile('rmsprop', 'mse')
+      new_model.train_on_batch([val_a, val_b], val_out)
+
+      # On top of new, non-Keras tensors
+      input_a = keras.backend.variable(val_a)
+      input_b = keras.backend.variable(val_b)
+      new_model = keras.models.clone_model(
+          model, input_tensors=[input_a, input_b])
+      new_model.compile('rmsprop', 'mse')
+      new_model.train_on_batch(None, val_out)
+
+  def test_model_cloning_invalid_use_cases(self):
+    seq_model = keras.models.Sequential()
+    seq_model.add(keras.layers.Dense(4, input_shape=(4,)))
+
+    x = keras.Input((4,))
+    y = keras.layers.Dense(4)(x)
+    fn_model = keras.models.Model(x, y)
+
+    with self.assertRaises(ValueError):
+      keras.models._clone_functional_model(seq_model)
+    with self.assertRaises(ValueError):
+      keras.models._clone_functional_model(None)
+    with self.assertRaises(ValueError):
+      keras.models._clone_sequential_model(fn_model)
+
+    with self.assertRaises(ValueError):
+      keras.models._clone_sequential_model(seq_model, input_tensors=[x, x])
+    with self.assertRaises(ValueError):
+      keras.models._clone_sequential_model(seq_model, input_tensors=y)
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/keras/python/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py
similarity index 79%
rename from tensorflow/contrib/keras/python/keras/optimizers.py
rename to tensorflow/python/keras/_impl/keras/optimizers.py
index a1bd3be026c5d8244af1320aa0ce7acd7f4be26f..a08073fa86442e0564aa63052bb87b92dc64cdf6 100644
--- a/tensorflow/contrib/keras/python/keras/optimizers.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers.py
@@ -23,11 +23,11 @@ import copy
 import six
 from six.moves import zip  # pylint: disable=redefined-builtin
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object
 from tensorflow.python.framework import dtypes as dtypes_module
 from tensorflow.python.framework import ops
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training import optimizer as tf_optimizer_module
@@ -88,7 +88,7 @@ class Optimizer(object):
     self.updates = []
     self.weights = []
 
-  def get_updates(self, params, constraints, loss):
+  def get_updates(self, loss, params):
     raise NotImplementedError
 
   def get_gradients(self, loss, params):
@@ -163,21 +163,22 @@ class SGD(Optimizer):
 
   def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False, **kwargs):
     super(SGD, self).__init__(**kwargs)
-    self.iterations = K.variable(0., name='iterations')
-    self.lr = K.variable(lr, name='lr')
-    self.momentum = K.variable(momentum, name='momentum')
-    self.decay = K.variable(decay, name='decay')
+    with K.name_scope(self.__class__.__name__):
+      self.iterations = K.variable(0, dtype='int64', name='iterations')
+      self.lr = K.variable(lr, name='lr')
+      self.momentum = K.variable(momentum, name='momentum')
+      self.decay = K.variable(decay, name='decay')
     self.initial_decay = decay
     self.nesterov = nesterov
 
-  def get_updates(self, params, constraints, loss):
+  def get_updates(self, loss, params):
     grads = self.get_gradients(loss, params)
-    self.updates = []
+    self.updates = [K.update_add(self.iterations, 1)]
 
     lr = self.lr
     if self.initial_decay > 0:
-      lr *= (1. / (1. + self.decay * self.iterations))
-      self.updates.append(K.update_add(self.iterations, 1))
+      lr *= (1. / (1. + self.decay * K.cast(self.iterations,
+                                            K.dtype(self.decay))))
 
     # momentum
     shapes = [K.int_shape(p) for p in params]
@@ -192,10 +193,9 @@ class SGD(Optimizer):
       else:
         new_p = p + v
 
-      # apply constraints
-      if p in constraints:
-        c = constraints[p]
-        new_p = c(new_p)
+      # Apply constraints.
+      if getattr(p, 'constraint', None) is not None:
+        new_p = p.constraint(new_p)
 
       self.updates.append(K.update(p, new_p))
     return self.updates
@@ -212,7 +212,6 @@ class SGD(Optimizer):
 
 
 class RMSprop(Optimizer):
-  # pylint: disable=line-too-long
   """RMSProp optimizer.
 
   It is recommended to leave the parameters of this optimizer
@@ -227,34 +226,30 @@ class RMSprop(Optimizer):
       rho: float >= 0.
       epsilon: float >= 0. Fuzz factor.
       decay: float >= 0. Learning rate decay over each update.
-
-  References:
-      - [rmsprop: Divide the gradient by a running average of its recent
-        magnitude](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
   """
 
-  # pylint: enable=line-too-long
-
   def __init__(self, lr=0.001, rho=0.9, epsilon=1e-8, decay=0., **kwargs):
     super(RMSprop, self).__init__(**kwargs)
-    self.lr = K.variable(lr, name='lr')
-    self.rho = K.variable(rho, name='rho')
+    with K.name_scope(self.__class__.__name__):
+      self.lr = K.variable(lr, name='lr')
+      self.rho = K.variable(rho, name='rho')
+      self.decay = K.variable(decay, name='decay')
+      self.iterations = K.variable(0, dtype='int64', name='iterations')
     self.epsilon = epsilon
-    self.decay = K.variable(decay, name='decay')
     self.initial_decay = decay
-    self.iterations = K.variable(0., name='iterations')
 
-  def get_updates(self, params, constraints, loss):
+  def get_updates(self, loss, params):
     grads = self.get_gradients(loss, params)
-    shapes = [K.int_shape(p) for p in params]
-    accumulators = [K.zeros(shape) for shape in shapes]
+    accumulators = [
+        K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params
+    ]
     self.weights = accumulators
-    self.updates = []
+    self.updates = [K.update_add(self.iterations, 1)]
 
     lr = self.lr
     if self.initial_decay > 0:
-      lr *= (1. / (1. + self.decay * self.iterations))
-      self.updates.append(K.update_add(self.iterations, 1))
+      lr *= (1. / (1. + self.decay * K.cast(self.iterations,
+                                            K.dtype(self.decay))))
 
     for p, g, a in zip(params, grads, accumulators):
       # update accumulator
@@ -262,10 +257,10 @@ class RMSprop(Optimizer):
       self.updates.append(K.update(a, new_a))
       new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
 
-      # apply constraints
-      if p in constraints:
-        c = constraints[p]
-        new_p = c(new_p)
+      # Apply constraints.
+      if getattr(p, 'constraint', None) is not None:
+        new_p = p.constraint(new_p)
+
       self.updates.append(K.update(p, new_p))
     return self.updates
 
@@ -281,7 +276,6 @@ class RMSprop(Optimizer):
 
 
 class Adagrad(Optimizer):
-  # pylint: disable=line-too-long
   """Adagrad optimizer.
 
   It is recommended to leave the parameters of this optimizer
@@ -297,36 +291,36 @@ class Adagrad(Optimizer):
         Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
   """
 
-  # pylint: enable=line-too-long
-
   def __init__(self, lr=0.01, epsilon=1e-8, decay=0., **kwargs):
     super(Adagrad, self).__init__(**kwargs)
-    self.lr = K.variable(lr, name='lr')
+    with K.name_scope(self.__class__.__name__):
+      self.lr = K.variable(lr, name='lr')
+      self.decay = K.variable(decay, name='decay')
+      self.iterations = K.variable(0, dtype='int64', name='iterations')
     self.epsilon = epsilon
-    self.decay = K.variable(decay, name='decay')
     self.initial_decay = decay
-    self.iterations = K.variable(0., name='iterations')
 
-  def get_updates(self, params, constraints, loss):
+  def get_updates(self, loss, params):
     grads = self.get_gradients(loss, params)
     shapes = [K.int_shape(p) for p in params]
     accumulators = [K.zeros(shape) for shape in shapes]
     self.weights = accumulators
-    self.updates = []
+    self.updates = [K.update_add(self.iterations, 1)]
 
     lr = self.lr
     if self.initial_decay > 0:
-      lr *= (1. / (1. + self.decay * self.iterations))
-      self.updates.append(K.update_add(self.iterations, 1))
+      lr *= (1. / (1. + self.decay * K.cast(self.iterations,
+                                            K.dtype(self.decay))))
 
     for p, g, a in zip(params, grads, accumulators):
       new_a = a + K.square(g)  # update accumulator
       self.updates.append(K.update(a, new_a))
       new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon)
-      # apply constraints
-      if p in constraints:
-        c = constraints[p]
-        new_p = c(new_p)
+
+      # Apply constraints.
+      if getattr(p, 'constraint', None) is not None:
+        new_p = p.constraint(new_p)
+
       self.updates.append(K.update(p, new_p))
     return self.updates
 
@@ -341,7 +335,6 @@ class Adagrad(Optimizer):
 
 
 class Adadelta(Optimizer):
-  # pylint: disable=line-too-long
   """Adadelta optimizer.
 
   It is recommended to leave the parameters of this optimizer
@@ -359,29 +352,28 @@ class Adadelta(Optimizer):
         method](http://arxiv.org/abs/1212.5701)
   """
 
-  # pylint: enable=line-too-long
-
   def __init__(self, lr=1.0, rho=0.95, epsilon=1e-8, decay=0., **kwargs):
     super(Adadelta, self).__init__(**kwargs)
-    self.lr = K.variable(lr, name='lr')
+    with K.name_scope(self.__class__.__name__):
+      self.lr = K.variable(lr, name='lr')
+      self.decay = K.variable(decay, name='decay')
+      self.iterations = K.variable(0, dtype='int64', name='iterations')
     self.rho = rho
     self.epsilon = epsilon
-    self.decay = K.variable(decay, name='decay')
     self.initial_decay = decay
-    self.iterations = K.variable(0., name='iterations')
 
-  def get_updates(self, params, constraints, loss):
+  def get_updates(self, loss, params):
     grads = self.get_gradients(loss, params)
     shapes = [K.int_shape(p) for p in params]
     accumulators = [K.zeros(shape) for shape in shapes]
     delta_accumulators = [K.zeros(shape) for shape in shapes]
     self.weights = accumulators + delta_accumulators
-    self.updates = []
+    self.updates = [K.update_add(self.iterations, 1)]
 
     lr = self.lr
     if self.initial_decay > 0:
-      lr *= (1. / (1. + self.decay * self.iterations))
-      self.updates.append(K.update_add(self.iterations, 1))
+      lr *= (1. / (1. + self.decay * K.cast(self.iterations,
+                                            K.dtype(self.decay))))
 
     for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators):
       # update accumulator
@@ -390,12 +382,12 @@ class Adadelta(Optimizer):
 
       # use the new accumulator and the *old* delta_accumulator
       update = g * K.sqrt(d_a + self.epsilon) / K.sqrt(new_a + self.epsilon)
-
       new_p = p - lr * update
-      # apply constraints
-      if p in constraints:
-        c = constraints[p]
-        new_p = c(new_p)
+
+      # Apply constraints.
+      if getattr(p, 'constraint', None) is not None:
+        new_p = p.constraint(new_p)
+
       self.updates.append(K.update(p, new_p))
 
       # update delta_accumulator
@@ -415,7 +407,6 @@ class Adadelta(Optimizer):
 
 
 class Adam(Optimizer):
-  # pylint: disable=line-too-long
   """Adam optimizer.
 
   Default parameters follow those provided in the original paper.
@@ -432,8 +423,6 @@ class Adam(Optimizer):
         Optimization](http://arxiv.org/abs/1412.6980v8)
   """
 
-  # pylint: enable=line-too-long
-
   def __init__(self,
                lr=0.001,
                beta_1=0.9,
@@ -442,29 +431,30 @@ class Adam(Optimizer):
                decay=0.,
                **kwargs):
     super(Adam, self).__init__(**kwargs)
-    self.iterations = K.variable(0, name='iterations')
-    self.lr = K.variable(lr, name='lr')
-    self.beta_1 = K.variable(beta_1, name='beta_1')
-    self.beta_2 = K.variable(beta_2, name='beta_2')
+    with K.name_scope(self.__class__.__name__):
+      self.iterations = K.variable(0, dtype='int64', name='iterations')
+      self.lr = K.variable(lr, name='lr')
+      self.beta_1 = K.variable(beta_1, name='beta_1')
+      self.beta_2 = K.variable(beta_2, name='beta_2')
+      self.decay = K.variable(decay, name='decay')
     self.epsilon = epsilon
-    self.decay = K.variable(decay, name='decay')
     self.initial_decay = decay
 
-  def get_updates(self, params, constraints, loss):
+  def get_updates(self, loss, params):
     grads = self.get_gradients(loss, params)
     self.updates = [K.update_add(self.iterations, 1)]
 
     lr = self.lr
     if self.initial_decay > 0:
-      lr *= (1. / (1. + self.decay * self.iterations))
+      lr *= (1. / (1. + self.decay * K.cast(self.iterations,
+                                            K.dtype(self.decay))))
 
-    t = self.iterations + 1
+    t = K.cast(self.iterations, K.floatx()) + 1
     lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
                  (1. - K.pow(self.beta_1, t)))
 
-    shapes = [K.int_shape(p) for p in params]
-    ms = [K.zeros(shape) for shape in shapes]
-    vs = [K.zeros(shape) for shape in shapes]
+    ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+    vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
     self.weights = [self.iterations] + ms + vs
 
     for p, g, m, v in zip(params, grads, ms, vs):
@@ -474,12 +464,12 @@ class Adam(Optimizer):
 
       self.updates.append(K.update(m, m_t))
       self.updates.append(K.update(v, v_t))
-
       new_p = p_t
-      # apply constraints
-      if p in constraints:
-        c = constraints[p]
-        new_p = c(new_p)
+
+      # Apply constraints.
+      if getattr(p, 'constraint', None) is not None:
+        new_p = p.constraint(new_p)
+
       self.updates.append(K.update(p, new_p))
     return self.updates
 
@@ -496,7 +486,6 @@ class Adam(Optimizer):
 
 
 class Adamax(Optimizer):
-  # pylint: disable=line-too-long
   """Adamax optimizer from Adam paper's Section 7.
 
   It is a variant of Adam based on the infinity norm.
@@ -513,8 +502,6 @@ class Adamax(Optimizer):
         Optimization](http://arxiv.org/abs/1412.6980v8)
   """
 
-  # pylint: enable=line-too-long
-
   def __init__(self,
                lr=0.002,
                beta_1=0.9,
@@ -523,23 +510,25 @@ class Adamax(Optimizer):
                decay=0.,
                **kwargs):
     super(Adamax, self).__init__(**kwargs)
-    self.iterations = K.variable(0., name='iterations')
-    self.lr = K.variable(lr, name='lr')
-    self.beta_1 = K.variable(beta_1, name='beta_1')
-    self.beta_2 = K.variable(beta_2, name='beta_2')
+    with K.name_scope(self.__class__.__name__):
+      self.iterations = K.variable(0, dtype='int64', name='iterations')
+      self.lr = K.variable(lr, name='lr')
+      self.beta_1 = K.variable(beta_1, name='beta_1')
+      self.beta_2 = K.variable(beta_2, name='beta_2')
+      self.decay = K.variable(decay, name='decay')
     self.epsilon = epsilon
-    self.decay = K.variable(decay, name='decay')
     self.initial_decay = decay
 
-  def get_updates(self, params, constraints, loss):
+  def get_updates(self, loss, params):
     grads = self.get_gradients(loss, params)
     self.updates = [K.update_add(self.iterations, 1)]
 
     lr = self.lr
     if self.initial_decay > 0:
-      lr *= (1. / (1. + self.decay * self.iterations))
+      lr *= (1. / (1. + self.decay * K.cast(self.iterations,
+                                            K.dtype(self.decay))))
 
-    t = self.iterations + 1
+    t = K.cast(self.iterations, K.floatx()) + 1
     lr_t = lr / (1. - K.pow(self.beta_1, t))
 
     shapes = [K.int_shape(p) for p in params]
@@ -557,12 +546,12 @@ class Adamax(Optimizer):
 
       self.updates.append(K.update(m, m_t))
       self.updates.append(K.update(u, u_t))
-
       new_p = p_t
-      # apply constraints
-      if p in constraints:
-        c = constraints[p]
-        new_p = c(new_p)
+
+      # Apply constraints.
+      if getattr(p, 'constraint', None) is not None:
+        new_p = p.constraint(new_p)
+
       self.updates.append(K.update(p, new_p))
     return self.updates
 
@@ -579,7 +568,6 @@ class Adamax(Optimizer):
 
 
 class Nadam(Optimizer):
-  # pylint: disable=line-too-long
   """Nesterov Adam optimizer.
 
   Much like Adam is essentially RMSprop with momentum,
@@ -600,8 +588,6 @@ class Nadam(Optimizer):
         learning](http://www.cs.toronto.edu/~fritz/absps/momentum.pdf)
   """
 
-  # pylint: enable=line-too-long
-
   def __init__(self,
                lr=0.002,
                beta_1=0.9,
@@ -610,26 +596,26 @@ class Nadam(Optimizer):
                schedule_decay=0.004,
                **kwargs):
     super(Nadam, self).__init__(**kwargs)
-    self.iterations = K.variable(0., name='iterations')
-    self.m_schedule = K.variable(1., name='m_schedule')
-    self.lr = K.variable(lr, name='lr')
-    self.beta_1 = K.variable(beta_1, name='beta_1')
-    self.beta_2 = K.variable(beta_2, name='beta_2')
+    with K.name_scope(self.__class__.__name__):
+      self.iterations = K.variable(0, dtype='int64', name='iterations')
+      self.m_schedule = K.variable(1., name='m_schedule')
+      self.lr = K.variable(lr, name='lr')
+      self.beta_1 = K.variable(beta_1, name='beta_1')
+      self.beta_2 = K.variable(beta_2, name='beta_2')
     self.epsilon = epsilon
     self.schedule_decay = schedule_decay
 
-  def get_updates(self, params, constraints, loss):
+  def get_updates(self, loss, params):
     grads = self.get_gradients(loss, params)
     self.updates = [K.update_add(self.iterations, 1)]
-
-    t = self.iterations + 1
+    t = K.cast(self.iterations, K.floatx()) + 1
 
     # Due to the recommendations in [2], i.e. warming momentum schedule
-    momentum_cache_t = self.beta_1 * (1. - 0.5 *
-                                      (K.pow(0.96, t * self.schedule_decay)))
-    momentum_cache_t_1 = self.beta_1 * (1. - 0.5 *
-                                        (K.pow(0.96,
-                                               (t + 1) * self.schedule_decay)))
+    momentum_cache_t = self.beta_1 * (
+        1. - 0.5 * (K.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
+    momentum_cache_t_1 = self.beta_1 * (
+        1. - 0.5 *
+        (K.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
     m_schedule_new = self.m_schedule * momentum_cache_t
     m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
     self.updates.append((self.m_schedule, m_schedule_new))
@@ -656,10 +642,10 @@ class Nadam(Optimizer):
       p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
       new_p = p_t
 
-      # apply constraints
-      if p in constraints:
-        c = constraints[p]
-        new_p = c(new_p)
+      # Apply constraints.
+      if getattr(p, 'constraint', None) is not None:
+        new_p = p.constraint(new_p)
+
       self.updates.append(K.update(p, new_p))
     return self.updates
 
@@ -681,16 +667,12 @@ class TFOptimizer(Optimizer):
 
   def __init__(self, optimizer):  # pylint: disable=super-init-not-called
     self.optimizer = optimizer
-    self.iterations = K.variable(0., name='iterations')
-    self.updates = []
+    with K.name_scope(self.__class__.__name__):
+      self.iterations = K.variable(0, dtype='int64', name='iterations')
 
-  def get_updates(self, params, constraints, loss):
-    if constraints:
-      raise ValueError('TF optimizers do not support '
-                       'weights constraints. Either remove '
-                       'all weights constraints in your model, '
-                       'or use a Keras optimizer.')
+  def get_updates(self, loss, params):
     grads = self.optimizer.compute_gradients(loss, params)
+    self.updates = [K.update_add(self.iterations, 1)]
     opt_update = self.optimizer.apply_gradients(
         grads, global_step=self.iterations)
     self.updates.append(opt_update)
diff --git a/tensorflow/contrib/keras/python/keras/optimizers_test.py b/tensorflow/python/keras/_impl/keras/optimizers_test.py
similarity index 80%
rename from tensorflow/contrib/keras/python/keras/optimizers_test.py
rename to tensorflow/python/keras/_impl/keras/optimizers_test.py
index bb598f30373e797a7850d232d11ac2ace3150b05..b63d82f6a0ff9af3cb3761ed11fd4367e542ad06 100644
--- a/tensorflow/contrib/keras/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/_impl/keras/optimizers_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 from tensorflow.python.training.adam import AdamOptimizer
 
@@ -54,6 +54,23 @@ def _test_optimizer(optimizer, target=0.75):
   new_config['class_name'] = new_config['class_name'].lower()
   assert config == new_config
 
+  # Test constraints.
+  model = keras.models.Sequential()
+  dense = keras.layers.Dense(10,
+                             input_shape=(x_train.shape[1],),
+                             kernel_constraint=lambda x: 0. * x + 1.,
+                             bias_constraint=lambda x: 0. * x + 2.,
+                             activation='relu')
+  model.add(dense)
+  model.add(keras.layers.Dense(y_train.shape[1], activation='softmax'))
+  model.compile(loss='categorical_crossentropy',
+                optimizer=optimizer,
+                metrics=['accuracy'])
+  model.train_on_batch(x_train[:10], y_train[:10])
+  kernel, bias = dense.get_weights()
+  np.testing.assert_allclose(kernel, 1., atol=1e-3)
+  np.testing.assert_allclose(bias, 2., atol=1e-3)
+
 
 class KerasOptimizersTest(test.TestCase):
 
@@ -105,19 +122,17 @@ class KerasOptimizersTest(test.TestCase):
                                            clipvalue=0.5))
 
   def test_tfoptimizer(self):
-    optimizer = keras.optimizers.TFOptimizer(AdamOptimizer)
+    optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
     model = keras.models.Sequential()
     model.add(keras.layers.Dense(
         2, input_shape=(3,), kernel_constraint=keras.constraints.MaxNorm(1)))
     # This is possible
     model.compile(loss='mean_squared_error', optimizer=optimizer)
-    # TF optimizers do not support weights constraints
-    with self.assertRaises(ValueError):
-      model.fit(np.random.random((5, 3)),
-                np.random.random((5, 2)),
-                epochs=1,
-                batch_size=5,
-                verbose=0)
+    model.fit(np.random.random((5, 3)),
+              np.random.random((5, 2)),
+              epochs=1,
+              batch_size=5,
+              verbose=0)
     # not supported
     with self.assertRaises(NotImplementedError):
       _ = optimizer.weights
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/__init__.py b/tensorflow/python/keras/_impl/keras/preprocessing/__init__.py
similarity index 79%
rename from tensorflow/contrib/keras/python/keras/preprocessing/__init__.py
rename to tensorflow/python/keras/_impl/keras/preprocessing/__init__.py
index 9ae14c9674e07c281d57993b0024edee4ead7231..2ca48cdbf9c066194f4f4ed448fd621167db7ba9 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/__init__.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras.preprocessing import image
-from tensorflow.contrib.keras.python.keras.preprocessing import sequence
-from tensorflow.contrib.keras.python.keras.preprocessing import text
+from tensorflow.python.keras._impl.keras.preprocessing import image
+from tensorflow.python.keras._impl.keras.preprocessing import sequence
+from tensorflow.python.keras._impl.keras.preprocessing import text
 
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/image.py b/tensorflow/python/keras/_impl/keras/preprocessing/image.py
similarity index 99%
rename from tensorflow/contrib/keras/python/keras/preprocessing/image.py
rename to tensorflow/python/keras/_impl/keras/preprocessing/image.py
index 4f2cff804e56aa7716e3368a2fcd4b0ecb45a49d..052a8addc4c37f6df01a9103dc8a07e4726ec735 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/image.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/image.py
@@ -30,7 +30,7 @@ import threading
 import numpy as np
 from six.moves import range  # pylint: disable=redefined-builtin
 
-from tensorflow.contrib.keras.python.keras import backend as K
+from tensorflow.python.keras._impl.keras import backend as K
 from tensorflow.python.platform import tf_logging as logging
 
 
@@ -583,9 +583,9 @@ class ImageDataGenerator(object):
                         'first by calling `.fit(numpy_data)`.')
     if self.zca_whitening:
       if self.principal_components is not None:
-        flatx = np.reshape(x, (x.size))
+        flatx = np.reshape(x, (-1, np.prod(x.shape[-3:])))
         whitex = np.dot(flatx, self.principal_components)
-        x = np.reshape(whitex, (x.shape[0], x.shape[1], x.shape[2]))
+        x = np.reshape(whitex, x.shape)
       else:
         logging.warning('This ImageDataGenerator specifies '
                         '`zca_whitening`, but it hasn\'t'
@@ -864,7 +864,7 @@ class NumpyArrayIterator(Iterator):
                        'with shape', self.x.shape)
     channels_axis = 3 if data_format == 'channels_last' else 1
     if self.x.shape[channels_axis] not in {1, 3, 4}:
-      raise ValueError(
+      logging.warning(
           'NumpyArrayIterator is set to use the '
           'data format convention "' + data_format + '" '
           '(channels on axis ' + str(channels_axis) + '), i.e. expected '
@@ -1076,7 +1076,7 @@ class DirectoryIterator(Iterator):
     self.save_prefix = save_prefix
     self.save_format = save_format
 
-    white_list_formats = {'png', 'jpg', 'jpeg', 'bmp'}
+    white_list_formats = {'png', 'jpg', 'jpeg', 'bmp', 'ppm'}
 
     # first, count the number of samples and classes
     self.samples = 0
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py
similarity index 96%
rename from tensorflow/contrib/keras/python/keras/preprocessing/image_test.py
rename to tensorflow/python/keras/_impl/keras/preprocessing/image_test.py
index bb09ed1ae831bdfc57725efb2757546994993970..19693410e761a2d800e8c8e151264f91ef30897c 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py
@@ -23,7 +23,7 @@ import shutil
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 try:
@@ -78,6 +78,11 @@ class TestImage(test.TestCase):
           cval=0.5,
           horizontal_flip=True,
           vertical_flip=True)
+      # Basic test before fit
+      x = np.random.random((32, 10, 10, 3))
+      generator.flow(x)
+
+      # Fit
       generator.fit(images, augment=True)
 
       for x, _ in generator.flow(
@@ -95,14 +100,17 @@ class TestImage(test.TestCase):
         samplewise_std_normalization=True,
         zca_whitening=True,
         data_format='channels_last')
+
     # Test fit with invalid data
     with self.assertRaises(ValueError):
       x = np.random.random((3, 10, 10))
       generator.fit(x)
     # Test flow with invalid data
     with self.assertRaises(ValueError):
-      x = np.random.random((32, 10, 10))
-      generator.flow(np.arange(x.shape[0]))
+      generator.flow(np.arange(5))
+    # Invalid number of channels: will work but raise a warning
+    x = np.random.random((32, 10, 10, 5))
+    generator.flow(x)
 
     with self.assertRaises(ValueError):
       generator = keras.preprocessing.image.ImageDataGenerator(
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/sequence.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/preprocessing/sequence.py
rename to tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
index 382aa386d4e05c6bafd8a76fa8beebac44497aba..a5deec87af7729c20face3517689b7da4b48c8df 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/sequence.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py
@@ -143,7 +143,8 @@ def skipgrams(sequence,
               negative_samples=1.,
               shuffle=True,
               categorical=False,
-              sampling_table=None):
+              sampling_table=None,
+              seed=None):
   """Generates skipgram word pairs.
 
   Takes a sequence (list of indexes of words),
@@ -169,6 +170,7 @@ def skipgrams(sequence,
           if True labels will be categorical eg. [[1,0],[0,1],[0,1] .. ]
       sampling_table: 1D array of size `vocabulary_size` where the entry i
           encodes the probabibily to sample a word of rank i.
+      seed: Random seed.
 
   Returns:
       couples, labels: where `couples` are int pairs and
@@ -214,7 +216,8 @@ def skipgrams(sequence,
       labels += [0] * num_negative_samples
 
   if shuffle:
-    seed = random.randint(0, 10e6)
+    if seed is None:
+      seed = random.randint(0, 10e6)
     random.seed(seed)
     random.shuffle(couples)
     random.seed(seed)
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/sequence_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/preprocessing/sequence_test.py
rename to tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py
index 4e54b95c8bf260dbf93e1352c6403c39f090d602..4529e6e94fc42661fb0474c1a827863ddb654776 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/sequence_test.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/text.py b/tensorflow/python/keras/_impl/keras/preprocessing/text.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/preprocessing/text.py
rename to tensorflow/python/keras/_impl/keras/preprocessing/text.py
index ed00eef6ad8b4d36ef4aeac49dbe4d80f9e90cbc..47e5aa064fd806196fc9457fc90bc1a26e55ebf3 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/text.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/text.py
@@ -52,7 +52,13 @@ def text_to_word_sequence(text,
   """
   if lower:
     text = text.lower()
-  text = text.translate(maketrans(filters, split * len(filters)))
+
+  if sys.version_info < (3,) and isinstance(text, unicode):
+    translate_map = dict((ord(c), unicode(split)) for c in filters)
+  else:
+    translate_map = maketrans(filters, split * len(filters))
+
+  text = text.translate(translate_map)
   seq = text.split(split)
   return [i for i in seq if i]
 
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/text_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
similarity index 89%
rename from tensorflow/contrib/keras/python/keras/preprocessing/text_test.py
rename to tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
index 7b26219e61bba0c0503e8a886e78810d8fad23fa..17ab48ba3fc9dfd553f8f425579c0a37ff42eb84 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/text_test.py
+++ b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
@@ -33,6 +33,13 @@ class TestText(test.TestCase):
     self.assertLessEqual(np.max(encoded), 4)
     self.assertGreaterEqual(np.min(encoded), 0)
 
+    # Test on unicode.
+    text = u'The cat sat on the mat.'
+    encoded = keras.preprocessing.text.one_hot(text, 5)
+    self.assertEqual(len(encoded), 6)
+    self.assertLessEqual(np.max(encoded), 4)
+    self.assertGreaterEqual(np.min(encoded), 0)
+
   def test_tokenizer(self):
     texts = [
         'The cat sat on the mat.',
diff --git a/tensorflow/contrib/keras/python/keras/regularizers.py b/tensorflow/python/keras/_impl/keras/regularizers.py
similarity index 90%
rename from tensorflow/contrib/keras/python/keras/regularizers.py
rename to tensorflow/python/keras/_impl/keras/regularizers.py
index 36cc5c47e41e8fa2a9019cad458b80cfa21a9d3f..161ff9bf5bf12b3521fe444f1d68bd62b6e8c71d 100644
--- a/tensorflow/contrib/keras/python/keras/regularizers.py
+++ b/tensorflow/python/keras/_impl/keras/regularizers.py
@@ -20,9 +20,9 @@ from __future__ import print_function
 
 import six
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
 
 
 class Regularizer(object):
diff --git a/tensorflow/contrib/keras/python/keras/regularizers_test.py b/tensorflow/python/keras/_impl/keras/regularizers_test.py
similarity index 95%
rename from tensorflow/contrib/keras/python/keras/regularizers_test.py
rename to tensorflow/python/keras/_impl/keras/regularizers_test.py
index 528024994f31fcaf4c42016be985fbb4d39b5643..9a1612b7779d1ede008b5bcd88173fe53762cf10 100644
--- a/tensorflow/contrib/keras/python/keras/regularizers_test.py
+++ b/tensorflow/python/keras/_impl/keras/regularizers_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/testing_utils.py b/tensorflow/python/keras/_impl/keras/testing_utils.py
similarity index 99%
rename from tensorflow/contrib/keras/python/keras/testing_utils.py
rename to tensorflow/python/keras/_impl/keras/testing_utils.py
index 2f51ace945fbce636f947659162af186da453ae0..f204a5df3e654eebd5c0165f383f2c418961f5ba 100644
--- a/tensorflow/contrib/keras/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/_impl/keras/testing_utils.py
@@ -20,7 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.util import tf_inspect
 
 
diff --git a/tensorflow/python/keras/_impl/keras/utils/__init__.py b/tensorflow/python/keras/_impl/keras/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa50b123b79cc599e3e1bd2328823dc3eefc1f95
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/utils/__init__.py
@@ -0,0 +1,43 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras utilities.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.utils import conv_utils
+from tensorflow.python.keras._impl.keras.utils import data_utils
+from tensorflow.python.keras._impl.keras.utils import generic_utils
+from tensorflow.python.keras._impl.keras.utils import io_utils
+from tensorflow.python.keras._impl.keras.utils import np_utils
+from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer
+from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
+from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope
+from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects
+from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
+from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix
+from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model
+from tensorflow.python.keras._impl.keras.utils.np_utils import normalize
+from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical
+from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model
+
+
+# Globally-importable utils.
diff --git a/tensorflow/contrib/keras/python/keras/utils/conv_utils.py b/tensorflow/python/keras/_impl/keras/utils/conv_utils.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/utils/conv_utils.py
rename to tensorflow/python/keras/_impl/keras/utils/conv_utils.py
index ea3a70edab80d6ca9fed0b30e76fbdfefef04972..583079d9626361eb594f16a57af86f103e5ee74d 100644
--- a/tensorflow/contrib/keras/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/conv_utils.py
@@ -22,7 +22,7 @@ import numpy as np
 from six.moves import range  # pylint: disable=redefined-builtin
 
 # pylint: disable=unused-import
-from tensorflow.contrib.keras.python.keras import backend as K
+from tensorflow.python.keras._impl.keras import backend as K
 from tensorflow.python.layers.utils import conv_input_length
 from tensorflow.python.layers.utils import conv_output_length
 from tensorflow.python.layers.utils import deconv_output_length as deconv_length
diff --git a/tensorflow/contrib/keras/python/keras/utils/data_utils.py b/tensorflow/python/keras/_impl/keras/utils/data_utils.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/utils/data_utils.py
rename to tensorflow/python/keras/_impl/keras/utils/data_utils.py
index 853625e7c475af7f5927bf60c9900c29708c4b48..0ede7f12f2cd31ee86baefc870748f206332342c 100644
--- a/tensorflow/contrib/keras/python/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/data_utils.py
@@ -36,7 +36,7 @@ from six.moves.urllib.error import HTTPError
 from six.moves.urllib.error import URLError
 from six.moves.urllib.request import urlopen
 
-from tensorflow.contrib.keras.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
 
 try:
   import queue  # pylint:disable=g-import-not-at-top
@@ -368,6 +368,12 @@ class Sequence(object):
     """
     raise NotImplementedError
 
+  @abstractmethod
+  def on_epoch_end(self):
+    """Method called at the end of every epoch.
+    """
+    raise NotImplementedError
+
 
 def get_index(ds, i):
   """Quick fix for Python2, otherwise, it cannot be pickled.
@@ -453,15 +459,16 @@ class OrderedEnqueuer(SequenceEnqueuer):
       use_multiprocessing: use multiprocessing if True, otherwise threading
       scheduling: Sequential querying of datas if 'sequential', random
         otherwise.
+      shuffle: Whether to shuffle the data at the beginning of each epoch.
   """
 
   def __init__(self,
                sequence,
                use_multiprocessing=False,
-               scheduling='sequential'):
+               shuffle=False):
     self.sequence = sequence
     self.use_multiprocessing = use_multiprocessing
-    self.scheduling = scheduling
+    self.shuffle = shuffle
     self.workers = 0
     self.executor = None
     self.queue = None
@@ -493,7 +500,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
     """Submits requests to the executor and queues the `Future` objects."""
     sequence = list(range(len(self.sequence)))
     while True:
-      if self.scheduling is not 'sequential':
+      if self.shuffle:
         random.shuffle(sequence)
       for i in sequence:
         if self.stop_signal.is_set():
@@ -501,6 +508,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
         self.queue.put(
             self.executor.apply_async(get_index, (self.sequence, i)),
             block=True)
+      self.sequence.on_epoch_end()
 
   def get(self):
     """Creates a generator to extract data from the queue.
diff --git a/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py b/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py
similarity index 99%
rename from tensorflow/contrib/keras/python/keras/utils/data_utils_test.py
rename to tensorflow/python/keras/_impl/keras/utils/data_utils_test.py
index 55d08a34d0af3cbd5e7b9a81a2349d463a1a5706..45322f1f29cb1351c409957d060c21abffdf1d6f 100644
--- a/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py
+++ b/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py
@@ -28,7 +28,7 @@ import numpy as np
 from six.moves.urllib.parse import urljoin
 from six.moves.urllib.request import pathname2url
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/utils/generic_utils.py
rename to tensorflow/python/keras/_impl/keras/utils/generic_utils.py
index 3428476b173f4ca7987a3561ca4502c44364c6f1..39a10c8650f67216ae6a238bb6f3b7e4088ad163 100644
--- a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
@@ -197,7 +197,8 @@ def func_dump(func):
       A tuple `(code, defaults, closure)`.
   """
   if os.name == 'nt':
-    code = marshal.dumps(func.__code__).replace(b'\\',b'/').decode('raw_unicode_escape')
+    code = marshal.dumps(
+        func.__code__).replace(b'\\', b'/').decode('raw_unicode_escape')
   else:
     code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
   defaults = func.__defaults__
@@ -331,7 +332,7 @@ class Progbar(object):
       for k in self.unique_values:
         info += ' - %s:' % k
         if isinstance(self.sum_values[k], list):
-          avg = self.sum_values[k][0] / max(1, self.sum_values[k][1])
+          avg = np.mean(self.sum_values[k][0] / max(1, self.sum_values[k][1]))
           if abs(avg) > 1e-3:
             info += ' %.4f' % avg
           else:
@@ -354,7 +355,7 @@ class Progbar(object):
         info = '%ds' % (now - self.start)
         for k in self.unique_values:
           info += ' - %s:' % k
-          avg = self.sum_values[k][0] / max(1, self.sum_values[k][1])
+          avg = np.mean(self.sum_values[k][0] / max(1, self.sum_values[k][1]))
           if avg > 1e-3:
             info += ' %.4f' % avg
           else:
diff --git a/tensorflow/contrib/keras/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils_test.py
similarity index 97%
rename from tensorflow/contrib/keras/python/keras/utils/generic_utils_test.py
rename to tensorflow/python/keras/_impl/keras/utils/generic_utils_test.py
index 8a6519f4cc7c0313dbd95331b912d0a7d4c84bf2..d57692f4f41753fc38ead2ace7e989b499bc23ff 100644
--- a/tensorflow/contrib/keras/python/keras/utils/generic_utils_test.py
+++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 
diff --git a/tensorflow/contrib/keras/python/keras/utils/io_utils.py b/tensorflow/python/keras/_impl/keras/utils/io_utils.py
similarity index 99%
rename from tensorflow/contrib/keras/python/keras/utils/io_utils.py
rename to tensorflow/python/keras/_impl/keras/utils/io_utils.py
index 70b2d96907d883a028cbf83ac94404d3846d4d19..5f2ba99be783f8d24e4aef0eaa450a94f9da6e8b 100644
--- a/tensorflow/contrib/keras/python/keras/utils/io_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/io_utils.py
@@ -89,7 +89,7 @@ class HDF5Matrix(object):
         idx = slice(start + self.start, stop + self.start)
       else:
         raise IndexError
-    elif isinstance(key, int):
+    elif isinstance(key, (int, np.integer)):
       if key + self.start < self.end:
         idx = key + self.start
       else:
diff --git a/tensorflow/contrib/keras/python/keras/utils/io_utils_test.py b/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/utils/io_utils_test.py
rename to tensorflow/python/keras/_impl/keras/utils/io_utils_test.py
index baa9781e71f30021739fa10daf20185bc8397599..cfeba188d3cadfa08efbd07fcbd46776b691e06f 100644
--- a/tensorflow/contrib/keras/python/keras/utils/io_utils_test.py
+++ b/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py
@@ -23,7 +23,7 @@ import shutil
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
+from tensorflow.python.keras._impl import keras
 from tensorflow.python.platform import test
 
 try:
diff --git a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/utils/layer_utils.py
rename to tensorflow/python/keras/_impl/keras/utils/layer_utils.py
index 12d5368b088a788f1bd97b93a526bf0425df1ce4..399bbf3475097a895f03e6b606711c36de9dcaaf 100644
--- a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras.utils.conv_utils import convert_kernel
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras.utils.conv_utils import convert_kernel
 
 
 def print_summary(model, line_length=None, positions=None, print_fn=None):
diff --git a/tensorflow/contrib/keras/python/keras/utils/np_utils.py b/tensorflow/python/keras/_impl/keras/utils/np_utils.py
similarity index 100%
rename from tensorflow/contrib/keras/python/keras/utils/np_utils.py
rename to tensorflow/python/keras/_impl/keras/utils/np_utils.py
diff --git a/tensorflow/contrib/keras/python/keras/utils/vis_utils.py b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
similarity index 95%
rename from tensorflow/contrib/keras/python/keras/utils/vis_utils.py
rename to tensorflow/python/keras/_impl/keras/utils/vis_utils.py
index 949767299b9671daaadd23db3457ccab7d48fbbc..f227f3c3f7ba9bdccb25f67cd603c00a87d866be 100644
--- a/tensorflow/contrib/keras/python/keras/utils/vis_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
@@ -65,8 +65,8 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
   Returns:
       A `pydot.Dot` instance representing the Keras model.
   """
-  from tensorflow.contrib.keras.python.keras.layers.wrappers import Wrapper  # pylint: disable=g-import-not-at-top
-  from tensorflow.contrib.keras.python.keras.models import Sequential  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.keras._impl.keras.models import Sequential  # pylint: disable=g-import-not-at-top
 
   _check_pydot()
   dot = pydot.Dot()
diff --git a/tensorflow/contrib/keras/python/keras/wrappers/__init__.py b/tensorflow/python/keras/_impl/keras/wrappers/__init__.py
similarity index 91%
rename from tensorflow/contrib/keras/python/keras/wrappers/__init__.py
rename to tensorflow/python/keras/_impl/keras/wrappers/__init__.py
index 51244ff681050d79433eb3a20557fc28da74dca0..20c95929e3d2e1f66e66efe43b9685c5d6ed1c10 100644
--- a/tensorflow/contrib/keras/python/keras/wrappers/__init__.py
+++ b/tensorflow/python/keras/_impl/keras/wrappers/__init__.py
@@ -18,5 +18,5 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.keras.python.keras.wrappers import scikit_learn
+from tensorflow.python.keras._impl.keras.wrappers import scikit_learn
 
diff --git a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
rename to tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py
index 0d04fc120f112d91cfbdcc59d1555d5fcb57e0ed..ac7bd4940628fa206b08899908c1cdd72a368f07 100644
--- a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
+++ b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn.py
@@ -23,8 +23,8 @@ import types
 
 import numpy as np
 
-from tensorflow.contrib.keras.python.keras.models import Sequential
-from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical
+from tensorflow.python.keras._impl.keras.models import Sequential
+from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical
 from tensorflow.python.util import tf_inspect
 
 
diff --git a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn_test.py b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn_test.py
similarity index 98%
rename from tensorflow/contrib/keras/python/keras/wrappers/scikit_learn_test.py
rename to tensorflow/python/keras/_impl/keras/wrappers/scikit_learn_test.py
index 95e0b951ebf82b248dc6be7134d2d5c27b316e5d..b20a84ee88b5b2b70ca2f718fbe86ffd6e949461 100644
--- a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn_test.py
+++ b/tensorflow/python/keras/_impl/keras/wrappers/scikit_learn_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow.contrib.keras.python import keras
-from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras import testing_utils
 from tensorflow.python.platform import test
 
 INPUT_DIM = 5
diff --git a/tensorflow/python/keras/activations/__init__.py b/tensorflow/python/keras/activations/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d04838c218d6643a703723a1d163c88547c14da7
--- /dev/null
+++ b/tensorflow/python/keras/activations/__init__.py
@@ -0,0 +1,41 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras built-in activation functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Activation functions.
+from tensorflow.python.keras._impl.keras.activations import elu
+from tensorflow.python.keras._impl.keras.activations import hard_sigmoid
+from tensorflow.python.keras._impl.keras.activations import linear
+from tensorflow.python.keras._impl.keras.activations import relu
+from tensorflow.python.keras._impl.keras.activations import selu
+from tensorflow.python.keras._impl.keras.activations import sigmoid
+from tensorflow.python.keras._impl.keras.activations import softmax
+from tensorflow.python.keras._impl.keras.activations import softplus
+from tensorflow.python.keras._impl.keras.activations import softsign
+from tensorflow.python.keras._impl.keras.activations import tanh
+
+# Auxiliary utils.
+# pylint: disable=g-bad-import-order
+from tensorflow.python.keras._impl.keras.activations import deserialize
+from tensorflow.python.keras._impl.keras.activations import serialize
+from tensorflow.python.keras._impl.keras.activations import get
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e34d9a8e0b9178a234ab6a6fc1090063363fa9b4
--- /dev/null
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -0,0 +1,36 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras Applications are canned architectures with pre-trained weights."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.applications import inception_v3
+from tensorflow.python.keras.applications import mobilenet
+from tensorflow.python.keras.applications import resnet50
+from tensorflow.python.keras.applications import vgg16
+from tensorflow.python.keras.applications import vgg19
+from tensorflow.python.keras.applications import xception
+from tensorflow.python.keras.applications.inception_v3 import InceptionV3
+from tensorflow.python.keras.applications.mobilenet import MobileNet
+from tensorflow.python.keras.applications.resnet50 import ResNet50
+from tensorflow.python.keras.applications.vgg16 import VGG16
+from tensorflow.python.keras.applications.vgg19 import VGG19
+from tensorflow.python.keras.applications.xception import Xception
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/applications/inception_v3/__init__.py b/tensorflow/python/keras/applications/inception_v3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf8393ae45d71dc0cb746706abb72f77b82d199
--- /dev/null
+++ b/tensorflow/python/keras/applications/inception_v3/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inception V3 Keras application."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.applications.inception_v3 import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3
+from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/applications/mobilenet/__init__.py b/tensorflow/python/keras/applications/mobilenet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b809e91193b459a46906443796344c092e1d2a6b
--- /dev/null
+++ b/tensorflow/python/keras/applications/mobilenet/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MobileNet Keras application."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.applications.mobilenet import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet
+from tensorflow.python.keras._impl.keras.applications.mobilenet import preprocess_input
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/applications/resnet50/__init__.py b/tensorflow/python/keras/applications/resnet50/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..530805d150bfe32c5b81d7d7d3f92e203b83b602
--- /dev/null
+++ b/tensorflow/python/keras/applications/resnet50/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ResNet50 Keras application."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.applications.resnet50 import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.resnet50 import preprocess_input
+from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/applications/vgg16/__init__.py b/tensorflow/python/keras/applications/vgg16/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..118361604bbc7e0a88ed34243c0d5ea98856a301
--- /dev/null
+++ b/tensorflow/python/keras/applications/vgg16/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""VGG16 Keras application."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.applications.vgg16 import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.vgg16 import preprocess_input
+from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/applications/vgg19/__init__.py b/tensorflow/python/keras/applications/vgg19/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cda52628f3c10d65fdbe70b2f86cc12c771870a9
--- /dev/null
+++ b/tensorflow/python/keras/applications/vgg19/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""VGG19 Keras application."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.applications.vgg19 import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.vgg19 import preprocess_input
+from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/applications/xception/__init__.py b/tensorflow/python/keras/applications/xception/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae9cd9cd18c5ccc5ec37c8cd1bf36f8aabd9929c
--- /dev/null
+++ b/tensorflow/python/keras/applications/xception/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Xception Keras application."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.applications.xception import decode_predictions
+from tensorflow.python.keras._impl.keras.applications.xception import preprocess_input
+from tensorflow.python.keras._impl.keras.applications.xception import Xception
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/backend/__init__.py b/tensorflow/python/keras/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..10ef5a75852deb6595bced2703d7c5f29b0efac3
--- /dev/null
+++ b/tensorflow/python/keras/backend/__init__.py
@@ -0,0 +1,163 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras backend API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=redefined-builtin
+from tensorflow.python.keras._impl.keras.backend import abs
+from tensorflow.python.keras._impl.keras.backend import all
+from tensorflow.python.keras._impl.keras.backend import any
+from tensorflow.python.keras._impl.keras.backend import arange
+from tensorflow.python.keras._impl.keras.backend import argmax
+from tensorflow.python.keras._impl.keras.backend import argmin
+from tensorflow.python.keras._impl.keras.backend import backend
+from tensorflow.python.keras._impl.keras.backend import batch_dot
+from tensorflow.python.keras._impl.keras.backend import batch_flatten
+from tensorflow.python.keras._impl.keras.backend import batch_get_value
+from tensorflow.python.keras._impl.keras.backend import batch_normalization
+from tensorflow.python.keras._impl.keras.backend import batch_set_value
+from tensorflow.python.keras._impl.keras.backend import bias_add
+from tensorflow.python.keras._impl.keras.backend import binary_crossentropy
+from tensorflow.python.keras._impl.keras.backend import cast
+from tensorflow.python.keras._impl.keras.backend import cast_to_floatx
+from tensorflow.python.keras._impl.keras.backend import categorical_crossentropy
+from tensorflow.python.keras._impl.keras.backend import clear_session
+from tensorflow.python.keras._impl.keras.backend import clip
+from tensorflow.python.keras._impl.keras.backend import concatenate
+from tensorflow.python.keras._impl.keras.backend import constant
+from tensorflow.python.keras._impl.keras.backend import conv1d
+from tensorflow.python.keras._impl.keras.backend import conv2d
+from tensorflow.python.keras._impl.keras.backend import conv2d_transpose
+from tensorflow.python.keras._impl.keras.backend import conv3d
+from tensorflow.python.keras._impl.keras.backend import cos
+from tensorflow.python.keras._impl.keras.backend import count_params
+from tensorflow.python.keras._impl.keras.backend import ctc_batch_cost
+from tensorflow.python.keras._impl.keras.backend import ctc_decode
+from tensorflow.python.keras._impl.keras.backend import ctc_label_dense_to_sparse
+from tensorflow.python.keras._impl.keras.backend import dot
+from tensorflow.python.keras._impl.keras.backend import dropout
+from tensorflow.python.keras._impl.keras.backend import dtype
+from tensorflow.python.keras._impl.keras.backend import elu
+from tensorflow.python.keras._impl.keras.backend import epsilon
+from tensorflow.python.keras._impl.keras.backend import equal
+from tensorflow.python.keras._impl.keras.backend import eval
+from tensorflow.python.keras._impl.keras.backend import exp
+from tensorflow.python.keras._impl.keras.backend import expand_dims
+from tensorflow.python.keras._impl.keras.backend import eye
+from tensorflow.python.keras._impl.keras.backend import flatten
+from tensorflow.python.keras._impl.keras.backend import floatx
+from tensorflow.python.keras._impl.keras.backend import foldl
+from tensorflow.python.keras._impl.keras.backend import foldr
+from tensorflow.python.keras._impl.keras.backend import function
+from tensorflow.python.keras._impl.keras.backend import gather
+from tensorflow.python.keras._impl.keras.backend import get_session
+from tensorflow.python.keras._impl.keras.backend import get_uid
+from tensorflow.python.keras._impl.keras.backend import get_value
+from tensorflow.python.keras._impl.keras.backend import gradients
+from tensorflow.python.keras._impl.keras.backend import greater
+from tensorflow.python.keras._impl.keras.backend import greater_equal
+from tensorflow.python.keras._impl.keras.backend import hard_sigmoid
+from tensorflow.python.keras._impl.keras.backend import image_data_format
+from tensorflow.python.keras._impl.keras.backend import in_test_phase
+from tensorflow.python.keras._impl.keras.backend import in_top_k
+from tensorflow.python.keras._impl.keras.backend import in_train_phase
+from tensorflow.python.keras._impl.keras.backend import int_shape
+from tensorflow.python.keras._impl.keras.backend import is_sparse
+from tensorflow.python.keras._impl.keras.backend import l2_normalize
+from tensorflow.python.keras._impl.keras.backend import learning_phase
+from tensorflow.python.keras._impl.keras.backend import less
+from tensorflow.python.keras._impl.keras.backend import less_equal
+from tensorflow.python.keras._impl.keras.backend import log
+from tensorflow.python.keras._impl.keras.backend import manual_variable_initialization
+from tensorflow.python.keras._impl.keras.backend import map_fn
+from tensorflow.python.keras._impl.keras.backend import max
+from tensorflow.python.keras._impl.keras.backend import maximum
+from tensorflow.python.keras._impl.keras.backend import mean
+from tensorflow.python.keras._impl.keras.backend import min
+from tensorflow.python.keras._impl.keras.backend import minimum
+from tensorflow.python.keras._impl.keras.backend import moving_average_update
+from tensorflow.python.keras._impl.keras.backend import name_scope
+from tensorflow.python.keras._impl.keras.backend import ndim
+from tensorflow.python.keras._impl.keras.backend import normalize_batch_in_training
+from tensorflow.python.keras._impl.keras.backend import not_equal
+from tensorflow.python.keras._impl.keras.backend import one_hot
+from tensorflow.python.keras._impl.keras.backend import ones
+from tensorflow.python.keras._impl.keras.backend import ones_like
+from tensorflow.python.keras._impl.keras.backend import permute_dimensions
+from tensorflow.python.keras._impl.keras.backend import placeholder
+from tensorflow.python.keras._impl.keras.backend import pool2d
+from tensorflow.python.keras._impl.keras.backend import pool3d
+from tensorflow.python.keras._impl.keras.backend import pow
+from tensorflow.python.keras._impl.keras.backend import print_tensor
+from tensorflow.python.keras._impl.keras.backend import prod
+from tensorflow.python.keras._impl.keras.backend import random_binomial
+from tensorflow.python.keras._impl.keras.backend import random_normal
+from tensorflow.python.keras._impl.keras.backend import random_normal_variable
+from tensorflow.python.keras._impl.keras.backend import random_uniform
+from tensorflow.python.keras._impl.keras.backend import random_uniform_variable
+from tensorflow.python.keras._impl.keras.backend import relu
+from tensorflow.python.keras._impl.keras.backend import repeat
+from tensorflow.python.keras._impl.keras.backend import repeat_elements
+from tensorflow.python.keras._impl.keras.backend import reset_uids
+from tensorflow.python.keras._impl.keras.backend import reshape
+from tensorflow.python.keras._impl.keras.backend import resize_images
+from tensorflow.python.keras._impl.keras.backend import resize_volumes
+from tensorflow.python.keras._impl.keras.backend import reverse
+from tensorflow.python.keras._impl.keras.backend import rnn
+from tensorflow.python.keras._impl.keras.backend import round
+from tensorflow.python.keras._impl.keras.backend import separable_conv2d
+from tensorflow.python.keras._impl.keras.backend import set_epsilon
+from tensorflow.python.keras._impl.keras.backend import set_floatx
+from tensorflow.python.keras._impl.keras.backend import set_image_data_format
+from tensorflow.python.keras._impl.keras.backend import set_learning_phase
+from tensorflow.python.keras._impl.keras.backend import set_session
+from tensorflow.python.keras._impl.keras.backend import set_value
+from tensorflow.python.keras._impl.keras.backend import shape
+from tensorflow.python.keras._impl.keras.backend import sigmoid
+from tensorflow.python.keras._impl.keras.backend import sign
+from tensorflow.python.keras._impl.keras.backend import sin
+from tensorflow.python.keras._impl.keras.backend import softmax
+from tensorflow.python.keras._impl.keras.backend import softplus
+from tensorflow.python.keras._impl.keras.backend import softsign
+from tensorflow.python.keras._impl.keras.backend import sparse_categorical_crossentropy
+from tensorflow.python.keras._impl.keras.backend import spatial_2d_padding
+from tensorflow.python.keras._impl.keras.backend import spatial_3d_padding
+from tensorflow.python.keras._impl.keras.backend import sqrt
+from tensorflow.python.keras._impl.keras.backend import square
+from tensorflow.python.keras._impl.keras.backend import squeeze
+from tensorflow.python.keras._impl.keras.backend import stack
+from tensorflow.python.keras._impl.keras.backend import std
+from tensorflow.python.keras._impl.keras.backend import stop_gradient
+from tensorflow.python.keras._impl.keras.backend import sum
+from tensorflow.python.keras._impl.keras.backend import switch
+from tensorflow.python.keras._impl.keras.backend import tanh
+from tensorflow.python.keras._impl.keras.backend import temporal_padding
+from tensorflow.python.keras._impl.keras.backend import to_dense
+from tensorflow.python.keras._impl.keras.backend import transpose
+from tensorflow.python.keras._impl.keras.backend import truncated_normal
+from tensorflow.python.keras._impl.keras.backend import update
+from tensorflow.python.keras._impl.keras.backend import update_add
+from tensorflow.python.keras._impl.keras.backend import update_sub
+from tensorflow.python.keras._impl.keras.backend import var
+from tensorflow.python.keras._impl.keras.backend import variable
+from tensorflow.python.keras._impl.keras.backend import zeros
+from tensorflow.python.keras._impl.keras.backend import zeros_like
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/callbacks/__init__.py b/tensorflow/python/keras/callbacks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d884790ddb9ccf49649c6af4cfd40cddbc38cb3
--- /dev/null
+++ b/tensorflow/python/keras/callbacks/__init__.py
@@ -0,0 +1,37 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras callback classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.callbacks import BaseLogger
+from tensorflow.python.keras._impl.keras.callbacks import Callback
+from tensorflow.python.keras._impl.keras.callbacks import CSVLogger
+from tensorflow.python.keras._impl.keras.callbacks import EarlyStopping
+from tensorflow.python.keras._impl.keras.callbacks import History
+from tensorflow.python.keras._impl.keras.callbacks import LambdaCallback
+from tensorflow.python.keras._impl.keras.callbacks import LearningRateScheduler
+from tensorflow.python.keras._impl.keras.callbacks import ModelCheckpoint
+from tensorflow.python.keras._impl.keras.callbacks import ProgbarLogger
+from tensorflow.python.keras._impl.keras.callbacks import ReduceLROnPlateau
+from tensorflow.python.keras._impl.keras.callbacks import RemoteMonitor
+from tensorflow.python.keras._impl.keras.callbacks import TensorBoard
+from tensorflow.python.keras._impl.keras.callbacks import TerminateOnNaN
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/constraints/__init__.py b/tensorflow/python/keras/constraints/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..152606d8ebbcadf57d971d508e15283da65e4aa3
--- /dev/null
+++ b/tensorflow/python/keras/constraints/__init__.py
@@ -0,0 +1,40 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras built-in constraints functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Constraints functions / callable classes.
+from tensorflow.python.keras._impl.keras.constraints import Constraint
+from tensorflow.python.keras._impl.keras.constraints import max_norm
+from tensorflow.python.keras._impl.keras.constraints import MaxNorm
+from tensorflow.python.keras._impl.keras.constraints import min_max_norm
+from tensorflow.python.keras._impl.keras.constraints import MinMaxNorm
+from tensorflow.python.keras._impl.keras.constraints import non_neg
+from tensorflow.python.keras._impl.keras.constraints import NonNeg
+from tensorflow.python.keras._impl.keras.constraints import unit_norm
+from tensorflow.python.keras._impl.keras.constraints import UnitNorm
+
+# Auxiliary utils.
+# pylint: disable=g-bad-import-order
+from tensorflow.python.keras._impl.keras.constraints import deserialize
+from tensorflow.python.keras._impl.keras.constraints import serialize
+from tensorflow.python.keras._impl.keras.constraints import get
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/datasets/__init__.py b/tensorflow/python/keras/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b76f278964b5f5ac7ea666fc12225f5bbd90ec58
--- /dev/null
+++ b/tensorflow/python/keras/datasets/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras built-in datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.datasets import boston_housing
+from tensorflow.python.keras.datasets import cifar10
+from tensorflow.python.keras.datasets import cifar100
+from tensorflow.python.keras.datasets import imdb
+from tensorflow.python.keras.datasets import mnist
+from tensorflow.python.keras.datasets import reuters
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/datasets/boston_housing/__init__.py b/tensorflow/python/keras/datasets/boston_housing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5371a03fd5f5755ba8844415276113c565f52db
--- /dev/null
+++ b/tensorflow/python/keras/datasets/boston_housing/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Boston housing price regression dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.datasets.boston_housing import load_data
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/datasets/cifar10/__init__.py b/tensorflow/python/keras/datasets/cifar10/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..68d3eb789ea2c410095c0c75e0b79a9b07d209a3
--- /dev/null
+++ b/tensorflow/python/keras/datasets/cifar10/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""CIFAR10 small image classification dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/datasets/cifar100/__init__.py b/tensorflow/python/keras/datasets/cifar100/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca93742673341660ba69712feb59c5dd32ea3252
--- /dev/null
+++ b/tensorflow/python/keras/datasets/cifar100/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""CIFAR100 small image classification dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/datasets/imdb/__init__.py b/tensorflow/python/keras/datasets/imdb/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c6396d2d32b88eaa900a5af4e62c7484fceab63
--- /dev/null
+++ b/tensorflow/python/keras/datasets/imdb/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""IMDB movie review sentiment classification dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.datasets.imdb import get_word_index
+from tensorflow.python.keras._impl.keras.datasets.imdb import load_data
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/datasets/mnist/__init__.py b/tensorflow/python/keras/datasets/mnist/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..364255f3387b59a419c010db9b93cdfbcba36186
--- /dev/null
+++ b/tensorflow/python/keras/datasets/mnist/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MNIST handwritten digits classification dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.datasets.mnist import load_data
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/datasets/reuters/__init__.py b/tensorflow/python/keras/datasets/reuters/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb6791a344ad0c372ac60cd4a332f5632841dd46
--- /dev/null
+++ b/tensorflow/python/keras/datasets/reuters/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reuters newswire topic classification dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.datasets.reuters import get_word_index
+from tensorflow.python.keras._impl.keras.datasets.reuters import load_data
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/initializers/__init__.py b/tensorflow/python/keras/initializers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b1fcfd2d9585d19ae3fd9705e128b19b1ec40e7
--- /dev/null
+++ b/tensorflow/python/keras/initializers/__init__.py
@@ -0,0 +1,49 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras built-in initializers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Initializer functions / callable classes.
+from tensorflow.python.keras._impl.keras.initializers import Constant
+from tensorflow.python.keras._impl.keras.initializers import Identity
+from tensorflow.python.keras._impl.keras.initializers import Initializer
+from tensorflow.python.keras._impl.keras.initializers import Ones
+from tensorflow.python.keras._impl.keras.initializers import Orthogonal
+from tensorflow.python.keras._impl.keras.initializers import RandomNormal
+from tensorflow.python.keras._impl.keras.initializers import RandomUniform
+from tensorflow.python.keras._impl.keras.initializers import TruncatedNormal
+from tensorflow.python.keras._impl.keras.initializers import VarianceScaling
+from tensorflow.python.keras._impl.keras.initializers import Zeros
+
+# Functional interface.
+# pylint: disable=g-bad-import-order
+from tensorflow.python.keras._impl.keras.initializers import glorot_normal
+from tensorflow.python.keras._impl.keras.initializers import glorot_uniform
+from tensorflow.python.keras._impl.keras.initializers import he_normal
+from tensorflow.python.keras._impl.keras.initializers import he_uniform
+from tensorflow.python.keras._impl.keras.initializers import lecun_normal
+from tensorflow.python.keras._impl.keras.initializers import lecun_uniform
+
+# Auxiliary utils.
+from tensorflow.python.keras._impl.keras.initializers import deserialize
+from tensorflow.python.keras._impl.keras.initializers import serialize
+from tensorflow.python.keras._impl.keras.initializers import get
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..acf0a5e1799b7c57dfd82861c9ccc1f132c34375
--- /dev/null
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -0,0 +1,148 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras layers API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Generic layers.
+# pylint: disable=g-bad-import-order
+from tensorflow.python.keras._impl.keras.engine import Input
+from tensorflow.python.keras._impl.keras.engine import InputLayer
+from tensorflow.python.keras._impl.keras.engine import InputSpec
+from tensorflow.python.keras._impl.keras.engine import Layer
+
+# Advanced activations.
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import LeakyReLU
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU
+from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU
+
+# Convolution layers.
+from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D
+
+# Convolution layer aliases.
+from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution1D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose
+from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D
+
+# Image processing layers.
+from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D
+from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling3D
+from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding1D
+from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding3D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping1D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping2D
+from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping3D
+
+# Convolutional-recurrent layers.
+from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import ConvLSTM2D
+
+# Core layers.
+from tensorflow.python.keras._impl.keras.layers.core import Masking
+from tensorflow.python.keras._impl.keras.layers.core import Dropout
+from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout1D
+from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout2D
+from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout3D
+from tensorflow.python.keras._impl.keras.layers.core import Activation
+from tensorflow.python.keras._impl.keras.layers.core import Reshape
+from tensorflow.python.keras._impl.keras.layers.core import Permute
+from tensorflow.python.keras._impl.keras.layers.core import Flatten
+from tensorflow.python.keras._impl.keras.layers.core import RepeatVector
+from tensorflow.python.keras._impl.keras.layers.core import Lambda
+from tensorflow.python.keras._impl.keras.layers.core import Dense
+from tensorflow.python.keras._impl.keras.layers.core import ActivityRegularization
+
+# Embedding layers.
+from tensorflow.python.keras._impl.keras.layers.embeddings import Embedding
+
+# Locally-connected layers.
+from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected1D
+from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected2D
+
+# Merge layers.
+from tensorflow.python.keras._impl.keras.layers.merge import Add
+from tensorflow.python.keras._impl.keras.layers.merge import Multiply
+from tensorflow.python.keras._impl.keras.layers.merge import Average
+from tensorflow.python.keras._impl.keras.layers.merge import Maximum
+from tensorflow.python.keras._impl.keras.layers.merge import Concatenate
+from tensorflow.python.keras._impl.keras.layers.merge import Dot
+from tensorflow.python.keras._impl.keras.layers.merge import add
+from tensorflow.python.keras._impl.keras.layers.merge import multiply
+from tensorflow.python.keras._impl.keras.layers.merge import average
+from tensorflow.python.keras._impl.keras.layers.merge import maximum
+from tensorflow.python.keras._impl.keras.layers.merge import concatenate
+from tensorflow.python.keras._impl.keras.layers.merge import dot
+
+# Noise layers.
+from tensorflow.python.keras._impl.keras.layers.noise import AlphaDropout
+from tensorflow.python.keras._impl.keras.layers.noise import GaussianNoise
+from tensorflow.python.keras._impl.keras.layers.noise import GaussianDropout
+
+# Normalization layers.
+from tensorflow.python.keras._impl.keras.layers.normalization import BatchNormalization
+
+# Pooling layers.
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D
+from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D
+from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D
+from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling1D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling2D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling3D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling1D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling2D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling3D
+
+# Pooling layer aliases.
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool1D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool2D
+from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool3D
+from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool1D
+from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool2D
+from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool3D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool1D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool2D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool3D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool1D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D
+from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D
+
+# Recurrent layers.
+from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN
+from tensorflow.python.keras._impl.keras.layers.recurrent import GRU
+from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM
+
+# Wrapper functions
+from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper
+from tensorflow.python.keras._impl.keras.layers.wrappers import Bidirectional
+from tensorflow.python.keras._impl.keras.layers.wrappers import TimeDistributed
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/losses/__init__.py b/tensorflow/python/keras/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..66721b694f5fd5fae7ca521ff56d4c6c6bce79b5
--- /dev/null
+++ b/tensorflow/python/keras/losses/__init__.py
@@ -0,0 +1,45 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras built-in loss functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Loss functions.
+from tensorflow.python.keras._impl.keras.losses import binary_crossentropy
+from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy
+from tensorflow.python.keras._impl.keras.losses import categorical_hinge
+from tensorflow.python.keras._impl.keras.losses import cosine_proximity
+from tensorflow.python.keras._impl.keras.losses import hinge
+from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence
+from tensorflow.python.keras._impl.keras.losses import logcosh
+from tensorflow.python.keras._impl.keras.losses import mean_absolute_error
+from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error
+from tensorflow.python.keras._impl.keras.losses import mean_squared_error
+from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error
+from tensorflow.python.keras._impl.keras.losses import poisson
+from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy
+from tensorflow.python.keras._impl.keras.losses import squared_hinge
+
+# Auxiliary utils.
+# pylint: disable=g-bad-import-order
+from tensorflow.python.keras._impl.keras.losses import deserialize
+from tensorflow.python.keras._impl.keras.losses import serialize
+from tensorflow.python.keras._impl.keras.losses import get
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/metrics/__init__.py b/tensorflow/python/keras/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..59faf037bce0f087d244a2faaeb52713bdc3b772
--- /dev/null
+++ b/tensorflow/python/keras/metrics/__init__.py
@@ -0,0 +1,47 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras built-in metrics functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Metrics functions.
+from tensorflow.python.keras._impl.keras.metrics import binary_accuracy
+from tensorflow.python.keras._impl.keras.metrics import binary_crossentropy
+from tensorflow.python.keras._impl.keras.metrics import categorical_accuracy
+from tensorflow.python.keras._impl.keras.metrics import categorical_crossentropy
+from tensorflow.python.keras._impl.keras.metrics import cosine_proximity
+from tensorflow.python.keras._impl.keras.metrics import hinge
+from tensorflow.python.keras._impl.keras.metrics import kullback_leibler_divergence
+from tensorflow.python.keras._impl.keras.metrics import mean_absolute_error
+from tensorflow.python.keras._impl.keras.metrics import mean_absolute_percentage_error
+from tensorflow.python.keras._impl.keras.metrics import mean_squared_error
+from tensorflow.python.keras._impl.keras.metrics import mean_squared_logarithmic_error
+from tensorflow.python.keras._impl.keras.metrics import poisson
+from tensorflow.python.keras._impl.keras.metrics import sparse_categorical_crossentropy
+from tensorflow.python.keras._impl.keras.metrics import sparse_top_k_categorical_accuracy
+from tensorflow.python.keras._impl.keras.metrics import squared_hinge
+from tensorflow.python.keras._impl.keras.metrics import top_k_categorical_accuracy
+
+# Auxiliary utils.
+# pylint: disable=g-bad-import-order
+from tensorflow.python.keras._impl.keras.metrics import deserialize
+from tensorflow.python.keras._impl.keras.metrics import serialize
+from tensorflow.python.keras._impl.keras.metrics import get
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/models/__init__.py b/tensorflow/python/keras/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fb4ac0960d38f28a1c9c897a0f1aedf57e048ac
--- /dev/null
+++ b/tensorflow/python/keras/models/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras models API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.models import load_model
+from tensorflow.python.keras._impl.keras.models import Model
+from tensorflow.python.keras._impl.keras.models import model_from_config
+from tensorflow.python.keras._impl.keras.models import model_from_json
+from tensorflow.python.keras._impl.keras.models import model_from_yaml
+from tensorflow.python.keras._impl.keras.models import save_model
+from tensorflow.python.keras._impl.keras.models import Sequential
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/optimizers/__init__.py b/tensorflow/python/keras/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44f47bc47f4a0e31aaf2ac8f67cfdbef410d8c44
--- /dev/null
+++ b/tensorflow/python/keras/optimizers/__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras built-in optimizers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Optimizer classes.
+from tensorflow.python.keras._impl.keras.optimizers import Adadelta
+from tensorflow.python.keras._impl.keras.optimizers import Adagrad
+from tensorflow.python.keras._impl.keras.optimizers import Adam
+from tensorflow.python.keras._impl.keras.optimizers import Adamax
+from tensorflow.python.keras._impl.keras.optimizers import Nadam
+from tensorflow.python.keras._impl.keras.optimizers import Optimizer
+from tensorflow.python.keras._impl.keras.optimizers import RMSprop
+from tensorflow.python.keras._impl.keras.optimizers import SGD
+
+# Auxiliary utils.
+# pylint: disable=g-bad-import-order
+from tensorflow.python.keras._impl.keras.optimizers import deserialize
+from tensorflow.python.keras._impl.keras.optimizers import serialize
+from tensorflow.python.keras._impl.keras.optimizers import get
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/preprocessing/__init__.py b/tensorflow/python/keras/preprocessing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fa3911a7a8833f4b296519c84662cf39ea2dc88
--- /dev/null
+++ b/tensorflow/python/keras/preprocessing/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras data preprocessing utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.preprocessing import image
+from tensorflow.python.keras.preprocessing import sequence
+from tensorflow.python.keras.preprocessing import text
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/preprocessing/image/__init__.py b/tensorflow/python/keras/preprocessing/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b96e7675527041d3952b049f5f431d3df36eea4c
--- /dev/null
+++ b/tensorflow/python/keras/preprocessing/image/__init__.py
@@ -0,0 +1,38 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras data preprocessing utils for image data."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.preprocessing.image import apply_transform
+from tensorflow.python.keras._impl.keras.preprocessing.image import array_to_img
+from tensorflow.python.keras._impl.keras.preprocessing.image import DirectoryIterator
+from tensorflow.python.keras._impl.keras.preprocessing.image import flip_axis
+from tensorflow.python.keras._impl.keras.preprocessing.image import ImageDataGenerator
+from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array
+from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator
+from tensorflow.python.keras._impl.keras.preprocessing.image import load_img
+from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator
+from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift
+from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation
+from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear
+from tensorflow.python.keras._impl.keras.preprocessing.image import random_shift
+from tensorflow.python.keras._impl.keras.preprocessing.image import random_zoom
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/preprocessing/sequence/__init__.py b/tensorflow/python/keras/preprocessing/sequence/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..112f6af5e588bcb2e85fdbecea86f402742d44e7
--- /dev/null
+++ b/tensorflow/python/keras/preprocessing/sequence/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras data preprocessing utils for sequence data."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table
+from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences
+from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/preprocessing/text/__init__.py b/tensorflow/python/keras/preprocessing/text/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bf1a2fb21dc27f7aa10cd08b1496e3991c61d2f
--- /dev/null
+++ b/tensorflow/python/keras/preprocessing/text/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras data preprocessing utils for text data."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot
+from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence
+from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/regularizers/__init__.py b/tensorflow/python/keras/regularizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e707ccab577b5e28febd83d91f84d7b1c0d5d82
--- /dev/null
+++ b/tensorflow/python/keras/regularizers/__init__.py
@@ -0,0 +1,38 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras built-in regularizers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Regularizer functions / callable classes.
+from tensorflow.python.keras._impl.keras.regularizers import L1L2
+from tensorflow.python.keras._impl.keras.regularizers import Regularizer
+
+# Functional interface.
+# pylint: disable=g-bad-import-order
+from tensorflow.python.keras._impl.keras.regularizers import l1
+from tensorflow.python.keras._impl.keras.regularizers import l2
+from tensorflow.python.keras._impl.keras.regularizers import l1_l2
+
+# Auxiliary utils.
+from tensorflow.python.keras._impl.keras.regularizers import deserialize
+from tensorflow.python.keras._impl.keras.regularizers import serialize
+from tensorflow.python.keras._impl.keras.regularizers import get
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/utils/__init__.py b/tensorflow/python/keras/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7c2179fe7ad434356921a5fb8709aa5b1f33498
--- /dev/null
+++ b/tensorflow/python/keras/utils/__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer
+from tensorflow.python.keras._impl.keras.utils.data_utils import get_file
+from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence
+from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer
+from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope
+from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects
+from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar
+from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix
+from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model
+from tensorflow.python.keras._impl.keras.utils.np_utils import normalize
+from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical
+from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/wrappers/__init__.py b/tensorflow/python/keras/wrappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..da579a7ab58aecd551e25a2730d611e4895c0e42
--- /dev/null
+++ b/tensorflow/python/keras/wrappers/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wrappers for Keras models, providing compatibility with other frameworks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.wrappers import scikit_learn
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/keras/wrappers/scikit_learn/__init__.py b/tensorflow/python/keras/wrappers/scikit_learn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a46f859273ea0117e29a403057f9f81bc758dd52
--- /dev/null
+++ b/tensorflow/python/keras/wrappers/scikit_learn/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras scikit-learn API wrapper."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasClassifier
+from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasRegressor
+
+del absolute_import
+del division
+del print_function
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 3e2cb82d023a0587f2d6313fb3c6952648d33569..d9c5f3bce9972b8bc7ce8bcd118ab4e0749f341d 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -518,14 +518,17 @@ tf_py_test(
 
 tf_py_test(
     name = "matrix_solve_ls_op_test",
-    size = "small",
+    size = "medium",
     srcs = ["matrix_solve_ls_op_test.py"],
     additional_deps = [
         "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:linalg_ops",
+        "//tensorflow/python:math_ops",
     ],
+    tags = ["nomsan"],  # fails in msan from numpy calls
 )
 
 tf_py_test(
@@ -534,6 +537,7 @@ tf_py_test(
     srcs = ["matrix_solve_op_test.py"],
     additional_deps = [
         "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:linalg_ops",
@@ -680,13 +684,15 @@ cuda_py_test(
 
 tf_py_test(
     name = "segment_reduction_ops_test",
-    size = "small",
+    size = "medium",
     srcs = ["segment_reduction_ops_test.py"],
     additional_deps = [
         "//third_party/py/numpy",
+        "//tensorflow/python:client",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:variables",
         "//tensorflow/python:nn_grad",
     ],
 )
@@ -932,12 +938,14 @@ tf_py_test(
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:state_ops",
         "//tensorflow/python:variables",
+        "//tensorflow/python/eager:context",
     ],
     tags = ["no_windows"],
 )
@@ -1479,6 +1487,7 @@ cuda_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:linalg_ops",
+        "//tensorflow/python:linalg_ns",
         "//tensorflow/python:math_ops",
     ],
     tags = ["no_windows_gpu"],
@@ -1702,6 +1711,26 @@ cuda_py_test(
     tags = ["no_windows_gpu"],
 )
 
+cuda_py_test(
+    name = "reduction_ops_test_big",
+    size = "medium",
+    srcs = ["reduction_ops_test_big.py"],
+    additional_deps = [
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:math_ops",
+    ],
+    tags = [
+        "manual",
+        "no_gpu",
+        "nogpu",
+        "noguitar",
+        "notap",
+    ],
+)
+
 cuda_py_test(
     name = "relu_op_test",
     size = "small",
@@ -2148,7 +2177,7 @@ cuda_py_test(
         "//tensorflow/python:nn_grad",
         "//tensorflow/python:nn_ops",
     ],
-    tags = ["noasan"],  # times out b/63680444
+    shard_count = 2,
 )
 
 cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 42e63d8b81445ecc667ee45c79be1f6ca752ac49..77c5bb6d400011e38208bf50b9dd321b6c1d71c0 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -981,15 +981,15 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
 
 class ConcatSliceResourceTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testConcatSlice(self):
-    with self.test_session():
-      r1 = test_ops.stub_resource_handle_op(container="a", shared_name="b")
-      r2 = test_ops.stub_resource_handle_op(container="a", shared_name="c")
-      c = array_ops.stack([r1, r2])
-      s = array_ops.strided_slice(c, [1], [2])
-      test_ops.resource_create_op(s).run()
-      with self.assertRaises(errors.AlreadyExistsError):
-        test_ops.resource_create_op(r2).run()
+    r1 = test_ops.stub_resource_handle_op(container="a", shared_name="b")
+    r2 = test_ops.stub_resource_handle_op(container="a", shared_name="c")
+    c = array_ops.stack([r1, r2])
+    s = array_ops.strided_slice(c, [1], [2])
+    self.evaluate(test_ops.resource_create_op(s))
+    with self.assertRaises(errors.AlreadyExistsError):
+      self.evaluate(test_ops.resource_create_op(r2))
 
 
 class IdentityTest(test_util.TensorFlowTestCase):
@@ -1020,5 +1020,19 @@ class IdentityTest(test_util.TensorFlowTestCase):
         _test(d, e, "gpu")
 
 
+class PadTest(test_util.TensorFlowTestCase):
+
+  def testEager(self):
+    with context.eager_mode():
+      t = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+      paddings = constant_op.constant([[1, 1,], [2, 2]])
+      padded = array_ops.pad(t, paddings, "CONSTANT")
+      self.assertAllEqual(padded.numpy(),
+                          [[0, 0, 0, 0, 0, 0, 0],
+                           [0, 0, 1, 2, 3, 0, 0],
+                           [0, 0, 4, 5, 6, 0, 0],
+                           [0, 0, 0, 0, 0, 0, 0]])
+
+
 if __name__ == "__main__":
   test_lib.main()
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index eb06e067a7f7a539efe240f2420563318238e76c..de80fb30554a8aec4da0d2a162e429ca914cda83 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -183,14 +183,11 @@ class CholeskyGradTest(test.TestCase):
     self.runFiniteDifferences(
         shapes, dtypes=(dtypes_lib.float32, dtypes_lib.float64))
 
-  # TODO(eriche): investigate why this test fails only in opensource
-  # ubuntu gpu python3
-
-  # def testSmallMatricesComplex(self):
-    # np.random.seed(0)
-    # shapes = self.getShapes([1, 2, 10])
-    # self.runFiniteDifferences(
-        # shapes, dtypes=(dtypes_lib.complex64, dtypes_lib.complex128))
+  def testSmallMatricesComplex(self):
+    np.random.seed(0)
+    shapes = self.getShapes([1, 2, 10])
+    self.runFiniteDifferences(
+        shapes, dtypes=(dtypes_lib.complex64, dtypes_lib.complex128))
 
   def testOneBlockMatrices(self):
     np.random.seed(0)
diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py
index 0e98afbe6e415b305b579d848de4c1fbfc9f02fd..0b4fa60d81b10497e1b609ff81381d09bea3090e 100644
--- a/tensorflow/python/kernel_tests/constant_op_eager_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py
@@ -26,27 +26,33 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.util import compat
 
 
-# TODO(josh11b): add tests with string types, lists/tuples, Shape.
+# TODO(josh11b): add tests with lists/tuples, Shape.
 class ConstantTest(test.TestCase):
 
   def _testCpu(self, x):
     np_ans = np.array(x)
-    tf_ans = ops.convert_to_tensor(x).numpy()
+    with context.device("/device:CPU:0"):
+      tf_ans = ops.convert_to_tensor(x).numpy()
     if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]:
       self.assertAllClose(np_ans, tf_ans)
     else:
       self.assertAllEqual(np_ans, tf_ans)
 
   def _testGpu(self, x):
-    np_ans = np.array(x)
-    tf_ans = ops.convert_to_tensor(x).numpy()
-    if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]:
-      self.assertAllClose(np_ans, tf_ans)
-    else:
-      self.assertAllEqual(np_ans, tf_ans)
+    device = test_util.gpu_device_name()
+    if device:
+      np_ans = np.array(x)
+      with context.device(device):
+        tf_ans = ops.convert_to_tensor(x).numpy()
+      if np_ans.dtype in [np.float32, np.float64, np.complex64, np.complex128]:
+        self.assertAllClose(np_ans, tf_ans)
+      else:
+        self.assertAllEqual(np_ans, tf_ans)
 
   def _testAll(self, x):
     self._testCpu(x)
@@ -78,11 +84,11 @@ class ConstantTest(test.TestCase):
 
   def testComplex64(self):
     self._testAll(
-        np.complex(1, 2) * np.arange(-15, 15).reshape([2, 3, 5
-                                                      ]).astype(np.complex64))
+        np.complex(1, 2) *
+        np.arange(-15, 15).reshape([2, 3, 5]).astype(np.complex64))
     self._testAll(
-        np.complex(1, 2) * np.random.normal(size=30).reshape(
-            [2, 3, 5]).astype(np.complex64))
+        np.complex(1, 2) *
+        np.random.normal(size=30).reshape([2, 3, 5]).astype(np.complex64))
     self._testAll(np.empty((2, 0, 5)).astype(np.complex64))
 
   def testComplex128(self):
@@ -94,6 +100,26 @@ class ConstantTest(test.TestCase):
             [2, 3, 5]).astype(np.complex128))
     self._testAll(np.empty((2, 0, 5)).astype(np.complex128))
 
+  def testString(self):
+    val = [compat.as_bytes(str(x)) for x in np.arange(-15, 15)]
+    self._testCpu(np.array(val).reshape([2, 3, 5]))
+    self._testCpu(np.empty((2, 0, 5)).astype(np.str_))
+
+  def testStringWithNulls(self):
+    val = ops.convert_to_tensor(b"\0\0\0\0").numpy()
+    self.assertEqual(len(val), 4)
+    self.assertEqual(val, b"\0\0\0\0")
+
+    val = ops.convert_to_tensor(b"xx\0xx").numpy()
+    self.assertEqual(len(val), 5)
+    self.assertAllEqual(val, b"xx\0xx")
+
+    nested = [[b"\0\0\0\0", b"xx\0xx"], [b"\0_\0_\0_\0", b"\0"]]
+    val = ops.convert_to_tensor(nested).numpy()
+    # NOTE(mrry): Do not use assertAllEqual, because it converts nested to a
+    #   numpy array, which loses the null terminators.
+    self.assertEqual(val.tolist(), nested)
+
   def testExplicitShapeNumPy(self):
     c = constant_op.constant(
         np.arange(-15, 15).reshape([2, 3, 5]).astype(np.float32),
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index df413939c766ac46fec6370a2ba3dfa270d336c2..6167cb9999b1be2b1e8b530ebacfe9c4a5a2d8d1 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
@@ -119,11 +120,11 @@ class ConstantTest(test.TestCase):
           variant_val=[
               tensor_pb2.VariantTensorDataProto(
                   # Match registration in variant_op_registry.cc
-                  type_name=b"int32",
+                  type_name=b"int",
                   metadata=np.array(1, dtype=np.int32).tobytes())
           ])
-      const_op = constant_op.constant(variant_tensor).op
-      const_value = const_op.get_attr("value")
+      const = constant_op.constant(variant_tensor)
+      const_value = const.op.get_attr("value")
 
       # Ensure we stored the tensor proto properly.
       self.assertProtoEquals(variant_tensor, const_value)
@@ -134,7 +135,10 @@ class ConstantTest(test.TestCase):
       # native numpy types cannot be passed to ops.convert_to_tensor.
       # TODO(ebrevdo): Add registration mechanism for
       # ops.convert_to_tensor and for session.run output.
-      const_op.run()
+      logging_const_op = logging_ops.Print(
+          const, [const],
+          message="Variant storing an int, decoded const value:").op
+      logging_const_op.run()
 
   def testStringWithNulls(self):
     with self.test_session():
@@ -469,6 +473,35 @@ class ZerosLikeTest(test.TestCase):
           self.assertEqual(y.shape, shape)
           self.assertAllEqual(y, np.zeros(shape, dtype=out_type))
 
+  def testZerosLikeVariant(self):
+    # TODO(ebrevdo): Re-enable use_gpu=True once non-DMA Variant
+    # copying between CPU and GPU is supported AND we register a
+    # ZerosLike callback for GPU for Variant storing primitive types
+    # in variant_op_registry.cc.
+    with self.test_session(use_gpu=False):
+      variant_tensor = tensor_pb2.TensorProto(
+          dtype=dtypes_lib.variant.as_datatype_enum,
+          tensor_shape=tensor_shape.TensorShape([]).as_proto(),
+          variant_val=[
+              tensor_pb2.VariantTensorDataProto(
+                  # Match registration in variant_op_registry.cc
+                  type_name=b"int",
+                  metadata=np.array(1, dtype=np.int32).tobytes())
+          ])
+      const_variant = constant_op.constant(variant_tensor)
+      zeros_like = array_ops.zeros_like(const_variant)
+      zeros_like_op = logging_ops.Print(
+          zeros_like, [const_variant, zeros_like],
+          message="Variant storing an int, input and output of zeros_like:").op
+
+      # Smoke test -- ensure this executes without trouble.
+      # Right now, non-numpy-compatible objects cannot be returned from a
+      # session.run call; similarly, objects that can't be converted to
+      # native numpy types cannot be passed to ops.convert_to_tensor.
+      # TODO(ebrevdo): Add registration mechanism for
+      # ops.convert_to_tensor and for session.run output.
+      zeros_like_op.run()
+
 
 class OnesTest(test.TestCase):
 
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index a43fe71b9f330485433ab07d1fc1baf9f23afcaa..6e81e1fdbd82e3fc678a187b8bbd810d4e2cf042 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -30,6 +30,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import device_lib
 from tensorflow.python.client import session
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -2849,5 +2850,50 @@ class WhileOpBenchmark(test.Benchmark):
         name="unroll_same_device", iters=iters, wall_time=duration)
 
 
+class EagerTest(test.TestCase):
+
+  def testCond(self):
+    with context.eager_mode():
+      pred = math_ops.less(1, 2)
+      fn1 = lambda: constant_op.constant(10)
+      fn2 = lambda: constant_op.constant(20)
+      r = control_flow_ops.cond(pred, fn1, fn2)
+
+      self.assertAllEqual(r.numpy(), 10)
+
+  def testWhileLoop(self):
+    with context.eager_mode():
+      tensor = constant_op.constant([1, 2, 3, 4, 5])
+      self.assertAllEqual(isum(tensor).numpy(),
+                          [46, 47, 48, 49, 50])
+
+  def testWithDependencies(self):
+    with context.eager_mode():
+      t1 = constant_op.constant(1)
+      t2 = constant_op.constant(2)
+      t3 = control_flow_ops.with_dependencies(t1, t2)
+      self.assertAllEqual(t2.numpy(), t3.numpy())
+
+  def testTuple(self):
+    with context.eager_mode():
+      t1 = constant_op.constant(1)
+      t2 = constant_op.constant(2)
+      tup1, tup2 = control_flow_ops.tuple([t1, t2])
+      self.assertAllEqual(t1.numpy(), tup1.numpy())
+      self.assertAllEqual(t2.numpy(), tup2.numpy())
+
+  def testCase(self):
+    with context.eager_mode():
+      x = constant_op.constant(1)
+      y = constant_op.constant(2)
+      z = constant_op.constant(3)
+      f1 = lambda: constant_op.constant(17)
+      f2 = lambda: constant_op.constant(23)
+      f3 = lambda: constant_op.constant(-1)
+
+      r1 = control_flow_ops.case([(x < y, f1), (x > z, f2)],
+                                 default=f3, exclusive=True)
+      self.assertAllEqual(r1.numpy(), 17)
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index f95d9833d844585cc487d380df33177e1c0fe31e..18801f6158b9bdbde39a6694989129f2ffbb250a 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -1956,7 +1956,7 @@ class ComplexMakeRealImagTest(test.TestCase):
     with self.test_session(use_gpu=use_gpu) as sess:
       inx = ops.convert_to_tensor(cplx)
       tf_angle = math_ops.angle(inx)
-      tf_angle_val = sess.run([tf_angle])
+      tf_angle_val = sess.run(tf_angle)
     self.assertAllEqual(np_angle, tf_angle_val)
     self.assertShapeEqual(np_angle, tf_angle)
 
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index 85e7b635d800b1aec1d61e27129cf5a5d14f25a3..748135440ec5e8ad387f910e1433f638abf2260a 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -1078,6 +1078,9 @@ class FIFOQueueTest(test.TestCase):
       self.assertEqual([50.0], dequeued_t.eval())
       self.assertEqual([60.0], dequeued_t.eval())
 
+      # Make sure the thread finishes before exiting.
+      thread.join()
+
   def testBlockingEnqueueBeforeClose(self):
     with self.test_session() as sess:
       q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index b40bcd7e131c2192d230388c87d2555c5ee5fbe0..7d367a92750ae3562c93d2381eb895c94a866eaa 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -59,7 +59,7 @@ class MatrixUnaryFunctorGradientTest(test_lib.TestCase):
 def _GetMatrixUnaryFunctorGradientTest(functor_, dtype_, shape_, **kwargs_):
 
   def Test(self):
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       np.random.seed(1)
       a_np = np.random.uniform(
           low=-1.0, high=1.0,
@@ -97,7 +97,11 @@ def _GetMatrixBinaryFunctorGradientTest(functor_,
                                         **kwargs_):
 
   def Test(self):
-    with self.test_session():
+    # TODO(rmlarsen): Debug illegal address bug on CUDA and re-enable
+    # GPU test for matrix_solve.
+    use_gpu = False if functor_ == linalg_ops.matrix_solve else True
+
+    with self.test_session(use_gpu=use_gpu):
       np.random.seed(1)
       a_np = np.random.uniform(
           low=-1.0, high=1.0,
@@ -142,26 +146,21 @@ if __name__ == '__main__':
           shape = extra + (size, size)
           name = '%s_%s_adj_%s' % (dtype.__name__, '_'.join(map(str, shape)),
                                    str(adjoint))
-          _AddTest(
-              MatrixBinaryFunctorGradientTest,
-              'MatrixSolveGradient',
-              name,
-              _GetMatrixBinaryFunctorGradientTest(
-                  linalg_ops.matrix_solve, dtype, shape, adjoint=adjoint))
+          _AddTest(MatrixBinaryFunctorGradientTest, 'MatrixSolveGradient', name,
+                   _GetMatrixBinaryFunctorGradientTest(
+                       linalg_ops.matrix_solve, dtype, shape, adjoint=adjoint))
 
           for lower in True, False:
             name = '%s_low_%s' % (name, lower)
-            _AddTest(
-                MatrixBinaryFunctorGradientTest,
-                'MatrixTriangularSolveGradient',
-                name,
-                _GetMatrixBinaryFunctorGradientTest(
-                    linalg_ops.matrix_triangular_solve,
-                    dtype,
-                    shape,
-                    float32_tol_fudge=4.0,
-                    adjoint=adjoint,
-                    lower=lower))
+            _AddTest(MatrixBinaryFunctorGradientTest,
+                     'MatrixTriangularSolveGradient', name,
+                     _GetMatrixBinaryFunctorGradientTest(
+                         linalg_ops.matrix_triangular_solve,
+                         dtype,
+                         shape,
+                         float32_tol_fudge=4.0,
+                         adjoint=adjoint,
+                         lower=lower))
 
   # Tests for gradients of unary matrix operations.
   for dtype in np.float32, np.float64:
@@ -191,8 +190,10 @@ if __name__ == '__main__':
               MatrixBinaryFunctorGradientTest,
               'MatrixSolveLsGradient',
               name,
+              # pylint: disable=long-lambda,g-long-lambda
               _GetMatrixBinaryFunctorGradientTest(
-                  lambda a, b, l=l2_regularization: linalg_ops.matrix_solve_ls(a, b, l),
+                  (lambda a, b, l=l2_regularization:
+                   linalg_ops.matrix_solve_ls(a, b, l)),
                   dtype,
                   shape,
                   float32_tol_fudge=4.0))
diff --git a/tensorflow/python/kernel_tests/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg_ops_test.py
index 0e096bbc09cf413a94c8a0c32ec24ab09676206a..c198e13f848e89d8d7d79001fefe3875f509bb0c 100644
--- a/tensorflow/python/kernel_tests/linalg_ops_test.py
+++ b/tensorflow/python/kernel_tests/linalg_ops_test.py
@@ -22,6 +22,7 @@ import numpy as np
 
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ns as linalg
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -34,10 +35,12 @@ def _AddTest(test_class, op_name, testcase_name, fn):
   setattr(test_class, test_name, fn)
 
 
-def _RandomPDMatrix(n, rng):
+def _RandomPDMatrix(n, rng, dtype=np.float64):
   """Random positive definite matrix."""
-  temp = rng.randn(n, n)
-  return temp.dot(temp.T)
+  temp = rng.randn(n, n).astype(dtype)
+  if dtype in [np.complex64, np.complex128]:
+    temp.imag = rng.randn(n, n)
+  return np.conj(temp).dot(temp.T)
 
 
 class CholeskySolveTest(test.TestCase):
@@ -46,9 +49,9 @@ class CholeskySolveTest(test.TestCase):
     self.rng = np.random.RandomState(0)
 
   def test_works_with_five_different_random_pos_def_matrices(self):
-    with self.test_session(use_gpu=True):
-      for n in range(1, 6):
-        for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]:
+    for n in range(1, 6):
+      for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]:
+        with self.test_session(use_gpu=True):
           # Create 2 x n x n matrix
           array = np.array(
               [_RandomPDMatrix(n, self.rng),
@@ -61,6 +64,35 @@ class CholeskySolveTest(test.TestCase):
                 rhs, math_ops.matmul(array, x).eval(), atol=atol)
 
 
+class LogdetTest(test.TestCase):
+
+  def setUp(self):
+    self.rng = np.random.RandomState(42)
+
+  def test_works_with_five_different_random_pos_def_matrices(self):
+    for n in range(1, 6):
+      for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5),
+                             (np.complex64, 0.05), (np.complex128, 1e-5)]:
+        matrix = _RandomPDMatrix(n, self.rng, np_dtype)
+        _, logdet_np = np.linalg.slogdet(matrix)
+        with self.test_session(use_gpu=True):
+          # Create 2 x n x n matrix
+          # matrix = np.array(
+          #     [_RandomPDMatrix(n, self.rng, np_dtype),
+          #      _RandomPDMatrix(n, self.rng, np_dtype)]).astype(np_dtype)
+          logdet_tf = linalg.logdet(matrix)
+          self.assertAllClose(logdet_np, logdet_tf.eval(), atol=atol)
+
+  def test_works_with_underflow_case(self):
+    for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5),
+                           (np.complex64, 0.05), (np.complex128, 1e-5)]:
+      matrix = (np.eye(20) * 1e-6).astype(np_dtype)
+      _, logdet_np = np.linalg.slogdet(matrix)
+      with self.test_session(use_gpu=True):
+        logdet_tf = linalg.logdet(matrix)
+        self.assertAllClose(logdet_np, logdet_tf.eval(), atol=atol)
+
+
 class EyeTest(test.TestCase):
   pass  # Will be filled in below
 
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index d081d266793c3016611390055ef00d82985f8b1f..da57f918ac286bb59e0525a02482b672dc40dc89 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -40,6 +40,7 @@ from tensorflow.python.training import momentum as momentum_lib
 class AbsoluteDifferenceLossTest(test.TestCase):
 
   def setUp(self):
+    super(AbsoluteDifferenceLossTest, self).setUp()
     self._predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
     self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
 
@@ -608,6 +609,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
 class LogLossTest(test.TestCase):
 
   def setUp(self):
+    super(LogLossTest, self).setUp()
     predictions = np.asarray([.9, .2, .2, .8, .4, .6]).reshape((2, 3))
     labels = np.asarray([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).reshape((2, 3))
 
@@ -868,6 +870,7 @@ class HuberLossTest(test.TestCase):
 class MeanSquaredErrorTest(test.TestCase):
 
   def setUp(self):
+    super(MeanSquaredErrorTest, self).setUp()
     self._predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
     self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
 
@@ -941,6 +944,7 @@ class MeanSquaredErrorTest(test.TestCase):
 class MeanPairwiseSquaredErrorTest(test.TestCase):
 
   def setUp(self):
+    super(MeanPairwiseSquaredErrorTest, self).setUp()
     self._predictions = np.array([[4, 8, 12], [8, 1, 3]])
     self._labels = np.array([[1, 9, 2], [-5, -5, 7]])
 
@@ -1167,6 +1171,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
 class CosineDistanceLossTest(test.TestCase):
 
   def setUp(self):
+    super(CosineDistanceLossTest, self).setUp()
     self._predictions = np.asarray([
         [1, 0, 0],  # Batch 1
         [0, 0, -1],
@@ -1290,6 +1295,7 @@ class AddLossTest(test.TestCase):
 class ComputeWeightedLossTest(test.TestCase):
 
   def setUp(self):
+    super(ComputeWeightedLossTest, self).setUp()
     self._shape = (3, 2, 4)
     raw_losses = np.zeros(self._shape)
     next_loss = 0.0
diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
index c2530b3597cfef0d81fdd71cf0e7a7860f4263fb..e641d5511f5d256efd7a729f71a4ee0885c99831 100644
--- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
@@ -33,7 +33,7 @@ def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_):
 
   def Test(self):
     mat = np.ones(shape_).astype(dtype_)
-    batch_mat = np.tile(mat, batch_shape + (1, 1))
+    batch_mat = np.tile(mat, batch_shape_ + (1, 1))
     with self.test_session(use_gpu=True):
       for lower in -1, 0, 1, shape_[-2] - 1:
         for upper in -1, 0, 1, shape_[-1] - 1:
@@ -42,7 +42,7 @@ def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_):
             band_np = np.triu(band_np, -lower)
           if upper >= 0:
             band_np = np.tril(band_np, upper)
-          if batch_shape is not ():
+          if batch_shape_ is not ():
             band_np = np.tile(band_np, batch_shape + (1, 1))
           band = array_ops.matrix_band_part(batch_mat, lower, upper)
           self.assertAllEqual(band_np, band.eval())
diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
index ec4effec1b77f386571f4a59d7dd679f4956f941..de495968a710276caef5214eb12fa965edbfd64c 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
@@ -20,158 +20,121 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import linalg_ops
-from tensorflow.python.platform import test
-
-
-def BatchMatMul(a, b):
-  # A numpy implementation of tf.matmul().
-  if a.ndim < 3:
-    return np.dot(a, b)
-  # Get the number of matrices.
-  n = np.prod(a.shape[:-2])
-  assert n == np.prod(b.shape[:-2])
-  a_flat = np.reshape(a, tuple([n]) + a.shape[-2:])
-  b_flat = np.reshape(b, tuple([n]) + b.shape[-2:])
-  c_flat_shape = [n, a.shape[-2], b.shape[-1]]
-  c_flat = np.empty(c_flat_shape)
-  for i in range(n):
-    c_flat[i, :, :] = np.dot(a_flat[i, :, :], b_flat[i, :, :])
-  return np.reshape(c_flat, a.shape[:-1] + b_flat.shape[-1:])
-
-
-def BatchRegularizedLeastSquares(matrices, rhss, l2_regularization=0.0):
-  # A numpy implementation of regularized least squares solver using
-  # the normal equations.
-  matrix_dims = matrices.shape
-  matrices_transposed = np.swapaxes(matrices, -2, -1)
-  rows = matrix_dims[-2]
-  cols = matrix_dims[-1]
-  if rows >= cols:
-    preconditioner = l2_regularization * np.identity(cols)
-    gramian = BatchMatMul(matrices_transposed, matrices) + preconditioner
-    inverse = np.linalg.inv(gramian)
-    left_pseudo_inverse = BatchMatMul(inverse, matrices_transposed)
-    return BatchMatMul(left_pseudo_inverse, rhss)
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test as test_lib
+
+
+def _AddTest(test, op_name, testcase_name, fn):
+  test_name = "_".join(["test", op_name, testcase_name])
+  if hasattr(test, test_name):
+    raise RuntimeError("Test %s defined more than once" % test_name)
+  setattr(test, test_name, fn)
+
+
+def _GenerateTestData(matrix_shape, num_rhs):
+  batch_shape = matrix_shape[:-2]
+  matrix_shape = matrix_shape[-2:]
+  m = matrix_shape[-2]
+  np.random.seed(1)
+  matrix = np.random.uniform(
+      low=-1.0, high=1.0,
+      size=np.prod(matrix_shape)).reshape(matrix_shape).astype(np.float32)
+  rhs = np.ones([m, num_rhs]).astype(np.float32)
+  matrix = variables.Variable(
+      np.tile(matrix, batch_shape + (1, 1)), trainable=False)
+  rhs = variables.Variable(np.tile(rhs, batch_shape + (1, 1)), trainable=False)
+  return matrix, rhs
+
+
+def _SolveWithNumpy(matrix, rhs, l2_regularizer=0):
+  if l2_regularizer == 0:
+    np_ans, _, _, _ = np.linalg.lstsq(matrix, rhs)
+    return np_ans
   else:
-    preconditioner = l2_regularization * np.identity(rows)
-    gramian = BatchMatMul(matrices, matrices_transposed) + preconditioner
-    inverse = np.linalg.inv(gramian)
-    right_pseudo_inverse = BatchMatMul(matrices_transposed, inverse)
-    return BatchMatMul(right_pseudo_inverse, rhss)
+    rows = matrix.shape[-2]
+    cols = matrix.shape[-1]
+    if rows >= cols:
+      preconditioner = l2_regularizer * np.identity(cols)
+      gramian = np.dot(np.conj(matrix.T), matrix) + preconditioner
+      rhs = np.dot(np.conj(matrix.T), rhs)
+      return np.linalg.solve(gramian, rhs)
+    else:
+      preconditioner = l2_regularizer * np.identity(rows)
+      gramian = np.dot(matrix, np.conj(matrix.T)) + preconditioner
+      z = np.linalg.solve(gramian, rhs)
+      return np.dot(np.conj(matrix.T), z)
 
 
-class MatrixSolveLsOpTest(test.TestCase):
+class MatrixSolveLsOpTest(test_lib.TestCase):
 
-  def _verifySolve(self, x, y):
-    for np_type in [np.float32, np.float64, np.complex64, np.complex128]:
-      a = x.astype(np_type)
-      b = y.astype(np_type)
-      if np_type in [np.complex64, np.complex128]:
+  def _verifySolve(self,
+                   x,
+                   y,
+                   dtype,
+                   use_placeholder,
+                   fast,
+                   l2_regularizer,
+                   batch_shape=()):
+    if not fast and l2_regularizer != 0:
+      # The slow path does not support regularization.
+      return
+    maxdim = np.max(x.shape)
+    if dtype == np.float32 or dtype == np.complex64:
+      tol = maxdim * 5e-4
+    else:
+      tol = maxdim * 5e-7
+      a = x.astype(dtype)
+      b = y.astype(dtype)
+      if dtype in [np.complex64, np.complex128]:
         a.imag = a.real
         b.imag = b.real
-      np_ans, _, _, _ = np.linalg.lstsq(a, b)
-      for fast in [True, False]:
-        with self.test_session():
-          tf_ans = linalg_ops.matrix_solve_ls(a, b, fast=fast)
-          ans = tf_ans.eval()
-        self.assertEqual(np_ans.shape, tf_ans.get_shape())
-        self.assertEqual(np_ans.shape, ans.shape)
-
-        # Check residual norm.
-        tf_r = b - BatchMatMul(a, ans)
-        tf_r_norm = np.sum(tf_r * tf_r)
-        np_r = b - BatchMatMul(a, np_ans)
-        np_r_norm = np.sum(np_r * np_r)
-        self.assertAllClose(np_r_norm, tf_r_norm)
-
-        # Check solution.
-        if np_type == np.float32 or np_type == np.complex64:
-          tol = 5e-5
+      # numpy.linalg.lstqr does not batching, so we just solve a single system
+      # and replicate the solution. and residual norm.
+      np_ans = _SolveWithNumpy(x, y, l2_regularizer=l2_regularizer)
+      np_r = np.dot(np.conj(a.T), b - np.dot(a, np_ans))
+      np_r_norm = np.sqrt(np.sum(np.conj(np_r) * np_r))
+      if batch_shape is not ():
+        a = np.tile(a, batch_shape + (1, 1))
+        b = np.tile(b, batch_shape + (1, 1))
+        np_ans = np.tile(np_ans, batch_shape + (1, 1))
+        np_r_norm = np.tile(np_r_norm, batch_shape)
+      with self.test_session(use_gpu=fast) as sess:
+        if use_placeholder:
+          a_ph = array_ops.placeholder(dtypes.as_dtype(dtype))
+          b_ph = array_ops.placeholder(dtypes.as_dtype(dtype))
+          feed_dict = {a_ph: a, b_ph: b}
+          tf_ans = linalg_ops.matrix_solve_ls(
+              a_ph, b_ph, fast=fast, l2_regularizer=l2_regularizer)
         else:
-          tol = 1e-12
-        self.assertAllClose(np_ans, ans, atol=tol, rtol=tol)
-
-  def _verifySolveBatch(self, x, y):
-    # Since numpy.linalg.lsqr does not support batch solves, as opposed
-    # to numpy.linalg.solve, we just perform this test for a fixed batch size
-    # of 2x3.
-    for np_type in [np.float32, np.float64]:
-      a = np.tile(x.astype(np_type), [2, 3, 1, 1])
-      b = np.tile(y.astype(np_type), [2, 3, 1, 1])
-      np_ans = np.empty([2, 3, a.shape[-1], b.shape[-1]])
-      for dim1 in range(2):
-        for dim2 in range(3):
-          np_ans[dim1, dim2, :, :], _, _, _ = np.linalg.lstsq(
-              a[dim1, dim2, :, :], b[dim1, dim2, :, :])
-      for fast in [True, False]:
-        with self.test_session():
-          tf_ans = linalg_ops.matrix_solve_ls(a, b, fast=fast).eval()
-        self.assertEqual(np_ans.shape, tf_ans.shape)
-        # Check residual norm.
-        tf_r = b - BatchMatMul(a, tf_ans)
-        tf_r_norm = np.sum(tf_r * tf_r)
-        np_r = b - BatchMatMul(a, np_ans)
-        np_r_norm = np.sum(np_r * np_r)
-        self.assertAllClose(np_r_norm, tf_r_norm)
-        # Check solution.
-        if fast or a.shape[-2] >= a.shape[-1]:
-          # We skip this test for the underdetermined case when using the
-          # slow path, because Eigen does not return a minimum norm solution.
-          # TODO(rmlarsen): Enable this check for all paths if/when we fix
-          # Eigen's solver.
-          self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
-
-  def _verifyRegularized(self, x, y, l2_regularizer):
-    for np_type in [np.float32, np.float64]:
-      # Test with a single matrix.
-      a = x.astype(np_type)
-      b = y.astype(np_type)
-      np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
-      with self.test_session():
-        # Test matrix_solve_ls on regular matrices
-        tf_ans = linalg_ops.matrix_solve_ls(
-            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
-        self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
-
-      # Test with a 2x3 batch of matrices.
-      a = np.tile(x.astype(np_type), [2, 3, 1, 1])
-      b = np.tile(y.astype(np_type), [2, 3, 1, 1])
-      np_ans = BatchRegularizedLeastSquares(a, b, l2_regularizer)
-      with self.test_session():
-        tf_ans = linalg_ops.matrix_solve_ls(
-            a, b, l2_regularizer=l2_regularizer, fast=True).eval()
-      self.assertAllClose(np_ans, tf_ans, atol=1e-5, rtol=1e-5)
-
-  def testSquare(self):
-    # 2x2 matrices, 2x3 right-hand sides.
-
-    matrix = np.array([[1., 2.], [3., 4.]])
-    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
-    self._verifySolve(matrix, rhs)
-    self._verifySolveBatch(matrix, rhs)
-    self._verifyRegularized(matrix, rhs, l2_regularizer=0.1)
-
-  def testOverdetermined(self):
-    # 2x2 matrices, 2x3 right-hand sides.
-    matrix = np.array([[1., 2.], [3., 4.], [5., 6.]])
-    rhs = np.array([[1., 0., 1.], [0., 1., 1.], [1., 1., 0.]])
-    self._verifySolve(matrix, rhs)
-    self._verifySolveBatch(matrix, rhs)
-    self._verifyRegularized(matrix, rhs, l2_regularizer=0.1)
+          tf_ans = linalg_ops.matrix_solve_ls(
+              a, b, fast=fast, l2_regularizer=l2_regularizer)
+          feed_dict = {}
+          self.assertEqual(np_ans.shape, tf_ans.get_shape())
+        if l2_regularizer == 0:
+          # The least squares solution should satisfy A^H * (b - A*x) = 0.
+          tf_r = b - math_ops.matmul(a, tf_ans)
+          tf_r = math_ops.matmul(a, tf_r, adjoint_a=True)
+          tf_r_norm = linalg_ops.norm(tf_r, ord="fro", axis=[-2, -1])
+          tf_ans_val, tf_r_norm_val = sess.run(
+              [tf_ans, tf_r_norm], feed_dict=feed_dict)
+          self.assertAllClose(np_r_norm, tf_r_norm_val, atol=tol, rtol=tol)
+        else:
+          tf_ans_val = sess.run(tf_ans, feed_dict=feed_dict)
 
-  def testUnderdetermined(self):
-    # 2x2 matrices, 2x3 right-hand sides.
-    matrix = np.array([[1., 2., 3], [4., 5., 6.]])
-    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
-    self._verifySolve(matrix, rhs)
-    self._verifySolveBatch(matrix, rhs)
-    self._verifyRegularized(matrix, rhs, l2_regularizer=0.1)
+      self.assertEqual(np_ans.shape, tf_ans_val.shape)
+      self.assertAllClose(np_ans, tf_ans_val, atol=2 * tol, rtol=2 * tol)
 
   def testWrongDimensions(self):
     # The matrix and right-hand sides should have the same number of rows.
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       matrix = constant_op.constant([[1., 0.], [0., 1.]])
       rhs = constant_op.constant([[1., 0.]])
       with self.assertRaises(ValueError):
@@ -182,7 +145,7 @@ class MatrixSolveLsOpTest(test.TestCase):
     empty0 = np.empty([3, 0])
     empty1 = np.empty([0, 2])
     for fast in [True, False]:
-      with self.test_session():
+      with self.test_session(use_gpu=True):
         tf_ans = linalg_ops.matrix_solve_ls(empty0, empty0, fast=fast).eval()
         self.assertEqual(tf_ans.shape, (0, 0))
         tf_ans = linalg_ops.matrix_solve_ls(empty0, full, fast=fast).eval()
@@ -202,5 +165,202 @@ class MatrixSolveLsOpTest(test.TestCase):
     self.assertEqual(answer.get_shape(), [3, 3, 1])
 
 
+def _GetSmallMatrixSolveLsOpTests(dtype, use_placeholder, fast, l2_regularizer):
+
+  def Square(self):
+    # 2x2 matrices, 2x3 right-hand sides.
+    matrix = np.array([[1., 2.], [3., 4.]])
+    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
+    for batch_shape in (), (2, 3):
+      self._verifySolve(
+          matrix,
+          rhs,
+          dtype,
+          use_placeholder,
+          fast,
+          l2_regularizer,
+          batch_shape=batch_shape)
+
+  def Overdetermined(self):
+    # 2x2 matrices, 2x3 right-hand sides.
+    matrix = np.array([[1., 2.], [3., 4.], [5., 6.]])
+    rhs = np.array([[1., 0., 1.], [0., 1., 1.], [1., 1., 0.]])
+    for batch_shape in (), (2, 3):
+      self._verifySolve(
+          matrix,
+          rhs,
+          dtype,
+          use_placeholder,
+          fast,
+          l2_regularizer,
+          batch_shape=batch_shape)
+
+  def Underdetermined(self):
+    # 2x2 matrices, 2x3 right-hand sides.
+    matrix = np.array([[1., 2., 3], [4., 5., 6.]])
+    rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
+    for batch_shape in (), (2, 3):
+      self._verifySolve(
+          matrix,
+          rhs,
+          dtype,
+          use_placeholder,
+          fast,
+          l2_regularizer,
+          batch_shape=batch_shape)
+
+  return (Square, Overdetermined, Underdetermined)
+
+
+def _GetLargeMatrixSolveLsOpTests(dtype, use_placeholder, fast, l2_regularizer):
+
+  def LargeBatchSquare(self):
+    np.random.seed(1)
+    num_rhs = 1
+    matrix_shape = (127, 127)
+    matrix = np.random.uniform(
+        low=-1.0, high=1.0,
+        size=np.prod(matrix_shape)).reshape(matrix_shape).astype(np.float32)
+    rhs = np.ones([matrix_shape[0], num_rhs]).astype(np.float32)
+    self._verifySolve(
+        matrix,
+        rhs,
+        dtype,
+        use_placeholder,
+        fast,
+        l2_regularizer,
+        batch_shape=(16, 8))
+
+  def LargeBatchOverdetermined(self):
+    np.random.seed(1)
+    num_rhs = 1
+    matrix_shape = (127, 64)
+    matrix = np.random.uniform(
+        low=-1.0, high=1.0,
+        size=np.prod(matrix_shape)).reshape(matrix_shape).astype(np.float32)
+    rhs = np.ones([matrix_shape[0], num_rhs]).astype(np.float32)
+    self._verifySolve(
+        matrix,
+        rhs,
+        dtype,
+        use_placeholder,
+        fast,
+        l2_regularizer,
+        batch_shape=(16, 8))
+
+  def LargeBatchUnderdetermined(self):
+    np.random.seed(1)
+    num_rhs = 1
+    matrix_shape = (64, 127)
+    matrix = np.random.uniform(
+        low=-1.0, high=1.0,
+        size=np.prod(matrix_shape)).reshape(matrix_shape).astype(np.float32)
+    rhs = np.ones([matrix_shape[0], num_rhs]).astype(np.float32)
+    self._verifySolve(
+        matrix,
+        rhs,
+        dtype,
+        use_placeholder,
+        fast,
+        l2_regularizer,
+        batch_shape=(16, 8))
+
+  return (LargeBatchSquare, LargeBatchOverdetermined, LargeBatchUnderdetermined)
+
+
+class MatrixSolveLsBenchmark(test_lib.Benchmark):
+
+  matrix_shapes = [
+      (4, 4),
+      (8, 4),
+      (4, 8),
+      (10, 10),
+      (10, 8),
+      (8, 10),
+      (16, 16),
+      (16, 10),
+      (10, 16),
+      (101, 101),
+      (101, 31),
+      (31, 101),
+      (256, 256),
+      (256, 200),
+      (200, 256),
+      (1001, 1001),
+      (1001, 501),
+      (501, 1001),
+      (1024, 1024),
+      (1024, 128),
+      (128, 1024),
+      (2048, 2048),
+      (2048, 64),
+      (64, 2048),
+      (513, 4, 4),
+      (513, 4, 2),
+      (513, 2, 4),
+      (513, 16, 16),
+      (513, 16, 10),
+      (513, 10, 16),
+      (513, 256, 256),
+      (513, 256, 128),
+      (513, 128, 256),
+  ]
+
+  def benchmarkMatrixSolveLsOp(self):
+    run_gpu_test = test_lib.is_gpu_available(True)
+    regularizer = 1.0
+    for matrix_shape in self.matrix_shapes:
+      for num_rhs in 1, 2, matrix_shape[-1]:
+
+        with ops.Graph().as_default(), \
+            session.Session() as sess, \
+            ops.device("/cpu:0"):
+          matrix, rhs = _GenerateTestData(matrix_shape, num_rhs)
+          x = linalg_ops.matrix_solve_ls(matrix, rhs, regularizer)
+          variables.global_variables_initializer().run()
+          self.run_op_benchmark(
+              sess,
+              control_flow_ops.group(x),
+              min_iters=25,
+              store_memory_usage=False,
+              name=("matrix_solve_ls_cpu_shape_{matrix_shape}_num_rhs_{num_rhs}"
+                   ).format(matrix_shape=matrix_shape, num_rhs=num_rhs))
+
+        if run_gpu_test and (len(matrix_shape) < 3 or matrix_shape[0] < 513):
+          with ops.Graph().as_default(), \
+                session.Session() as sess, \
+                ops.device("/gpu:0"):
+            matrix, rhs = _GenerateTestData(matrix_shape, num_rhs)
+            x = linalg_ops.matrix_solve_ls(matrix, rhs, regularizer)
+            variables.global_variables_initializer().run()
+            self.run_op_benchmark(
+                sess,
+                control_flow_ops.group(x),
+                min_iters=25,
+                store_memory_usage=False,
+                name=("matrix_solve_ls_gpu_shape_{matrix_shape}_num_rhs_"
+                      "{num_rhs}").format(
+                          matrix_shape=matrix_shape, num_rhs=num_rhs))
+
+
 if __name__ == "__main__":
-  test.main()
+  for dtype_ in [np.float32, np.float64, np.complex64, np.complex128]:
+    for use_placeholder_ in [True, False]:
+      for fast_ in [True, False]:
+        l2_regularizers = [0] if dtype_ == np.complex128 else [0, 0.1]
+        for l2_regularizer_ in l2_regularizers:
+          for test_case in _GetSmallMatrixSolveLsOpTests(
+              dtype_, use_placeholder_, fast_, l2_regularizer_):
+            name = "%s_%s_placeholder_%s_fast_%s_regu_%s" % (test_case.__name__,
+                                                             dtype_.__name__,
+                                                             use_placeholder_,
+                                                             fast_,
+                                                             l2_regularizer_)
+            _AddTest(MatrixSolveLsOpTest, "MatrixSolveLsOpTest", name,
+                     test_case)
+  for dtype_ in [np.float32, np.float64, np.complex64, np.complex128]:
+    for test_case in _GetLargeMatrixSolveLsOpTests(dtype_, False, True, 0.0):
+      name = "%s_%s" % (test_case.__name__, dtype_.__name__)
+      _AddTest(MatrixSolveLsOpTest, "MatrixSolveLsOpTest", name, test_case)
+
+  test_lib.main()
diff --git a/tensorflow/python/kernel_tests/matrix_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
index 4a1f3591f14c8e99af13e6d9f0feeae8128390f0..96993595387ec5a003ea84326ffac960ed76e28f 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
@@ -22,7 +22,9 @@ import numpy as np
 
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import variables
@@ -44,18 +46,25 @@ class MatrixSolveOpTest(test.TestCase):
         else:
           a = x.astype(np_type)
           b = y.astype(np_type)
-        a_np = np.conj(np.transpose(a)) if adjoint else a
+          a_np = np.conj(np.transpose(a)) if adjoint else a
         if batch_dims is not None:
           a = np.tile(a, batch_dims + [1, 1])
           a_np = np.tile(a_np, batch_dims + [1, 1])
           b = np.tile(b, batch_dims + [1, 1])
         np_ans = np.linalg.solve(a_np, b)
-        with self.test_session(use_gpu=True):
-          tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
-          out = tf_ans.eval()
-          self.assertEqual(tf_ans.get_shape(), out.shape)
-          self.assertEqual(np_ans.shape, out.shape)
-          self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
+        for use_placeholder in False, True:
+          with self.test_session(use_gpu=True) as sess:
+            if use_placeholder:
+              a_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
+              b_ph = array_ops.placeholder(dtypes.as_dtype(np_type))
+              tf_ans = linalg_ops.matrix_solve(a_ph, b_ph, adjoint=adjoint)
+              out = sess.run(tf_ans, {a_ph: a, b_ph: b})
+            else:
+              tf_ans = linalg_ops.matrix_solve(a, b, adjoint=adjoint)
+              out = tf_ans.eval()
+              self.assertEqual(tf_ans.get_shape(), out.shape)
+            self.assertEqual(np_ans.shape, out.shape)
+            self.assertAllClose(np_ans, out, atol=tol, rtol=tol)
 
   def _generateMatrix(self, m, n):
     matrix = (np.random.normal(-5, 5,
diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
index 53b1897f488636683a5a03f0cb3b95340fa4b25c..d8c3f9823c3d5ab5832305d988890b213cbab9b7 100644
--- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
@@ -1191,6 +1191,9 @@ class PaddingFIFOQueueTest(test.TestCase):
       self.assertEqual([50.0], dequeued_t.eval())
       self.assertEqual([60.0], dequeued_t.eval())
 
+      # Make sure the thread finishes before exiting.
+      thread.join()
+
   def testBlockingEnqueueBeforeClose(self):
     with self.test_session() as sess:
       q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
diff --git a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
index fa1553a3f6b421e551b19ad763cb5434bb528eb6..b01fc129538b8f54adcdf4b38ac8cc095e3901f4 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
@@ -321,6 +321,15 @@ class PoolingTest(test.TestCase):
         strides=(1, 1, 1),
         padding="VALID")
 
+  def testMaxPoolGradValidPadding1_2_3d(self):
+    self._ConstructAndTestGradient(
+        nn_ops.max_pool3d,
+        input_sizes=[1, 3, 3, 3, 1],
+        output_sizes=[1, 2, 2, 2, 1],
+        window=(1, 1, 1),
+        strides=(2, 2, 2),
+        padding="VALID")
+
   def testMaxPoolGradValidPadding2_2_3d(self):
     self._ConstructAndTestGradient(
         nn_ops.max_pool3d,
@@ -339,6 +348,15 @@ class PoolingTest(test.TestCase):
         strides=(1, 1, 1),
         padding="SAME")
 
+  def testMaxPoolGradSamePadding1_2_3d(self):
+    self._ConstructAndTestGradient(
+        nn_ops.max_pool3d,
+        input_sizes=[1, 3, 2, 4, 1],
+        output_sizes=[1, 2, 1, 2, 1],
+        window=(1, 1, 1),
+        strides=(2, 2, 2),
+        padding="SAME")
+
   def testMaxPoolGradSamePadding2_1_3d(self):
     self._ConstructAndTestGradient(
         nn_ops.max_pool3d,
@@ -375,6 +393,15 @@ class PoolingTest(test.TestCase):
         strides=(1, 1, 1),
         padding="VALID")
 
+  def testAvgPoolGradValidPadding1_2_3d(self):
+    self._ConstructAndTestGradient(
+        nn_ops.avg_pool3d,
+        input_sizes=[1, 3, 3, 3, 1],
+        output_sizes=[1, 2, 2, 2, 1],
+        window=(1, 1, 1),
+        strides=(2, 2, 2),
+        padding="VALID")
+
   def testAvgPoolGradValidPadding2_1_3d(self):
     self._ConstructAndTestGradient(
         nn_ops.avg_pool3d,
@@ -402,6 +429,15 @@ class PoolingTest(test.TestCase):
         strides=(1, 1, 1),
         padding="SAME")
 
+  def testAvgPoolGradSamePadding1_2_3d(self):
+    self._ConstructAndTestGradient(
+        nn_ops.avg_pool3d,
+        input_sizes=[1, 3, 2, 4, 2],
+        output_sizes=[1, 2, 1, 2, 2],
+        window=(1, 1, 1),
+        strides=(2, 2, 2),
+        padding="SAME")
+
   def testAvgPoolGradSamePadding2_1_3d(self):
     self._ConstructAndTestGradient(
         nn_ops.avg_pool3d,
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index da14871c872a3c585952d374910bd24e325ccf37..9eb1fea80375d107b5ac9c2a2d5a38a314bfa51c 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -998,6 +998,20 @@ class PoolingTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)
 
+  def _testMaxPoolGradValidPadding1_2(self, data_format, use_gpu):
+    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+      self._ConstructAndTestGradient(
+          pool_func,
+          input_sizes=[1, 3, 3, 1],
+          output_sizes=[1, 2, 2, 1],
+          window_rows=1,
+          window_cols=1,
+          row_stride=2,
+          col_stride=2,
+          padding="VALID",
+          data_format=data_format,
+          use_gpu=use_gpu)
+
   def _testMaxPoolGradValidPadding2_2(self, data_format, use_gpu):
     for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
       self._ConstructAndTestGradient(
@@ -1026,6 +1040,20 @@ class PoolingTest(test.TestCase):
           data_format=data_format,
           use_gpu=use_gpu)
 
+  def _testMaxPoolGradSamePadding1_2(self, data_format, use_gpu):
+    for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
+      self._ConstructAndTestGradient(
+          pool_func,
+          input_sizes=[2, 2, 4, 3],
+          output_sizes=[2, 1, 2, 3],
+          window_rows=1,
+          window_cols=1,
+          row_stride=2,
+          col_stride=2,
+          padding="SAME",
+          data_format=data_format,
+          use_gpu=use_gpu)
+
   def _testMaxPoolGradSamePadding2_1(self, data_format, use_gpu):
     for pool_func in [gen_nn_ops._max_pool_v2, nn_ops.max_pool]:
       self._ConstructAndTestGradient(
@@ -1071,10 +1099,12 @@ class PoolingTest(test.TestCase):
   def testMaxPoolGrad(self):
     for (data_format, use_gpu) in GetTestConfigs():
       self._testMaxPoolGradValidPadding1_1(data_format, use_gpu)
+      self._testMaxPoolGradValidPadding1_2(data_format, use_gpu)
       self._testMaxPoolGradValidPadding2_1_6(data_format, use_gpu)
       self._testMaxPoolGradValidPadding2_1_7(data_format, use_gpu)
       self._testMaxPoolGradValidPadding2_2(data_format, use_gpu)
       self._testMaxPoolGradSamePadding1_1(data_format, use_gpu)
+      self._testMaxPoolGradSamePadding1_2(data_format, use_gpu)
       self._testMaxPoolGradSamePadding2_1(data_format, use_gpu)
       self._testMaxPoolGradSamePadding2_2(data_format, use_gpu)
       self._testMaxPoolGradSamePadding3_1(data_format, use_gpu)
@@ -1497,9 +1527,11 @@ class PoolingTest(test.TestCase):
   def testAvgPoolGrad(self):
     for (data_format, use_gpu) in GetTestConfigs():
       self._testAvgPoolGradValidPadding1_1(data_format, use_gpu)
+      self._testAvgPoolGradValidPadding1_2(data_format, use_gpu)
       self._testAvgPoolGradValidPadding2_1(data_format, use_gpu)
       self._testAvgPoolGradValidPadding2_2(data_format, use_gpu)
       self._testAvgPoolGradSamePadding1_1(data_format, use_gpu)
+      self._testAvgPoolGradSamePadding1_2(data_format, use_gpu)
       self._testAvgPoolGradSamePadding2_1(data_format, use_gpu)
       self._testAvgPoolGradSamePadding2_2(data_format, use_gpu)
       self._testAvgPoolGradSamePadding3_1(data_format, use_gpu)
@@ -1517,6 +1549,19 @@ class PoolingTest(test.TestCase):
         data_format=data_format,
         use_gpu=use_gpu)
 
+  def _testAvgPoolGradValidPadding1_2(self, data_format, use_gpu):
+    self._ConstructAndTestGradient(
+        nn_ops.avg_pool,
+        input_sizes=[2, 3, 3, 3],
+        output_sizes=[2, 2, 2, 3],
+        window_rows=1,
+        window_cols=1,
+        row_stride=2,
+        col_stride=2,
+        padding="VALID",
+        data_format=data_format,
+        use_gpu=use_gpu)
+
   def _testAvgPoolGradValidPadding2_1(self, data_format, use_gpu):
     self._ConstructAndTestGradient(
         nn_ops.avg_pool,
@@ -1556,6 +1601,19 @@ class PoolingTest(test.TestCase):
         data_format=data_format,
         use_gpu=use_gpu)
 
+  def _testAvgPoolGradSamePadding1_2(self, data_format, use_gpu):
+    self._ConstructAndTestGradient(
+        nn_ops.avg_pool,
+        input_sizes=[2, 2, 4, 3],
+        output_sizes=[2, 1, 2, 3],
+        window_rows=1,
+        window_cols=1,
+        row_stride=2,
+        col_stride=2,
+        padding="SAME",
+        data_format=data_format,
+        use_gpu=use_gpu)
+
   def _testAvgPoolGradSamePadding2_1(self, data_format, use_gpu):
     self._ConstructAndTestGradient(
         nn_ops.avg_pool,
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index 04ce99a4a63e1827c37c2c572fafe5801b1b4bd3..8d6b7925e4551a78d32de960dedbda093b128162 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -175,6 +175,24 @@ class SumReductionTest(BaseReductionTest):
       np_arr = self._makeIncremental((2,) * rank, dtypes.int32)
       self._compareAllAxes(np_arr)
 
+  def testFloat16(self):
+    for rank in range(1, _MAX_RANK + 1):
+      np_arr = self._makeIncremental((2,) * rank, dtypes.float16)
+      self._compareAllAxes(np_arr)
+
+    # test that mean doesn't overflow
+    # only on GPU, since it has the more accurate implementation
+    if not test.is_gpu_available():
+      return
+
+    arr = np.ones([68000], dtype=np.float16)
+
+    with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+      tf_arr = array_ops.constant(arr)
+      tf_mean = math_ops.reduce_mean(tf_arr, 0, False)
+      tf_out_mean = sess.run(tf_mean)
+    self.assertAllClose(tf_out_mean, 1.)
+
   def testFloat32(self):
     for rank in range(1, _MAX_RANK + 1):
       np_arr = self._makeIncremental((2,) * rank, dtypes.float32)
@@ -523,7 +541,7 @@ class MinReductionTest(test.TestCase):
   def testFloatReduce3D(self):
     # Create a 3D array of floats and reduce across all possible
     # dimensions
-    np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
+    np_arr = np.arange(1, 31).reshape([2, 3, 5]).astype(np.float32)
     self._compareAll(np_arr, None)
     self._compareAll(np_arr, [])
     self._compareAll(np_arr, [0])
@@ -537,7 +555,7 @@ class MinReductionTest(test.TestCase):
   def testDoubleReduce3D(self):
     # Create a 3D array of doubles and reduce across all possible
     # dimensions
-    np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float64)
+    np_arr = np.arange(1, 31).reshape([2, 3, 5]).astype(np.float64)
     self._compareAll(np_arr, None)
     self._compareAll(np_arr, [])
     self._compareAll(np_arr, [0])
@@ -629,7 +647,7 @@ class MaxReductionTest(test.TestCase):
   def testFloatReduce3D(self):
     # Create a 3D array of floats and reduce across all possible
     # dimensions
-    np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float32)
+    np_arr = np.arange(-31, -1).reshape([2, 3, 5]).astype(np.float32)
     self._compareAll(np_arr, None)
     self._compareAll(np_arr, [])
     self._compareAll(np_arr, [0])
@@ -643,7 +661,7 @@ class MaxReductionTest(test.TestCase):
   def testDoubleReduce3D(self):
     # Create a 3D array of doubles and reduce across all possible
     # dimensions
-    np_arr = np.arange(0, 30).reshape([2, 3, 5]).astype(np.float64)
+    np_arr = np.arange(-31, -1).reshape([2, 3, 5]).astype(np.float64)
     self._compareAll(np_arr, None)
     self._compareAll(np_arr, [])
     self._compareAll(np_arr, [0])
@@ -656,7 +674,7 @@ class MaxReductionTest(test.TestCase):
 
   def testGradient(self):
     s = [2, 3, 4, 2]
-    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+    x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
     with self.test_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_max(t, [1, 2])
@@ -666,7 +684,7 @@ class MaxReductionTest(test.TestCase):
 
   def testGradient2(self):
     s = [2, 3, 4, 2]
-    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+    x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
     with self.test_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_max(t, [1])
@@ -676,7 +694,7 @@ class MaxReductionTest(test.TestCase):
 
   def testGradient3(self):
     s = [2, 3, 4, 2]
-    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+    x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
     with self.test_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_max(t, [2])
@@ -686,7 +704,7 @@ class MaxReductionTest(test.TestCase):
 
   def testGradient4(self):
     s = [2, 3, 4, 2]
-    x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
+    x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
     with self.test_session():
       t = ops.convert_to_tensor(x)
       su = math_ops.reduce_max(t)
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test_big.py b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
new file mode 100644
index 0000000000000000000000000000000000000000..0959adb026e3934713442e6f3487b30a0b252943
--- /dev/null
+++ b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
@@ -0,0 +1,179 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for reduction ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class BaseReductionTest(test.TestCase):
+
+  def _tf_reduce(self, x, reduction_axes, keep_dims):
+    raise NotImplementedError()
+
+
+class BigReductionTest(BaseReductionTest):
+  """Test reductions for sum and boolean all over a wide range of shapes."""
+
+  def _tf_reduce_max(self, x, reduction_axes, keep_dims):
+    return math_ops.reduce_max(x, reduction_axes, keep_dims)
+
+  def _tf_reduce_all(self, x, reduction_axes, keep_dims):
+    return math_ops.reduce_all(x, reduction_axes, keep_dims)
+
+  def _tf_reduce_mean(self, x, reduction_axes, keep_dims):
+    return math_ops.reduce_mean(x, reduction_axes, keep_dims)
+
+  def _tf_reduce_sum(self, x, reduction_axes, keep_dims):
+    return math_ops.reduce_sum(x, reduction_axes, keep_dims)
+
+  def testFloat32Sum(self):
+    # make sure we test all possible kernel invocations
+    # logic is the same for all ops, test just float32 for brevity
+    arr_ = np.ones([4097, 4097], dtype=np.float32)
+    for size_x in [
+        1, 2, 3, 4, 16, 17, 32, 33, 64, 65, 128, 131, 256, 263, 1024, 1025,
+        4096, 4097
+    ]:
+      for size_y in [
+          1, 2, 3, 4, 16, 17, 32, 33, 64, 65, 128, 131, 256, 263, 1024, 1025,
+          4096, 4097
+      ]:
+        arr = arr_[0:size_x, 0:size_y]
+        col_sum = np.ones([size_y], dtype=np.float32) * size_x
+        row_sum = np.ones([size_x], dtype=np.float32) * size_y
+        full_sum = np.ones([], dtype=np.float32) * size_x * size_y
+
+        with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+          tf_row_sum = self._tf_reduce_sum(arr, 1, False)
+          tf_col_sum = self._tf_reduce_sum(arr, 0, False)
+          tf_full_sum = self._tf_reduce_sum(arr, [0, 1], False)
+          tf_out_row, tf_out_col, tf_out_full = sess.run(
+              [tf_row_sum, tf_col_sum, tf_full_sum])
+        self.assertAllClose(col_sum, tf_out_col)
+        self.assertAllClose(row_sum, tf_out_row)
+        self.assertAllClose(full_sum, tf_out_full)
+
+    arr_ = np.ones([130, 130, 130], dtype=np.float32)
+    for size_x in range(1, 130, 13):
+      for size_y in range(1, 130, 13):
+        for size_z in range(1, 130, 13):
+          arr = arr_[0:size_x, 0:size_y, 0:size_z]
+          sum_y = np.ones([size_x, size_z], dtype=np.float32)
+          sum_xz = np.ones([size_y], dtype=np.float32)
+
+          with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+            tf_sum_xz = self._tf_reduce_mean(arr, [0, 2], False)
+            tf_sum_y = self._tf_reduce_mean(arr, 1, False)
+            tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
+          self.assertAllClose(sum_y, tf_out_sum_y)
+          self.assertAllClose(sum_xz, tf_out_sum_xz)
+
+  def testFloat32Max(self):
+    # make sure we test all possible kernel invocations
+    # logic is the same for all ops, test just float32 for brevity
+    arr_ = np.random.uniform(
+        low=-3, high=-1, size=[4105, 4105]).astype(np.float32)
+    for size_x in [
+        1, 2, 3, 4, 16, 17, 32, 33, 64, 65, 128, 131, 256, 263, 1024, 1025,
+        4096, 4097
+    ]:
+      for size_y in [
+          1, 2, 3, 4, 16, 17, 32, 33, 64, 65, 128, 131, 256, 263, 1024, 1025,
+          4096, 4097
+      ]:
+        arr = arr_[0:size_x, 0:size_y]
+        col_max = np.max(arr, axis=0)
+        row_max = np.max(arr, axis=1)
+        full_max = np.max(col_max)
+
+        with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+          tf_row_max = self._tf_reduce_max(arr, 1, False)
+          tf_col_max = self._tf_reduce_max(arr, 0, False)
+          tf_full_max = self._tf_reduce_max(arr, [0, 1], False)
+          tf_out_row, tf_out_col, tf_out_full = sess.run(
+              [tf_row_max, tf_col_max, tf_full_max])
+        self.assertAllClose(col_max, tf_out_col)
+        self.assertAllClose(row_max, tf_out_row)
+        self.assertAllClose(full_max, tf_out_full)
+
+    arr_ = np.random.uniform(
+        low=-3, high=-1, size=[130, 130, 130]).astype(np.float32)
+    for size_x in range(1, 130, 13):
+      for size_y in range(1, 130, 13):
+        for size_z in range(1, 130, 13):
+          arr = arr_[0:size_x, 0:size_y, 0:size_z]
+          sum_y = np.max(arr, axis=1)
+          sum_xz = np.max(arr, axis=(0, 2))
+
+          with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+            tf_sum_xz = self._tf_reduce_max(arr, [0, 2], False)
+            tf_sum_y = self._tf_reduce_max(arr, 1, False)
+            tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
+          self.assertAllClose(sum_y, tf_out_sum_y)
+          self.assertAllClose(sum_xz, tf_out_sum_xz)
+
+  def testBooleanAll(self):
+    # make sure we test all possible kernel invocations
+    # test operation where T(0) is not the identity
+    arr_ = np.ones([4097, 4097], dtype=np.bool)
+    for size_x in [
+        1, 2, 3, 4, 16, 17, 32, 33, 64, 65, 128, 131, 256, 263, 1024, 1025,
+        4096, 4097
+    ]:
+      for size_y in [
+          1, 2, 3, 4, 16, 17, 32, 33, 64, 65, 128, 131, 256, 263, 1024, 1025,
+          4096, 4097
+      ]:
+        arr = arr_[0:size_x, 0:size_y]
+        col_sum = np.ones([size_y], dtype=np.bool)
+        row_sum = np.ones([size_x], dtype=np.bool)
+        full_sum = np.ones([1], dtype=np.bool).reshape([])
+
+        with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+          tf_row_sum = self._tf_reduce_all(arr, 1, False)
+          tf_col_sum = self._tf_reduce_all(arr, 0, False)
+          tf_full_sum = self._tf_reduce_all(arr, [0, 1], False)
+          tf_out_row, tf_out_col, tf_out_full = sess.run(
+              [tf_row_sum, tf_col_sum, tf_full_sum])
+        self.assertAllClose(col_sum, tf_out_col)
+        self.assertAllClose(row_sum, tf_out_row)
+        self.assertAllClose(full_sum, tf_out_full)
+
+    arr_ = np.ones([130, 130, 130], dtype=np.bool)
+    for size_x in range(1, 130, 13):
+      for size_y in range(1, 130, 13):
+        for size_z in range(1, 130, 13):
+          arr = arr_[0:size_x, 0:size_y, 0:size_z]
+          sum_y = np.ones([size_x, size_z], dtype=np.bool)
+          sum_xz = np.ones([size_y], dtype=np.bool)
+
+          with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+            tf_sum_xz = self._tf_reduce_all(arr, [0, 2], False)
+            tf_sum_y = self._tf_reduce_all(arr, 1, False)
+            tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
+          self.assertAllClose(sum_y, tf_out_sum_y)
+          self.assertAllClose(sum_xz, tf_out_sum_xz)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index cd895040517db2c52ad091edd948f0df4710598a..c31732d80773f50d1e04cda40ff7087426094722 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -53,7 +53,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
                                                    0,
                                                    dtype=dtypes.int32)).run()
 
-  def testReadVariableDtypeMismatch(self):
+  def testReadVariableDtypeMismatchEager(self):
     with context.eager_mode():
       handle = resource_variable_ops.var_handle_op(
           dtype=dtypes.int32, shape=[1], name="foo")
@@ -62,7 +62,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
                                    "Expected float got int32."):
         _ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32)
 
-  def testAssignVariableDtypeMismatch(self):
+  def testAssignVariableDtypeMismatchEager(self):
     with context.eager_mode():
       handle = resource_variable_ops.var_handle_op(
           dtype=dtypes.int32, shape=[1], name="foo")
@@ -145,17 +145,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
               resource_variable_ops.var_is_initialized_op(abc.handle)),
           True)
 
-  # TODO(alive): fix bug in convert_to_tensor; get this to work in Eager.
+  @test_util.run_in_graph_and_eager_modes()
   def testConstraintArg(self):
     constraint = lambda x: x
     v = resource_variable_ops.ResourceVariable(
-        initial_value=lambda: 1, constraint=constraint)
+        initial_value=lambda: 1, constraint=constraint, name="var0")
     self.assertEqual(v.constraint, constraint)
 
     constraint = 0
     with self.assertRaises(ValueError):
       v = resource_variable_ops.ResourceVariable(
-          initial_value=lambda: 1, constraint=constraint)
+          initial_value=lambda: 1, constraint=constraint, name="var1")
 
   # TODO(alive): how should this work in Eager mode?
   def testInitFn(self):
@@ -165,18 +165,17 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
       self.assertEqual(v.handle.op.colocation_groups(),
                        v.initializer.inputs[1].op.colocation_groups())
 
-  # TODO(alive): fix bug in convert_to_tensor; get this to work in Eager.
+  @test_util.run_in_graph_and_eager_modes()
   def testInitFnDtype(self):
-    with self.test_session():
-      v = resource_variable_ops.ResourceVariable(
-          initial_value=lambda: 1, dtype=dtypes.float32)
-      self.assertEqual(dtypes.float32, v.value().dtype)
+    v = resource_variable_ops.ResourceVariable(
+        initial_value=lambda: 1, dtype=dtypes.float32)
+    self.assertEqual(dtypes.float32, v.value().dtype)
 
-  # TODO(alive): fix bug in convert_to_tensor; get this to work in Eager.
+  @test_util.run_in_graph_and_eager_modes()
   def testInitFnNoDtype(self):
-    with self.test_session():
-      v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1)
-      self.assertEqual(dtypes.int32, v.value().dtype)
+    v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
+                                               name="var2")
+    self.assertEqual(dtypes.int32, v.value().dtype)
 
   @test_util.run_in_graph_and_eager_modes()
   def testInitializeAllVariables(self):
@@ -209,7 +208,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
     with self.test_session():
       init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
       v = resource_variable_ops.ResourceVariable(
-          constant_op.constant(init_value, dtype=dtypes.int32), name="var0")
+          constant_op.constant(init_value, dtype=dtypes.int32), name="var3")
       self.evaluate(variables.global_variables_initializer())
 
       value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
@@ -293,33 +292,30 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testSharedName(self):
-    v = resource_variable_ops.ResourceVariable(300.0, name="var1")
+    v = resource_variable_ops.ResourceVariable(300.0, name="var4")
     self.evaluate(variables.global_variables_initializer())
 
     w = resource_variable_ops.var_handle_op(
-        dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var1")
+        dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var4")
     w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
     self.assertEqual(300.0, self.evaluate(w_read))
 
     x = resource_variable_ops.var_handle_op(
-        dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var2")
-    if context.in_graph_mode():
-      with self.assertRaisesOpError("Resource .*/var2/.* does not exist"):
-        x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
-        self.evaluate(x_read)
-    else:
-      with self.assertRaisesRegexp(errors.NotFoundError,
-                                   "Attempted to read a nonexistent variable."):
-        _ = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
+        dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="var5")
+    with self.assertRaisesOpError("Resource .*/var5/.* does not exist"):
+      x_read = resource_variable_ops.read_variable_op(x, v.dtype.base_dtype)
+      self.evaluate(x_read)
 
   @test_util.run_in_graph_and_eager_modes()
   def testSharedNameWithNamescope(self):
     with ops.name_scope("foo"):
-      v = resource_variable_ops.ResourceVariable(300.0, name="var3")
+      v = resource_variable_ops.ResourceVariable(300.0, name="var6")
+      self.assertEqual("foo/var6", v._shared_name)  # pylint: disable=protected-access
+      self.assertEqual("foo/var6:0", v.name)
       self.evaluate(variables.global_variables_initializer())
 
     w = resource_variable_ops.var_handle_op(
-        dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var3")
+        dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var6")
     w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
     self.assertEqual(300.0, self.evaluate(w_read))
 
@@ -364,15 +360,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
       constraint = lambda x: x
       with ops.name_scope("foo"):
         v = resource_variable_ops.ResourceVariable(
-            name="var5",
+            name="var7",
             initial_value=init,
             caching_device="cpu:0",
             constraint=constraint)
       # Test properties
       self.assertEqual(dtypes.int32, v.dtype)
-      self.assertEqual("foo/var5:0", v.name)
+      self.assertEqual("foo/var7:0", v.name)
       self.assertAllEqual([10, 20, 35], v.shape.as_list())
-      self.assertAllEqual(init.device, v.device)
+      self.assertEqual(context.get_default_context().device_name, v.device)
       self.assertTrue(isinstance(v.handle, ops.EagerTensor))
       self.assertEqual(constraint, v.constraint)
       self.assertAllEqual(init.numpy(), v.read_value().numpy())
@@ -381,8 +377,8 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
       # Callable init.
       callable_init = lambda: init * 2
       v2 = resource_variable_ops.ResourceVariable(
-          initial_value=callable_init, name="var6")
-      self.assertEqual("var6:0", v2.name)
+          initial_value=callable_init, name="var7")
+      self.assertEqual("var7:0", v2.name)
       self.assertAllEqual(2 * init.numpy(), v2.read_value().numpy())
 
       # Test assign_add.
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index 33269c912343311cba22aace50bdb8b0ba87b127..516a9d000e91f55d595ee9dc9cf633fd578942b1 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -18,13 +18,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+
 import numpy as np
 
+from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import math_ops
-import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
@@ -107,19 +111,19 @@ class SegmentReductionOpTest(SegmentReductionHelper):
         curr_ops_list = complex_ops_list
       else:
         curr_ops_list = ops_list
-
-      with self.test_session(use_gpu=False):
-        tf_x, np_x = self._input(shape, dtype=dtype)
-        for np_op1, np_op2, tf_op in curr_ops_list:
-          np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
-          s = tf_op(data=tf_x, segment_ids=indices)
-          tf_ans = s.eval()
-          self.assertAllClose(np_ans, tf_ans)
-          # NOTE(mrry): The static shape inference that computes
-          # `tf_ans.shape` can only infer that sizes from dimension 1
-          # onwards, because the size of dimension 0 is data-dependent
-          # and may therefore vary dynamically.
-          self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
+      for use_gpu in [True, False]:
+        with self.test_session(use_gpu=use_gpu):
+          tf_x, np_x = self._input(shape, dtype=dtype)
+          for np_op1, np_op2, tf_op in curr_ops_list:
+            np_ans = self._segmentReduce(indices, np_x, np_op1, np_op2)
+            s = tf_op(data=tf_x, segment_ids=indices)
+            tf_ans = s.eval()
+            self.assertAllClose(np_ans, tf_ans)
+            # NOTE(mrry): The static shape inference that computes
+            # `tf_ans.shape` can only infer that sizes from dimension 1
+            # onwards, because the size of dimension 0 is data-dependent
+            # and may therefore vary dynamically.
+            self.assertAllEqual(np_ans.shape[1:], tf_ans.shape[1:])
 
   def testSegmentIdsShape(self):
     shape = [4, 4]
@@ -130,41 +134,45 @@ class SegmentReductionOpTest(SegmentReductionHelper):
 
   def testSegmentIdsSize(self):
     shape = [4, 4]
-    with self.test_session():
-      tf_x, _ = self._input(shape)
-      indices = [0, 1]
-      s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
-      with self.assertRaisesOpError("segment_ids should be the same size"):
-        s.eval()
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, _ = self._input(shape)
+        indices = [0, 1]
+        s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+        with self.assertRaisesOpError("segment_ids should be the same size"):
+          s.eval()
 
   def testSegmentIdsValid(self):
     # This is a baseline for the following SegmentIdsInvalid* tests.
     shape = [4, 4]
-    with self.test_session():
-      tf_x, _ = self._input(shape)
-      indices = [0, 0, 0, 1]
-      result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval()
-      self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result)
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+        indices = [0, 0, 0, 1]
+        result = math_ops.segment_sum(data=tf_x, segment_ids=indices).eval()
+        self.assertAllEqual([[15, 18, 21, 24], [13, 14, 15, 16]], result)
 
   def testSegmentIdsGreaterThanZero(self):
     shape = [4, 4]
-    with self.test_session():
-      tf_x, np_x = self._input(shape)
-      indices = [1, 1, 2, 2]
-      np_ans = self._segmentReduce(indices, np_x, np.add)
-      s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
-      tf_ans = s.eval()
-      self.assertAllClose(np_ans, tf_ans)
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
+        indices = [1, 1, 2, 2]
+        np_ans = self._segmentReduce(indices, np_x, np.add)
+        s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+        tf_ans = s.eval()
+        self.assertAllClose(np_ans, tf_ans)
 
   def testSegmentIdsHole(self):
     shape = [4, 4]
-    with self.test_session():
-      tf_x, np_x = self._input(shape)
-      indices = [0, 0, 3, 3]
-      np_ans = self._segmentReduce(indices, np_x, np.add)
-      s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
-      tf_ans = s.eval()
-      self.assertAllClose(np_ans, tf_ans)
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, np_x = self._input(shape, dtype=dtypes_lib.float32)
+        indices = [0, 0, 3, 3]
+        np_ans = self._segmentReduce(indices, np_x, np.add)
+        s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+        tf_ans = s.eval()
+        self.assertAllClose(np_ans, tf_ans)
 
   def testSegmentIdsInvalid1(self):
     shape = [4, 4]
@@ -199,21 +207,23 @@ class SegmentReductionOpTest(SegmentReductionHelper):
 
   def testSegmentIdsInvalid4(self):
     shape = [4, 4]
-    with self.test_session():
-      tf_x, _ = self._input(shape)
-      indices = [0, 0, 0, -1]
-      s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
-      with self.assertRaisesOpError("segment ids must be >= 0"):
-        s.eval()
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+        indices = [0, 0, 0, -1]
+        s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+        with self.assertRaisesOpError("segment ids must be >= 0"):
+          s.eval()
 
   def testSegmentIdsInvalid5(self):
     shape = [4, 4]
-    with self.test_session():
-      tf_x, _ = self._input(shape)
-      indices = [0, 0, 0, -2]
-      s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
-      with self.assertRaisesOpError("segment ids must be >= 0"):
-        s.eval()
+    for use_gpu in [True, False]:
+      with self.test_session(use_gpu=use_gpu):
+        tf_x, _ = self._input(shape, dtype=dtypes_lib.float32)
+        indices = [0, 0, 0, -2]
+        s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
+        with self.assertRaisesOpError("segment ids must be >= 0"):
+          s.eval()
 
   def testGradient(self):
     shape = [4, 4]
@@ -340,8 +350,8 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
       shape = indices.shape + (num_cols,)
       with self.test_session(use_gpu=True):
         tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
-        s = math_ops.unsorted_segment_max(data=tf_x, segment_ids=indices,
-                                    num_segments=num_segments)
+        s = math_ops.unsorted_segment_max(
+            data=tf_x, segment_ids=indices, num_segments=num_segments)
         jacob_t, jacob_n = gradient_checker.compute_gradient(
             tf_x,
             shape,
@@ -635,6 +645,67 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
         with self.assertRaisesOpError(r"Segment id 0 out of range \[0, 0\)"):
           s.eval()
 
+class SegmentReductionOpBenchmark(test.Benchmark):
+  outer_dim_options = [2**x for x in range(9, 14, 2)]
+  ratio_options = [2**x for x in range(1, 6, 2)]
+  inner_dim_options = [2**x for x in range(9, 14, 2)]
+  # randomly generated sizes with less alignments
+  inner_dim_options += [
+      1120, 1215, 1856, 1302, 1329, 1531, 1313, 1672, 1851, 1584
+  ]
+  dtype_options = [np.float32, np.float64]
+  options = (outer_dim_options, ratio_options, inner_dim_options, dtype_options)
+  # pylint: disable=g-long-lambda
+  op_functors = [lambda vc, vs, seg_ids:
+                 ("sorted", math_ops.segment_sum(vc, vs)),
+                 lambda vc, vs, seg_ids:
+                 ("unsorted",
+                  math_ops.unsorted_segment_sum(vc, vs, seg_ids[-1]+1))]
+  # pylint: enable=g-long-lambda
+  repeat = 10
+
+  def _npTypeToStr(self, t):
+    if t == np.float32:
+      return "fp32"
+    if t == np.float64:
+      return "fp64"
+
+  def _runGraph(self, op_functor, outer_dim, ratio, inner_dim, dtype):
+    output_outer_dim = int(outer_dim / ratio)
+    const = np.random.randint(5, size=(outer_dim, inner_dim))
+    seg_ids = np.sort(np.random.randint(output_outer_dim, size=outer_dim))
+    vs = variables.Variable(seg_ids.astype(np.int32))
+    with ops.device("/gpu:0"):
+      vc = variables.Variable(const.astype(dtype))
+    name, op = op_functor(vc, vs, seg_ids)
+    with session.Session() as sess:
+      variables.global_variables_initializer().run()
+      r = self.run_op_benchmark(
+          sess,
+          op,
+          min_iters=self.repeat,
+          name="_".join(
+              map(str,
+                  [name, outer_dim, ratio, inner_dim,
+                   self._npTypeToStr(dtype)])))
+    return name, r["wall_time"]
+
+  def benchmarkSegmentSumGPU(self):
+    if not test.is_gpu_available(cuda_only=True):
+      return
+    for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
+      op_functor = self.op_functors[0]
+      with ops.Graph().as_default():
+        self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
+
+  def benchmarkUnsortedSegmentSumGPU(self):
+    if not test.is_gpu_available(cuda_only=True):
+      return
+    for outer_dim, ratio, inner_dim, dtype in itertools.product(*self.options):
+      op_functor = self.op_functors[1]
+      with ops.Graph().as_default():
+        self._runGraph(op_functor, outer_dim, ratio, inner_dim, dtype)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index 51bfceee01f5d3656f1769a3077ce45462b26cf9..9161b8c5d1c1e05de30520dcbb5b9c607b6774b8 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import variables
 import tensorflow.python.ops.sparse_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import test
@@ -544,6 +545,22 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
       self.assertAllEqual(empty_row_indicator_out, np.zeros(2).astype(np.bool))
 
 
+class SparseAddTest(test_util.TensorFlowTestCase):
+
+  def testValuesInVariable(self):
+    indices = constant_op.constant([[1]], dtype=dtypes.int64)
+    values = variables.Variable([1], trainable=False, dtype=dtypes.float32)
+    shape = constant_op.constant([1], dtype=dtypes.int64)
+
+    sp_input = sparse_tensor.SparseTensor(indices, values, shape)
+    sp_output = sparse_ops.sparse_add(sp_input, sp_input)
+
+    with self.test_session(use_gpu=False) as sess:
+      sess.run(variables.global_variables_initializer())
+      output = sess.run(sp_output)
+      self.assertAllEqual(output.values, [2])
+
+
 class SparseReduceTest(test_util.TensorFlowTestCase):
 
   # [[1, ?, 2]
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 67932d08236b3401499d593e746ba3a7202618bb..cdac12f05a747574a86575b043fdc67faa3e16b3 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -20,10 +20,12 @@ from __future__ import print_function
 
 import numpy
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import init_ops
@@ -37,32 +39,38 @@ from tensorflow.python.platform import test
 
 class VariableScopeTest(test.TestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGetVar(self):
     vs = variable_scope._get_default_variable_store()
     v = vs.get_variable("v", [1])
     v1 = vs.get_variable("v", [1])
     self.assertEqual(v, v1)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testResource(self):
     vs = variable_scope._get_default_variable_store()
     v1 = vs.get_variable("v", [1], use_resource=True)
     self.assertTrue(isinstance(v1, resource_variable_ops.ResourceVariable))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testNameExists(self):
     vs = variable_scope._get_default_variable_store()
     # No check by default, so we can both create and get existing names.
     v = vs.get_variable("v", [1])
     v1 = vs.get_variable("v", [1])
     self.assertEqual(v, v1)
-    # When reuse is False, we fail when variables are already there.
-    vs.get_variable("w", [1], reuse=False)  # That's ok.
-    with self.assertRaises(ValueError):
-      vs.get_variable("v", [1], reuse=False)  # That fails.
-    # When reuse is True, we fail when variables are new.
-    vs.get_variable("v", [1], reuse=True)  # That's ok.
-    with self.assertRaises(ValueError):
-      vs.get_variable("u", [1], reuse=True)  # That fails.
 
+    if context.in_graph_mode():
+      # When reuse is False, we fail when variables are already there.
+      vs.get_variable("w", [1], reuse=False)  # That's ok.
+      with self.assertRaises(ValueError):
+        vs.get_variable("v", [1], reuse=False)  # That fails.
+      # When reuse is True, we fail when variables are new.
+      vs.get_variable("v", [1], reuse=True)  # That's ok.
+      with self.assertRaises(ValueError):
+        vs.get_variable("u", [1], reuse=True)  # That fails.
+
+  @test_util.run_in_graph_and_eager_modes()
   def testNamelessStore(self):
     vs = variable_scope._get_default_variable_store()
     vs.get_variable("v1", [2])
@@ -71,22 +79,23 @@ class VariableScopeTest(test.TestCase):
     self.assertEqual(
         set(expected_names), set([v.name for v in vs._vars.values()]))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testVarScopeInitializer(self):
-    with self.test_session() as sess:
-      init = init_ops.constant_initializer(0.3)
-      with variable_scope.variable_scope("tower") as tower:
-        with variable_scope.variable_scope("foo", initializer=init):
-          v = variable_scope.get_variable("v", [])
-          sess.run(variables_lib.initialize_variables([v]))
-          self.assertAllClose(v.eval(), 0.3)
-        with variable_scope.variable_scope(tower, initializer=init):
-          w = variable_scope.get_variable("w", [])
-          sess.run(variables_lib.initialize_variables([w]))
-          self.assertAllClose(w.eval(), 0.3)
+    init = init_ops.constant_initializer(0.3)
+    with variable_scope.variable_scope("tower0") as tower:
+      with variable_scope.variable_scope("foo", initializer=init):
+        v = variable_scope.get_variable("v", [])
+        self.evaluate(variables_lib.variables_initializer([v]))
+        self.assertAllClose(self.evaluate(v.value()), 0.3)
+      with variable_scope.variable_scope(tower, initializer=init):
+        w = variable_scope.get_variable("w", [])
+        self.evaluate(variables_lib.variables_initializer([w]))
+        self.assertAllClose(self.evaluate(w.value()), 0.3)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testVarScopeConstraint(self):
     constraint = lambda x: 0. * x
-    with variable_scope.variable_scope("tower") as tower:
+    with variable_scope.variable_scope("tower1") as tower:
       with variable_scope.variable_scope("foo", constraint=constraint):
         v = variable_scope.get_variable("v", [])
         self.assertEqual(v.constraint, constraint)
@@ -94,51 +103,56 @@ class VariableScopeTest(test.TestCase):
         w = variable_scope.get_variable("w", [])
         self.assertEqual(w.constraint, constraint)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testVarScopeDType(self):
-    with self.test_session():
-      with variable_scope.variable_scope("tower") as tower:
-        with variable_scope.variable_scope("foo", dtype=dtypes.float16):
-          v = variable_scope.get_variable("v", [])
-          self.assertEqual(v.dtype.base_dtype, dtypes.float16)
-        with variable_scope.variable_scope(tower, dtype=dtypes.float16):
-          w = variable_scope.get_variable("w", [])
-          self.assertEqual(w.dtype.base_dtype, dtypes.float16)
+    with variable_scope.variable_scope("tower2") as tower:
+      with variable_scope.variable_scope("foo", dtype=dtypes.float16):
+        v = variable_scope.get_variable("v", [])
+        self.assertEqual(v.dtype.base_dtype, dtypes.float16)
+      with variable_scope.variable_scope(tower, dtype=dtypes.float16):
+        w = variable_scope.get_variable("w", [])
+        self.assertEqual(w.dtype.base_dtype, dtypes.float16)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testInitFromNonTensorValue(self):
-    with self.test_session() as sess:
-      v = variable_scope.get_variable("v", initializer=4, dtype=dtypes.int32)
-      sess.run(variables_lib.initialize_variables([v]))
-      self.assertAllClose(v.eval(), 4)
+    v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32)
+    self.evaluate(variables_lib.variables_initializer([v]))
+    self.assertAllClose(self.evaluate(v.value()), 4)
 
-      w = variable_scope.get_variable(
-          "w", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64)
-      sess.run(variables_lib.initialize_variables([w]))
-      self.assertAllClose(w.eval(), [1, 2, 3])
+    w = variable_scope.get_variable(
+        "w4", initializer=numpy.array([1, 2, 3]), dtype=dtypes.int64)
+    self.evaluate(variables_lib.variables_initializer([w]))
+    self.assertAllClose(self.evaluate(w.value()), [1, 2, 3])
 
+    if context.in_graph_mode():
       with self.assertRaises(TypeError):
-        variable_scope.get_variable("x", initializer={})
+        variable_scope.get_variable("x4", initializer={})
+    else:
+      with self.assertRaises(errors.InvalidArgumentError):
+        variable_scope.get_variable("x4", initializer={})
 
+  @test_util.run_in_graph_and_eager_modes()
   def testInitFromNonInitializer(self):
-    with self.test_session():
-      # Test various dtypes with zeros initializer as following:
-      types = [
-          dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32,
-          dtypes.int64, dtypes.bool
-      ]
-
-      # Use different variable_name to distinguish various dtypes
-      for (i, dtype) in enumerate(types):
-        x = variable_scope.get_variable(
-            name="x%d" % i, shape=(3, 4), dtype=dtype)
-        y = variable_scope.get_variable(
-            name="y%d" % i,
-            shape=(3, 4),
-            dtype=dtype,
-            initializer=init_ops.zeros_initializer(dtype=dtype))
-
-        variables_lib.global_variables_initializer().run()
-        self.assertAllEqual(x.eval(), y.eval())
-
+    # Test various dtypes with zeros initializer as following:
+    types = [
+        dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.uint16, dtypes.int32,
+        dtypes.int64, dtypes.bool
+    ]
+
+    # Use different variable_name to distinguish various dtypes
+    for (i, dtype) in enumerate(types):
+      x = variable_scope.get_variable(
+          name="xx%d" % i, shape=(3, 4), dtype=dtype)
+      y = variable_scope.get_variable(
+          name="yy%d" % i,
+          shape=(3, 4),
+          dtype=dtype,
+          initializer=init_ops.zeros_initializer(dtype=dtype))
+
+      self.evaluate(variables_lib.global_variables_initializer())
+      self.assertAllEqual(self.evaluate(x.value()), self.evaluate(y.value()))
+
+  # TODO(alive): support variable partitioning/caching in eager mode.
   def testVarScopeCachingDevice(self):
     with self.test_session():
       caching_device = "/job:moo"
@@ -172,74 +186,74 @@ class VariableScopeTest(test.TestCase):
         v_tower = variable_scope.get_variable("v", [])
         self.assertFalse(v_tower.value().device.startswith(caching_device))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testVarScopeRegularizer(self):
-    with self.test_session() as sess:
-      init = init_ops.constant_initializer(0.3)
-
-      def regularizer1(v):
-        return math_ops.reduce_mean(v) + 0.1
-
-      def regularizer2(v):
-        return math_ops.reduce_mean(v) + 0.2
-
-      with variable_scope.variable_scope(
-          "tower", regularizer=regularizer1) as tower:
-        with variable_scope.variable_scope("foo", initializer=init):
-          v = variable_scope.get_variable("v", [])
-          sess.run(variables_lib.initialize_variables([v]))
-          losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
-          self.assertEqual(1, len(losses))
-          self.assertAllClose(losses[0].eval(), 0.4)
-        with variable_scope.variable_scope(tower, initializer=init) as vs:
-          u = variable_scope.get_variable("u", [])
-          vs.set_regularizer(regularizer2)
-          w = variable_scope.get_variable("w", [])
-          # Next 3 variable not regularized to test disabling regularization.
-          x = variable_scope.get_variable(
-              "x", [], regularizer=variable_scope.no_regularizer)
-          with variable_scope.variable_scope(
-              "baz", regularizer=variable_scope.no_regularizer):
-            y = variable_scope.get_variable("y", [])
-          vs.set_regularizer(variable_scope.no_regularizer)
-          z = variable_scope.get_variable("z", [])
-          # Check results.
-          losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
-          self.assertEqual(3, len(losses))
-          sess.run(variables_lib.initialize_variables([u, w, x, y, z]))
-          self.assertAllClose(losses[0].eval(), 0.4)
-          self.assertAllClose(losses[1].eval(), 0.4)
-          self.assertAllClose(losses[2].eval(), 0.5)
-        with variable_scope.variable_scope("foo", reuse=True):
-          v = variable_scope.get_variable("v",
-                                          [])  # "v" is alredy there, reused
-          losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
-          self.assertEqual(3, len(losses))  # No new loss added.
+    init = init_ops.constant_initializer(0.3)
 
-  def testInitializeFromValue(self):
-    with self.test_session() as sess:
-      init = constant_op.constant(0.1)
-      w = variable_scope.get_variable("v", initializer=init)
-      sess.run(variables_lib.initialize_variables([w]))
-      self.assertAllClose(w.eval(), 0.1)
+    def regularizer1(v):
+      return math_ops.reduce_mean(v) + 0.1
 
-      with self.assertRaisesRegexp(ValueError, "shape"):
-        # We disallow explicit shape specification when initializer is constant.
-        variable_scope.get_variable("u", [1], initializer=init)
+    def regularizer2(v):
+      return math_ops.reduce_mean(v) + 0.2
 
+    with variable_scope.variable_scope(
+        "tower3", regularizer=regularizer1) as tower:
       with variable_scope.variable_scope("foo", initializer=init):
-        # Constant initializer can be passed through scopes if needed.
-        v = variable_scope.get_variable("v")
-        sess.run(variables_lib.initialize_variables([v]))
-        self.assertAllClose(v.eval(), 0.1)
-
-      # Check that non-float32 initializer creates a non-float32 variable.
-      init = constant_op.constant(1, dtype=dtypes.int32)
-      t = variable_scope.get_variable("t", initializer=init)
-      self.assertEqual(t.dtype.base_dtype, dtypes.int32)
-
-      # Raise error if `initializer` dtype and `dtype` are not identical.
-      with self.assertRaisesRegexp(ValueError, "don't match"):
-        variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
+        v = variable_scope.get_variable("v", [])
+        self.evaluate(variables_lib.variables_initializer([v]))
+        losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+        self.assertEqual(1, len(losses))
+        self.assertAllClose(self.evaluate(losses[0]), 0.4)
+      with variable_scope.variable_scope(tower, initializer=init) as vs:
+        u = variable_scope.get_variable("u", [])
+        vs.set_regularizer(regularizer2)
+        w = variable_scope.get_variable("w", [])
+        # Next 3 variable not regularized to test disabling regularization.
+        x = variable_scope.get_variable(
+            "x", [], regularizer=variable_scope.no_regularizer)
+        with variable_scope.variable_scope(
+            "baz", regularizer=variable_scope.no_regularizer):
+          y = variable_scope.get_variable("y", [])
+        vs.set_regularizer(variable_scope.no_regularizer)
+        z = variable_scope.get_variable("z", [])
+        # Check results.
+        losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+        self.assertEqual(3, len(losses))
+        self.evaluate(variables_lib.variables_initializer([u, w, x, y, z]))
+        self.assertAllClose(self.evaluate(losses[0]), 0.4)
+        self.assertAllClose(self.evaluate(losses[1]), 0.4)
+        self.assertAllClose(self.evaluate(losses[2]), 0.5)
+      with variable_scope.variable_scope("foo", reuse=True):
+        v = variable_scope.get_variable("v",
+                                        [])  # "v" is alredy there, reused
+        losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
+        self.assertEqual(3, len(losses))  # No new loss added.
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testInitializeFromValue(self):
+    init = constant_op.constant(0.1)
+    w = variable_scope.get_variable("v", initializer=init)
+    self.evaluate(variables_lib.variables_initializer([w]))
+    self.assertAllClose(self.evaluate(w.value()), 0.1)
+
+    with self.assertRaisesRegexp(ValueError, "shape"):
+      # We disallow explicit shape specification when initializer is constant.
+      variable_scope.get_variable("u", [1], initializer=init)
+
+    with variable_scope.variable_scope("foo", initializer=init):
+      # Constant initializer can be passed through scopes if needed.
+      v = variable_scope.get_variable("v")
+      self.evaluate(variables_lib.variables_initializer([v]))
+      self.assertAllClose(self.evaluate(v.value()), 0.1)
+
+    # Check that non-float32 initializer creates a non-float32 variable.
+    init = constant_op.constant(1, dtype=dtypes.int32)
+    t = variable_scope.get_variable("t", initializer=init)
+    self.assertEqual(t.dtype.base_dtype, dtypes.int32)
+
+    # Raise error if `initializer` dtype and `dtype` are not identical.
+    with self.assertRaisesRegexp(ValueError, "don't match"):
+      variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
 
   def testControlDeps(self):
     with self.test_session() as sess:
@@ -250,16 +264,16 @@ class VariableScopeTest(test.TestCase):
             "v1", [1], initializer=init_ops.constant_initializer(1))
         add = v1 + v0
       # v0 should be uninitialized.
-      with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
+      with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
         sess.run(v0)
       # We should be able to initialize and run v1 without initializing
       # v0, even if the variable was created with a control dep on v0.
       sess.run(v1.initializer)
       self.assertEqual(1, sess.run(v1))
       # v0 should still be uninitialized.
-      with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
+      with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
         sess.run(v0)
-      with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
+      with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
         sess.run(add)
       # If we initialize v0 we should be able to run 'add'.
       sess.run(v0.initializer)
@@ -295,82 +309,85 @@ class VariableScopeTest(test.TestCase):
       sess.run(v2.initializer)
       self.assertEqual([2], sess.run(v2))
       # v0 should still be uninitialized.
-      with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
+      with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
         sess.run(v0)
       # We should not be able to run 'add' yet.
-      with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
+      with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
         sess.run(add)
       # If we initialize v0 we should be able to run 'add'.
       sess.run(v0.initializer)
       sess.run(add)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGetVariableScope(self):
     # Test the get_variable_scope() function and setting properties of result.
-    with self.test_session() as sess:
-      init = init_ops.constant_initializer(0.3)
-      with variable_scope.variable_scope("foo"):
-        new_init1 = variable_scope.get_variable_scope().initializer
-        self.assertEqual(new_init1, None)
-        # Check that we can set initializer like this.
-        variable_scope.get_variable_scope().set_initializer(init)
-        v = variable_scope.get_variable("v", [])
-        sess.run(variables_lib.initialize_variables([v]))
-        self.assertAllClose(v.eval(), 0.3)
+    init = init_ops.constant_initializer(0.3)
+    with variable_scope.variable_scope("bar"):
+      new_init1 = variable_scope.get_variable_scope().initializer
+      self.assertEqual(new_init1, None)
+      # Check that we can set initializer like this.
+      variable_scope.get_variable_scope().set_initializer(init)
+      v = variable_scope.get_variable("v", [])
+      self.evaluate(variables_lib.variables_initializer([v]))
+      self.assertAllClose(self.evaluate(v.value()), 0.3)
+      if context.in_graph_mode():
         # Check that we can set reuse.
         variable_scope.get_variable_scope().reuse_variables()
         with self.assertRaises(ValueError):  # Fail, w does not exist yet.
           variable_scope.get_variable("w", [1])
-      # Check that the set initializer goes away.
-      new_init = variable_scope.get_variable_scope().initializer
-      self.assertEqual(new_init, None)
+    # Check that the set initializer goes away.
+    new_init = variable_scope.get_variable_scope().initializer
+    self.assertEqual(new_init, None)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testVarScope(self):
-    with self.test_session():
-      with variable_scope.variable_scope("tower") as tower:
-        self.assertEqual(tower.name, "tower")
+    with variable_scope.variable_scope("tower4") as tower:
+      self.assertEqual(tower.name, "tower4")
+      with ops.name_scope("scope") as sc:
+        self.assertEqual(sc, "tower4/scope/")
+
+    with variable_scope.variable_scope("tower5"):
+      with variable_scope.variable_scope("bar") as bar:
+        self.assertEqual(bar.name, "tower5/bar")
         with ops.name_scope("scope") as sc:
-          self.assertEqual(sc, "tower/scope/")
+          self.assertEqual(sc, "tower5/bar/scope/")
 
-      with variable_scope.variable_scope("foo"):
-        with variable_scope.variable_scope("bar") as bar:
-          self.assertEqual(bar.name, "foo/bar")
-          with ops.name_scope("scope") as sc:
-            self.assertEqual(sc, "foo/bar/scope/")
-
-      with variable_scope.variable_scope("foo"):
-        with variable_scope.variable_scope(tower, reuse=True) as tower_shared:
-          self.assertEqual(tower_shared.name, "tower")
-          with ops.name_scope("scope") as sc:
-            self.assertEqual(sc, "foo_1/tower/scope/")
+    with variable_scope.variable_scope("tower6"):
+      with variable_scope.variable_scope(tower, reuse=True) as tower_shared:
+        self.assertEqual(tower_shared.name, "tower4")
+        with ops.name_scope("scope") as sc:
+          self.assertEqual(sc, "tower6/tower4/scope/")
 
+  @test_util.run_in_graph_and_eager_modes()
   def testVarScopeNameScope(self):
-    with self.test_session():
-      with ops.name_scope("scope1"):
-        with variable_scope.variable_scope("tower") as tower:
-          with ops.name_scope("scope2") as sc2:
-            self.assertEqual(sc2, "scope1/tower/scope2/")
+    with ops.name_scope("testVarScopeNameScope1"):
+      with variable_scope.variable_scope("tower") as tower:
+        with ops.name_scope("scope2") as sc2:
+          self.assertEqual(sc2, "testVarScopeNameScope1/tower/scope2/")
+      if context.in_graph_mode():
         with variable_scope.variable_scope(
             tower):  # Re-entering acts like another "tower".
           with ops.name_scope("scope2") as sc2:
-            self.assertEqual(sc2, "scope1/tower_1/scope2/")
+            self.assertEqual(sc2, "testVarScopeNameScope1/tower_1/scope2/")
         with variable_scope.variable_scope(
             "tower"):  # Re-entering by string acts the same.
           with ops.name_scope("scope2") as sc2:
-            self.assertEqual(sc2, "scope1/tower_2/scope2/")
+            self.assertEqual(sc2, "testVarScopeNameScope1/tower_2/scope2/")
 
-      with ops.name_scope("scope3"):
-        with variable_scope.variable_scope("tower"):
-          with ops.name_scope("scope2") as sc2:
-            self.assertEqual(sc2, "scope3/tower/scope2/")
+    with ops.name_scope("testVarScopeNameScope2"):
+      with variable_scope.variable_scope("tower"):
+        with ops.name_scope("scope2") as sc2:
+          self.assertEqual(sc2, "testVarScopeNameScope2/tower/scope2/")
+      if context.in_graph_mode():
         with variable_scope.variable_scope(tower):
           with ops.name_scope("scope2") as sc2:
-            self.assertEqual(sc2, "scope3/tower_1/scope2/")
+            self.assertEqual(sc2, "testVarScopeNameScope2/tower_1/scope2/")
 
-      root_var_scope = variable_scope.get_variable_scope()
-      with ops.name_scope("scope4"):
-        with variable_scope.variable_scope(root_var_scope):
-          with ops.name_scope("scope2") as sc2:
-            self.assertEqual(sc2, "scope4/scope2/")
+    root_var_scope = variable_scope.get_variable_scope()
+    with ops.name_scope("testVarScopeNameScope3"):
+      with variable_scope.variable_scope(root_var_scope):
+        with ops.name_scope("scope2") as sc2:
+          self.assertEqual(sc2, "testVarScopeNameScope3/scope2/")
 
   def testVarScopeOriginalNameScope(self):
     with self.test_session():
@@ -422,51 +439,46 @@ class VariableScopeTest(test.TestCase):
       with variable_scope.variable_scope(vs, reuse=False) as jump_no_reuse:
         self.assertFalse(jump_no_reuse.reuse)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testVarScopeGetOrCreateReuse(self):
-    x = array_ops.placeholder(dtypes.float32)
-
-    with variable_scope.variable_scope("bar",
-                                       reuse=variable_scope.AUTO_REUSE):
-      v_assign = state_ops.assign(variable_scope.get_variable("var", []), x)
-
-    with variable_scope.variable_scope("bar",
-                                       reuse=variable_scope.AUTO_REUSE):
-      v = variable_scope.get_variable("var", [])
-
-    with self.test_session() as sess:
-      def test_value(value):
-        sess.run(v_assign, feed_dict={x: value})
-        self.assertEqual(value, v.eval())
-
-      test_value(42)  # Variable is created.
-      test_value(13)  # Variable is reused hereafter.
-      test_value(17)
+    def test_value(value):
+      x = constant_op.constant(value)
+      with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar",
+                                         reuse=variable_scope.AUTO_REUSE):
+        _ = state_ops.assign(variable_scope.get_variable("var", []), x)
+      with variable_scope.variable_scope("testVarScopeGetOrCreateReuse_bar",
+                                         reuse=variable_scope.AUTO_REUSE):
+        _ = variable_scope.get_variable("var", [])
+      self.assertEqual(value, self.evaluate(x))
+    test_value(42.)  # Variable is created.
+    test_value(13.)  # Variable is reused hereafter.
+    test_value(17.)
 
   def testVarOpScope(self):
     with self.test_session():
-      with ops.name_scope("scope1"):
+      with ops.name_scope("testVarOpScope1"):
         with variable_scope.variable_scope("tower", "default", []):
           self.assertEqual(
               variable_scope.get_variable("w", []).name, "tower/w:0")
-          with ops.name_scope("scope2") as sc2:
-            self.assertEqual(sc2, "scope1/tower/scope2/")
+          with ops.name_scope("testVarOpScope2") as sc2:
+            self.assertEqual(sc2, "testVarOpScope1/tower/testVarOpScope2/")
         with variable_scope.variable_scope("tower", "default", []):
           with self.assertRaises(ValueError):
             variable_scope.get_variable("w", [])
-          with ops.name_scope("scope2") as sc2:
-            self.assertEqual(sc2, "scope1/tower_1/scope2/")
+          with ops.name_scope("testVarOpScope2") as sc2:
+            self.assertEqual(sc2, "testVarOpScope1/tower_1/testVarOpScope2/")
 
-      with ops.name_scope("scope2"):
+      with ops.name_scope("testVarOpScope2"):
         with variable_scope.variable_scope(None, "default", []):
           self.assertEqual(
               variable_scope.get_variable("w", []).name, "default/w:0")
-          with ops.name_scope("scope2") as sc2:
-            self.assertEqual(sc2, "scope2/default/scope2/")
+          with ops.name_scope("testVarOpScope2") as sc2:
+            self.assertEqual(sc2, "testVarOpScope2/default/testVarOpScope2/")
         with variable_scope.variable_scope(None, "default", []):
           self.assertEqual(
               variable_scope.get_variable("w", []).name, "default_1/w:0")
-          with ops.name_scope("scope2") as sc2:
-            self.assertEqual(sc2, "scope2/default_1/scope2/")
+          with ops.name_scope("testVarOpScope2") as sc2:
+            self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/")
 
   def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self):
     with self.test_session():
@@ -714,27 +726,27 @@ class VariableScopeTest(test.TestCase):
           with ops.name_scope("scope2") as sc2:
             self.assertEqual(sc2, "outer_1/default/scope2/")
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGetLocalVar(self):
-    with self.test_session():
-      # Check that local variable respects naming.
-      with variable_scope.variable_scope("outer") as outer:
-        with variable_scope.variable_scope(outer, "default", []):
-          local_var = variable_scope.get_local_variable(
-              "w", [], collections=["foo"])
-          self.assertEqual(local_var.name, "outer/w:0")
-
-      # Since variable is local, it should be in the local variable collection
-      # but not the trainable collection.
-      self.assertIn(local_var,
-                    ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
-      self.assertIn(local_var, ops.get_collection("foo"))
-      self.assertNotIn(local_var,
-                       ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
-
-      # Check that local variable respects `reuse`.
-      with variable_scope.variable_scope(outer, "default", reuse=True):
-        self.assertEqual(
-            variable_scope.get_local_variable("w", []).name, "outer/w:0")
+    # Check that local variable respects naming.
+    with variable_scope.variable_scope("outer") as outer:
+      with variable_scope.variable_scope(outer, "default", []):
+        local_var = variable_scope.get_local_variable(
+            "w", [], collections=["foo"])
+        self.assertEqual(local_var.name, "outer/w:0")
+
+    # Since variable is local, it should be in the local variable collection
+    # but not the trainable collection.
+    self.assertIn(local_var,
+                  ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
+    self.assertIn(local_var, ops.get_collection("foo"))
+    self.assertNotIn(local_var,
+                     ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+
+    # Check that local variable respects `reuse`.
+    with variable_scope.variable_scope(outer, "default", reuse=True):
+      self.assertEqual(
+          variable_scope.get_local_variable("w", []).name, "outer/w:0")
 
   def testGetVarWithDevice(self):
     g = ops.Graph()
@@ -753,69 +765,93 @@ class VariableScopeTest(test.TestCase):
     self.assertEqual(varname_type[0], ("x", dtypes.float32))
     self.assertEqual(varname_type[1], ("y", dtypes.int64))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGetCollection(self):
-    with self.test_session():
-      _ = variable_scope.get_variable("a", [])
-      _ = variable_scope.get_variable("b", [], trainable=False)
-      with variable_scope.variable_scope("foo_") as scope1:
-        _ = variable_scope.get_variable("a", [])
-        _ = variable_scope.get_variable("b", [], trainable=False)
-        self.assertEqual([
-            v.name
-            for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
-        ], ["foo_/a:0"])
-        self.assertEqual([
-            v.name
-            for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
-        ], ["foo_/a:0", "foo_/b:0"])
-      with variable_scope.variable_scope("foo") as scope2:
-        _ = variable_scope.get_variable("a", [])
-        _ = variable_scope.get_variable("b", [], trainable=False)
-        self.assertEqual([
-            v.name
-            for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
-        ], ["foo/a:0"])
-        self.assertEqual([
-            v.name
-            for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
-        ], ["foo/a:0", "foo/b:0"])
-      scope = variable_scope.get_variable_scope()
+    _ = variable_scope.get_variable("testGetCollection_a", [])
+    _ = variable_scope.get_variable("testGetCollection_b", [], trainable=False)
+    with variable_scope.variable_scope("testGetCollection_foo_") as scope1:
+      _ = variable_scope.get_variable("testGetCollection_a", [])
+      _ = variable_scope.get_variable("testGetCollection_b", [],
+                                      trainable=False)
       self.assertEqual([
-          v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
-      ], ["a:0", "b:0", "foo_/a:0", "foo_/b:0", "foo/a:0", "foo/b:0"])
+          v.name
+          for v in scope1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+      ], ["testGetCollection_foo_/testGetCollection_a:0"])
       self.assertEqual([
           v.name
-          for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
-      ], ["a:0", "foo_/a:0", "foo/a:0"])
-
+          for v in scope1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+      ], [
+          "testGetCollection_foo_/testGetCollection_a:0",
+          "testGetCollection_foo_/testGetCollection_b:0"
+      ])
+    with variable_scope.variable_scope("testGetCollection_foo") as scope2:
+      _ = variable_scope.get_variable("testGetCollection_a", [])
+      _ = variable_scope.get_variable("testGetCollection_b", [],
+                                      trainable=False)
+      self.assertEqual([
+          v.name
+          for v in scope2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+      ], ["testGetCollection_foo/testGetCollection_a:0"])
+      self.assertEqual([
+          v.name
+          for v in scope2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+      ], [
+          "testGetCollection_foo/testGetCollection_a:0",
+          "testGetCollection_foo/testGetCollection_b:0"
+      ])
+    scope = variable_scope.get_variable_scope()
+    self.assertEqual([
+        v.name for v in scope.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+    ], [
+        "testGetCollection_a:0", "testGetCollection_b:0",
+        "testGetCollection_foo_/testGetCollection_a:0",
+        "testGetCollection_foo_/testGetCollection_b:0",
+        "testGetCollection_foo/testGetCollection_a:0",
+        "testGetCollection_foo/testGetCollection_b:0"
+    ])
+    self.assertEqual([
+        v.name
+        for v in scope.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+    ], [
+        "testGetCollection_a:0",
+        "testGetCollection_foo_/testGetCollection_a:0",
+        "testGetCollection_foo/testGetCollection_a:0"
+    ])
+
+  @test_util.run_in_graph_and_eager_modes()
   def testGetTrainableVariables(self):
-    with self.test_session():
-      _ = variable_scope.get_variable("a", [])
-      with variable_scope.variable_scope("foo") as scope:
-        _ = variable_scope.get_variable("b", [])
-        _ = variable_scope.get_variable("c", [], trainable=False)
-        self.assertEqual([v.name
-                          for v in scope.trainable_variables()], ["foo/b:0"])
-
+    _ = variable_scope.get_variable("testGetTrainableVariables_a", [])
+    with variable_scope.variable_scope(
+        "testGetTrainableVariables_foo") as scope:
+      _ = variable_scope.get_variable("testGetTrainableVariables_b", [])
+      _ = variable_scope.get_variable("testGetTrainableVariables_c", [],
+                                      trainable=False)
+      self.assertEqual([v.name
+                        for v in scope.trainable_variables()],
+                       ["testGetTrainableVariables_foo/"
+                        "testGetTrainableVariables_b:0"])
+
+  @test_util.run_in_graph_and_eager_modes()
   def testGetGlobalVariables(self):
-    with self.test_session():
-      _ = variable_scope.get_variable("a", [])
-      with variable_scope.variable_scope("foo") as scope:
-        _ = variable_scope.get_variable("b", [])
-        self.assertEqual([v.name
-                          for v in scope.global_variables()], ["foo/b:0"])
-
+    _ = variable_scope.get_variable("testGetGlobalVariables_a", [])
+    with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
+      _ = variable_scope.get_variable("testGetGlobalVariables_b", [])
+      self.assertEqual([v.name
+                        for v in scope.global_variables()],
+                       ["testGetGlobalVariables_foo/"
+                        "testGetGlobalVariables_b:0"])
+
+  @test_util.run_in_graph_and_eager_modes()
   def testGetLocalVariables(self):
-    with self.test_session():
+    _ = variable_scope.get_variable(
+        "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
+    with variable_scope.variable_scope("foo") as scope:
       _ = variable_scope.get_variable(
-          "a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
-      with variable_scope.variable_scope("foo") as scope:
-        _ = variable_scope.get_variable(
-            "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
-        _ = variable_scope.get_variable(
-            "c", [])
-        self.assertEqual([v.name
-                          for v in scope.local_variables()], ["foo/b:0"])
+          "b", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
+      _ = variable_scope.get_variable(
+          "c", [])
+      self.assertEqual([v.name
+                        for v in scope.local_variables()], ["foo/b:0"])
 
   def testGetVariableWithRefDtype(self):
     v = variable_scope.get_variable("v", shape=[3, 4], dtype=dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 414fb0c699edde4373deaeda66fe65e796590a86..7718710c690555e183150347b8cc125af1365f6b 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -44,12 +44,14 @@ class VariablesTestCase(test.TestCase):
     with self.test_session():
       var0 = variables.Variable(0.0)
       self.assertEqual("Variable:0", var0.name)
+      self.assertEqual("Variable", var0._shared_name)
       self.assertEqual([], var0.get_shape())
       self.assertEqual([], var0.get_shape())
       self.assertEqual([], var0.shape)
 
       var1 = variables.Variable(1.1)
       self.assertEqual("Variable_1:0", var1.name)
+      self.assertEqual("Variable_1", var1._shared_name)
       self.assertEqual([], var1.get_shape())
       self.assertEqual([], var1.get_shape())
       self.assertEqual([], var1.shape)
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 5f21a2bdfa2346232222bd92898ea6cf63456e69..ab5557672fc9e6285e46ad7f9a154badfacd7e63 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -33,6 +33,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 import numpy as np
 import six
 
+from tensorflow.python.eager import context
 from tensorflow.python.estimator import util as estimator_util
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import dtypes
@@ -102,6 +103,8 @@ class Layer(object):
     self._per_input_updates = {}
     self.dtype = dtypes.as_dtype(dtype).name
     self.input_spec = None
+    self._compute_previous_mask = ('mask' in estimator_util.fn_args(self.call)
+                                   or hasattr(self, 'compute_mask'))
 
     # These lists will be filled via successive calls
     # to self._add_inbound_node().
@@ -181,6 +184,8 @@ class Layer(object):
 
   @property
   def updates(self):
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.updates not supported in Eager mode.')
     return self._updates
 
   def add_update(self, updates, inputs=None):
@@ -202,7 +207,12 @@ class Layer(object):
         match the `inputs` argument passed to the `__call__` method at the time
         the updates are created. If `None` is passed, the updates are assumed
         to be unconditional, and will apply across all dataflows of the layer.
+
+    Raises:
+      RuntimeError: If called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.add_update not supported in Eager mode.')
     updates = _to_list(updates)
     if not updates:
       return
@@ -232,7 +242,12 @@ class Layer(object):
 
     Returns:
       List of update ops of the layer that depend on `inputs`.
+
+    Raises:
+      RuntimeError: If called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.get_updates_for not supported in Eager mode.')
     if inputs is not None:
       inputs = _to_list(inputs)
     if not inputs:
@@ -245,6 +260,8 @@ class Layer(object):
 
   @property
   def losses(self):
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.losses not supported in Eager mode.')
     return self._losses
 
   def add_loss(self, losses, inputs=None):
@@ -266,7 +283,12 @@ class Layer(object):
         the losses are created. If `None` is passed, the losses are assumed
         to be unconditional, and will apply across all dataflows of the layer
         (e.g. weight regularization losses).
+
+    Raises:
+      RuntimeError: If called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.add_loss not supported in Eager mode.')
     losses = _to_list(losses)
     if not losses:
       return
@@ -297,7 +319,12 @@ class Layer(object):
 
     Returns:
       List of loss tensors of the layer that depend on `inputs`.
+
+    Raises:
+      RuntimeError: If called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.get_losses_for not supported in Eager mode.')
     if inputs is not None:
       inputs = _to_list(inputs)
     if not inputs:
@@ -376,15 +403,24 @@ class Layer(object):
 
     Returns:
       The created variable.
+
+    Raises:
+      RuntimeError: If called in Eager mode with regularizers.
     """
+    # Note that we currently don't support variable regularization in Eager
+    # mode. An alternative is for users to directly compute these losses before
+    # performing a backward pass.
+    if regularizer is not None and context.in_eager_mode():
+      raise RuntimeError('Variable regularization not supported in Eager mode.')
     if dtype is None:
       dtype = self.dtype
     existing_variables = set(tf_variables.global_variables())
 
     self._set_scope(None)
 
-    with vs.variable_scope(self._scope,
-                           reuse=self.built or self._reuse) as scope:
+    vs_reuse = ((self.built or self._reuse)
+                if context.in_graph_mode() else vs.AUTO_REUSE)
+    with vs.variable_scope(self._scope, reuse=vs_reuse) as scope:
       with ops.name_scope(scope.original_name_scope):
         variable = vs.get_variable(name,
                                    shape=shape,
@@ -446,29 +482,48 @@ class Layer(object):
     """
     self._set_scope(kwargs.pop('scope', None))
 
+    in_graph_mode = context.in_graph_mode()
     # Ensure the Layer, if being reused, is working with inputs from
     # the same graph as where it was created.
-    try:
-      ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
-    except ValueError as e:
-      raise ValueError('Input graph and Layer graph are not the same: %s' % e)
+    if in_graph_mode:
+      try:
+        ops._get_graph_from_inputs(nest.flatten(inputs), graph=self.graph)  # pylint: disable=protected-access
+      except ValueError as e:
+        raise ValueError('Input graph and Layer graph are not the same: %s' % e)
+      user_kwargs = copy.copy(kwargs)
 
     # Handle Keras mask propagation from previous layer to current layer.
-    previous_mask = _collect_previous_mask(inputs)
-    user_kwargs = copy.copy(kwargs)
-    if not _is_all_none(previous_mask):
-      # The previous layer generated a mask, which we should use as default
-      # for any "mask" argument in `call`.
-      if 'mask' in estimator_util.fn_args(self.call):
-        if 'mask' not in kwargs:
-          # If mask is explicitly passed to __call__,
-          # we should override the default mask.
-          kwargs['mask'] = previous_mask
-
-    with vs.variable_scope(self._scope,
-                           reuse=self.built or self._reuse) as scope:
+    previous_mask = None
+    if (not hasattr(self, '_compute_previous_mask') or
+        self._compute_previous_mask):
+      previous_mask = _collect_previous_mask(inputs)
+      if ('mask' in estimator_util.fn_args(self.call) and
+          'mask' not in kwargs and
+          not _is_all_none(previous_mask)):
+        # The previous layer generated a mask, and mask was not explicitly pass
+        # to __call__, hence we set previous_mask as the default value.
+        kwargs['mask'] = previous_mask
+
+    vs_reuse = ((self.built or self._reuse)
+                if context.in_graph_mode else vs.AUTO_REUSE)
+    with vs.variable_scope(self._scope, reuse=vs_reuse) as scope:
       with ops.name_scope(scope.original_name_scope):
         if not self.built:
+          if not in_graph_mode:
+            # Activity regularization is unsupported in Eager mode.
+            if hasattr(self,
+                       'activity_regularizer') and self.activity_regularizer:
+              raise ValueError('activity_regularizer currently unsupported in '
+                               'Eager mode. Found an activity_regularizer in '
+                               '%s(%s).' % (self.__class__.__name__, self))
+            # TODO(agarwal): support _keras_history in Eager mode.
+            for x in _to_list(inputs):
+              if hasattr(x, '_keras_history'):
+                raise ValueError('_keras_history currently unsupported in '
+                                 'Eager mode. Found _keras_history in %s while '
+                                 'executing __call__ for %s(%s)' %
+                                 (x, self.__class_.__name__, self))
+
           # Check input assumptions set before layer building, e.g. input rank.
           self._assert_input_compatibility(inputs)
           input_list = nest.flatten(inputs)
@@ -480,24 +535,27 @@ class Layer(object):
         if 'scope' in estimator_util.fn_args(self.call):
           kwargs['scope'] = scope
         # Check input assumptions set after layer building, e.g. input shape.
-        self._assert_input_compatibility(inputs)
+        if in_graph_mode:
+          self._assert_input_compatibility(inputs)
         outputs = self.call(inputs, *args, **kwargs)
 
         if outputs is None:
           raise ValueError('A layer\'s `call` method should return a Tensor '
                            'or a list of Tensors, not None.')
 
-        # Apply activity regularization.
-        # Note that it should be applied every time the layer creates a new
-        # output, since it is output-specific.
-        if hasattr(self, 'activity_regularizer') and self.activity_regularizer:
-          output_list = _to_list(outputs)
-          for output in output_list:
-            with ops.name_scope('ActivityRegularizer'):
-              activity_regularization = self.activity_regularizer(output)
-            self.add_loss(activity_regularization)
-            _add_elements_to_collection(
-                activity_regularization, ops.GraphKeys.REGULARIZATION_LOSSES)
+        if in_graph_mode:
+          # Apply activity regularization.
+          # Note that it should be applied every time the layer creates a new
+          # output, since it is output-specific.
+          if hasattr(self,
+                     'activity_regularizer') and self.activity_regularizer:
+            output_list = _to_list(outputs)
+            for output in output_list:
+              with ops.name_scope('ActivityRegularizer'):
+                activity_regularization = self.activity_regularizer(output)
+              self.add_loss(activity_regularization)
+              _add_elements_to_collection(activity_regularization,
+                                          ops.GraphKeys.REGULARIZATION_LOSSES)
 
         # Handle mask computation and propagation to the next layer.
         if hasattr(self, 'compute_mask'):
@@ -510,40 +568,42 @@ class Layer(object):
           else:
             outputs._keras_mask = output_mask  # pylint: disable=protected-access
 
-    # If all input tensors have history metadata,
-    # we update the output tensors
-    # with corresponding history metadata, thus eventually allowing to use
-    # these tensors to instantiate a Network.
-    if _have_all_keras_metadata(inputs):
-      # If the layer returns tensors from its inputs, unmodified,
-      # we copy them to avoid loss of tensor metadata.
-      output_ls = _to_list(outputs)
-      inputs_ls = _to_list(inputs)
-      output_ls_copy = []
-      for x in output_ls:
-        if x in inputs_ls:
-          with ops.name_scope(scope.original_name_scope):
-            x = array_ops.identity(x)
-        output_ls_copy.append(x)
-      if len(output_ls_copy) == 1:
-        outputs = output_ls_copy[0]
-      else:
-        outputs = output_ls_copy
+    if in_graph_mode:
+      # If all input tensors have history metadata,
+      # we update the output tensors
+      # with corresponding history metadata, thus eventually allowing to use
+      # these tensors to instantiate a Network.
+      if _have_all_keras_metadata(inputs):
+        # If the layer returns tensors from its inputs, unmodified,
+        # we copy them to avoid loss of tensor metadata.
+        output_ls = _to_list(outputs)
+        inputs_ls = _to_list(inputs)
+        output_ls_copy = []
+        for x in output_ls:
+          if x in inputs_ls:
+            with ops.name_scope(scope.original_name_scope):
+              x = array_ops.identity(x)
+          output_ls_copy.append(x)
+        if len(output_ls_copy) == 1:
+          outputs = output_ls_copy[0]
+        else:
+          outputs = output_ls_copy
 
-      # Add an inbound node to the layer, so it can keep track of this call.
-      # This updates the layer history of the output tensor(s).
-      self._add_inbound_node(
-          input_tensors=inputs,
-          output_tensors=outputs,
-          arguments=user_kwargs)
+        # Add an inbound node to the layer, so it can keep track of this call.
+        # This updates the layer history of the output tensor(s).
+        self._add_inbound_node(
+            input_tensors=inputs, output_tensors=outputs, arguments=user_kwargs)
+
+      # Update global default collections.
+      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
 
-    # Update global default collections.
-    _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
     self.built = True
     return outputs
 
   @property
   def graph(self):
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.graph not supported in Eager mode.')
     return self._graph
 
   def __deepcopy__(self, memo):
@@ -590,6 +650,7 @@ class Layer(object):
         arguments: dictionary of keyword arguments that were passed to the
             `call` method of the layer at the call that created the node.
     """
+    assert context.in_graph_mode()
     input_tensors = _to_list(input_tensors)
     output_tensors = _to_list(output_tensors)
 
@@ -642,9 +703,11 @@ class Layer(object):
         The layer's attribute `attr` at the node of index `node_index`.
 
     Raises:
-        RuntimeError: If the layer has no inbound nodes.
+        RuntimeError: If the layer has no inbound nodes, or if called in Eager
+        mode.
         ValueError: If the index provided does not match any node.
     """
+    assert context.in_graph_mode()
     if not self.inbound_nodes:
       raise RuntimeError('The layer has never been called '
                          'and thus has no defined ' + attr_name + '.')
@@ -670,7 +733,13 @@ class Layer(object):
     Returns:
         A shape tuple
         (or list of shape tuples if the layer has multiple inputs).
+
+    Raises:
+      RuntimeError: If called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError(
+          'Layer.get_input_shape_at not supported in Eager mode.')
     return self._get_node_attribute_at_index(node_index, 'input_shapes',
                                              'input shape')
 
@@ -686,7 +755,13 @@ class Layer(object):
     Returns:
         A shape tuple
         (or list of shape tuples if the layer has multiple outputs).
+
+    Raises:
+      RuntimeError: If called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError(
+          'Layer.get_output_shape_at not supported in Eager mode.')
     return self._get_node_attribute_at_index(node_index, 'output_shapes',
                                              'output shape')
 
@@ -701,7 +776,12 @@ class Layer(object):
 
     Returns:
         A tensor (or list of tensors if the layer has multiple inputs).
+
+    Raises:
+      RuntimeError: If called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.get_input_at not supported in Eager mode.')
     return self._get_node_attribute_at_index(node_index, 'input_tensors',
                                              'input')
 
@@ -716,7 +796,12 @@ class Layer(object):
 
     Returns:
         A tensor (or list of tensors if the layer has multiple outputs).
+
+    Raises:
+      RuntimeError: If called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.get_output_at not supported in Eager mode.')
     return self._get_node_attribute_at_index(node_index, 'output_tensors',
                                              'output')
 
@@ -733,7 +818,13 @@ class Layer(object):
     Raises:
         AttributeError: if the layer is connected to
         more than one incoming layers.
+
+    Raises:
+      RuntimeError: If called in Eager mode.
+      AttributeError: If no inbound nodes are found.
     """
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.input not supported in Eager mode.')
     if not self.inbound_nodes:
       raise AttributeError('Layer ' + self.name +
                            ' is not connected, no input to return.')
@@ -747,12 +838,15 @@ class Layer(object):
     i.e. if it is connected to one incoming layer.
 
     Returns:
-        Output tensor or list of output tensors.
+      Output tensor or list of output tensors.
 
     Raises:
-        AttributeError: if the layer is connected to
-        more than one incoming layers.
+      AttributeError: if the layer is connected to more than one incoming
+        layers.
+      RuntimeError: if called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.output not supported in Eager mode.')
     if not self.inbound_nodes:
       raise AttributeError('Layer ' + self.name + ' has no inbound nodes.')
     return self._get_node_attribute_at_index(0, 'output_tensors', 'output')
@@ -771,7 +865,10 @@ class Layer(object):
 
     Raises:
         AttributeError: if the layer has no defined input_shape.
+        RuntimeError: if called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.input_shape not supported in Eager mode.')
     if not self.inbound_nodes:
       raise AttributeError('The layer has never been called '
                            'and thus has no defined input shape.')
@@ -829,7 +926,10 @@ class Layer(object):
 
     Raises:
         AttributeError: if the layer has no defined output shape.
+        RuntimeError: if called in Eager mode.
     """
+    if context.in_eager_mode():
+      raise RuntimeError('Layer.output_shape not supported in Eager mode.')
     if not self.inbound_nodes:
       raise AttributeError('The layer has never been called '
                            'and thus has no defined output shape.')
@@ -897,8 +997,8 @@ class Layer(object):
         if ndim != spec.ndim:
           raise ValueError('Input ' + str(input_index) + ' of layer ' +
                            self.name + ' is incompatible with the layer: '
-                           'expected ndim=' + str(spec.ndim) + ', found ndim='
-                           + str(ndim) + '. Full shape received: ' +
+                           'expected ndim=' + str(spec.ndim) + ', found ndim=' +
+                           str(ndim) + '. Full shape received: ' +
                            str(x.get_shape().as_list()))
       if spec.max_ndim is not None:
         ndim = x.get_shape().ndims
@@ -1115,6 +1215,9 @@ class InputLayer(Layer):
       sparse: Boolean, whether the placeholder created
           is meant to be sparse.
       name: Name of the layer (string).
+
+    Raises:
+      RuntimeError: If created in Eager mode.
   """
 
   def __init__(self,
@@ -1124,6 +1227,8 @@ class InputLayer(Layer):
                input_tensor=None,
                sparse=False,
                name=None):
+    if context.in_eager_mode():
+      raise RuntimeError('InputLayer not supported in Eager mode.')
     super(InputLayer, self).__init__(dtype=dtype, name=name)
     self.built = True
     self.sparse = sparse
@@ -1210,7 +1315,12 @@ def Input(  # pylint: disable=invalid-name
   Returns:
       A tensor: either a new placeholder (with history metadata) or
       `tensor` (if passed), with added history metadata.
+
+  Raises:
+    RuntimeError: If called in Eager mode.
   """
+  if context.in_eager_mode():
+    raise RuntimeError('Input not supported in Eager mode.')
   input_layer = InputLayer(
       input_shape=shape,
       batch_size=batch_size,
@@ -1268,9 +1378,15 @@ class Network(Layer):
   Methods:
     Network has the same methods as Layer. On top of it, it also has:
       - get_layer: retrieves a child layer by name or index in the graph.
+
+  Raises:
+    RuntimeError: If created in Eager mode.
   """
 
   def __init__(self, inputs, outputs, name=None):  # pylint: disable=super-init-not-called
+    # TODO(agarwal): Make Network work in Eager mode.
+    if context.in_eager_mode():
+      raise RuntimeError('Network not supported in Eager mode.')
     # Set layer name and scope
     if isinstance(name, vs.VariableScope):
       base_name = name.name
@@ -1282,6 +1398,8 @@ class Network(Layer):
       self.name = _unique_layer_name(base_name)
     self._scope = next(vs.variable_scope(None, default_name=base_name).gen)
     self._base_name = base_name
+    self._compute_previous_mask = ('mask' in estimator_util.fn_args(self.call)
+                                   or hasattr(self, 'compute_mask'))
 
     # This acts just like the `trainable` attribute of any layer instance.
     # It does not affect users of the underlying layers, only users of the
@@ -2033,6 +2151,10 @@ def _to_list(x):
 
 
 def _add_elements_to_collection(elements, collection_list):
+  if context.in_eager_mode():
+    raise RuntimeError('Using collections from Layers not supported in Eager '
+                       'mode. Tried to add %s to %s' % (elements,
+                                                        collection_list))
   elements = _to_list(elements)
   collection_list = _to_list(collection_list)
   for name in collection_list:
@@ -2102,6 +2224,7 @@ def _collect_previous_mask(input_tensors):
     return masks[0]
   return masks
 
+
 # A global dictionary mapping graph objects to an index of counters used
 # for various layer names in each graph.
 # Allows to give unique autogenerated names to layers, in a graph-specific way.
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index a1af153e5514c0f2251032042dcb7208e9566ecc..ced5466bebf03c452fdbfb8df16bcd22ec3f805b 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -20,7 +20,11 @@ from __future__ import print_function
 
 import copy
 
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.layers import base as base_layers
 from tensorflow.python.layers import core as core_layers
 from tensorflow.python.ops import array_ops
@@ -34,45 +38,48 @@ from tensorflow.python.platform import test
 
 class BaseLayerTest(test.TestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testLayerProperties(self):
     layer = base_layers.Layer(name='my_layer')
     self.assertListEqual(layer.variables, [])
     self.assertListEqual(layer.trainable_variables, [])
     self.assertListEqual(layer.non_trainable_variables, [])
-    self.assertListEqual(layer.updates, [])
-    self.assertListEqual(layer.losses, [])
+    if context.in_graph_mode():
+      # updates, losses only suppported in GRAPH mode
+      self.assertListEqual(layer.updates, [])
+      self.assertListEqual(layer.losses, [])
     self.assertEqual(layer.built, False)
     layer = base_layers.Layer(name='my_layer', trainable=False)
     self.assertEqual(layer.trainable, False)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testAddWeight(self):
-    with self.test_session():
-      layer = base_layers.Layer(name='my_layer')
+    layer = base_layers.Layer(name='my_layer')
 
-      # Test basic variable creation.
-      variable = layer.add_variable(
-          'my_var', [2, 2], initializer=init_ops.zeros_initializer())
-      self.assertEqual(variable.name, 'my_layer/my_var:0')
-      self.assertListEqual(layer.variables, [variable])
-      self.assertListEqual(layer.trainable_variables, [variable])
-      self.assertListEqual(layer.non_trainable_variables, [])
-      self.assertListEqual(
-          layer.variables,
-          ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
-
-      # Test non-trainable variable creation.
-      # layer.add_variable should work even outside `build` and `call`.
-      variable_2 = layer.add_variable(
-          'non_trainable_var', [2, 2],
-          initializer=init_ops.zeros_initializer(),
-          trainable=False)
-      self.assertListEqual(layer.variables, [variable, variable_2])
-      self.assertListEqual(layer.trainable_variables, [variable])
-      self.assertListEqual(layer.non_trainable_variables, [variable_2])
-      self.assertEqual(
-          len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
-
-      # Test with regularizer.
+    # Test basic variable creation.
+    variable = layer.add_variable(
+        'my_var', [2, 2], initializer=init_ops.zeros_initializer())
+    self.assertEqual(variable.name, 'my_layer/my_var:0')
+    self.assertListEqual(layer.variables, [variable])
+    self.assertListEqual(layer.trainable_variables, [variable])
+    self.assertListEqual(layer.non_trainable_variables, [])
+    self.assertListEqual(layer.variables,
+                         ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+
+    # Test non-trainable variable creation.
+    # layer.add_variable should work even outside `build` and `call`.
+    variable_2 = layer.add_variable(
+        'non_trainable_var', [2, 2],
+        initializer=init_ops.zeros_initializer(),
+        trainable=False)
+    self.assertListEqual(layer.variables, [variable, variable_2])
+    self.assertListEqual(layer.trainable_variables, [variable])
+    self.assertListEqual(layer.non_trainable_variables, [variable_2])
+    self.assertEqual(
+        len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)
+
+    if context.in_graph_mode():
+      # regularizers only supported in GRAPH mode.
       regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
       variable = layer.add_variable(
           'reg_var', [2, 2],
@@ -80,69 +87,63 @@ class BaseLayerTest(test.TestCase):
           regularizer=regularizer)
       self.assertEqual(len(layer.losses), 1)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testGetVariable(self):
-    with self.test_session():
 
-      class MyLayer(base_layers.Layer):
+    class MyLayer(base_layers.Layer):
 
-        def build(self, input_shape):
-          self.my_var = self.add_variable(
-              'my_var', [2, 2], initializer=init_ops.zeros_initializer())
+      def build(self, input_shape):
+        self.my_var = self.add_variable(
+            'my_var', [2, 2], initializer=init_ops.zeros_initializer())
 
-        def call(self, inputs):
-          return inputs * 2
+      def call(self, inputs):
+        return inputs * 2
 
-      layer = MyLayer(name='my_layer')
-      inputs = random_ops.random_uniform((5,), seed=1)
-      layer.apply(inputs)
-      layer.apply(inputs)
-      self.assertListEqual([v.name for v in layer.variables],
-                           ['my_layer/my_var:0'])
-
-      # Creating a layer with no scope leads to lazy construction of
-      # the scope at apply() time.  It uses scope "/base_name"
-      lazy_layer = MyLayer(_reuse=True)
-      with variable_scope.variable_scope('new_scope'):
-        # This should attempt to reuse 'my_var' in 'new_scope'
-        with self.assertRaisesRegexp(
-            ValueError, r'new_scope/my_layer/my_var does not exist'):
-          lazy_layer.apply(inputs)
-        with variable_scope.variable_scope('my_layer'):
-          variable_scope.get_variable('my_var', [2, 2])
-
-        # Smoke test: it runs.
-        lazy_layer.apply(inputs)
-        # The variables were created outside of the Layer, and
-        # reuse=True, so the Layer does not own them and they are not
-        # stored in its collection.
-        self.assertListEqual(lazy_layer.variables, [])
-        self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer')
-
-      # Creating a layer with no scope leads to lazy construction of
-      # the scope at apply() time.  If 'scope' argument is passed to
-      # apply(), it uses that scope when accessing variables.
-      lazy_layer = MyLayer(_reuse=True)
-      with variable_scope.variable_scope('new_scope') as new_scope:
-        # This should attempt to reuse 'my_var' in 'new_scope'
-        with self.assertRaisesRegexp(
-            ValueError, r'new_scope/my_var does not exist'):
-          lazy_layer.apply(inputs, scope=new_scope)
+    layer = MyLayer(name='my_layer')
+    inputs = random_ops.random_uniform((5,), seed=1)
+    layer.apply(inputs)
+    layer.apply(inputs)
+    self.assertListEqual([v.name for v in layer.variables],
+                         ['my_layer/my_var:0'])
+
+    # Creating a layer with no scope leads to lazy construction of
+    # the scope at apply() time.  It uses scope "/base_name"
+    lazy_layer = MyLayer(_reuse=True)
+    with variable_scope.variable_scope('new_scope'):
+      with variable_scope.variable_scope('my_layer'):
         variable_scope.get_variable('my_var', [2, 2])
 
-        # Smoke test: it runs.
-        lazy_layer.apply(inputs, scope=new_scope)
-        # The variables were created outside of the Layer, and
-        # reuse=True, so the Layer does not own them and they are not
-        # stored in its collection.
-        self.assertListEqual(lazy_layer.variables, [])
-        self.assertEqual(lazy_layer._scope.name, 'new_scope')
-
+      # Smoke test: it runs.
+      lazy_layer.apply(inputs)
+      # The variables were created outside of the Layer, and
+      # reuse=True, so the Layer does not own them and they are not
+      # stored in its collection.
+      self.assertListEqual(lazy_layer.variables, [])
+      self.assertEqual(lazy_layer._scope.name, 'new_scope/my_layer')
+
+    # Creating a layer with no scope leads to lazy construction of
+    # the scope at apply() time. If 'scope' argument is passed to
+    # apply(), it uses that scope when accessing variables.
+    lazy_layer = MyLayer(_reuse=True)
+    with variable_scope.variable_scope('new_scope') as new_scope:
+      variable_scope.get_variable('my_var', [2, 2])
+
+      # Smoke test: it runs.
+      lazy_layer.apply(inputs, scope=new_scope)
+      # The variables were created outside of the Layer, and
+      # reuse=True, so the Layer does not own them and they are not
+      # stored in its collection.
+      self.assertListEqual(lazy_layer.variables, [])
+      self.assertEqual(lazy_layer._scope.name, 'new_scope')
+
+    if context.in_graph_mode():
+      # Checking for graph equality is only done in GRAPH mode.
       with ops.Graph().as_default():
         inputs_ng = random_ops.random_uniform((5,), seed=1)
-        with self.assertRaisesRegexp(ValueError,
-                                     r'graph are not the same'):
+        with self.assertRaisesRegexp(ValueError, r'graph are not the same'):
           layer.apply(inputs_ng)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testCall(self):
 
     class MyLayer(base_layers.Layer):
@@ -154,9 +155,13 @@ class BaseLayerTest(test.TestCase):
     inputs = random_ops.random_uniform((5,), seed=1)
     outputs = layer.apply(inputs)
     self.assertEqual(layer.built, True)
-    self.assertEqual(outputs.op.name, 'my_layer/Square')
+    if context.in_graph_mode():
+      # op is only supported in GRAPH mode
+      self.assertEqual(outputs.op.name, 'my_layer/Square')
 
   def testFirstCallCanCreateVariablesButSecondCanNotWhenBuildEmpty(self):
+    # Note that this test is only run in Graph mode since with EAGER mode we can
+    # still create a new variable on second call.
 
     class MyLayer(base_layers.Layer):
 
@@ -177,15 +182,16 @@ class BaseLayerTest(test.TestCase):
     outputs = layer.apply(inputs)
     self.assertEqual(layer.built, True)
     self.assertEqual(outputs.op.name, 'my_layer/add')
-    self.assertListEqual(
-        [v.name for v in layer.variables], ['my_layer/my_var:0'])
+    self.assertListEqual([v.name
+                          for v in layer.variables], ['my_layer/my_var:0'])
     with self.assertRaisesRegexp(ValueError,
                                  'my_layer/this_will_break_on_second_call'):
       layer.apply(inputs)
     # The list of variables hasn't changed.
-    self.assertListEqual(
-        [v.name for v in layer.variables], ['my_layer/my_var:0'])
+    self.assertListEqual([v.name
+                          for v in layer.variables], ['my_layer/my_var:0'])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testDeepCopy(self):
 
     class MyLayer(base_layers.Layer):
@@ -198,7 +204,9 @@ class BaseLayerTest(test.TestCase):
     inputs = random_ops.random_uniform((5,), seed=1)
     outputs = layer.apply(inputs)
     self.assertEqual(layer.built, True)
-    self.assertEqual(outputs.op.name, 'my_layer/Square')
+    if context.in_graph_mode():
+      # op only supported in GRAPH mode.
+      self.assertEqual(outputs.op.name, 'my_layer/Square')
 
     layer_copy = copy.deepcopy(layer)
     self.assertEqual(layer_copy.name, layer.name)
@@ -206,6 +214,7 @@ class BaseLayerTest(test.TestCase):
     self.assertEqual(layer_copy._graph, layer._graph)
     self.assertEqual(layer_copy._private_tensor, layer._private_tensor)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testScopeNaming(self):
 
     class PrivateLayer(base_layers.Layer):
@@ -253,6 +262,7 @@ class BaseLayerTest(test.TestCase):
       my_layer_scoped1.apply(inputs)
       self.assertEqual(my_layer_scoped1._scope.name, 'var_scope/my_layer_1')
 
+  @test_util.run_in_graph_and_eager_modes()
   def testInputSpecNdimCheck(self):
 
     class CustomerLayer(base_layers.Layer):
@@ -264,18 +274,22 @@ class BaseLayerTest(test.TestCase):
       def call(self, inputs):
         return inputs
 
-    layer = CustomerLayer()
-    with self.assertRaisesRegexp(ValueError,
-                                 r'requires a defined rank'):
-      layer.apply(array_ops.placeholder('int32'))
+    if context.in_graph_mode():
+      layer = CustomerLayer()
+      with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
+        layer.apply(array_ops.placeholder('int32'))
 
-    with self.assertRaisesRegexp(ValueError,
-                                 r'expected ndim=2'):
-      layer.apply(array_ops.placeholder('int32', shape=(None,)))
+    layer = CustomerLayer()
+    with self.assertRaisesRegexp(ValueError, r'expected ndim=2'):
+      layer.apply(constant_op.constant([1]))
 
+    # Note that we re-create the layer since in Eager mode, input spec checks
+    # only happen on first call.
     # Works
-    layer.apply(array_ops.placeholder('int32', shape=(None, None)))
+    layer = CustomerLayer()
+    layer.apply(constant_op.constant([[1], [2]]))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testInputSpecMinNdimCheck(self):
 
     class CustomerLayer(base_layers.Layer):
@@ -287,19 +301,23 @@ class BaseLayerTest(test.TestCase):
       def call(self, inputs):
         return inputs
 
-    layer = CustomerLayer()
-    with self.assertRaisesRegexp(ValueError,
-                                 r'requires a defined rank'):
-      layer.apply(array_ops.placeholder('int32'))
+    if context.in_graph_mode():
+      layer = CustomerLayer()
+      with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
+        layer.apply(array_ops.placeholder('int32'))
 
-    with self.assertRaisesRegexp(ValueError,
-                                 r'expected min_ndim=2'):
-      layer.apply(array_ops.placeholder('int32', shape=(None,)))
+    layer = CustomerLayer()
+    with self.assertRaisesRegexp(ValueError, r'expected min_ndim=2'):
+      layer.apply(constant_op.constant([1]))
 
     # Works
-    layer.apply(array_ops.placeholder('int32', shape=(None, None)))
-    layer.apply(array_ops.placeholder('int32', shape=(None, None, None)))
+    layer = CustomerLayer()
+    layer.apply(constant_op.constant([[1], [2]]))
+
+    layer = CustomerLayer()
+    layer.apply(constant_op.constant([[[1], [2]]]))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testInputSpecMaxNdimCheck(self):
 
     class CustomerLayer(base_layers.Layer):
@@ -311,19 +329,23 @@ class BaseLayerTest(test.TestCase):
       def call(self, inputs):
         return inputs
 
-    layer = CustomerLayer()
-    with self.assertRaisesRegexp(ValueError,
-                                 r'requires a defined rank'):
-      layer.apply(array_ops.placeholder('int32'))
+    if context.in_graph_mode():
+      layer = CustomerLayer()
+      with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
+        layer.apply(array_ops.placeholder('int32'))
 
-    with self.assertRaisesRegexp(ValueError,
-                                 r'expected max_ndim=2'):
-      layer.apply(array_ops.placeholder('int32', shape=(None, None, None)))
+    layer = CustomerLayer()
+    with self.assertRaisesRegexp(ValueError, r'expected max_ndim=2'):
+      layer.apply(constant_op.constant([[[1], [2]]]))
 
     # Works
-    layer.apply(array_ops.placeholder('int32', shape=(None, None)))
-    layer.apply(array_ops.placeholder('int32', shape=(None,)))
+    layer = CustomerLayer()
+    layer.apply(constant_op.constant([1]))
 
+    layer = CustomerLayer()
+    layer.apply(constant_op.constant([[1], [2]]))
+
+  @test_util.run_in_graph_and_eager_modes()
   def testInputSpecDtypeCheck(self):
 
     class CustomerLayer(base_layers.Layer):
@@ -336,13 +358,14 @@ class BaseLayerTest(test.TestCase):
         return inputs
 
     layer = CustomerLayer()
-    with self.assertRaisesRegexp(ValueError,
-                                 r'expected dtype=float32'):
-      layer.apply(array_ops.placeholder('int32'))
+    with self.assertRaisesRegexp(ValueError, r'expected dtype=float32'):
+      layer.apply(constant_op.constant(1, dtype=dtypes.int32))
 
     # Works
-    layer.apply(array_ops.placeholder('float32', shape=(None, None)))
+    layer = CustomerLayer()
+    layer.apply(constant_op.constant(1.0, dtype=dtypes.float32))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testInputSpecAxesCheck(self):
 
     class CustomerLayer(base_layers.Layer):
@@ -355,14 +378,16 @@ class BaseLayerTest(test.TestCase):
         return inputs
 
     layer = CustomerLayer()
-    with self.assertRaisesRegexp(ValueError,
-                                 r'expected axis'):
-      layer.apply(array_ops.placeholder('int32', shape=(None, 3)))
+    with self.assertRaisesRegexp(ValueError, r'expected axis'):
+      layer.apply(constant_op.constant([1, 2, 3]))
 
     # Works
-    layer.apply(array_ops.placeholder('int32', shape=(None, None, 2)))
-    layer.apply(array_ops.placeholder('int32', shape=(None, 2)))
+    layer = CustomerLayer()
+    layer.apply(constant_op.constant([1, 2]))
+    layer = CustomerLayer()
+    layer.apply(constant_op.constant([[1, 2], [3, 4], [5, 6]]))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testInputSpecShapeCheck(self):
 
     class CustomerLayer(base_layers.Layer):
@@ -375,14 +400,14 @@ class BaseLayerTest(test.TestCase):
         return inputs
 
     layer = CustomerLayer()
-    with self.assertRaisesRegexp(ValueError,
-                                 r'expected shape'):
-      layer.apply(array_ops.placeholder('int32', shape=(None, 2)))
+    with self.assertRaisesRegexp(ValueError, r'expected shape'):
+      layer.apply(constant_op.constant([[1, 2]]))
 
     # Works
-    layer.apply(array_ops.placeholder('int32', shape=(None, 3)))
-    layer.apply(array_ops.placeholder('int32', shape=(2, 3)))
+    layer = CustomerLayer()
+    layer.apply(constant_op.constant([[1, 2, 3], [4, 5, 6]]))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testNoInputSpec(self):
 
     class CustomerLayer(base_layers.Layer):
@@ -396,9 +421,12 @@ class BaseLayerTest(test.TestCase):
 
     layer = CustomerLayer()
 
+    layer.apply(constant_op.constant(1))
+
     # Works
-    layer.apply(array_ops.placeholder('int32'))
-    layer.apply(array_ops.placeholder('int32', shape=(2, 3)))
+    if context.in_graph_mode():
+      layer.apply(array_ops.placeholder('int32'))
+      layer.apply(array_ops.placeholder('int32', shape=(2, 3)))
 
   def test_get_updates_for(self):
     a = base_layers.Input(shape=(2,))
@@ -476,10 +504,11 @@ class BaseLayerTest(test.TestCase):
       _ = new_dense.output_shape
 
   def testTopologicalAttributesMultiOutputLayer(self):
+
     class PowersLayer(base_layers.Layer):
 
       def call(self, inputs):
-        return [inputs ** 2, inputs ** 3]
+        return [inputs**2, inputs**3]
 
     x = base_layers.Input(shape=(32,))
     test_layer = PowersLayer()
@@ -491,6 +520,7 @@ class BaseLayerTest(test.TestCase):
     self.assertEqual(test_layer.output_shape, [(None, 32), (None, 32)])
 
   def testTopologicalAttributesMultiInputLayer(self):
+
     class AddLayer(base_layers.Layer):
 
       def call(self, inputs):
@@ -507,6 +537,7 @@ class BaseLayerTest(test.TestCase):
     self.assertEqual(test_layer.input_shape, [(None, 32), (None, 32)])
     self.assertEqual(test_layer.output_shape, (None, 32))
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_count_params(self):
     dense = core_layers.Dense(16)
     dense.build((None, 4))
@@ -652,7 +683,7 @@ class NetworkTest(test.TestCase):
     class PowersLayer(base_layers.Layer):
 
       def call(self, inputs):
-        return [inputs ** 2, inputs ** 3]
+        return [inputs**2, inputs**3]
 
     x = base_layers.Input(shape=(32,))
     p1, p2 = PowersLayer()(x)  # pylint: disable=not-callable
@@ -673,7 +704,7 @@ class NetworkTest(test.TestCase):
 
   def testNetworkAttributes(self):
     x = base_layers.Input(shape=(32,))
-    z = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x ** 2))(x)
+    z = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2))(x)
     dense = core_layers.Dense(2, name='dense')
     dense.add_update(1)
     y = dense(z)
@@ -795,6 +826,7 @@ class NetworkTest(test.TestCase):
     self.assertEqual(len(network.layers), 2)
     self.assertEqual(network.layers[0].sparse, True)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testMaskingSingleInput(self):
 
     class MaskedLayer(base_layers.Layer):
@@ -807,19 +839,29 @@ class NetworkTest(test.TestCase):
       def compute_mask(self, inputs, mask=None):
         return array_ops.ones_like(inputs)
 
-    x = base_layers.Input(shape=(32,))
-    y = MaskedLayer()(x)  # pylint: disable=not-callable
-    network = base_layers.Network(x, y)
-
-    # test callability on Input
-    x_2 = base_layers.Input(shape=(32,))
-    y_2 = network(x_2)
-    self.assertEqual(y_2.get_shape().as_list(), [None, 32])
-
-    # test callability on regular tensor
-    x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32))
-    y_2 = network(x_2)
-    self.assertEqual(y_2.get_shape().as_list(), [None, 32])
+    if context.in_graph_mode():
+      x = base_layers.Input(shape=(32,))
+      y = MaskedLayer()(x)  # pylint: disable=not-callable
+      network = base_layers.Network(x, y)
+
+      # test callability on Input
+      x_2 = base_layers.Input(shape=(32,))
+      y_2 = network(x_2)
+      self.assertEqual(y_2.get_shape().as_list(), [None, 32])
+
+      # test callability on regular tensor
+      x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32))
+      y_2 = network(x_2)
+      self.assertEqual(y_2.get_shape().as_list(), [None, 32])
+    else:
+      a = constant_op.constant([2] * 32)
+      mask = constant_op.constant([0, 1] * 16)
+      a._keras_mask = mask
+      b = MaskedLayer().apply(a)
+      self.assertTrue(hasattr(b, '_keras_mask'))
+      self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)),
+                          self.evaluate(getattr(b, '_keras_mask')))
+      self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index 68293aa5fe58e2d2cfcd8cbf9730a9e48f6ff913..41c67743b6df4bd7c3a1a105bda199fe42d2c577 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -24,6 +24,7 @@ import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -171,7 +172,7 @@ class _Conv(base.Layer):
         padding=self.padding.upper(),
         data_format=utils.convert_data_format(self.data_format, self.rank + 2))
 
-    if self.bias is not None:
+    if self.use_bias:
       if self.data_format == 'channels_first':
         if self.rank == 1:
           # nn.bias_add does not accept a 1D input tensor.
@@ -988,7 +989,7 @@ class SeparableConv2D(Conv2D):
         rate=self.dilation_rate,
         data_format=utils.convert_data_format(self.data_format, ndim=4))
 
-    if self.bias is not None:
+    if self.use_bias:
       outputs = nn.bias_add(
           outputs,
           self.bias,
@@ -1293,20 +1294,21 @@ class Conv2DTranspose(Conv2D):
         padding=self.padding.upper(),
         data_format=utils.convert_data_format(self.data_format, ndim=4))
 
-    # Infer the static output shape:
-    out_shape = inputs.get_shape().as_list()
-    out_shape[c_axis] = self.filters
-    out_shape[h_axis] = utils.deconv_output_length(out_shape[h_axis],
-                                                   kernel_h,
-                                                   self.padding,
-                                                   stride_h)
-    out_shape[w_axis] = utils.deconv_output_length(out_shape[w_axis],
-                                                   kernel_w,
-                                                   self.padding,
-                                                   stride_w)
-    outputs.set_shape(out_shape)
-
-    if self.bias:
+    if context.in_graph_mode():
+      # Infer the static output shape:
+      out_shape = inputs.get_shape().as_list()
+      out_shape[c_axis] = self.filters
+      out_shape[h_axis] = utils.deconv_output_length(out_shape[h_axis],
+                                                     kernel_h,
+                                                     self.padding,
+                                                     stride_h)
+      out_shape[w_axis] = utils.deconv_output_length(out_shape[w_axis],
+                                                     kernel_w,
+                                                     self.padding,
+                                                     stride_w)
+      outputs.set_shape(out_shape)
+
+    if self.use_bias:
       outputs = nn.bias_add(
           outputs,
           self.bias,
@@ -1591,24 +1593,25 @@ class Conv3DTranspose(Conv3D):
         data_format=utils.convert_data_format(self.data_format, ndim=5),
         padding=self.padding.upper())
 
-    # Infer the static output shape:
-    out_shape = inputs.get_shape().as_list()
-    out_shape[c_axis] = self.filters
-    out_shape[d_axis] = utils.deconv_output_length(out_shape[d_axis],
-                                                   kernel_d,
-                                                   self.padding,
-                                                   stride_d)
-    out_shape[h_axis] = utils.deconv_output_length(out_shape[h_axis],
-                                                   kernel_h,
-                                                   self.padding,
-                                                   stride_h)
-    out_shape[w_axis] = utils.deconv_output_length(out_shape[w_axis],
-                                                   kernel_w,
-                                                   self.padding,
-                                                   stride_w)
-    outputs.set_shape(out_shape)
-
-    if self.bias:
+    if context.in_graph_mode():
+      # Infer the static output shape:
+      out_shape = inputs.get_shape().as_list()
+      out_shape[c_axis] = self.filters
+      out_shape[d_axis] = utils.deconv_output_length(out_shape[d_axis],
+                                                     kernel_d,
+                                                     self.padding,
+                                                     stride_d)
+      out_shape[h_axis] = utils.deconv_output_length(out_shape[h_axis],
+                                                     kernel_h,
+                                                     self.padding,
+                                                     stride_h)
+      out_shape[w_axis] = utils.deconv_output_length(out_shape[w_axis],
+                                                     kernel_w,
+                                                     self.padding,
+                                                     stride_w)
+      outputs.set_shape(out_shape)
+
+    if self.use_bias:
       outputs_shape = outputs.shape.as_list()
       if self.data_format == 'channels_first':
         outputs_4d = array_ops.reshape(outputs, [
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index a2906802dfd3ab1d43daf719b6ae6ac35e05b733..f63f004726d8a639dd97b9e1a640de42db5a45ef 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -26,6 +26,7 @@ import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
@@ -148,13 +149,14 @@ class Dense(base.Layer):
   def call(self, inputs):
     inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
     shape = inputs.get_shape().as_list()
-    output_shape = shape[:-1] + [self.units]
-    if len(output_shape) > 2:
+    if len(shape) > 2:
       # Broadcasting is required for the inputs.
       outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
                                                              [0]])
       # Reshape the output back to the original ndim of the input.
-      outputs.set_shape(output_shape)
+      if context.in_graph_mode():
+        output_shape = shape[:-1] + [self.units]
+        outputs.set_shape(output_shape)
     else:
       outputs = standard_ops.matmul(inputs, self.kernel)
     if self.use_bias:
@@ -287,6 +289,7 @@ class Dropout(base.Layer):
     return self.noise_shape
 
   def call(self, inputs, training=False):
+
     def dropped_inputs():
       return nn.dropout(inputs, 1  - self.rate,
                         noise_shape=self._get_noise_shape(inputs),
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 6f315f8c4df2014036744c32077bff3cd1dd9783..188bdeb85c679f220f0ab71515f0593e26914ad8 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -20,9 +20,11 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
 from tensorflow.python.layers import core as core_layers
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -36,6 +38,7 @@ from tensorflow.python.platform import test
 
 class DenseTest(test.TestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testDenseProperties(self):
     dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
     self.assertEqual(dense.units, 2)
@@ -53,10 +56,12 @@ class DenseTest(test.TestCase):
     dense.apply(random_ops.random_uniform((5, 2)))
     self.assertEqual(dense.name, 'dense_2')
 
+  @test_util.run_in_graph_and_eager_modes()
   def testCall(self):
     dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
-    inputs = random_ops.random_uniform((5, 2), seed=1)
-    _ = dense(inputs)
+    inputs = random_ops.random_uniform((5, 4), seed=1)
+    outputs = dense(inputs)
+    self.assertListEqual([5, 2], outputs.get_shape().as_list())
     self.assertListEqual(dense.variables, [dense.kernel, dense.bias])
     self.assertListEqual(dense.trainable_variables, [dense.kernel, dense.bias])
     self.assertListEqual(dense.non_trainable_variables, [])
@@ -65,6 +70,14 @@ class DenseTest(test.TestCase):
     self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
     self.assertEqual(dense.bias.name, 'my_dense/bias:0')
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testCallTensorDot(self):
+    dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
+    inputs = random_ops.random_uniform((5, 4, 3), seed=1)
+    outputs = dense(inputs)
+    self.assertListEqual([5, 4, 2], outputs.get_shape().as_list())
+
+  @test_util.run_in_graph_and_eager_modes()
   def testNoBias(self):
     dense = core_layers.Dense(2, use_bias=False, name='my_dense')
     inputs = random_ops.random_uniform((5, 2), seed=1)
@@ -77,6 +90,7 @@ class DenseTest(test.TestCase):
     self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
     self.assertEqual(dense.bias, None)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testNonTrainable(self):
     dense = core_layers.Dense(2, trainable=False, name='my_dense')
     inputs = random_ops.random_uniform((5, 2), seed=1)
@@ -88,6 +102,7 @@ class DenseTest(test.TestCase):
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testOutputShape(self):
     dense = core_layers.Dense(7, activation=nn_ops.relu, name='my_dense')
     inputs = random_ops.random_uniform((5, 3), seed=1)
@@ -127,16 +142,19 @@ class DenseTest(test.TestCase):
     dense = core_layers.Dense(4, name='my_dense')
     dense(inputs)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testActivation(self):
     dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
     inputs = random_ops.random_uniform((5, 3), seed=1)
     outputs = dense(inputs)
-    self.assertEqual(outputs.op.name, 'dense1/Relu')
+    if context.in_graph_mode():
+      self.assertEqual(outputs.op.name, 'dense1/Relu')
 
     dense = core_layers.Dense(2, name='dense2')
     inputs = random_ops.random_uniform((5, 3), seed=1)
     outputs = dense(inputs)
-    self.assertEqual(outputs.op.name, 'dense2/BiasAdd')
+    if context.in_graph_mode():
+      self.assertEqual(outputs.op.name, 'dense2/BiasAdd')
 
   def testActivityRegularizer(self):
     regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
@@ -179,15 +197,18 @@ class DenseTest(test.TestCase):
     self.assertEqual(len(loss_keys), 1)
     self.assertListEqual(dense.losses, loss_keys)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testFunctionalDense(self):
     inputs = random_ops.random_uniform((5, 3), seed=1)
     outputs = core_layers.dense(
         inputs, 2, activation=nn_ops.relu, name='my_dense')
     self.assertEqual(
         len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 2)
-    self.assertEqual(outputs.op.name, 'my_dense/Relu')
+    if context.in_graph_mode():
+      self.assertEqual(outputs.op.name, 'my_dense/Relu')
     self.assertEqual(outputs.get_shape().as_list(), [5, 2])
 
+  @test_util.run_in_graph_and_eager_modes()
   def testFunctionalDenseTwice(self):
     inputs = random_ops.random_uniform((5, 3), seed=1)
     core_layers.dense(inputs, 2)
@@ -197,6 +218,7 @@ class DenseTest(test.TestCase):
     self.assertEqual(len(vars1), 2)
     self.assertEqual(len(vars2), 4)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testFunctionalDenseTwiceReuse(self):
     inputs = random_ops.random_uniform((5, 3), seed=1)
     core_layers.dense(inputs, 2, name='my_dense')
@@ -205,6 +227,7 @@ class DenseTest(test.TestCase):
     vars2 = variables.trainable_variables()
     self.assertEqual(vars1, vars2)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testFunctionalDenseTwiceReuseFromScope(self):
     with variable_scope.variable_scope('scope'):
       inputs = random_ops.random_uniform((5, 3), seed=1)
@@ -215,20 +238,23 @@ class DenseTest(test.TestCase):
       vars2 = variables.trainable_variables()
     self.assertEqual(vars1, vars2)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testFunctionalDenseInitializerFromScope(self):
-    with self.test_session() as sess:
-      with variable_scope.variable_scope(
-          'scope', initializer=init_ops.ones_initializer()):
-        inputs = random_ops.random_uniform((5, 3), seed=1)
-        core_layers.dense(inputs, 2)
-        sess.run(variables.global_variables_initializer())
-        weights = sess.run(variables.trainable_variables())
-        self.assertEqual(len(weights), 2)
-        # Check that the matrix weights got initialized to ones (from scope).
-        self.assertAllClose(weights[0], np.ones((3, 2)))
-        # Check that the bias still got initialized to zeros.
-        self.assertAllClose(weights[1], np.zeros((2)))
-
+    with variable_scope.variable_scope(
+        'scope', initializer=init_ops.ones_initializer()):
+      inputs = random_ops.random_uniform((5, 3), seed=1)
+      core_layers.dense(inputs, 2)
+      if context.in_graph_mode():
+        self.evaluate(variables.global_variables_initializer())
+      weights = variables.trainable_variables()
+      self.assertEqual(len(weights), 2)
+      # Check that the matrix weights got initialized to ones (from scope).
+      self.assertAllClose(
+          self.evaluate(weights[0].read_value()), np.ones((3, 2)))
+      # Check that the bias still got initialized to zeros.
+      self.assertAllClose(self.evaluate(weights[1].read_value()), np.zeros((2)))
+
+  @test_util.run_in_graph_and_eager_modes()
   def testFunctionalDenseWithCustomGetter(self):
     called = [0]
 
@@ -241,6 +267,7 @@ class DenseTest(test.TestCase):
       core_layers.dense(inputs, 2)
     self.assertEqual(called[0], 2)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testFunctionalDenseInScope(self):
     with variable_scope.variable_scope('test'):
       inputs = random_ops.random_uniform((5, 3), seed=1)
@@ -258,6 +285,7 @@ class DenseTest(test.TestCase):
       var = variables.trainable_variables()[4]
       self.assertEqual(var.name, 'test2/dense/kernel:0')
 
+  @test_util.run_in_graph_and_eager_modes()
   def testComputeOutputShape(self):
     dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
     ts = tensor_shape.TensorShape
@@ -279,6 +307,7 @@ class DenseTest(test.TestCase):
         dense._compute_output_shape(ts([None, 4, 3])).as_list())
     # pylint: enable=protected-access
 
+  @test_util.run_in_graph_and_eager_modes()
   def testConstraints(self):
     k_constraint = lambda x: x / math_ops.reduce_sum(x)
     b_constraint = lambda x: x / math_ops.reduce_max(x)
@@ -293,6 +322,7 @@ class DenseTest(test.TestCase):
 
 class DropoutTest(test.TestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testDropoutProperties(self):
     dp = core_layers.Dropout(0.5, name='dropout')
     self.assertEqual(dp.rate, 0.5)
@@ -300,17 +330,18 @@ class DropoutTest(test.TestCase):
     dp.apply(array_ops.ones(()))
     self.assertEqual(dp.name, 'dropout')
 
+  @test_util.run_in_graph_and_eager_modes()
   def testBooleanLearningPhase(self):
-    with self.test_session() as sess:
-      dp = core_layers.Dropout(0.5)
-      inputs = array_ops.ones((5, 3))
-      dropped = dp.apply(inputs, training=True)
-      sess.run(variables.global_variables_initializer())
-      np_output = sess.run(dropped)
-      self.assertAlmostEqual(0., np_output.min())
-      dropped = dp.apply(inputs, training=False)
-      np_output = sess.run(dropped)
-      self.assertAllClose(np.ones((5, 3)), np_output)
+    dp = core_layers.Dropout(0.5)
+    inputs = array_ops.ones((5, 3))
+    dropped = dp.apply(inputs, training=True)
+    if context.in_graph_mode():
+      self.evaluate(variables.global_variables_initializer())
+    np_output = self.evaluate(dropped)
+    self.assertAlmostEqual(0., np_output.min())
+    dropped = dp.apply(inputs, training=False)
+    np_output = self.evaluate(dropped)
+    self.assertAllClose(np.ones((5, 3)), np_output)
 
   def testDynamicLearningPhase(self):
     with self.test_session() as sess:
@@ -318,35 +349,34 @@ class DropoutTest(test.TestCase):
       inputs = array_ops.ones((5, 5))
       training = array_ops.placeholder(dtype='bool')
       dropped = dp.apply(inputs, training=training)
-      sess.run(variables.global_variables_initializer())
+      self.evaluate(variables.global_variables_initializer())
       np_output = sess.run(dropped, feed_dict={training: True})
       self.assertAlmostEqual(0., np_output.min())
       np_output = sess.run(dropped, feed_dict={training: False})
       self.assertAllClose(np.ones((5, 5)), np_output)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testCustomNoiseShape(self):
-    with self.test_session() as sess:
-      inputs = array_ops.ones((5, 3, 2))
-      noise_shape = [5, 1, 2]
-      dp = core_layers.Dropout(0.5, noise_shape=noise_shape, seed=1)
-      dropped = dp.apply(inputs, training=True)
-      sess.run(variables.global_variables_initializer())
-      np_output = sess.run(dropped)
-      self.assertAlmostEqual(0., np_output.min())
-      self.assertAllClose(np_output[:, 0, :], np_output[:, 1, :])
-
+    inputs = array_ops.ones((5, 3, 2))
+    noise_shape = [5, 1, 2]
+    dp = core_layers.Dropout(0.5, noise_shape=noise_shape, seed=1)
+    dropped = dp.apply(inputs, training=True)
+    self.evaluate(variables.global_variables_initializer())
+    np_output = self.evaluate(dropped)
+    self.assertAlmostEqual(0., np_output.min())
+    self.assertAllClose(np_output[:, 0, :], np_output[:, 1, :])
+
+  @test_util.run_in_graph_and_eager_modes()
   def testFunctionalDropout(self):
-    with self.test_session() as sess:
-      inputs = array_ops.ones((5, 5))
-      training = array_ops.placeholder(dtype='bool')
-      dropped = core_layers.dropout(inputs, 0.5, training=training, seed=1)
-      self.assertEqual(dropped.op.name, 'dropout/cond/Merge')
-
-      sess.run(variables.global_variables_initializer())
-      np_output = sess.run(dropped, feed_dict={training: True})
-      self.assertAlmostEqual(0., np_output.min())
-      np_output = sess.run(dropped, feed_dict={training: False})
-      self.assertAllClose(np.ones((5, 5)), np_output)
+    inputs = array_ops.ones((5, 5))
+    dropped = core_layers.dropout(inputs, 0.5, training=True, seed=1)
+    if context.in_graph_mode():
+      self.evaluate(variables.global_variables_initializer())
+    np_output = self.evaluate(dropped)
+    self.assertAlmostEqual(0., np_output.min())
+    dropped = core_layers.dropout(inputs, 0.5, training=False, seed=1)
+    np_output = self.evaluate(dropped)
+    self.assertAllClose(np.ones((5, 5)), np_output)
 
   def testDynamicRate(self):
     with self.test_session() as sess:
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index dffbfc72902d6607bce03a924e3de4d608c0d7e5..62265dce3c512cf12a9ad1ab147a42b737ef3b00 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -25,14 +25,19 @@ import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import numpy as np
 
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import standard_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.training import moving_averages
 from tensorflow.python.framework import tensor_util
@@ -187,17 +192,6 @@ class BatchNormalization(base.Layer):
     self.input_spec = base.InputSpec(ndim=ndim,
                                      axes={self.axis: param_dim.value})
 
-    if self.center:
-      self.beta = self.add_variable(name='beta',
-                                    shape=(param_dim,),
-                                    initializer=self.beta_initializer,
-                                    regularizer=self.beta_regularizer,
-                                    constraint=self.beta_constraint,
-                                    trainable=True)
-    else:
-      self.beta = None
-      if self.fused:
-        self._beta_const = array_ops.constant(0.0, shape=(param_dim,))
     if self.scale:
       self.gamma = self.add_variable(name='gamma',
                                      shape=(param_dim,),
@@ -210,6 +204,18 @@ class BatchNormalization(base.Layer):
       if self.fused:
         self._gamma_const = array_ops.constant(1.0, shape=(param_dim,))
 
+    if self.center:
+      self.beta = self.add_variable(name='beta',
+                                    shape=(param_dim,),
+                                    initializer=self.beta_initializer,
+                                    regularizer=self.beta_regularizer,
+                                    constraint=self.beta_constraint,
+                                    trainable=True)
+    else:
+      self.beta = None
+      if self.fused:
+        self._beta_const = array_ops.constant(0.0, shape=(param_dim,))
+
     # Disable variable partitioning when creating the moving mean and variance
     try:
       if self._scope:
@@ -227,6 +233,7 @@ class BatchNormalization(base.Layer):
           shape=(param_dim,),
           initializer=self.moving_variance_initializer,
           trainable=False)
+      self._one_minus_decay = 1.0 - self.momentum
       if self.renorm:
         # Create variables to maintain the moving mean and standard deviation.
         # These are used in training and thus are different from the moving
@@ -241,15 +248,20 @@ class BatchNormalization(base.Layer):
                                   initializer=init_ops.zeros_initializer(),
                                   trainable=False)
           return var
+
         with ops.device(None):
-          with ops.device(lambda _: self.moving_mean.device):
+          device = ((lambda _: self.moving_mean.device)
+                    if context.in_graph_mode() else self.moving_mean.device)
+          with ops.device(device):
             self.renorm_mean = _renorm_variable('renorm_mean', (param_dim,))
             self.renorm_mean_weight = _renorm_variable('renorm_mean_weight', ())
           # We initialize renorm_stddev to 0, and maintain the (0-initialized)
           # renorm_stddev_weight. This allows us to (1) mix the average
           # stddev with the minibatch stddev early in training, and (2) compute
           # the unbiased average stddev by dividing renorm_stddev by the weight.
-          with ops.device(lambda _: self.moving_variance.device):
+          device = ((lambda _: self.moving_variance.device)
+                    if context.in_graph_mode() else self.moving_variance.device)
+          with ops.device(device):
             self.renorm_stddev = _renorm_variable('renorm_stddev', (param_dim,))
             self.renorm_stddev_weight = _renorm_variable(
                 'renorm_stddev_weight', ())
@@ -258,6 +270,19 @@ class BatchNormalization(base.Layer):
         self._scope.set_partitioner(partitioner)
     self.built = True
 
+  def _assign_moving_average(self, variable, value, one_minus_decay):
+    with ops.name_scope(None, 'AssignMovingAvg',
+                        [variable, value, one_minus_decay]) as scope:
+      with ops.colocate_with(variable):
+        update_delta = (variable.read_value() - value) * one_minus_decay
+        if isinstance(variable, resource_variable_ops.ResourceVariable):
+          # state_ops.assign_sub does an extra read_variable_op after the
+          # assign. We avoid that here.
+          return gen_resource_variable_ops.assign_sub_variable_op(
+              variable.handle, update_delta, name=scope)
+        else:
+          return state_ops.assign_sub(variable, update_delta, name=scope)
+
   def _fused_batch_norm(self, inputs, training):
     """Returns the output of fused batch norm."""
     beta = self.beta if self.center else self._beta_const
@@ -294,14 +319,23 @@ class BatchNormalization(base.Layer):
       variance *= factor
 
     training_value = utils.constant_value(training)
-    if training_value is not False:
-      decay = _smart_select(training, lambda: self.momentum, lambda: 1.)
-      mean_update = moving_averages.assign_moving_average(
-          self.moving_mean, mean, decay, zero_debias=False)
-      variance_update = moving_averages.assign_moving_average(
-          self.moving_variance, variance, decay, zero_debias=False)
-      self.add_update(mean_update, inputs=inputs)
-      self.add_update(variance_update, inputs=inputs)
+    if training_value is None:
+      one_minus_decay = _smart_select(training,
+                                      lambda: self._one_minus_decay,
+                                      lambda: 0.)
+    else:
+      one_minus_decay = self._one_minus_decay
+    if training_value or training_value is None:
+      mean_update = self._assign_moving_average(self.moving_mean, mean,
+                                                one_minus_decay)
+      variance_update = self._assign_moving_average(self.moving_variance,
+                                                    variance, one_minus_decay)
+      if context.in_graph_mode():
+        # Note that in Eager mode, the updates are already executed when running
+        # assign_moving_averages. So we do not need to put them into
+        # collections.
+        self.add_update(mean_update, inputs=inputs)
+        self.add_update(variance_update, inputs=inputs)
 
     return output
 
@@ -334,6 +368,7 @@ class BatchNormalization(base.Layer):
     r = _smart_select(training, lambda: r, lambda: array_ops.ones_like(r))
     d = _smart_select(training, lambda: d, lambda: array_ops.zeros_like(d))
     decay = _smart_select(training, lambda: self.renorm_momentum, lambda: 1.)
+
     def _update_renorm_variable(var, weight, value):
       """Updates a moving average and weight, returns the unbiased value."""
       # Update the variables without zero debiasing. The debiasing will be
@@ -417,9 +452,9 @@ class BatchNormalization(base.Layer):
           self.moving_mean, new_mean, decay, zero_debias=False)
       variance_update = moving_averages.assign_moving_average(
           self.moving_variance, new_variance, decay, zero_debias=False)
-
-      self.add_update(mean_update, inputs=inputs)
-      self.add_update(variance_update, inputs=inputs)
+      if context.in_graph_mode():
+        self.add_update(mean_update, inputs=inputs)
+        self.add_update(variance_update, inputs=inputs)
 
     else:
       mean, variance = self.moving_mean, self.moving_variance
@@ -565,7 +600,6 @@ def batch_normalization(inputs,
 BatchNorm = BatchNormalization
 batch_norm = batch_normalization
 
-
 # Helper function
 
 
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 55b1a7296dd74217c2eee9bc2c34f7ee79e054a5..bdc1f406153572bd4234da8631ab69d9e8718c92 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -534,7 +534,8 @@ def _TileGrad(op, grad):
   axes = math_ops.range(0, array_ops.size(split_shape), 2)
   input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
   # Fix shape inference
-  input_grad.set_shape(op.inputs[0].get_shape())
+  if context.in_graph_mode():
+    input_grad.set_shape(op.inputs[0].get_shape())
   return [input_grad, None]
 
 
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 39609255b146d3ea4b18620a4ec150132e34f997..33ba5df7a6e8791f0b416e1bab57eeecb951d92a 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1466,12 +1466,15 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
   with ops.name_scope(name, "zeros_like", [tensor]) as name:
     tensor = ops.convert_to_tensor(tensor, name="tensor")
 
-    if tensor.shape.is_fully_defined():
+    # For now, variant types must be created via zeros_like; as we need to
+    # pass the input variant object to the proper zeros callback.
+
+    if tensor.shape.is_fully_defined() and tensor.dtype != dtypes.variant:
       # We can produce a zeros tensor independent of the value of 'tensor',
       # since the shape is known statically.
       return zeros(tensor.shape, dtype=dtype or tensor.dtype, name=name)
 
-    if dtype is not None and dtype != tensor.dtype:
+    if dtype is not None and dtype != tensor.dtype and dtype != dtypes.variant:
       return zeros(
           shape_internal(tensor, optimize=optimize), dtype=dtype, name=name)
     else:
@@ -1726,18 +1729,19 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0):  # pyl
     raise ValueError("Unknown padding mode: %s" % mode)
 
   # Restore shape information where possible.
-  paddings_constant = tensor_util.constant_value(
-      result.op.inputs[1], partial=True)
-  input_shape = result.op.inputs[0].shape
-  if (input_shape.ndims is not None and not result.shape.is_fully_defined() and
-      paddings_constant is not None):
-    new_shape = []
-    for padding, dim in zip(paddings_constant, input_shape.as_list()):
-      if padding is None or dim is None or not all(padding):
-        new_shape.append(None)
-      else:
-        new_shape.append(sum(padding) + dim)
-    result.set_shape(new_shape)
+  if context.in_graph_mode():
+    paddings_constant = tensor_util.constant_value(
+        result.op.inputs[1], partial=True)
+    input_shape = result.op.inputs[0].shape
+    if (input_shape.ndims is not None and not result.shape.is_fully_defined()
+        and paddings_constant is not None):
+      new_shape = []
+      for padding, dim in zip(paddings_constant, input_shape.as_list()):
+        if padding is None or dim is None or not all(padding):
+          new_shape.append(None)
+        else:
+          new_shape.append(sum(padding) + dim)
+      result.set_shape(new_shape)
 
   return result
 
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index b3e724be9c1fea15ff35e2c9e59dcd1bb4650cbf..fb4817528551f6a1b0c8be33c8508beec8a83693 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -46,6 +46,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
@@ -861,8 +862,12 @@ def assert_type(tensor, tf_type, message=None, name=None):
   with ops.name_scope(name, 'assert_type', [tensor]):
     tensor = ops.convert_to_tensor(tensor, name='tensor')
     if tensor.dtype != tf_type:
-      raise TypeError(
-          '%s  %s must be of type %s' % (message, tensor.op.name, tf_type))
+      if context.in_graph_mode():
+        raise TypeError(
+            '%s  %s must be of type %s' % (message, tensor.name, tf_type))
+      else:
+        raise TypeError(
+            '%s tensor must be of type %s' % (message, tf_type))
 
     return control_flow_ops.no_op('statically_determined_correct_type')
 
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 89de88a5307b8a8ee9154ab31bb652cc10c649f9..97b37ea0272438a113808d627fd262d342f53cf2 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1813,6 +1813,11 @@ def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
   if not callable(false_fn):
     raise TypeError("false_fn must be callable.")
 
+  if context.in_eager_mode():
+    if pred:
+      return true_fn()
+    return false_fn()
+
   with ops.name_scope(name, "cond", [pred]):
     # Add the Switch to the graph.
     if isinstance(pred, bool):
@@ -2671,7 +2676,7 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
   Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
   call to `while_loop`, and not at all during `Session.run()`). `while_loop`
   stitches together the graph fragments created during the `cond` and `body`
-  calls with some additional graph nodes to create the graph flow that 
+  calls with some additional graph nodes to create the graph flow that
   repeats `body` until `cond` returns false.
 
   For correctness, `tf.while_loop()` strictly enforces shape invariants for
@@ -2779,12 +2784,17 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
     if parallel_iterations < 1:
       raise TypeError("parallel_iterations must be a positive integer.")
 
+    if context.in_eager_mode():
+      while cond(*loop_vars):
+        loop_vars = body(*loop_vars)
+      return loop_vars
+
     if shape_invariants is not None:
       nest.assert_same_structure(loop_vars, shape_invariants)
 
-    context = WhileContext(parallel_iterations, back_prop, swap_memory)  # pylint: disable=redefined-outer-name
-    ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context)
-    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
+    loop_context = WhileContext(parallel_iterations, back_prop, swap_memory)  # pylint: disable=redefined-outer-name
+    ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
+    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
     return result
 
 
@@ -2850,6 +2860,8 @@ def with_dependencies(dependencies, output_tensor, name=None):
   Raises:
     TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
   """
+  if context.in_eager_mode():
+    return output_tensor
   with ops.name_scope(name, "control_dependency",
                       list(dependencies) + [output_tensor]) as name:
     with ops.colocate_with(output_tensor):
@@ -2962,6 +2974,8 @@ def tuple(tensors, name=None, control_inputs=None):
       objects.
 
   """
+  if context.in_eager_mode():
+    return tensors
   with ops.name_scope(name, "tuple", tensors) as name:
     gating_ops = [t.op for t in tensors if t is not None]
     if control_inputs:
@@ -3087,12 +3101,18 @@ def case(pred_fn_pairs, default=None, exclusive=False, strict=False,
     TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
     TypeError: If `fns[i]` is not callable for any i, or `default` is not
                callable.
+    ValueError: If in eager mode and all predicates are false and no
+               default is provided.
+    ValueError: If in eager mode and is passed a dictionary.
   """
   pfp = pred_fn_pairs  # For readability
   if not (isinstance(pfp, list) or isinstance(pfp, _basetuple)
           or isinstance(pfp, dict)):
     raise TypeError("fns must be a list, tuple, or dict")
   if isinstance(pfp, dict):
+    if context.in_eager_mode():
+      raise ValueError(
+          "In eager mode the predicates must be a list, not a dictionary.")
     if isinstance(pfp, collections.OrderedDict):
       pfp = pfp.items()
     else:
@@ -3113,6 +3133,14 @@ def case(pred_fn_pairs, default=None, exclusive=False, strict=False,
   if default is not None and not callable(default):
     raise TypeError("default must be callable.")
 
+  if context.in_eager_mode():
+    for pred, fn in pfp:
+      if pred:
+        return fn()
+    if default is None:
+      raise ValueError("tf.case received all false predicates and no default.")
+    return default()
+
   preds, fns = map(list, zip(*pfp))
   del pfp  # From now on, preds and fns form the source of truth.
 
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 2a2ef5809d0eba86b0b30dd2858a89af78342a8b..41dd7f1467657ff755e44fc7bf27b34cdea61fdb 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -25,6 +25,7 @@ import threading
 
 import six
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes as _dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
@@ -160,7 +161,10 @@ class QueueBase(object):
     else:
       self._names = None
     self._queue_ref = queue_ref
-    self._name = self._queue_ref.op.name.split("/")[-1]
+    if context.in_graph_mode():
+      self._name = self._queue_ref.op.name.split("/")[-1]
+    else:
+      self._name = context.context().scope_name
 
   @staticmethod
   def from_list(index, queues):
@@ -208,7 +212,9 @@ class QueueBase(object):
   @property
   def name(self):
     """The name of the underlying queue."""
-    return self._queue_ref.op.name
+    if context.in_graph_mode():
+      return self._queue_ref.op.name
+    return self._name
 
   @property
   def dtypes(self):
@@ -419,9 +425,10 @@ class QueueBase(object):
 
     # NOTE(mrry): Not using a shape function because we need access to
     # the `QueueBase` object.
-    op = ret[0].op
-    for output, shape in zip(op.values(), self._shapes):
-      output.set_shape(shape)
+    if context.in_graph_mode():
+      op = ret[0].op
+      for output, shape in zip(op.values(), self._shapes):
+        output.set_shape(shape)
 
     return self._dequeue_return_value(ret)
 
@@ -458,10 +465,13 @@ class QueueBase(object):
 
     # NOTE(mrry): Not using a shape function because we need access to
     # the Queue object.
-    op = ret[0].op
-    batch_dim = tensor_shape.Dimension(tensor_util.constant_value(op.inputs[1]))
-    for output, shape in zip(op.values(), self._shapes):
-      output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(shape))
+    if context.in_graph_mode():
+      op = ret[0].op
+      batch_dim = tensor_shape.Dimension(
+          tensor_util.constant_value(op.inputs[1]))
+      for output, shape in zip(op.values(), self._shapes):
+        output.set_shape(
+            tensor_shape.TensorShape([batch_dim]).concatenate(shape))
 
     return self._dequeue_return_value(ret)
 
@@ -499,9 +509,10 @@ class QueueBase(object):
 
     # NOTE(mrry): Not using a shape function because we need access to
     # the Queue object.
-    op = ret[0].op
-    for output, shape in zip(op.values(), self._shapes):
-      output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
+    if context.in_graph_mode():
+      op = ret[0].op
+      for output, shape in zip(op.values(), self._shapes):
+        output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
 
     return self._dequeue_return_value(ret)
 
@@ -897,7 +908,10 @@ class Barrier(object):
     self._barrier_ref = gen_data_flow_ops._barrier(
         component_types=self._types, shapes=self._shapes,
         shared_name=shared_name, name=name)
-    self._name = self._barrier_ref.op.name.split("/")[-1]
+    if context.in_graph_mode():
+      self._name = self._barrier_ref.op.name.split("/")[-1]
+    else:
+      self._name = context.context().scope_name
 
   @property
   def barrier_ref(self):
@@ -907,7 +921,9 @@ class Barrier(object):
   @property
   def name(self):
     """The name of the underlying barrier."""
-    return self._barrier_ref.op.name
+    if context.in_graph_mode():
+      return self._barrier_ref.op.name
+    return self._name
 
   def insert_many(self, component_index, keys, values, name=None):
     """For each key, assigns the respective value to the specified component.
@@ -984,16 +1000,19 @@ class Barrier(object):
 
     # NOTE(mrry): Not using a shape function because we need access to
     # the Barrier object.
-    op = ret[0].op
-    if allow_small_batch:
-      batch_dim = None
-    else:
-      batch_dim = tensor_shape.Dimension(
-          tensor_util.constant_value(op.inputs[1]))
-    op.outputs[0].set_shape(tensor_shape.vector(batch_dim))  # indices
-    op.outputs[1].set_shape(tensor_shape.vector(batch_dim))  # keys
-    for output, shape in zip(op.outputs[2:], self._shapes):  # value_list
-      output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(shape))
+    if context.in_graph_mode():
+      op = ret[0].op
+      if allow_small_batch:
+        batch_dim = None
+      else:
+        batch_dim = tensor_shape.Dimension(
+            tensor_util.constant_value(op.inputs[1]))
+      op.outputs[0].set_shape(tensor_shape.vector(batch_dim))  # indices
+      op.outputs[1].set_shape(tensor_shape.vector(batch_dim))  # keys
+      for output, shape in zip(op.outputs[2:], self._shapes):  # value_list
+        output.set_shape(
+            tensor_shape.TensorShape([batch_dim]).concatenate(
+                shape))
 
     return ret
 
@@ -1081,7 +1100,10 @@ class ConditionalAccumulatorBase(object):
     else:
       self._shape = tensor_shape.unknown_shape()
     self._accumulator_ref = accumulator_ref
-    self._name = self._accumulator_ref.op.name.split("/")[-1]
+    if context.in_graph_mode():
+      self._name = self._accumulator_ref.op.name.split("/")[-1]
+    else:
+      self._name = context.context().scope_name
 
   @property
   def accumulator_ref(self):
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 59add19a58171e8febfbb981273e00a89ec22c1f..0c6096a0755e1d4635048281a7b34092d8fcf9bf 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -770,7 +770,7 @@ def fill_lower_triangular(x, validate_args=False, name="fill_lower_triangular"):
     else:
       d = math_ops.cast(array_ops.shape(x)[-1], dtype=dtypes.float32)
       # d = n(n+1)/2 implies n is:
-      n = math_ops.cast(0.5 * (dtypes.sqrt(1. + 8. * d) - 1.),
+      n = math_ops.cast(0.5 * (math_ops.sqrt(1. + 8. * d) - 1.),
                         dtype=dtypes.int32)
       if validate_args:
         is_valid_input_shape = check_ops.assert_equal(
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index e073dbc640cdc2e9dd8e0b1b7840cfa0ce6b3e9d..cb7d409f3bce38571d9a5574e9ea08da92efd6d6 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -27,6 +27,7 @@ import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -277,7 +278,7 @@ def _VerifyGeneratedGradients(grads, op):
                      "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
 
 
-def _StopOps(from_ops, pending_count):
+def _StopOps(from_ops, stop_gradient_ops, pending_count):
   """The set of ops that terminate the gradient computation.
 
   This computes the frontier of the forward graph *before* which backprop
@@ -287,8 +288,11 @@ def _StopOps(from_ops, pending_count):
   `_PendingCount(g, xs, from_ops)`. An 'op' has predecessors in `from_ops`
   iff pending_count[op._id] > 0.
 
+  In addition, none of `stop_gradient_ops` will be differentiated.
+
   Args:
     from_ops: list of Operations.
+    stop_gradient_ops: list of Operations never to backprop through.
     pending_count: List of integers, indexed by operation id.
 
   Returns:
@@ -303,6 +307,7 @@ def _StopOps(from_ops, pending_count):
         break
     if is_stop_op:
       stop_ops.add(op._id)
+  stop_ops.update(op._id for op in stop_gradient_ops)  # pylint: disable=protected-access
   return stop_ops
 
 
@@ -373,17 +378,17 @@ def gradients(ys,
               name="gradients",
               colocate_gradients_with_ops=False,
               gate_gradients=False,
-              aggregation_method=None):
-  """Constructs symbolic partial derivatives of sum of `ys` w.r.t. x in `xs`.
+              aggregation_method=None,
+              stop_gradients=None):
+  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.
 
   `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
   is a list of `Tensor`, holding the gradients received by the
   `ys`. The list must be the same length as `ys`.
 
-  `gradients()` adds ops to the graph to output the partial
-  derivatives of `ys` with respect to `xs`.  It returns a list of
-  `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
-  for y in `ys`.
+  `gradients()` adds ops to the graph to output the derivatives of `ys` with
+  respect to `xs`.  It returns a list of `Tensor` of length `len(xs)` where
+  each tensor is the `sum(dy/dx)` for y in `ys`.
 
   `grad_ys` is a list of tensors of the same length as `ys` that holds
   the initial gradients for each y in `ys`.  When `grad_ys` is None,
@@ -393,6 +398,31 @@ def gradients(ys,
   one wanted to weight the gradient differently for each value in
   each y).
 
+  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
+  with respect to all `xs`. These tensors will not be backpropagated through,
+  as though they had been explicitly disconnected using `stop_gradient`.  Among
+  other things, this allows computation of partial derivatives as opposed to
+  total derivatives. For example:
+
+    a = tf.constant(0.)
+    b = 2 * a
+    g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
+
+  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
+  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
+  influence of `a` on `b` and evaluate to `[3.0, 1.0]`.  Note that the above is
+  equivalent to:
+
+    a = tf.stop_gradient(tf.constant(0.))
+    b = tf.stop_gradient(2 * a)
+    g = tf.gradients(a + b, [a, b])
+
+  `stop_gradients` provides a way of stopping gradient after the graph has
+  already been constructed, as compared to `tf.stop_gradient` which is used
+  during graph construction.  When the two approaches are combined,
+  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
+  `stop_gradients`, whichever is encountered first.
+
   Args:
     ys: A `Tensor` or list of tensors to be differentiated.
     xs: A `Tensor` or list of tensors to be used for differentiation.
@@ -406,6 +436,8 @@ def gradients(ys,
       for an operations.  This avoids some race conditions.
     aggregation_method: Specifies the method used to combine gradient terms.
       Accepted values are constants defined in the class `AggregationMethod`.
+    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
+      through.
 
   Returns:
     A list of `sum(dy/dx)` for each x in `xs`.
@@ -414,16 +446,23 @@ def gradients(ys,
     LookupError: if one of the operations between `x` and `y` does not
       have a registered gradient function.
     ValueError: if the arguments are invalid.
+    RuntimeError: if called in Eager mode.
 
   """
+  if context.in_eager_mode():
+    raise RuntimeError("tf.gradients not supported in EAGER mode. Use "
+                       "functions in tf.contrib.eager.backprop instead.")
   ys = _AsList(ys)
   xs = _AsList(xs)
+  stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
   if grad_ys is None:
     grad_ys = [None] * len(ys)
   else:
     grad_ys = _AsList(grad_ys)
 
-  with ops.name_scope(name, "gradients", ys + xs + grad_ys) as grad_scope:
+  with ops.name_scope(
+      name, "gradients",
+      list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
     ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
     xs = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable)
           else x
@@ -445,6 +484,7 @@ def gradients(ys,
       ys = [array_ops.identity(y) if y.consumers() else y for y in ys]
     to_ops = [t.op for t in ys]
     from_ops = [t.op for t in xs]
+    stop_gradient_ops = [t.op for t in stop_gradients]
     pending_count, loop_state = _PendingCount(ops.get_default_graph(), to_ops,
                                               from_ops,
                                               colocate_gradients_with_ops)
@@ -483,8 +523,7 @@ def gradients(ys,
           _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
           queue.append(y.op)
 
-    # The set of 'from_ops'.
-    stop_ops = _StopOps(from_ops, pending_count)
+    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
     while queue:
       # generate gradient subgraph for op.
       op = queue.popleft()
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 11c204b5b7f3b720d07a445b852551c4465bd3d1..7a561d046a8d5e6f72e3bdab59c238a88425a706 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -349,6 +349,64 @@ class GradientsTest(test_util.TensorFlowTestCase):
       g = gradients.gradients([z, z2], x)
       self.assertAllClose(17502.0, g[0].eval())
 
+  def testPartialDerivatives(self):
+    with self.test_session():
+      x = constant_op.constant(1.)
+      y = 2 * x
+      z = x + y
+      totalg = gradients.gradients(z, [x, y])
+      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
+      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
+      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])
+
+  def testStopGradients(self):
+    def _MakeGraph(rng, stop_gradients=()):
+      def _FunctionOf(xs, k=3):
+        return ops.convert_to_tensor(
+            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
+            + rng.rand(k, k))
+
+      a = _FunctionOf([])
+      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
+      b = _FunctionOf([a])
+      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
+      c = _FunctionOf([a, b])
+      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
+      d = _FunctionOf([b, c])
+      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
+      return dict(a=a, b=b, c=c, d=d)
+
+    def _Gradients(ys, xs, **kwargs):
+      dydxs = gradients.gradients(ys, xs, **kwargs)
+      dydxs = [0. * x if dydx is None else dydx
+               for x, dydx in zip(xs, dydxs)]
+      return dydxs
+
+    seed = np.random.randint(1000)
+    cases = []
+    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
+    graph = _MakeGraph(np.random.RandomState(seed))
+    for constants in subsets:
+      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
+      for variables_ in subsets:
+        # compute the gradient when stopped using tf.stop_gradients
+        grad1 = _Gradients([graph_with_stops["d"]],
+                           [graph_with_stops[v] for v in variables_])
+        # compute the gradient when stopped using the stop_gradients kwarg
+        grad2 = _Gradients([graph["d"]],
+                           [graph[v] for v in variables_],
+                           stop_gradients=[graph[v] for v in constants])
+        cases.append(dict(grad1=grad1, grad2=grad2,
+                          constants=constants, variables=variables_))
+
+    # evaluate all tensors in one call to session.run for speed
+    with self.test_session() as session:
+      results = session.run([(case["grad1"], case["grad2"]) for case in cases])
+
+    for (npgrad1, npgrad2), case in zip(results, cases):
+      for a, b in zip(npgrad1, npgrad2):
+        np.testing.assert_allclose(a, b)
+
 
 class FunctionGradientsTest(test_util.TensorFlowTestCase):
 
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 75c67dcb3c2ad34b53a86970562fe770f7fd4e69..31485ae9d41ab892a688a3c28b25e55903210ffc 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -22,6 +22,7 @@ See the @{$python/image} guide.
 @@decode_gif
 @@decode_jpeg
 @@encode_jpeg
+@@extract_jpeg_shape
 @@decode_png
 @@encode_png
 @@decode_image
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 19a65037e9e5362100dd93cfc145c28701de0dd5..9e656f0e0831f65ff995d72cac00923d2511d76a 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -386,11 +386,11 @@ class AdjustHueBenchmark(test.Benchmark):
           sess.run(run_op)
     end = time.time()
     step_time = (end - start) / benchmark_rounds
-    tag = "%s" % (cpu_count) if cpu_count is not None else "_all"
-    print("benchmarkAdjustHue_299_299_3_cpu%s step_time: %.2f us" %
+    tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
+    print("benchmarkAdjustHue_299_299_3_%s step_time: %.2f us" %
           (tag, step_time * 1e6))
     self.report_benchmark(
-        name="benchmarkAdjustHue_299_299_3_cpu%s" % (tag),
+        name="benchmarkAdjustHue_299_299_3_%s" % (tag),
         iters=benchmark_rounds,
         wall_time=step_time)
 
@@ -432,11 +432,11 @@ class AdjustSaturationBenchmark(test.Benchmark):
           sess.run(run_op)
     end = time.time()
     step_time = (end - start) / benchmark_rounds
-    tag = "%s" % (cpu_count) if cpu_count is not None else "_all"
-    print("benchmarkAdjustSaturation_599_599_3_cpu%s step_time: %.2f us" %
+    tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
+    print("benchmarkAdjustSaturation_299_299_3_%s step_time: %.2f us" %
           (tag, step_time * 1e6))
     self.report_benchmark(
-        name="benchmarkAdjustSaturation_599_599_3_cpu%s" % (tag),
+        name="benchmarkAdjustSaturation_299_299_3_%s" % (tag),
         iters=benchmark_rounds,
         wall_time=step_time)
 
@@ -704,7 +704,7 @@ class AdjustSaturationTest(test_util.TensorFlowTestCase):
         "gb_same",
         "rgb_same",
     ]
-    with self.test_session():
+    with self.test_session(use_gpu=True):
       for x_shape in x_shapes:
         for test_style in test_styles:
           x_np = np.random.rand(*x_shape) * 255.
@@ -2455,6 +2455,26 @@ class JpegTest(test_util.TensorFlowTestCase):
         self.assertEqual(image.get_shape().as_list(),
                          [None, None, channels or None])
 
+  def testExtractJpegShape(self):
+    # Read a real jpeg and verify shape.
+    path = ("tensorflow/core/lib/jpeg/testdata/"
+            "jpeg_merge_test1.jpg")
+    with self.test_session(use_gpu=True) as sess:
+      jpeg = io_ops.read_file(path)
+      # Extract shape without decoding.
+      [image_shape] = sess.run([image_ops.extract_jpeg_shape(jpeg)])
+      self.assertEqual(image_shape.tolist(), [256, 128, 3])
+
+  def testExtractJpegShapeforCmyk(self):
+    # Read a cmyk jpeg image, and verify its shape.
+    path = ("tensorflow/core/lib/jpeg/testdata/"
+            "jpeg_merge_test1_cmyk.jpg")
+    with self.test_session(use_gpu=True) as sess:
+      jpeg = io_ops.read_file(path)
+      [image_shape] = sess.run([image_ops.extract_jpeg_shape(jpeg)])
+      # Cmyk jpeg image has 4 channels.
+      self.assertEqual(image_shape.tolist(), [256, 128, 4])
+
 
 class PngTest(test_util.TensorFlowTestCase):
 
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 5cd5d7ba2f3dc6d88b3f3f37ee70ddd6be62aebd..bd879ac423847c07167672ee5464e146629d5eb7 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -37,6 +37,7 @@ See the @{$python/io_ops} guide.
 @@parse_example
 @@parse_single_example
 @@parse_tensor
+@@serialize_tensor
 @@decode_json_example
 @@QueueBase
 @@FIFOQueue
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 9cc05e7a6a65dd3adc7566a8bc5fe07c41d6a006..b5e4e0e0af8b5a94048da52219cedbf469e2e09d 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -39,8 +39,7 @@ def _MatrixInverseGrad(op, grad):
   """Gradient for MatrixInverse."""
   ainv = op.outputs[0]
   return -math_ops.matmul(
-      ainv, math_ops.matmul(
-          grad, ainv, adjoint_b=True), adjoint_a=True)
+      ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True)
 
 
 @ops.RegisterGradient("MatrixDeterminant")
@@ -49,8 +48,9 @@ def _MatrixDeterminantGrad(op, grad):
   a = op.inputs[0]
   c = op.outputs[0]
   a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
-  multipliers = array_ops.reshape(
-      grad * c, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
+  multipliers = array_ops.reshape(grad * c,
+                                  array_ops.concat([array_ops.shape(c), [1, 1]],
+                                                   0))
   return multipliers * a_adj_inv
 
 
@@ -62,8 +62,11 @@ def _CholeskyGrad(op, grad):
   l = op.outputs[0]
   num_rows = array_ops.shape(l)[-1]
   batch_shape = array_ops.shape(l)[:-2]
-  l_inverse = linalg_ops.matrix_triangular_solve(
-      l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))
+  l_inverse = linalg_ops.matrix_triangular_solve(l,
+                                                 linalg_ops.eye(
+                                                     num_rows,
+                                                     batch_shape=batch_shape,
+                                                     dtype=l.dtype))
 
   middle = math_ops.matmul(l, grad, adjoint_a=True)
   middle = array_ops.matrix_set_diag(middle,
@@ -112,15 +115,12 @@ def _MatrixSolveLsGrad(op, grad):
     """
     a = op.inputs[0]
     b = op.inputs[1]
-    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
     x = op.outputs[0]
-    a_shape = array_ops.shape(a)
-    batch_shape = a_shape[:-2]
-    n = a_shape[-1]
-
-    identity = linalg_ops.eye(n, batch_shape=batch_shape, dtype=a.dtype)
-    gramian = math_ops.matmul(a, a, adjoint_a=True) + l2_regularizer * identity
-    chol = linalg_ops.cholesky(gramian)
+    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
+    # pylint: disable=protected-access
+    chol = linalg_ops._RegularizedGramianCholesky(
+        a, l2_regularizer=l2_regularizer, first_kind=True)
+    # pylint: enable=protected-access
     # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
     z = linalg_ops.cholesky_solve(chol, grad)
     xzt = math_ops.matmul(x, z, adjoint_b=True)
@@ -141,13 +141,10 @@ def _MatrixSolveLsGrad(op, grad):
     a = op.inputs[0]
     b = op.inputs[1]
     l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
-    a_shape = array_ops.shape(a)
-    batch_shape = a_shape[:-2]
-    m = a_shape[-2]
-
-    identity = linalg_ops.eye(m, batch_shape=batch_shape, dtype=a.dtype)
-    gramian = math_ops.matmul(a, a, adjoint_b=True) + l2_regularizer * identity
-    chol = linalg_ops.cholesky(gramian)
+    # pylint: disable=protected-access
+    chol = linalg_ops._RegularizedGramianCholesky(
+        a, l2_regularizer=l2_regularizer, first_kind=False)
+    # pylint: enable=protected-access
     grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad))
     # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
     tmp = linalg_ops.cholesky_solve(chol, b)
@@ -203,7 +200,7 @@ def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
   compute_v = op.get_attr("compute_v")
   # a = op.inputs[0], which satisfies
   # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
-  with ops.control_dependencies([grad_e.op, grad_v.op]):
+  with ops.control_dependencies([grad_e, grad_v]):
     if compute_v:
       v = op.outputs[1]
       # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
@@ -219,15 +216,17 @@ def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
       grad_a = math_ops.matmul(
           v,
           math_ops.matmul(
-              array_ops.matrix_diag(grad_e) + f * math_ops.matmul(
-                  v, grad_v, adjoint_a=True),
+              array_ops.matrix_diag(grad_e) +
+              f * math_ops.matmul(v, grad_v, adjoint_a=True),
               v,
               adjoint_b=True))
     else:
       _, v = linalg_ops.self_adjoint_eig(op.inputs[0])
-      grad_a = math_ops.matmul(
-          v, math_ops.matmul(
-              array_ops.matrix_diag(grad_e), v, adjoint_b=True))
+      grad_a = math_ops.matmul(v,
+                               math_ops.matmul(
+                                   array_ops.matrix_diag(grad_e),
+                                   v,
+                                   adjoint_b=True))
     # The forward op only depends on the lower triangular part of a, so here we
     # symmetrize and take the lower triangle
     grad_a = array_ops.matrix_band_part(
diff --git a/tensorflow/python/ops/linalg_impl.py b/tensorflow/python/ops/linalg_impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca57653d14f8e4e3350500708f6ce4019838058f
--- /dev/null
+++ b/tensorflow/python/ops/linalg_impl.py
@@ -0,0 +1,56 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operations for linear algebra."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_linalg_ops
+from tensorflow.python.ops import math_ops
+
+
+def logdet(matrix, name=None):
+  """Computes log of the determinant of a hermitian positive definite matrix.
+
+  ```python
+  # Compute the determinant of a matrix while reducing the chance of over- or
+  underflow:
+  A = ... # shape 10 x 10
+  det = tf.exp(tf.logdet(A))  # scalar
+  ```
+
+  Args:
+    matrix:  A `Tensor`. Must be `float32`, `float64`, `complex64`, or
+      `complex128` with shape `[..., M, M]`.
+    name:  A name to give this `Op`.  Defaults to `logdet`.
+
+  Returns:
+    The natural log of the determinant of `matrix`.
+
+  @compatibility(numpy)
+  Equivalent to numpy.linalg.slogdet, although no sign is returned since only
+  hermitian positive definite matrices are supported.
+  @end_compatibility
+  """
+  # This uses the property that the log det(A) = 2*sum(log(real(diag(C))))
+  # where C is the cholesky decomposition of A.
+  with ops.name_scope(name, 'logdet', [matrix]):
+    chol = gen_linalg_ops.cholesky(matrix)
+    return 2.0 * math_ops.reduce_sum(
+        math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))),
+        reduction_indices=[-1])
diff --git a/tensorflow/python/ops/linalg_ns.py b/tensorflow/python/ops/linalg_ns.py
index 6451f91c72450eb33c67b978b2e5fb7771315c32..ccd7e452a02fa1bbc6ad28e465131f0f7dfe73c9 100644
--- a/tensorflow/python/ops/linalg_ns.py
+++ b/tensorflow/python/ops/linalg_ns.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Public API for tf.linalg namespace."""
+"""Public API for tf.linalg namespace.
+
+@@logdet
+"""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -23,6 +26,12 @@ from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import special_math_ops
 
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import,unused-import
+from tensorflow.python.ops.linalg_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
 # Linear algebra ops.
 band_part = array_ops.matrix_band_part
 cholesky = linalg_ops.cholesky
@@ -60,3 +69,7 @@ del linalg_ops
 del math_ops
 del print_function
 del special_math_ops
+
+_allowed_symbols = []
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 34fe3a124130e2b45ce0279a240979d0916ef82d..2b11835d85fdeb0352a0007290a99bb43c1393c1 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -23,6 +23,7 @@ import numpy as np
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_linalg_ops
 from tensorflow.python.ops import math_ops
 # pylint: disable=wildcard-import
@@ -34,6 +35,47 @@ from tensorflow.python.util import compat
 # pylint: disable=invalid-name
 
 
+def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind):
+  r"""Computes Cholesky factorization of regularized gramian matrix.
+
+  Below we will use the following notation for each pair of matrix and
+  right-hand sides in the batch:
+
+  `matrix`=\\(A \in \Re^{m \times n}\\),
+  `output`=\\(C  \in \Re^{\min(m, n) \times \min(m,n)}\\),
+  `l2_regularizer`=\\(\lambda\\).
+
+  If `first_kind` is True, returns the Cholesky factorization \\(L\\) such that
+  \\(L L^H =  A^H A + \lambda I\\).
+  If `first_kind` is False, returns the Cholesky factorization \\(L\\) such that
+  \\(L L^H =  A A^H + \lambda I\\).
+
+  Args:
+    matrix: `Tensor` of shape `[..., M, N]`.
+    l2_regularizer: 0-D `double` `Tensor`. Ignored if `fast=False`.
+    first_kind: bool. Controls what gramian matrix to factor.
+  Returns:
+    output: `Tensor` of shape `[..., min(M,N), min(M,N)]` whose inner-most 2
+      dimensions contain the Cholesky factors \\(L\\) described above.
+  """
+
+  gramian = math_ops.matmul(
+      matrix, matrix, adjoint_a=first_kind, adjoint_b=not first_kind)
+  if isinstance(l2_regularizer, ops.Tensor) or l2_regularizer != 0:
+    matrix_shape = array_ops.shape(matrix)
+    batch_shape = matrix_shape[:-2]
+    if first_kind:
+      small_dim = matrix_shape[-1]
+    else:
+      small_dim = matrix_shape[-2]
+    identity = eye(small_dim, batch_shape=batch_shape, dtype=matrix.dtype)
+    small_dim_static = matrix.shape[-1 if first_kind else -2]
+    identity.set_shape(
+        matrix.shape[:-2].concatenate([small_dim_static, small_dim_static]))
+    gramian += l2_regularizer * identity
+  return gen_linalg_ops.cholesky(gramian)
+
+
 def cholesky_solve(chol, rhs, name=None):
   """Solves systems of linear eqns `A X = RHS`, given Cholesky factorizations.
 
@@ -195,10 +237,90 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
       `M`-by-`K` matrices that solve the equations
       `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]` in the least
       squares sense.
+
+  Raises:
+    NotImplementedError: matrix_solve_ls is currently disabled for complex128
+    and l2_regularizer != 0 due to poor accuracy.
   """
-  # pylint: disable=protected-access
-  return gen_linalg_ops._matrix_solve_ls(
-      matrix, rhs, l2_regularizer, fast=fast, name=name)
+
+  # pylint: disable=protected-access,long-lambda
+  def _use_composite_impl(fast, tensor_shape):
+    """Determines whether to use the composite or specialized CPU kernel.
+
+    When the total size of the tensor is larger than the cache size and the
+    batch size is large compared to the smallest matrix dimension, then the
+    composite implementation is inefficient since it has to read the entire
+    tensor from memory multiple times. In this case we fall back to the
+    original CPU kernel, which does all the computational steps on each
+    matrix separately.
+
+    Only fast mode is supported by the composite impl, so `False` is returned
+    if `fast` is `False`.
+
+    Args:
+      fast: bool indicating if fast mode in the solver was requested.
+      tensor_shape: The shape of the tensor.
+
+    Returns:
+      True if the composite impl should be used. False otherwise.
+    """
+    if fast is False:
+      return False
+    batch_shape = tensor_shape[:-2]
+    matrix_shape = tensor_shape[-2:]
+    if not tensor_shape.is_fully_defined():
+      return True
+    tensor_size = tensor_shape.num_elements() * matrix.dtype.size
+    is_io_bound = batch_shape.num_elements() > np.min(matrix_shape)
+    L2_CACHE_SIZE_GUESSTIMATE = 256000
+    if tensor_size > L2_CACHE_SIZE_GUESSTIMATE and is_io_bound:
+      return False
+    else:
+      return True
+
+  def _overdetermined(matrix, rhs, l2_regularizer):
+    """Computes (A^H*A + l2_regularizer)^{-1} * A^H * rhs."""
+    chol = _RegularizedGramianCholesky(
+        matrix, l2_regularizer=l2_regularizer, first_kind=True)
+    return cholesky_solve(chol, math_ops.matmul(matrix, rhs, adjoint_a=True))
+
+  def _underdetermined(matrix, rhs, l2_regularizer):
+    """Computes A^H * (A*A^H + l2_regularizer)^{-1} * rhs."""
+    chol = _RegularizedGramianCholesky(
+        matrix, l2_regularizer=l2_regularizer, first_kind=False)
+    return math_ops.matmul(matrix, cholesky_solve(chol, rhs), adjoint_a=True)
+
+  def _composite_impl(matrix, rhs, l2_regularizer):
+    """Composite implementation of matrix_solve_ls that supports GPU."""
+    with ops.name_scope(name, 'matrix_solve_ls', [matrix, rhs, l2_regularizer]):
+      matrix_shape = matrix.get_shape()[-2:]
+      if matrix_shape.is_fully_defined():
+        if matrix_shape[-2] >= matrix_shape[-1]:
+          return _overdetermined(matrix, rhs, l2_regularizer)
+        else:
+          return _underdetermined(matrix, rhs, l2_regularizer)
+      else:
+        # We have to defer determining the shape to runtime and use
+        # conditional execution of the appropriate graph.
+        matrix_shape = array_ops.shape(matrix)[-2:]
+        return control_flow_ops.cond(
+            matrix_shape[-2] >= matrix_shape[-1],
+            lambda: _overdetermined(matrix, rhs, l2_regularizer),
+            lambda: _underdetermined(matrix, rhs, l2_regularizer))
+
+  matrix = ops.convert_to_tensor(matrix, name='matrix')
+  if matrix.dtype == dtypes.complex128 and l2_regularizer != 0:
+    # TODO(rmlarsen): Investigate and fix accuracy bug.
+    raise NotImplementedError('matrix_solve_ls is currently disabled for '
+                              'complex128 and l2_regularizer != 0 due to '
+                              'poor accuracy.')
+  tensor_shape = matrix.get_shape()
+  if _use_composite_impl(fast, tensor_shape):
+    return _composite_impl(matrix, rhs, l2_regularizer)
+  else:
+    return gen_linalg_ops._matrix_solve_ls(
+        matrix, rhs, l2_regularizer, fast=fast, name=name)
+  # pylint: enable=protected-access
 
 
 def self_adjoint_eig(tensor, name=None):
@@ -291,7 +413,7 @@ def svd(tensor, full_matrices=False, compute_uv=True, name=None):
   """
   # pylint: disable=protected-access
   s, u, v = gen_linalg_ops._svd(
-      tensor, compute_uv=compute_uv, full_matrices=full_matrices)
+      tensor, compute_uv=compute_uv, full_matrices=full_matrices, name=name)
   # pylint: enable=protected-access
   if compute_uv:
     return math_ops.real(s), u, v
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index ab9e3b0af9bed968577ec269086c4dd2c3983cd2..f9b1733dda621df38e89c07ea9f46fd6f745261d 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -21,6 +21,7 @@ from __future__ import print_function
 import collections
 import functools
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -152,9 +153,13 @@ class InitializableLookupTableBase(LookupInterface):
       default_value: The value to use if a key is missing in the table.
       initializer: The table initializer to use.
     """
+    if context.in_graph_mode():
+      name = table_ref.op.name.split("/")[-1]
+    else:
+      name = context.context().scope_name
     super(InitializableLookupTableBase,
           self).__init__(initializer.key_dtype, initializer.value_dtype,
-                         table_ref.op.name.split("/")[-1])
+                         name)
     self._table_ref = table_ref
     self._default_value = ops.convert_to_tensor(
         default_value, dtype=self._value_dtype)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 3d70465e68a5eeb9fa3587241cf7c49b31b646c6..4175f6ec44693c95d34f91ef3ea9bfbf330515ec 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -750,6 +750,12 @@ def _FloorDivGrad(_, unused_grad):
   return None, None
 
 
+@ops.RegisterGradient("FloorMod")
+def _FloorModGrad(_, unused_grad):
+  """The gradient for the FloorMod operator."""
+  return None, None
+
+
 @ops.RegisterGradient("TruncateDiv")
 def _TruncateDivGrad(_, unused_grad):
   return None, None
@@ -903,7 +909,7 @@ def _SparseMatMulGrad(op, grad):
       op.inputs[0]: op.get_attr("a_is_sparse"),
       op.inputs[1]: op.get_attr("b_is_sparse"),
       # Use heuristic to figure out if grad might be sparse
-      grad: (grad.op.type == "ReluGrad")
+      grad: context.in_graph_mode() and (grad.op.type == "ReluGrad")
   }
 
   def _SparseMatMul(t1, t2, out_dtype, transpose_a=False, transpose_b=False):
@@ -1027,7 +1033,7 @@ def _ImagGrad(_, grad):
 def _AngleGrad(op, grad):
   """Returns -grad / (Im(x) + iRe(x))"""
   x = op.inputs[0]
-  with ops.control_dependencies([grad.op]):
+  with ops.control_dependencies([grad]):
     re = math_ops.real(x)
     im = math_ops.imag(x)
     z = math_ops.reciprocal(math_ops.complex(im, re))
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index fd8a5daa334e456b6c25e34021136aeaf343dcc4..6559929560c3b3a820b6edd8c3798fa1815bffd1 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -144,9 +144,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from autograd import core as ag_core
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -1112,11 +1114,12 @@ floormod = gen_math_ops._floor_mod
 
 def _mul_dispatch(x, y, name=None):
   """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse"."""
-  is_tensor_y = isinstance(y, ops.Tensor)
+  is_tensor_y = isinstance(ag_core.getval(y), ops.Tensor)
   if is_tensor_y:
     return gen_math_ops._mul(x, y, name=name)
   else:
-    assert isinstance(y, sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
+    assert isinstance(ag_core.getval(y),
+                      sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
     new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values,
                                                      y.dense_shape, x, name)
     return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
@@ -2039,6 +2042,10 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
     ValueError: If `inputs` don't all have same shape and dtype or the shape
     cannot be inferred.
   """
+  if context.in_eager_mode():
+    # TODO(apassos) remove this once the lifetime of eager variables gets
+    # addressed.
+    raise ValueError("accumulate_n not supported in eager mode")
   if not inputs or not isinstance(inputs, (list, tuple)):
     raise ValueError("inputs must be a list of at least one Tensor with the "
                      "same dtype and shape")
@@ -2060,7 +2067,7 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
     tensor_dtype = inputs[0].dtype
   if tensor_dtype != inputs[0].dtype:
     raise TypeError("tensor_dtype is {}, but input is of type {}"
-                     .format(tensor_dtype, inputs[0].dtype))
+                    .format(tensor_dtype, inputs[0].dtype))
   if len(inputs) == 1:
     return inputs[0]
   with ops.name_scope(name, "AccumulateN", inputs) as name:
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index f1103f209c659d2323af91bcdf3f523e8347031f..4642f4c580fbf5401af4c6a5ec43851e67a0af8b 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -26,6 +27,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 
@@ -37,23 +39,32 @@ log = np.log
 
 class ReduceTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testReduceAllDims(self):
     x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
-    with self.test_session(use_gpu=True):
-      y_tf = math_ops.reduce_sum(x).eval()
+    with test_util.device(use_gpu=True):
+      y_tf = self.evaluate(math_ops.reduce_sum(x))
       self.assertEqual(y_tf, 21)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testReduceExplicitAxes(self):
     x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
-    with self.test_session(use_gpu=True):
+    with test_util.device(use_gpu=True):
       for axis in (0, -2, (0, 0), (0, -2)):
-        self.assertAllEqual(math_ops.reduce_sum(x, axis=axis).eval(), [5, 7, 9])
+        self.assertAllEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)),
+                            [5, 7, 9])
       for axis in (1, -1, (1, 1), (1, -1)):
-        self.assertAllEqual(math_ops.reduce_sum(x, axis=axis).eval(), [6, 15])
+        self.assertAllEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)),
+                            [6, 15])
       for axis in (None, (0, 1), (-1, -2), (-2, -1, 0, 1)):
-        self.assertEqual(math_ops.reduce_sum(x, axis=axis).eval(), 21)
+        self.assertEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), 21)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testReduceInvalidAxis(self):
+    if context.in_eager_mode():
+      # The shape check is in run a graph contruction time. In eager mode,
+      # it misses the check, magically return result given wrong shape.
+      return
     x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
     axis = np.array([[0], [1]])
     with self.assertRaisesRegexp(errors.InvalidArgumentError,
@@ -142,16 +153,17 @@ class LogSumExpTest(test_util.TensorFlowTestCase):
 
 class RoundTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testRounding(self):
     x = [0.49, 0.7, -0.3, -0.8]
     # TODO(nolivia): Remove this when RoundOp is forwards compatible
     # x = np.arange(-5.0, 5.0, .25)
     for dtype in [np.float32, np.double, np.int32]:
       x_np = np.array(x, dtype=dtype)
-      with self.test_session(use_gpu=True):
+      with test_util.device(use_gpu=True):
         x_tf = constant_op.constant(x_np, shape=x_np.shape)
         y_tf = math_ops.round(x_tf)
-        y_tf_np = y_tf.eval()
+        y_tf_np = self.evaluate(y_tf)
         y_np = np.round(x_np)
         self.assertAllClose(y_tf_np, y_np, atol=1e-2)
 
@@ -187,77 +199,87 @@ class ModTest(test_util.TensorFlowTestCase):
 
 class SquaredDifferenceTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testSquaredDifference(self):
     for dtype in [np.int32, np.float16]:
       x = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
       y = np.array([-3, -2, -1], dtype=dtype)
       z = (x - y) * (x - y)
-      with self.test_session(use_gpu=True):
-        z_tf = math_ops.squared_difference(x, y).eval()
+      with test_util.device(use_gpu=True):
+        z_tf = self.evaluate(math_ops.squared_difference(x, y))
         self.assertAllClose(z, z_tf)
 
 
 class ApproximateEqualTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testApproximateEqual(self):
     for dtype in [np.float32, np.double]:
       x = dtype(1)
       y = dtype(1.00009)
       z = False
-      with self.test_session(use_gpu=True):
+      with test_util.device(use_gpu=True):
         # Default tolerance is 0.00001
-        z_tf = math_ops.approximate_equal(x, y).eval()
+        z_tf = self.evaluate(math_ops.approximate_equal(x, y))
         self.assertAllEqual(z, z_tf)
 
     for dtype in [np.float32, np.double]:
       x = dtype(1)
       y = dtype(1.000009)
       z = True
-      with self.test_session(use_gpu=True):
+      with test_util.device(use_gpu=True):
         # Default tolerance is 0.00001
-        z_tf = math_ops.approximate_equal(x, y).eval()
+        z_tf = self.evaluate(math_ops.approximate_equal(x, y))
         self.assertAllEqual(z, z_tf)
 
     for dtype in [np.float32, np.double]:
       x = np.array([[[[-1, 2.00009999], [-3, 4.01]]]], dtype=dtype)
       y = np.array([[[[-1.001, 2], [-3.00009, 4]]]], dtype=dtype)
       z = np.array([[[[False, True], [True, False]]]], dtype=np.bool)
-      with self.test_session(use_gpu=True):
-        z_tf = math_ops.approximate_equal(x, y, tolerance=0.0001).eval()
+      with test_util.device(use_gpu=True):
+        z_tf = self.evaluate(math_ops.approximate_equal(x, y, tolerance=0.0001))
         self.assertAllEqual(z, z_tf)
 
 
 class ScalarMulTest(test_util.TensorFlowTestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def testAcceptsRefs(self):
-    var = variables.Variable(10)
+    if context.in_eager_mode():
+      var = resource_variable_ops.ResourceVariable(10, name="var")
+    else:
+      var = variables.Variable(10)
     result = math_ops.scalar_mul(3, var)
     init = variables.global_variables_initializer()
-    with self.test_session(use_gpu=True) as sess:
-      sess.run(init)
-      self.assertEqual(30, result.eval())
+    with test_util.device(use_gpu=True):
+      self.evaluate(init)
+      self.assertEqual(30, self.evaluate(result))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testAcceptsConstant(self):
     const = constant_op.constant(10)
     result = math_ops.scalar_mul(3, const)
-    with self.test_session(use_gpu=True):
-      self.assertEqual(30, result.eval())
+    with test_util.device(use_gpu=True):
+      self.assertEqual(30, self.evaluate(result))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testAcceptsTensor(self):
     tensor = array_ops.ones([10, 10])
     result = math_ops.scalar_mul(3, tensor)
     expected = array_ops.ones([10, 10]) * 3
 
-    with self.test_session(use_gpu=True):
-      self.assertAllEqual(expected.eval(), result.eval())
+    with test_util.device(use_gpu=True):
+      self.assertAllEqual(self.evaluate(expected), self.evaluate(result))
 
+  @test_util.run_in_graph_and_eager_modes()
   def testAcceptsIndexedSlices(self):
     values = constant_op.constant([2, 3, 5, 7, 0, -1], shape=[3, 2])
     indices = constant_op.constant([0, 2, 5])
     x = math_ops.scalar_mul(-3, ops.IndexedSlices(values, indices))
-    with self.test_session(use_gpu=True):
-      self.assertAllEqual(x.values.eval(), [[-6, -9], [-15, -21], [0, 3]])
-      self.assertAllEqual(x.indices.eval(), [0, 2, 5])
+    with test_util.device(use_gpu=True):
+      self.assertAllEqual(self.evaluate(x.values),
+                          [[-6, -9], [-15, -21], [0, 3]])
+      self.assertAllEqual(self.evaluate(x.indices), [0, 2, 5])
 
 
 class AccumulateNTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 3b0a357b16452741ef352b9628e2e932fc31714a..eb0b08c5fd3f70863fd795761273040e24b50fe9 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -461,12 +461,22 @@ def _confusion_matrix_at_thresholds(
   else:
     for include in includes:
       if include not in all_includes:
-        raise ValueError('Invaild key: %s.' % include)
-
-  predictions, labels, weights = _remove_squeezable_dimensions(
-      predictions=math_ops.to_float(predictions),
-      labels=math_ops.cast(labels, dtype=dtypes.bool),
-      weights=weights)
+        raise ValueError('Invalid key: %s.' % include)
+
+  with ops.control_dependencies([
+      check_ops.assert_greater_equal(
+          predictions,
+          math_ops.cast(0.0, dtype=predictions.dtype),
+          message='predictions must be in [0, 1]'),
+      check_ops.assert_less_equal(
+          predictions,
+          math_ops.cast(1.0, dtype=predictions.dtype),
+          message='predictions must be in [0, 1]')
+  ]):
+    predictions, labels, weights = _remove_squeezable_dimensions(
+        predictions=math_ops.to_float(predictions),
+        labels=math_ops.cast(labels, dtype=dtypes.bool),
+        weights=weights)
 
   num_thresholds = len(thresholds)
 
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index df5e66664f19a297a157eb1d5d4871a5962b32ec..4ad0603e56c425c5ccac078a51c0652358e5dbf5 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
+from tensorflow.python.eager import tensor
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
@@ -373,7 +375,7 @@ def _SoftplusGradGrad(op, grad):
   #   dx = gen_nn_ops._softplus_grad(dy, x) = dy / (1 + exp(-x))
   # This op computes (ddy, d2x) from op.inputs == [dy, x] and grad == ddx.
   dy, x = op.inputs
-  with ops.control_dependencies([grad.op]):
+  with ops.control_dependencies([grad]):
     ddy = gen_nn_ops._softplus_grad(grad, x)  # pylint: disable=protected-access
     d2x = grad * dy / (math_ops.exp(-x) + 2.0 + math_ops.exp(x))
     return (ddy, d2x)
@@ -420,6 +422,8 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
 
   def IsZero(g):
     # Some introspection to check if the gradient is feeding zeros
+    if context.in_eager_mode():
+      return isinstance(g, tensor.LazyZero)
     if g.op.type in ("ZerosLike", "Zeros"):
       return True
     const_fill_value = tensor_util.constant_value(g)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 467567f30e37db004b25107948faef6eb1bcd2d8..d4b16635071b594888727d88664b54d293722a2c 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1368,21 +1368,21 @@ def _flatten_outer_dims(logits):
   output = array_ops.reshape(logits, array_ops.concat([[-1], last_dim_size], 0))
 
   # Set output shape if known.
-  shape = logits.get_shape()
-  if shape is not None and shape.dims is not None:
-    shape = shape.as_list()
-    product = 1
-    product_valid = True
-    for d in shape[:-1]:
-      if d is None:
-        product_valid = False
-        break
-      else:
-        product *= d
-    # Only need to set shape if in graph mode
-    if product_valid and context.in_graph_mode():
-      output_shape = [product, shape[-1]]
-      output.set_shape(output_shape)
+  if context.in_graph_mode():
+    shape = logits.get_shape()
+    if shape is not None and shape.dims is not None:
+      shape = shape.as_list()
+      product = 1
+      product_valid = True
+      for d in shape[:-1]:
+        if d is None:
+          product_valid = False
+          break
+        else:
+          product *= d
+      if product_valid:
+        output_shape = [product, shape[-1]]
+        output.set_shape(output_shape)
 
   return output
 
@@ -1605,7 +1605,7 @@ def softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=invalid
 
   # Make shape inference work since reshape and transpose may erase its static
   # shape.
-  if shape is not None and shape.dims is not None and context.in_graph_mode():
+  if context.in_graph_mode() and shape is not None and shape.dims is not None:
     shape = shape.as_list()
     del shape[dim]
     cost.set_shape(shape)
@@ -1922,7 +1922,8 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):  # pylint: di
     # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
     binary_tensor = math_ops.floor(random_tensor)
     ret = math_ops.div(x, keep_prob) * binary_tensor
-    ret.set_shape(x.get_shape())
+    if context.in_graph_mode():
+      ret.set_shape(x.get_shape())
     return ret
 
 
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index e0e3d08e7ce4c777ae1a8ed0d3c58a2108917528..c5fd15bae4566961af8e7c63147df27c5b9a6575 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -40,6 +40,7 @@ from tensorflow.python.platform import tf_logging
 
 ops.NotDifferentiable("DecodeRaw")
 ops.NotDifferentiable("ParseTensor")
+ops.NotDifferentiable("SerializeTensor")
 ops.NotDifferentiable("StringToNumber")
 
 
@@ -198,7 +199,11 @@ def _features_to_raw_params(features, types):
   sparse_types = []
   dense_keys = []
   dense_types = []
-  dense_defaults = {}
+  # When the graph is built twice, multiple dense_defaults in a normal dict
+  # could come out in different orders. This will fail the _e2e_test which
+  # expects exactly the same graph.
+  # OrderedDict which preserves the order can solve the problem.
+  dense_defaults = collections.OrderedDict()
   dense_shapes = []
   if features:
     # NOTE: We iterate over sorted keys to keep things deterministic.
@@ -624,7 +629,8 @@ def _parse_example_raw(serialized,
   """
   with ops.name_scope(name, "ParseExample", [serialized, names]):
     names = [] if names is None else names
-    dense_defaults = {} if dense_defaults is None else dense_defaults
+    dense_defaults = collections.OrderedDict(
+    ) if dense_defaults is None else dense_defaults
     sparse_keys = [] if sparse_keys is None else sparse_keys
     sparse_types = [] if sparse_types is None else sparse_types
     dense_keys = [] if dense_keys is None else dense_keys
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index ed344c0b148461cfb10649a74f780e84cbe88176..fdc8a5843fee80883956dc6ac5871af420bb1b56 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -19,13 +19,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from autograd import core as ag_core
+
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import variable_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import custom_gradient
+from tensorflow.python.eager import tape
+from tensorflow.python.eager import tensor_node
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_resource_variable_ops
@@ -37,6 +42,29 @@ from tensorflow.python.ops.gen_resource_variable_ops import *
 from tensorflow.python.util import compat
 
 
+def _eager_safe_variable_handle(shape, dtype, shared_name, name,
+                                container=None):
+  """Creates a variable handle with information to do shape inference."""
+  handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
+                                                   shared_name=shared_name,
+                                                   name=name,
+                                                   container=container)
+  if context.in_graph_mode():
+    return handle
+  with context.graph_mode(), ops.Graph().as_default():
+    h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
+                                                shared_name=shared_name,
+                                                name=name,
+                                                container=container)
+
+    # Tensor._handle_data contains information for the shape-inference code to
+    # know the shape and dtype of the variable pointed to by a handle. Since
+    # shape inference doesn't run in eager mode we copy this data here for when
+    # the handle is captured by an eager mode function.
+    handle._handle_data = h._handle_data  # pylint: disable=protected-access
+  return handle
+
+
 class ResourceVariable(variables.Variable):
   """Variable based on resource handles.
 
@@ -223,10 +251,11 @@ class ResourceVariable(variables.Variable):
     if constraint is not None and not callable(constraint):
       raise ValueError("The `constraint` argument must be a callable.")
 
+    self._trainable = trainable
     if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
       collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
     self._save_slice_info = None
-    in_graph_mode = context.in_graph_mode()
+    self._in_graph_mode = context.in_graph_mode()
     with ops.control_dependencies(None):
       with ops.name_scope(name, "Variable", []
                           if init_from_fn else [initial_value]) as name:
@@ -236,7 +265,7 @@ class ResourceVariable(variables.Variable):
           # Use attr_scope and device(None) to simulate the behavior of
           # colocate_with when the variable we want to colocate with doesn't
           # yet exist.
-          if in_graph_mode:
+          if self._in_graph_mode:
             attr = attr_value_pb2.AttrValue(
                 list=attr_value_pb2.AttrValue.ListValue(
                     s=[compat.as_bytes("loc:@%s" % handle_name)]))
@@ -244,21 +273,28 @@ class ResourceVariable(variables.Variable):
               with ops.name_scope("Initializer"), ops.device(None):
                 initial_value = ops.convert_to_tensor(
                     initial_value(), name="initial_value", dtype=dtype)
-              self._handle = gen_resource_variable_ops.var_handle_op(
+              self._handle = _eager_safe_variable_handle(
                   shape=initial_value.get_shape(),
                   dtype=initial_value.dtype.base_dtype,
                   shared_name=handle_name,
                   name=name)
+              self._handle_device = (
+                  self._handle.device if self._in_graph_mode else
+                  context.get_default_context().device_name)
           else:
             initial_value = initial_value()
-            initial_value = ops.convert_to_tensor(
-                initial_value, name="initial_value", dtype=dtype)
-            self._handle = gen_resource_variable_ops.var_handle_op(
+            with ops.name_scope("Initializer"):
+              initial_value = ops.convert_to_tensor(
+                  initial_value, name="initial_value", dtype=dtype)
+            self._handle = _eager_safe_variable_handle(
                 shape=initial_value.get_shape(),
                 dtype=initial_value.dtype.base_dtype,
                 shared_name=handle_name,
                 name=name,
                 container="")
+            self._handle_device = (
+                self._handle.device if self._in_graph_mode else
+                context.get_default_context().device_name)
         # pylint: enable=protected-access
 
         # Or get the initial value from a Tensor or Python object.
@@ -267,8 +303,7 @@ class ResourceVariable(variables.Variable):
             initial_value = ops.convert_to_tensor(
                 initial_value, name="initial_value", dtype=dtype)
           # pylint: disable=protected-access
-          if (in_graph_mode and
-              initial_value is not None and
+          if (self._in_graph_mode and initial_value is not None and
               initial_value.op._get_control_flow_context() is not None):
             raise ValueError(
                 "Initializer for variable %s is from inside a control-flow "
@@ -276,19 +311,21 @@ class ResourceVariable(variables.Variable):
                 "variable inside a loop or conditional, use a lambda as the "
                 "initializer." % name)
           # pylint: enable=protected-access
-          self._handle = gen_resource_variable_ops.var_handle_op(
+          self._handle = _eager_safe_variable_handle(
               shape=initial_value.get_shape(),
               dtype=initial_value.dtype.base_dtype,
               shared_name=handle_name,
               name=name,
               container="")
+          self._handle_device = (self._handle.device if self._in_graph_mode else
+                                 context.get_default_context().device_name)
 
-        self._initial_value = initial_value if in_graph_mode else None
+        self._initial_value = initial_value if self._in_graph_mode else None
         self._handle_name = handle_name + ":0"
         self._dtype = initial_value.dtype.base_dtype
         self._constraint = constraint
 
-        if in_graph_mode:
+        if self._in_graph_mode:
           with ops.name_scope("IsInitialized"):
             self._is_initialized_op = (
                 gen_resource_variable_ops.var_is_initialized_op(self._handle))
@@ -302,8 +339,8 @@ class ResourceVariable(variables.Variable):
           with ops.name_scope("Read"), ops.colocate_with(self._handle):
             # Manually assign reads to the handle's device to avoid log
             # messages.
-            with ops.device(self._handle.device):
-              value = read_variable_op(self._handle, dtype=self._dtype)
+            with ops.device(self._handle_device):
+              value = self._read_variable_op()
             self._graph_element = value
             if caching_device is not None:
               # Variables may be created in a tf.device() or ops.colocate_with()
@@ -326,14 +363,14 @@ class ResourceVariable(variables.Variable):
           self._graph_element = None
           if caching_device:
             with ops.device(caching_device):
-              self._cached_value = read_variable_op(self._handle,
-                                                    dtype=self._dtype)
+              self._cached_value = self._read_variable_op()
           else:
             self._cached_value = None
         ops.add_to_collections(collections, self)
 
   def _init_from_proto(self, variable_def, import_scope=None):
     """Initializes from `VariableDef` proto."""
+    # Note that init_from_proto is currently not supported in Eager mode.
     assert context.in_graph_mode()
     assert isinstance(variable_def, variable_pb2.VariableDef)
     if not variable_def.is_resource:
@@ -344,6 +381,7 @@ class ResourceVariable(variables.Variable):
     self._handle = g.as_graph_element(
         ops.prepend_name_scope(
             variable_def.variable_name, import_scope=import_scope))
+    self._handle_device = self._handle.device
     self._handle_name = self._handle.name
     self._initializer_op = g.as_graph_element(
         ops.prepend_name_scope(
@@ -372,7 +410,7 @@ class ResourceVariable(variables.Variable):
   @property
   def device(self):
     """The device this variable is on."""
-    return self._handle.device
+    return self._handle_device
 
   @property
   def graph(self):
@@ -387,10 +425,11 @@ class ResourceVariable(variables.Variable):
   @property
   def shape(self):
     """The shape of this variable."""
-    if context.in_graph_mode():
+    if self._in_graph_mode:
       return tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
     return tensor_shape.TensorShape(
-        gen_resource_variable_ops.variable_shape(self._handle).numpy())
+        tensor_util.constant_value(
+            gen_resource_variable_ops.variable_shape(self._handle)))
 
   @property
   def create(self):
@@ -409,8 +448,8 @@ class ResourceVariable(variables.Variable):
     if self._cached_value is not None:
       return self._cached_value
     with ops.colocate_with(None, ignore_existing=True):
-      with ops.device(self._handle.device):
-        return read_variable_op(self._handle, dtype=self._dtype)
+      with ops.device(self._handle_device):
+        return self._read_variable_op()
 
   def _as_graph_element(self):
     """Conversion function for Graph.as_graph_element()."""
@@ -425,7 +464,7 @@ class ResourceVariable(variables.Variable):
   def initial_value(self):
     """Returns the Tensor used as the initial value for the variable."""
     if context.in_eager_mode():
-      raise RuntimeError("initial_value not supported in EAGER mode.""")
+      raise RuntimeError("initial_value not supported in EAGER mode.")
     return self._initial_value
 
   @property
@@ -460,6 +499,14 @@ class ResourceVariable(variables.Variable):
   def _get_save_slice_info(self):
     return self._save_slice_info
 
+  def _read_variable_op(self):
+    if hasattr(self, "_trainable") and self._trainable:
+      tape.watch(self._handle)
+      return read_variable_op(self._handle, dtype=self._dtype)
+    else:
+      return gen_resource_variable_ops.read_variable_op(self._handle,
+                                                        self._dtype)
+
   def read_value(self):
     """Constructs an op which reads the value of this variable.
 
@@ -477,10 +524,10 @@ class ResourceVariable(variables.Variable):
       # separate notions of device and memory, so handle.device can be GPU while
       # handle.memory_space is always CPU.
       if context.in_graph_mode():
-        with ops.device(self._handle.device):
-          value = read_variable_op(self._handle, dtype=self._dtype)
+        with ops.device(self._handle_device):
+          value = self._read_variable_op()
       else:
-        value = read_variable_op(self._handle, dtype=self._dtype)
+        value = self._read_variable_op()
     # Return an identity so it can get placed on whatever device the context
     # specifies instead of the device where the variable is.
     return array_ops.identity(value)
@@ -488,6 +535,8 @@ class ResourceVariable(variables.Variable):
   def sparse_read(self, indices, name=None):
     """Reads the value of this variable sparsely, using `gather`."""
     with ops.name_scope("Gather" if name is None else name) as name:
+      if self._trainable:
+        tape.watch(self._handle)
       value = resource_gather(
           self._handle, indices, dtype=self._dtype, name=name)
     return array_ops.identity(value)
@@ -560,7 +609,14 @@ class ResourceVariable(variables.Variable):
 
     def _run_op(a, *args):
       # pylint: disable=protected-access
-      return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
+      value = a._AsTensor()
+      if ag_core.isnode(value):
+        # This avoids autograd trying to wrap a ResourceVariable.
+        value = ops.convert_to_tensor(value)
+        args = [ops.convert_to_tensor(x) for x in args]
+        return getattr(tensor_node.TensorNode, operator)(value, *args)
+      else:
+        return getattr(ops.Tensor, operator)(value, *args)
 
     # Propagate __doc__ to wrapper
     try:
@@ -653,8 +709,10 @@ def read_variable_op(handle, dtype):
     A `Tensor` of type `dtype`.
   """
   result = gen_resource_variable_ops.read_variable_op(handle, dtype)
+
   def grad(dresult):
     return dresult
+
   return result, grad
 
 
@@ -720,7 +778,8 @@ def resource_gather(resource, indices, dtype, validate_indices=True, name=None):
 
   def grad(dresult):
     return ops.IndexedSlices(
-        dresult, indices,
+        dresult,
+        indices,
         dense_shape=gen_resource_variable_ops.variable_shape(resource))
 
   return result, grad
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index b1626feb27ac04b375c279a5f4c92aee0ac0d422..25a0ad0a37e33b0732e2ec038615e93d843a7def 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -28,6 +28,7 @@ import collections
 import hashlib
 import numbers
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -124,9 +125,10 @@ def _zero_state_tensors(state_size, batch_size, dtype):
   def get_state_shape(s):
     """Combine s with batch_size to get a proper tensor shape."""
     c = _concat(batch_size, s)
-    c_static = _concat(batch_size, s, static=True)
     size = array_ops.zeros(c, dtype=dtype)
-    size.set_shape(c_static)
+    if context.in_graph_mode():
+      c_static = _concat(batch_size, s, static=True)
+      size.set_shape(c_static)
     return size
   return nest.map_structure(get_state_shape, state_size)
 
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 5a179048b17d28faf8b8d10ff39930a8897bbc31..e3990791c62139b0809a07cfa55c37e8f8f4c70a 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -296,7 +296,7 @@ def sparse_add(a, b, thresh=0):
     a = _convert_to_sparse_tensor(a)
     b = _convert_to_sparse_tensor(b)
     thresh = ops.convert_to_tensor(
-        thresh, dtype=a.values.dtype.real_dtype, name="thresh")
+        thresh, dtype=a.values.dtype.real_dtype.base_dtype, name="thresh")
     output_ind, output_val, output_shape = (gen_sparse_ops._sparse_add(
         a.indices, a.values, a.dense_shape,
         b.indices, b.values, b.dense_shape,
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 684d34b170dcaa006de98d58bf69ba2b63e7a960..f54bbfe90e9a9913dae90f8d7310cae4a71c4627 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -81,6 +81,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import gen_resource_variable_ops
@@ -183,7 +184,7 @@ def is_variable_initialized(ref, name=None):
   if ref.dtype._is_ref_dtype:
     return gen_state_ops.is_variable_initialized(ref=ref, name=name)
   # Handle resource variables.
-  if ref.op.type == "VarHandleOp":
+  if context.in_eager_mode() or ref.op.type == "VarHandleOp":
     return gen_resource_variable_ops.var_is_initialized_op(ref.handle,
                                                            name=name)
 
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index 20ae082ee12bd9278bcacfa05c676cff7aa605af..08325ba7710d5f7007f5c55934f15ab5a4015536 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -24,6 +24,7 @@ from __future__ import print_function
 
 import contextlib
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
@@ -445,7 +446,7 @@ class TensorArray(object):
       ta._infer_shape = self._infer_shape
       ta._element_shape = self._element_shape
       ta._colocate_with = self._colocate_with
-      if ta._infer_shape:
+      if ta._infer_shape and context.in_graph_mode():
         val_shape = flow_out.op.inputs[2].get_shape()
         element_shape = tensor_shape.unknown_shape()
         if val_shape.dims is not None:
@@ -487,7 +488,7 @@ class TensorArray(object):
       ta._infer_shape = self._infer_shape
       ta._element_shape = self._element_shape
       ta._colocate_with = self._colocate_with
-      if ta._infer_shape:
+      if ta._infer_shape and context.in_graph_mode():
         val_shape = flow_out.op.inputs[1].get_shape()
         clengths = tensor_util.constant_value(flow_out.op.inputs[2])
         element_shape = tensor_shape.unknown_shape()
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 908b6b2111a0e67a568479d06f2e845f4347dfd4..645775239fd4b01a5775874a1fcd829fe8239a59 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -259,7 +259,8 @@ class _VariableStore(object):
         applying it on a newly created variable will be added to the collection
         GraphKeys.REGULARIZATION_LOSSES and can be used for regularization.
       reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation
-        of variables.
+        of variables. In Eager mode, this argument is always forced to be
+        tf.AUTO_REUSE.
       trainable: If `True` also add the variable to the graph collection
         `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
       collections: List of graph collections keys to add the `Variable` to.
@@ -278,6 +279,7 @@ class _VariableStore(object):
       use_resource: If False, creates a regular Variable. If True, creates
         instead an experimental ResourceVariable which has well-defined
         semantics. Defaults to False (will later change to True).
+        In Eager mode, this argument is always forced to be true.
       custom_getter: Callable that takes as a first argument the true getter,
         and allows overwriting the internal get_variable method.
         The signature of `custom_getter` should match that of this method,
@@ -311,6 +313,10 @@ class _VariableStore(object):
       raise ValueError(
           "Passed a custom_getter which is not callable: %s" % custom_getter)
 
+    if context.in_eager_mode():
+      reuse = AUTO_REUSE
+      use_resource = True
+
     # If a *_ref type is passed in an error would be triggered further down the
     # stack. We prevent this using base_dtype to get a non-ref version of the
     # type, before doing anything else. When _ref types are removed in favor of
@@ -498,6 +504,9 @@ class _VariableStore(object):
         when violating reuse during variable creation, or if an existing
         sharded variable exists for the given name but with different sharding.
     """
+    if context.in_eager_mode():
+      raise NotImplementedError("Partitioned variables are not yet supported "
+                                "in Eager mode.")
 
     initializing_from_value = initializer is not None and isinstance(
         initializer, ops.Tensor)
@@ -792,14 +801,19 @@ class _VariableStore(object):
 
     # Run the regularizer if requested and save the resulting loss.
     if regularizer:
-      with ops.colocate_with(v.op):
+      with ops.colocate_with(v):
         with ops.name_scope(name + "/Regularizer/"):
           loss = regularizer(v)
         if loss is not None:
+          if context.in_graph_mode():
+            v_name = v.name
+            loss_name = loss.name
+          else:
+            v_name = "v_%s" % type(v)
+            loss_name = "loss_%s" % type(loss)
           logging.vlog(1, "Applied regularizer to %s and added the result %s "
-                       "to REGULARIZATION_LOSSES.", v.name, loss.name)
+                       "to REGULARIZATION_LOSSES.", v_name, loss_name)
           ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, loss)
-
     return v
 
   # Initialize variable when no initializer provided
@@ -853,7 +867,8 @@ class VariableScope(object):
     initializer: default initializer passed to get_variable.
     regularizer: default regularizer passed to get_variable.
     reuse: Boolean, None, or tf.AUTO_REUSE, setting the reuse in
-      get_variable.
+      get_variable. In Eager mode, this argument is always forced to be
+      tf.AUTO_REUSE.
     caching_device: string, callable, or None: the caching device passed to
       get_variable.
     partitioner: callable or `None`: the partitioner passed to `get_variable`.
@@ -862,7 +877,8 @@ class VariableScope(object):
     dtype: default type passed to get_variable (defaults to DT_FLOAT).
     use_resource: if False, create a normal Variable; if True create an
       experimental ResourceVariable with well-defined semantics. Defaults
-      to False (will later change to True).
+      to False (will later change to True). In Eager mode, this argument is
+      always forced to be True.
     constraint: An optional projection function to be applied to the variable
       after being updated by an `Optimizer` (e.g. used to implement norm
       constraints or value constraints for layer weights). The function must
@@ -903,6 +919,7 @@ class VariableScope(object):
       if self._partitioner is not None:
         raise NotImplementedError("Partitioned variables are not yet supported "
                                   "in Eager mode.")
+      self._reuse = AUTO_REUSE
       self._use_resource = True
 
   @property
@@ -963,6 +980,8 @@ class VariableScope(object):
 
   def set_use_resource(self, use_resource):
     """Sets whether to use ResourceVariables for this scope."""
+    if context.in_eager_mode() and not use_resource:
+      raise ValueError("In eager mode, use_resource cannot be set to false.")
     self._use_resource = use_resource
 
   def set_regularizer(self, regularizer):
@@ -978,7 +997,7 @@ class VariableScope(object):
 
   def set_partitioner(self, partitioner):
     """Set partitioner for this scope."""
-    if context.in_eager_mode():
+    if partitioner and context.in_eager_mode():
       raise NotImplementedError("Partitioned variables are not yet supported "
                                 "in Eager mode.")
     self._partitioner = partitioner
@@ -1029,8 +1048,14 @@ class VariableScope(object):
       partitioner = self._partitioner
     if custom_getter is None:
       custom_getter = self._custom_getter
-    if reuse is None:
-      reuse = self._reuse
+    if context.in_graph_mode():
+      if reuse is None:
+        reuse = self._reuse
+      if use_resource is None:
+        use_resource = self._use_resource
+    else:
+      reuse = AUTO_REUSE
+      use_resource = True
 
     full_name = self.name + "/" + name if self.name else name
     # Variable names only depend on variable_scope (full_name here),
@@ -1050,12 +1075,6 @@ class VariableScope(object):
         constraint = self._constraint
       if dtype is None:
         dtype = self._dtype
-      if context.in_graph_mode():
-        if use_resource is None:
-          use_resource = self._use_resource
-      else:
-        use_resource = True
-
       return var_store.get_variable(
           full_name, shape=shape, dtype=dtype, initializer=initializer,
           regularizer=regularizer, reuse=reuse, trainable=trainable,
@@ -1232,7 +1251,8 @@ Args:
       must be known.
   use_resource: If False, creates a regular Variable. If true, creates an
     experimental ResourceVariable instead with well-defined semantics.
-    Defaults to False (will later change to True).
+    Defaults to False (will later change to True). In Eager mode, this argument
+    is always forced to be True.
   custom_getter: Callable that takes as a first argument the true getter, and
     allows overwriting the internal get_variable method.
     The signature of `custom_getter` should match that of this method,
@@ -1661,12 +1681,14 @@ def variable_scope(name_or_scope,
     reuse: `True`, None, or tf.AUTO_REUSE; if `True`, we go into reuse mode
       for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create
       variables if they do not exist, and return them otherwise; if None, we
-      inherit the parent scope's reuse flag.
+      inherit the parent scope's reuse flag. In Eager mode, this argument is
+      always forced to be tf.AUTO_REUSE.
     dtype: type of variables created in this scope (defaults to the type
       in the passed scope, or inherited from parent scope).
     use_resource: If False, all variables will be regular Variables. If True,
       experimental ResourceVariables with well-defined semantics will be used
-      instead. Defaults to False (will later change to True).
+      instead. Defaults to False (will later change to True). In Eager mode,
+      this argument is always forced to be True.
     constraint: An optional projection function to be applied to the variable
       after being updated by an `Optimizer` (e.g. used to implement norm
       constraints or value constraints for layer weights). The function must
@@ -1676,7 +1698,7 @@ def variable_scope(name_or_scope,
       use when doing asynchronous distributed training.
 
   Returns:
-    A scope that can be to captured and reused.
+    A scope that can be captured and reused.
 
   Raises:
     ValueError: when trying to reuse within a create scope, or create within
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 33b9907dc636bd211ec11fb0bf5c07b2b5d225ae..2e49e452d0dd356ba1947a554f420f037d040655 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -857,6 +857,18 @@ class Variable(object):
     """The name of this variable."""
     return self._variable.name
 
+  @property
+  def _shared_name(self):
+    """The shared name of the variable.
+
+      Unlike name(), shared_name doesn't have ":0" suffix. It is user-specified
+      name with name scope prefix.
+
+    Returns:
+      variable name.
+    """
+    return self.name[:-2]
+
   @property
   def initializer(self):
     """The initializer operation for this variable."""
diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py
index 534594966469824ac0dcd2be7f82ee0c7d82dd0f..a1fe47982f08dfbdff67fc87dbfe9ee5546095ad 100644
--- a/tensorflow/python/profiler/model_analyzer.py
+++ b/tensorflow/python/profiler/model_analyzer.py
@@ -117,7 +117,7 @@ class Profiler(object):
   ```python
   Typical use case:
     # Currently we are only allowed to create 1 profiler per process.
-    profiler = Profile(sess.graph)
+    profiler = Profiler(sess.graph)
 
     for i in xrange(total_steps):
       if i % 10000 == 0:
@@ -174,7 +174,7 @@ class Profiler(object):
     """Add statistics of a step.
 
     Args:
-      step: A step uint64 used to identify the RunMetadata. Must be different
+      step: int, A step used to identify the RunMetadata. Must be different
          across different AddStep() calls.
       run_meta: RunMetadata proto that contains statistics of a session run.
     """
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index 9492cadb8b8951597b073a3f658d3861b77f5b06..3432765b60bf9f72271058b8aafc94c420d7da17 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -215,7 +215,7 @@ class PrintModelAnalysisTest(test.TestCase):
       with gfile.Open(outfile, 'r') as f:
         lines = f.read().split('\n')
         result = '\n'.join([l[:min(len(l), 80)] for l in lines])
-        self.assertEqual('node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/91.04k flops)\n  model_analyzer_testlib.py:58:BuildFullModel:seq.append(array_... (0/1.80k para\n    model_analyzer_testlib.py:35:BuildSmallModel:image = array_ops... (0/0 param\n    model_analyzer_testlib.py:39:BuildSmallModel:initializer=init_... (0/4 param\n    model_analyzer_testlib.py:43:BuildSmallModel:initializer=init_... (0/648 par\n    model_analyzer_testlib.py:44:BuildSmallModel:x = nn_ops.conv2d... (0/0 param\n    model_analyzer_testlib.py:48:BuildSmallModel:initializer=init_... (0/1.15k p\n    model_analyzer_testlib.py:49:BuildSmallModel:x = nn_ops.conv2d... (0/0 param\n  model_analyzer_testlib.py:58:BuildFullModel:seq.append(array_... (gradient) (0\n    model_analyzer_testlib.py:44:BuildSmallModel:x = nn_ops.conv2d... (gradient)\n    model_analyzer_testlib.py:49:BuildSmallModel:x = nn_ops.conv2d... (gradient)\n  model_analyzer_testlib.py:62:BuildFullModel:cell, array_ops.c... (0/1.04k para\n  model_analyzer_testlib.py:62:BuildFullModel:cell, array_ops.c... (gradient) (0\n  model_analyzer_testlib.py:64:BuildFullModel:target = array_op... (0/0 params, \n  model_analyzer_testlib.py:65:BuildFullModel:loss = nn_ops.l2_... (0/0 params, \n  model_analyzer_testlib.py:65:BuildFullModel:loss = nn_ops.l2_... (gradient) (0\n  model_analyzer_testlib.py:67:BuildFullModel:return sgd_op.min... (0/0 params, \n',
+        self.assertEqual('node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/91.04k flops)\n  model_analyzer_testlib.py:58:BuildFullModel (0/1.80k params, 0/41.76k flops)\n    model_analyzer_testlib.py:35:BuildSmallModel (0/0 params, 0/0 flops)\n    model_analyzer_testlib.py:39:BuildSmallModel (0/4 params, 0/0 flops)\n    model_analyzer_testlib.py:43:BuildSmallModel (0/648 params, 0/0 flops)\n    model_analyzer_testlib.py:44:BuildSmallModel (0/0 params, 0/23.33k flops)\n    model_analyzer_testlib.py:48:BuildSmallModel (0/1.15k params, 0/0 flops)\n    model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/18.43k flops)\n  model_analyzer_testlib.py:58:BuildFullModel (gradient) (0/0 params, 0/0 flops)\n    model_analyzer_testlib.py:44:BuildSmallModel (gradient) (0/0 params, 0/0 flo\n    model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/0 flo\n  model_analyzer_testlib.py:62:BuildFullModel (0/1.04k params, 0/16.51k flops)\n  model_analyzer_testlib.py:62:BuildFullModel (gradient) (0/0 params, 0/32.77k f\n  model_analyzer_testlib.py:64:BuildFullModel (0/0 params, 0/0 flops)\n  model_analyzer_testlib.py:65:BuildFullModel (0/0 params, 0/0 flops)\n  model_analyzer_testlib.py:65:BuildFullModel (gradient) (0/0 params, 0/0 flops)\n  model_analyzer_testlib.py:67:BuildFullModel (0/0 params, 0/0 flops)\n',
                          result)
 
       self.assertLess(0, tfprof_node.total_exec_micros)
@@ -224,28 +224,28 @@ class PrintModelAnalysisTest(test.TestCase):
       self.assertEqual(8, len(tfprof_node.children))
       self.assertEqual('_TFProfRoot', tfprof_node.name)
       self.assertEqual(
-          'model_analyzer_testlib.py:58:BuildFullModel:seq.append(array_...',
+          'model_analyzer_testlib.py:58:BuildFullModel',
           tfprof_node.children[0].name)
       self.assertEqual(
-          'model_analyzer_testlib.py:58:BuildFullModel:seq.append(array_... (gradient)',
+          'model_analyzer_testlib.py:58:BuildFullModel (gradient)',
           tfprof_node.children[1].name)
       self.assertEqual(
-          'model_analyzer_testlib.py:62:BuildFullModel:cell, array_ops.c...',
+          'model_analyzer_testlib.py:62:BuildFullModel',
           tfprof_node.children[2].name)
       self.assertEqual(
-          'model_analyzer_testlib.py:62:BuildFullModel:cell, array_ops.c... (gradient)',
+          'model_analyzer_testlib.py:62:BuildFullModel (gradient)',
           tfprof_node.children[3].name)
       self.assertEqual(
-          'model_analyzer_testlib.py:64:BuildFullModel:target = array_op...',
+          'model_analyzer_testlib.py:64:BuildFullModel',
           tfprof_node.children[4].name)
       self.assertEqual(
-          'model_analyzer_testlib.py:65:BuildFullModel:loss = nn_ops.l2_...',
+          'model_analyzer_testlib.py:65:BuildFullModel',
           tfprof_node.children[5].name)
       self.assertEqual(
-          'model_analyzer_testlib.py:65:BuildFullModel:loss = nn_ops.l2_... (gradient)',
+          'model_analyzer_testlib.py:65:BuildFullModel (gradient)',
           tfprof_node.children[6].name)
       self.assertEqual(
-          'model_analyzer_testlib.py:67:BuildFullModel:return sgd_op.min...',
+          'model_analyzer_testlib.py:67:BuildFullModel',
           tfprof_node.children[7].name)
       # pylint: enable=line-too-long
 
diff --git a/tensorflow/python/summary/text_summary.py b/tensorflow/python/summary/text_summary.py
index f0788399ffc5d98af28183a1beabf16de634cad2..4031355b03d3831453f52848bd092c8f45e1ef69 100644
--- a/tensorflow/python/summary/text_summary.py
+++ b/tensorflow/python/summary/text_summary.py
@@ -23,7 +23,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from collections import namedtuple
 import json
 
 from tensorflow.core.framework import summary_pb2
@@ -33,9 +32,6 @@ from tensorflow.python.summary import plugin_asset
 
 PLUGIN_NAME = "text"
 
-# Contains event-related data specific to the text plugin.
-_TextPluginData = namedtuple("_TextPluginData", [])
-
 
 def text_summary(name, tensor, collections=None):
   """Summarizes textual data.
@@ -67,11 +63,9 @@ def text_summary(name, tensor, collections=None):
     raise ValueError("Expected tensor %s to have dtype string, got %s" %
                      (tensor.name, tensor.dtype))
 
-  summary_metadata = summary_pb2.SummaryMetadata()
-  text_plugin_data = _TextPluginData()
-  data_dict = text_plugin_data._asdict()  # pylint: disable=protected-access
-  summary_metadata.plugin_data.plugin_name = PLUGIN_NAME
-  summary_metadata.plugin_data.content = json.dumps(data_dict)
+  summary_metadata = summary_pb2.SummaryMetadata(
+      plugin_data=summary_pb2.SummaryMetadata.PluginData(
+          plugin_name=PLUGIN_NAME))
   t_summary = tensor_summary(
       name=name,
       tensor=tensor,
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index 9d3e20e408a56e7a50d340c34aa50fe5f7153d02..88ade0aac33f1cd8f9d8cb30344aabca76a13511 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -39,6 +39,7 @@ from tensorflow.python.summary import plugin_asset
 from tensorflow.python.summary import summary_iterator
 from tensorflow.python.summary.writer import writer
 from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.util import compat
 
 
 class SummaryWriterTestCase(test.TestCase):
@@ -334,11 +335,11 @@ class SummaryWriterTestCase(test.TestCase):
     # should strip the metadata from the second one.
     value = summary_pb2.Summary.Value(tag="foo", simple_value=10.0)
     value.metadata.plugin_data.plugin_name = "bar"
-    value.metadata.plugin_data.content = "... content ..."
+    value.metadata.plugin_data.content = compat.as_bytes("... content ...")
     sw.add_summary(summary_pb2.Summary(value=[value]), 10)
     value = summary_pb2.Summary.Value(tag="foo", simple_value=10.0)
     value.metadata.plugin_data.plugin_name = "bar"
-    value.metadata.plugin_data.content = "... content ..."
+    value.metadata.plugin_data.content = compat.as_bytes("... content ...")
     sw.add_summary(summary_pb2.Summary(value=[value]), 10)
 
     sw.close()
diff --git a/tensorflow/python/tools/import_pb_to_tensorboard.py b/tensorflow/python/tools/import_pb_to_tensorboard.py
index a8712fc37e631cd7c3ddb76b9ca21f78599d668c..00de044505f7f18e6af8237be57c4d8b346caa42 100644
--- a/tensorflow/python/tools/import_pb_to_tensorboard.py
+++ b/tensorflow/python/tools/import_pb_to_tensorboard.py
@@ -51,7 +51,7 @@ def import_to_tensorboard(model_dir, log_dir):
     pb_visual_writer = summary.FileWriter(log_dir)
     pb_visual_writer.add_graph(sess.graph)
     print("Model Imported. Visualize by running: "
-          "> tensorboard --logdir={}".format(log_dir))
+          "tensorboard --logdir={}".format(log_dir))
 
 
 def main(unused_args):
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 796402425a123d0063084f3f9886855789a40e10..cdc532a38e8e683b18619b0f1f795f3cb0d748f3 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
@@ -118,8 +119,11 @@ class AdamOptimizer(optimizer.Optimizer):
     # silently ignored).
     first_var = min(var_list, key=lambda x: x.name)
 
-    if (self._beta1_power is None or
-        self._beta1_power.graph is not first_var.graph):
+    create_new = self._beta1_power is None
+    if not create_new and context.in_graph_mode():
+      create_new = (self._beta1_power.graph is not first_var.graph)
+
+    if create_new:
       with ops.colocate_with(first_var):
         self._beta1_power = variable_scope.variable(self._beta1,
                                                     name="beta1_power",
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index 62b171e234eebcb3e12508d8e000f565e5e89903..defcf3371410ec3ba73f8796ebac2708da3b34e8 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -21,9 +21,11 @@ from __future__ import print_function
 import numpy as np
 
 from tensorflow.python.client import session
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -149,49 +151,60 @@ class AdamOptimizerTest(test.TestCase):
                               repeated_index_update_var.eval())
 
   def doTestBasic(self, use_resource=False):
-    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
-        # Initialize variables for numpy implementation.
-        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
-        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
-        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
-        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
-        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+      # Initialize variables for numpy implementation.
+      m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+      var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+      grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+      var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+      grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+      if use_resource:
+        var0 = resource_variable_ops.ResourceVariable(
+            var0_np, name="var0_%d" % i)
+        var1 = resource_variable_ops.ResourceVariable(
+            var1_np, name="var1_%d" % i)
+      else:
+        var0 = variables.Variable(var0_np)
+        var1 = variables.Variable(var1_np)
+      grads0 = constant_op.constant(grads0_np)
+      grads1 = constant_op.constant(grads1_np)
 
-        if use_resource:
-          var0 = resource_variable_ops.ResourceVariable(var0_np)
-          var1 = resource_variable_ops.ResourceVariable(var1_np)
-        else:
-          var0 = variables.Variable(var0_np)
-          var1 = variables.Variable(var1_np)
-        grads0 = constant_op.constant(grads0_np)
-        grads1 = constant_op.constant(grads1_np)
-        opt = adam.AdamOptimizer()
-        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
-        variables.global_variables_initializer().run()
+      opt = adam.AdamOptimizer()
+      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
 
+      if context.in_graph_mode():
+        self.evaluate(variables.global_variables_initializer())
         # Fetch params to validate initial values
-        self.assertAllClose([1.0, 2.0], var0.eval())
-        self.assertAllClose([3.0, 4.0], var1.eval())
+        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
 
-        beta1_power, beta2_power = opt._get_beta_accumulators()
+      beta1_power, beta2_power = opt._get_beta_accumulators()
 
-        # Run 3 steps of Adam
-        for t in range(1, 4):
-          self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
-          self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
-          update.run()
+      # Run 3 steps of Adam
+      for t in range(1, 4):
+        if context.in_graph_mode():
+          self.evaluate(update)
+        elif t > 1:
+          opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
 
-          var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
-          var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+        self.assertAllCloseAccordingToType(0.9**(t + 1),
+                                           self.evaluate(beta1_power))
+        self.assertAllCloseAccordingToType(0.999**(t + 1),
+                                           self.evaluate(beta2_power))
 
-          # Validate updated params
-          self.assertAllCloseAccordingToType(var0_np, var0.eval())
-          self.assertAllCloseAccordingToType(var1_np, var1.eval())
+        var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+        var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+        # Validate updated params
+        self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+        self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
 
   def testBasic(self):
-    self.doTestBasic(use_resource=False)
+    with self.test_session():
+      self.doTestBasic(use_resource=False)
 
+  @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testResourceBasic(self):
     self.doTestBasic(use_resource=True)
 
diff --git a/tensorflow/python/training/checkpoint_ops.py b/tensorflow/python/training/checkpoint_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..70460ceb4802f3f30eaab4b3ae10a6e59589d83d
--- /dev/null
+++ b/tensorflow/python/training/checkpoint_ops.py
@@ -0,0 +1,453 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operations for generating and loading vocab remappings."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_checkpoint_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+
+ops.NotDifferentiable("GenerateVocabRemapping")
+ops.NotDifferentiable("LoadAndRemapMatrix")
+
+
+def _load_and_remap_matrix(ckpt_path,
+                           old_tensor_name,
+                           new_row_vocab_offset,
+                           num_rows_to_load,
+                           new_col_vocab_size,
+                           initializer,
+                           old_row_vocab_file=None,
+                           new_row_vocab_file=None,
+                           old_col_vocab_file=None,
+                           new_col_vocab_file=None,
+                           num_row_oov_buckets=0,
+                           num_col_oov_buckets=0,
+                           max_rows_in_memory=-1):
+  """Loads a 2-D (matrix) `Tensor` from checkpoint.
+
+  Generates 1D-remappings for rows and columns using the
+  `GenerateVocabRemapping` op, and initializes any anticipated values with the
+  provided initializer. Then, uses the `LoadAndRemapMatrix` op to create a
+  matrix that loads existing values from the checkpoint, while filling out
+  "missing" values with the newly initialized values. See
+  contrib/framework/ops/checkpoint_ops.cc for more information on the wrapped
+  functionality (LoadAndRemapMatrix). This wrapper can be used to perform only
+  row remapping or only col remapping. If only row remapping is desired,
+  {new,old}_col_vocab_file should be `None`, and vice versa for column
+  remapping.
+
+  NOTE: This only supports div-partitioning the vocabulary on the 1st dimension
+  (row axis) via `new_row_vocab_offset`.
+
+  Args:
+    ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
+      from which the old matrix `Tensor` will be loaded.
+    old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
+    new_row_vocab_offset: A 0-indexed integer representing what line to
+      start reading at in the new row vocabulary. Used for partitioned
+      variables.
+    num_rows_to_load: Number of rows to load for the new vocabulary (note: to
+      support variable partitioning and partial loading, this does not need to
+      be the same as the number of entries in `new_row_vocab_file`).
+    new_col_vocab_size: Number of columns to load - should be the same as the
+      number of entries in `new_col_vocab_file`, since we don't support
+      partitioning along the column axis.
+    initializer: Callable initializer function that accepts a 1-D tensor as the
+      arg to specify the shape of the returned tensor. Used to initialize
+      missing values.
+    old_row_vocab_file: A scalar `Tensor` of type `string` containing the
+      path to the old row vocabulary file. Can be None, which represents no
+      remapping on the row axis.
+    new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
+      to the new row vocabulary file. Can be None, which represents no remapping
+      on the row axis - in which case, `new_row_vocab_offset` and
+      `num_rows_to_load` work under the assumption that the new row vocab is the
+      same as the old row vocab.
+    old_col_vocab_file: A scalar `Tensor` of type `string` containing the
+      path to the old column vocabulary file. Can be None, which represents no
+      remapping on the column axis.
+    new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
+      to the new column vocabulary file. Can be None, which represents no
+      remapping on the column axis - in which case, `new_col_vocab_size` works
+      under the assumption that the new col vocab is the same as the old col
+      vocab.
+    num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
+      to append. Must be >= 0.
+    num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
+      columns to append. Must be >= 0.
+    max_rows_in_memory: `int` specifying the maximum number of rows to load from
+      the checkpoint at once. If less than or equal to 0, the entire matrix will
+      be loaded into memory. Setting this arg trades increased disk reads for
+      lower memory usage.
+
+  Returns:
+    A Tensor of shape `[num_rows_to_load + num_row_oov_buckets,
+    new_col_vocab_size + num_col_oov_buckets]`, with values loaded from the
+    specified tensor in the checkpoint, and any missing or OOV values
+    initialized with the given `initializer`.
+
+  Raises:
+    ValueError: If `num_row_oov_buckets` or `num_col_oov_buckets` < 0.
+    ValueError: If either `old_row_vocab_file` or `new_row_vocab_file` is
+      provided, while the other is not. Same for `old_col_vocab_file` and
+      `new_col_vocab_file`.
+    ValueError: If neither row vocabs or col vocabs are provided.
+  """
+  if num_row_oov_buckets < 0:
+    raise ValueError("num_row_oov_buckets must be >= 0, but received %d" %
+                     num_row_oov_buckets)
+  if num_col_oov_buckets < 0:
+    raise ValueError("num_col_oov_buckets must be >= 0, but received %d" %
+                     num_col_oov_buckets)
+
+  if bool(old_row_vocab_file) != bool(new_row_vocab_file):
+    raise ValueError(
+        "old_row_vocab_file and new_row_vocab_file must both be specified or "
+        "left unspecified. old_row_vocab_file='{}', new_row_vocab_file='{}'".
+        format(old_row_vocab_file, new_row_vocab_file))
+  if bool(old_col_vocab_file) != bool(new_col_vocab_file):
+    raise ValueError(
+        "old_col_vocab_file and new_col_vocab_file must both be specified or "
+        "left unspecified. old_col_vocab_file='{}', new_col_vocab_file='{}'".
+        format(old_col_vocab_file, new_col_vocab_file))
+
+  remap_rows = new_row_vocab_file and old_row_vocab_file
+  remap_cols = new_col_vocab_file and old_col_vocab_file
+  if not (remap_rows or remap_cols):
+    raise ValueError(
+        "Must provide either row or column vocab files. If no remapping is "
+        "necessary, consider using `tf.contrib.framework.init_from_checkpoint` "
+        "instead.")
+
+  num_rows_present = num_rows_to_load
+  if remap_rows:
+    row_remapping, num_rows_present = (
+        gen_checkpoint_ops._generate_vocab_remapping(  # pylint: disable=protected-access
+            new_vocab_file=new_row_vocab_file,
+            old_vocab_file=old_row_vocab_file,
+            new_vocab_offset=new_row_vocab_offset,
+            num_new_vocab=num_rows_to_load))
+  else:
+    # Even when the rows are not being reordered, we still need to generate a
+    # remapping to account for initializing partitioned Variables (when
+    # new_row_vocab_offset is non-zero).
+    row_remapping = math_ops.range(
+        new_row_vocab_offset,
+        new_row_vocab_offset + num_rows_to_load,
+        dtype=dtypes.int64)
+
+  col_remapping = []
+  num_cols_present = new_col_vocab_size
+  if remap_cols:
+    col_remapping, num_cols_present = (
+        gen_checkpoint_ops._generate_vocab_remapping(  # pylint: disable=protected-access
+            new_vocab_file=new_col_vocab_file,
+            old_vocab_file=old_col_vocab_file,
+            new_vocab_offset=0,  # Offset is unused for cols (no partitioning).
+            num_new_vocab=new_col_vocab_size))
+
+  init_vals = initializer([
+      num_rows_to_load * new_col_vocab_size -
+      num_rows_present * num_cols_present, 1
+  ])
+  return_tensor = gen_checkpoint_ops._load_and_remap_matrix(  # pylint: disable=protected-access
+      ckpt_path=ckpt_path,
+      old_tensor_name=old_tensor_name,
+      row_remapping=row_remapping,
+      col_remapping=col_remapping,
+      initializing_values=init_vals,
+      num_rows=num_rows_to_load,
+      num_cols=new_col_vocab_size,
+      max_rows_in_memory=max_rows_in_memory)
+
+  # Add OOV row(s) and column(s).
+  if num_row_oov_buckets > 0:
+    init_row_oov_val = initializer([num_row_oov_buckets, new_col_vocab_size])
+    init_row_oov_val = ops.convert_to_tensor(init_row_oov_val)
+    return_tensor = array_ops.concat([return_tensor, init_row_oov_val], 0)
+  if num_col_oov_buckets > 0:
+    # We need to add any row OOV to the new column shape.
+    init_col_oov_val = initializer(
+        [num_rows_to_load + num_row_oov_buckets, num_col_oov_buckets])
+    init_col_oov_val = ops.convert_to_tensor(init_col_oov_val)
+    return_tensor = array_ops.concat([return_tensor, init_col_oov_val], 1)
+
+  return return_tensor
+
+
+def _load_and_remap_matrix_initializer(ckpt_path,
+                                       old_tensor_name,
+                                       new_row_vocab_size,
+                                       new_col_vocab_size,
+                                       old_row_vocab_file=None,
+                                       new_row_vocab_file=None,
+                                       old_col_vocab_file=None,
+                                       new_col_vocab_file=None,
+                                       num_row_oov_buckets=0,
+                                       num_col_oov_buckets=0,
+                                       initializer=None,
+                                       max_rows_in_memory=-1):
+  r"""Returns a var initializer for loading and remapping a 2-D (matrix) tensor.
+
+  The returned initializer loads a 2-D (matrix) `Tensor` with name
+  `old_tensor_name` from the checkpoint at `ckpt_path`. It will reorder the
+  rows/columns according to the specified vocab files and append additional
+  out-of-vocabulary rows/columns according to the number of OOV buckets.
+
+  The format of the file at the `{old,new}_{row,col}_vocab_file` path should be
+  a text file, with each line containing a single entity within the vocabulary.
+  Let the function `line_of(f, "x")` return the 0-indexed line number of the
+  entity "x" in file f, and the function `entity_at(f, i)` return the entity at
+  line i of file f. Then, row i of the new output matrix will be taken from row
+  `line_of(old_row_vocab_file, entity_at(new_row_vocab_file, i))` of the old
+  matrix. If any entity in `new_row_vocab_file` is not found in
+  `old_row_vocab_file`, that row is considered a "missing" row, and its values
+  will be initialized using the `initializer` arg. The same logic also applies
+  for the columns.
+
+  For example, assuming that:
+
+  * `old_row_vocab_file` contains "mercury\nvenus\nmars"
+  * `new_row_vocab_file` contains "venus\njupiter\nmercury"
+  * `old_col_vocab_file` contains "good\nbetter\nbest"
+  * `new_col_vocab_file` contains "good\nbest\nfantastic"
+  * `initializer` returns the natural numbers `[1, 2, 3, 4, ...]`
+  * `w(i, j)` represents the value from row i, column j of the old matrix
+
+  Then the new output matrix will look like:
+
+  `[[w(1, 0), w(1, 2), 1],
+    [2,       3,       4],
+    [w(0, 0), w(0, 2), 5]]`
+
+  If we further specify that:
+
+  * `num_row_oov_buckets` == 2
+  * `num_col_oov_buckets` == 1
+
+  Then the new output matrix will look like:
+
+  `[[w(1, 0), w(1, 2), 1,  12],
+    [2,       3,       4,  13],
+    [w(0, 0), w(0, 2), 5,  14],
+    [6,       7,       8,  15],
+    [9,       10,      11, 16]]`
+
+  If `{old,new}_row_vocab_file` are None, we assume that the old and new row
+  vocab files are the same, and no row remapping is done. If
+  `{old,new}_col_vocab_file` are None, we assume that the old and new column
+  vocab files are the same, and no column remapping is done.
+
+  The returned initializer only supports div-partitioning along the row axis. It
+  does not support partitioning along the column axis or mod-partitioning.
+
+  NOTE: When this is used to warm-start variables, client code should use
+  `tf.lookup.index_table_from_tensor()` like
+  contrib/layers/python/layers/feature_column.py does, as opposed to
+  `tf.feature_to_id()` - in order to ensure the underlying lookup tables are the
+  same.
+
+  Args:
+    ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
+      from which the old matrix `Tensor` will be loaded.
+    old_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
+    new_row_vocab_size: `int` specifying the number of entries in
+      `new_row_vocab_file`. If no row remapping is needed (no row vocab
+      provided), this should be equal to the number of rows to load from the old
+      matrix (which can theoretically be smaller than the number of rows in the
+      old matrix).
+    new_col_vocab_size: `int` specifying the number of entries in
+      `new_col_vocab_file`. If no column remapping is needed (no column vocab
+      provided), this should be equal to the number of columns in the old
+      matrix.
+    old_row_vocab_file: A scalar `Tensor` of type `string` containing the
+      path to the old row vocabulary file. Can be None, which represents no
+      remapping on the row axis.
+    new_row_vocab_file: A scalar `Tensor` of type `string` containing the path
+      to the new row vocabulary file. Can be None, which represents no remapping
+      on the row axis.
+    old_col_vocab_file: A scalar `Tensor` of type `string` containing the
+      path to the old column vocabulary file. Can be None, which represents no
+      remapping on the column axis.
+    new_col_vocab_file: A scalar `Tensor` of type `string` containing the path
+      to the new column vocabulary file. Can be None, which represents no
+      remapping on the column axis.
+    num_row_oov_buckets: `int` specifying the number of out-of-vocabulary rows
+      to append. Must be >= 0.
+    num_col_oov_buckets: `int` specifying the number of out-of-vocabulary
+      columns to append. Must be >= 0.
+    initializer: Initializer function to initialize missing values. Accepts a
+      1-D tensor as the arg to specify the shape of the returned tensor. If
+      `None`, defaults to using `zeros_initializer()`.
+    max_rows_in_memory: `int` specifying the maximum number of rows to load from
+      the checkpoint at once. If less than or equal to 0, the entire matrix will
+      be loaded into memory. Setting this arg trades increased disk reads for
+      lower memory usage.
+
+  Returns:
+    A variable initializer function that should be used to initialize a
+    (potentially partitioned) `Variable` whose complete shape is
+    `[new_row_vocab_size + num_row_oov_buckets, new_col_vocab_size +
+    num_col_oov_buckets]`.
+
+  Raises:
+    TypeError: If `initializer` is specified but not callable.
+  """
+  if initializer is None:
+    # TODO(b/25671353): Consider using sqrt(6/(fan_in + fan_out)) instead, from
+    # Glorot and Bengio, 2010.
+    initializer = init_ops.zeros_initializer()
+
+  if not callable(initializer):
+    raise TypeError(
+        "initializer must be callable, instead of being {} of type {}.".format(
+            initializer, type(initializer)))
+
+  def _initializer(shape, dtype=dtypes.float32, partition_info=None):
+    """Variable initializer.
+
+    Args:
+      shape: Shape of `Tensor` to return. Should include OOV on both axes.
+      dtype: Must be float32.
+      partition_info: variable_scope._PartitionInfo.
+
+    Returns:
+      `Tensor` of shape `shape`.
+
+    Raises:
+      TypeError: If `dtype` is anything other than float32.
+      ValueError: For shape mismatch upon invocation.
+    """
+    # Sanity checks.
+    if dtype != dtypes.float32:
+      raise TypeError(
+          "Currently, only float32 is supported. Received dtype: {}".format(
+              dtype))
+    if len(shape) != 2:
+      raise ValueError("Expected 2-dim shape, but received: {}".format(shape))
+    if shape[0] <= 0:
+      raise ValueError(
+          "Expected 1st dim of shape to be > 0, but received shape: {}".format(
+              shape))
+    if shape[1] != (new_col_vocab_size + num_col_oov_buckets):
+      raise ValueError(
+          "Expected 2nd dim of shape to be new_col_vocab_size ({}) + "
+          "num_col_oov_buckets ({}) = {}, but received shape: {}".format(
+              new_col_vocab_size, num_col_oov_buckets,
+              new_col_vocab_size + num_col_oov_buckets, shape))
+
+    offset = 0
+    if partition_info is not None:
+      offset = partition_info.single_offset(shape)
+
+    if offset + shape[0] > new_row_vocab_size + num_row_oov_buckets:
+      raise ValueError(
+          "Trying to initialize {} additional rows after {} rows have already "
+          "been initialized, which would exceed expected total row count of "
+          "new_row_vocab_size ({}) + num_row_oov_buckets ({}) = {}.".format(
+              shape[0], offset, new_row_vocab_size, num_row_oov_buckets,
+              new_row_vocab_size + num_row_oov_buckets))
+
+    row_oov_buckets_to_use = min(shape[0],
+                                 max(0, offset + shape[0] - new_row_vocab_size))
+    num_rows_to_load = shape[0] - row_oov_buckets_to_use
+
+    return _load_and_remap_matrix(
+        ckpt_path=ckpt_path,
+        old_tensor_name=old_tensor_name,
+        new_row_vocab_offset=offset,
+        num_rows_to_load=num_rows_to_load,
+        new_col_vocab_size=new_col_vocab_size,
+        initializer=initializer,
+        old_row_vocab_file=old_row_vocab_file,
+        new_row_vocab_file=new_row_vocab_file,
+        old_col_vocab_file=old_col_vocab_file,
+        new_col_vocab_file=new_col_vocab_file,
+        num_row_oov_buckets=row_oov_buckets_to_use,
+        num_col_oov_buckets=num_col_oov_buckets,
+        max_rows_in_memory=max_rows_in_memory)
+
+  return _initializer
+
+
+def _load_embedding_initializer(ckpt_path,
+                                embedding_tensor_name,
+                                new_vocab_size,
+                                embedding_dim,
+                                old_vocab_file,
+                                new_vocab_file,
+                                num_oov_buckets=0,
+                                initializer=None,
+                                max_rows_in_memory=-1):
+  """Returns a variable initializer for loading pre-trained embeddings.
+
+  Wrapper around `load_and_remap_matrix_initializer()` specialized for loading
+  embedding weights and remapping according to the provided vocab files. See
+  docs for `load_and_remap_matrix_initializer()` for more details.
+
+  NOTE: Only for use with div-partitioned variables / vocabularies.
+
+  Args:
+    ckpt_path: Path to the TensorFlow checkpoint (version 2, `TensorBundle`)
+      from which the old matrix `Tensor` will be loaded.
+    embedding_tensor_name: Name of the 2-D `Tensor` to load from checkpoint.
+    new_vocab_size: Number of entries in the new vocab.
+    embedding_dim: `int` specifying the dimension of the embedding vectors from
+      the checkpoint. Must match the number of columns in the old embedding
+      matrix.
+    old_vocab_file: A scalar `Tensor` of type `string` containing the
+      path to the old vocabulary file.
+    new_vocab_file: A scalar `Tensor` of type `string` containing the
+      path to the new vocabulary file.
+    num_oov_buckets: `int` specifying the number of out-of-vocabulary
+      buckets to use. Must be >= 0.
+    initializer: Initializer function that accepts a 1-D tensor as the arg to
+      specify the shape of the returned tensor. If `None`, defaults to using
+      `truncated_normal_initializer()`.
+    max_rows_in_memory: `int` specifying the maximum number of rows to load from
+      the checkpoint at once. If less than or equal to 0, the entire matrix will
+      be loaded into memory. Setting this arg trades increased disk reads for
+      lower memory usage.
+
+  Returns:
+    A variable initializer function.
+  """
+  if initializer is None:
+    # TODO(b/25671353): This should be kept in sync with the stddev used by
+    # feature_column.py's _EmbeddingColumn.
+    initializer = init_ops.truncated_normal_initializer(
+        stddev=1.0 / math.sqrt(embedding_dim))
+
+  return _load_and_remap_matrix_initializer(
+      ckpt_path=ckpt_path,
+      old_tensor_name=embedding_tensor_name,
+      new_row_vocab_size=new_vocab_size,
+      new_col_vocab_size=embedding_dim,
+      old_row_vocab_file=old_vocab_file,
+      new_row_vocab_file=new_vocab_file,
+      old_col_vocab_file=None,
+      new_col_vocab_file=None,
+      num_row_oov_buckets=num_oov_buckets,
+      num_col_oov_buckets=0,
+      initializer=initializer,
+      max_rows_in_memory=max_rows_in_memory)
diff --git a/tensorflow/python/training/checkpoint_ops_test.py b/tensorflow/python/training/checkpoint_ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..39c4d2911f2d279b8817e70fa23596ab195dbcd8
--- /dev/null
+++ b/tensorflow/python/training/checkpoint_ops_test.py
@@ -0,0 +1,305 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for Python wrappers around warm-starting."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_ops
+from tensorflow.python.training import saver as saver_lib
+
+
+class LoadAndRemapWrappersTest(test.TestCase):
+  """Tests for the functionality of the Python wrappers."""
+
+  def setUp(self):
+    ops.reset_default_graph()
+    # Create the checkpoint file in a temporary directory.
+    checkpoint_prefix = os.path.join(self.get_temp_dir(), 'model')
+    # 0., 1., ..., 79. reshaped into [5, 16].
+    initializer = init_ops.constant_initializer(
+        np.reshape(np.linspace(0.0, 79, 5 * 16), (5, 16)))
+    with self.test_session() as sess:
+      with variable_scope.variable_scope('some_scope'):
+        variable_scope.get_variable(name='embeddings', shape=[5, 16],
+                                    initializer=initializer)
+      sess.run(variables.global_variables_initializer())
+      saver = saver_lib.Saver()
+      saver.save(sess, checkpoint_prefix, global_step=5)
+    self.checkpoint_file = '{}-5'.format(checkpoint_prefix)
+
+    # Create the vocabulary files.
+    self.new_feature_vocab_file = os.path.join(
+        self.get_temp_dir(), 'new_feature_vocab.txt')
+    with open(self.new_feature_vocab_file, 'w') as f:
+      f.write('\n'.join(['zero', 'one', 'two', 'three', 'four']) + '\n')
+
+    self.old_feature_vocab_file = os.path.join(
+        self.get_temp_dir(), 'old_feature_vocab.txt')
+    with open(self.old_feature_vocab_file, 'w') as f:
+      f.write('\n'.join(['zero', 'one', 'two', 'three']) + '\n')
+
+    self.new_class_vocab_file = os.path.join(
+        self.get_temp_dir(), 'new_class_vocab.txt')
+    with open(self.new_class_vocab_file, 'w') as f:
+      f.write('\n'.join(['MISSING', 'knitting', 'flask', 'eminem']) + '\n')
+
+    self.old_class_vocab_file = os.path.join(
+        self.get_temp_dir(), 'old_class_vocab.txt')
+    with open(self.old_class_vocab_file, 'w') as f:
+      f.write('\n'.join(['knitting', 'eminem', 'MISSING']) + '\n')
+
+    self.init_val = 42
+
+    def _init_val_initializer(shape, dtype=None, partition_info=None):
+      del dtype, partition_info  # Unused by this unit-testing initializer.
+      return array_ops.tile(
+          constant_op.constant([[self.init_val]], dtype=dtypes.float32), shape)
+
+    self.initializer = _init_val_initializer
+
+  def test_load_and_remap_matrix(self):
+    """Tests the end-to-end loading / remapping of weights."""
+    # _load_and_remap_matrix() is the generalized wrapper that takes in row and
+    # column vocabulary files, calls the relevant remappings, and returns the
+    # weight matrix.  Take this example to be linear multi-class by providing
+    # both row and column vocabularies.
+    remapped_matrix = checkpoint_ops._load_and_remap_matrix(
+        new_row_vocab_file=self.new_feature_vocab_file,
+        old_row_vocab_file=self.old_feature_vocab_file,
+        num_rows_to_load=4,
+        new_col_vocab_file=self.new_class_vocab_file,
+        old_col_vocab_file=self.old_class_vocab_file,
+        new_col_vocab_size=4,
+        old_tensor_name='some_scope/embeddings',
+        ckpt_path=[self.checkpoint_file],
+        new_row_vocab_offset=1,
+        initializer=self.initializer,
+        num_row_oov_buckets=1,
+        num_col_oov_buckets=1)
+
+    # [4 in vocab + 1 oov features, 4 in vocab + 1 oov classes].  The offset
+    # means we read
+    expected_remapped_matrix = np.concatenate(
+        [
+            np.reshape([18, 34, 50, self.init_val, self.init_val], [5, 1]),
+            np.reshape([16, 32, 48, self.init_val, self.init_val], [5, 1]),
+            np.reshape([self.init_val] * 5, [5, 1]),
+            np.reshape([17, 33, 49, self.init_val, self.init_val], [5, 1]),
+            np.reshape([self.init_val] * 5, [5, 1])
+        ],
+        axis=1)
+
+    with self.test_session():
+      self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
+
+  def test_load_and_remap_output_layer_weight_initializer_linear(self):
+    """Tests for the output layer initializer in the linear multi-class case."""
+    loading_initializer = (checkpoint_ops._load_and_remap_matrix_initializer(
+        new_row_vocab_size=5,
+        new_col_vocab_file=self.new_class_vocab_file,
+        old_col_vocab_file=self.old_class_vocab_file,
+        new_col_vocab_size=4,
+        old_tensor_name='some_scope/embeddings',
+        ckpt_path=[self.checkpoint_file],
+        new_row_vocab_file=self.new_feature_vocab_file,
+        old_row_vocab_file=self.old_feature_vocab_file,
+        num_row_oov_buckets=1,
+        num_col_oov_buckets=1,
+        initializer=self.initializer))
+
+    expected_remapped_matrix = np.concatenate(
+        [
+            np.reshape([2, 18, 34, 50, self.init_val, self.init_val], [6, 1]),
+            np.reshape([0, 16, 32, 48, self.init_val, self.init_val], [6, 1]),
+            np.reshape([self.init_val] * 6, [6, 1]),
+            np.reshape([1, 17, 33, 49, self.init_val, self.init_val], [6, 1]),
+            np.reshape([self.init_val] * 6, [6, 1])
+        ],
+        axis=1)
+
+    # The new weight matrix is of size
+    # [5 feature vocab + 1 feature OOV, 4 class vocab + 1 class OOV].  Use a
+    # partitioned variable to confirm that the offset logic works.
+    remapped_matrix = variable_scope.get_variable(
+        name='linear/obtained_weight_matrix',
+        shape=[6, 5],
+        initializer=loading_initializer,
+        partitioner=partitioned_variables.fixed_size_partitioner(2))
+
+    with self.test_session():
+      variables.global_variables_initializer().run()
+      self.assertAllClose(expected_remapped_matrix,
+                          remapped_matrix.as_tensor().eval())
+
+  def test_load_and_remap_output_layer_weight_initializer_dnn_output(self):
+    """Tests for the output layer initializer in the DNN output case."""
+    loading_initializer = (checkpoint_ops._load_and_remap_matrix_initializer(
+        new_row_vocab_size=5,
+        new_col_vocab_file=self.new_class_vocab_file,
+        old_col_vocab_file=self.old_class_vocab_file,
+        new_col_vocab_size=4,
+        old_tensor_name='some_scope/embeddings',
+        ckpt_path=[self.checkpoint_file],
+        num_col_oov_buckets=1,
+        initializer=self.initializer))
+
+    expected_remapped_matrix = np.concatenate(
+        [
+            np.reshape([2, 18, 34, 50, 66], [5, 1]),
+            np.reshape([0, 16, 32, 48, 64], [5, 1]),
+            np.reshape([self.init_val] * 5, [5, 1]),
+            np.reshape([1, 17, 33, 49, 65], [5, 1]),
+            np.reshape([self.init_val] * 5, [5, 1])
+        ],
+        axis=1)
+
+    # The new weight matrix is of size
+    # [5-sized input layer, 4 class vocab + 1 class OOV].
+    remapped_matrix = variable_scope.get_variable(
+        name='dnn_output/obtained_weight_matrix',
+        shape=[5, 5],
+        initializer=loading_initializer,
+        partitioner=partitioned_variables.fixed_size_partitioner(2))
+
+    with self.test_session():
+      variables.global_variables_initializer().run()
+      self.assertAllClose(expected_remapped_matrix,
+                          remapped_matrix.as_tensor().eval())
+
+  def test_initializer_with_oov_only_partition(self):
+    """Tests for the output layer initializer where one partition is all OOV."""
+    loading_initializer = (checkpoint_ops._load_and_remap_matrix_initializer(
+        new_row_vocab_size=5,
+        new_col_vocab_file=self.new_class_vocab_file,
+        old_col_vocab_file=self.old_class_vocab_file,
+        new_col_vocab_size=4,
+        old_tensor_name='some_scope/embeddings',
+        ckpt_path=[self.checkpoint_file],
+        new_row_vocab_file=self.new_feature_vocab_file,
+        old_row_vocab_file=self.old_feature_vocab_file,
+        num_row_oov_buckets=5,
+        num_col_oov_buckets=1,
+        initializer=self.initializer))
+
+    expected_remapped_matrix = np.concatenate(
+        [
+            np.reshape([2, 18, 34, 50] + [self.init_val] * 6, [10, 1]),
+            np.reshape([0, 16, 32, 48] + [self.init_val] * 6, [10, 1]),
+            np.reshape([self.init_val] * 10, [10, 1]),
+            np.reshape([1, 17, 33, 49] + [self.init_val] * 6, [10, 1]),
+            np.reshape([self.init_val] * 10, [10, 1]),
+        ],
+        axis=1)
+
+    # The new weight matrix is of size
+    # [5 feature vocab + 5 feature OOV, 4 class vocab + 1 class OOV].  The
+    # second partition has only OOV.
+    remapped_matrix = variable_scope.get_variable(
+        name='linear_all_oov/obtained_weight_matrix',
+        shape=[10, 5],
+        initializer=loading_initializer,
+        partitioner=partitioned_variables.fixed_size_partitioner(2))
+
+    with self.test_session():
+      variables.global_variables_initializer().run()
+      self.assertAllClose(expected_remapped_matrix,
+                          remapped_matrix.as_tensor().eval())
+
+  def test_load_and_remap_linear_multiclass_initializer_default_init(self):
+    """Tests where the zeros_initializer default is used for linear."""
+    loading_initializer = (checkpoint_ops._load_and_remap_matrix_initializer(
+        new_row_vocab_size=5,
+        new_col_vocab_file=self.new_class_vocab_file,
+        old_col_vocab_file=self.old_class_vocab_file,
+        new_col_vocab_size=4,
+        old_tensor_name='some_scope/embeddings',
+        ckpt_path=[self.checkpoint_file],
+        new_row_vocab_file=self.new_feature_vocab_file,
+        old_row_vocab_file=self.old_feature_vocab_file,
+        num_row_oov_buckets=1,
+        num_col_oov_buckets=1))
+
+    expected_remapped_matrix = np.concatenate(
+        [
+            np.reshape([2, 18, 34, 50, 0, 0], [6, 1]),
+            np.reshape([0, 16, 32, 48, 0, 0], [6, 1]),
+            np.reshape([0] * 6, [6, 1]),
+            np.reshape([1, 17, 33, 49, 0, 0], [6, 1]),
+            np.reshape([0] * 6, [6, 1])
+        ],
+        axis=1)
+
+    remapped_matrix = variable_scope.get_variable(
+        name='linear_init_fallback/obtained_weight_matrix',
+        shape=[6, 5],
+        initializer=loading_initializer,
+        partitioner=partitioned_variables.fixed_size_partitioner(2))
+
+    with self.test_session():
+      variables.global_variables_initializer().run()
+      self.assertAllClose(expected_remapped_matrix,
+                          remapped_matrix.as_tensor().eval())
+
+  def test_load_embedding_initializer(self):
+    """Tests for the load_embedding_initializer wrapper."""
+    embedding_loading_initializer = (checkpoint_ops._load_embedding_initializer(
+        new_vocab_file=self.new_feature_vocab_file,
+        old_vocab_file=self.old_feature_vocab_file,
+        new_vocab_size=5,
+        embedding_dim=16,
+        embedding_tensor_name='some_scope/embeddings',
+        ckpt_path=[self.checkpoint_file],
+        num_oov_buckets=1,
+        initializer=self.initializer))
+
+    expected_remapped_embeddings = np.concatenate(
+        [
+            np.reshape(range(64), [4, 16]),
+            np.reshape([self.init_val] * 32, [2, 16]),
+        ],
+        axis=0)
+
+    # The new weight matrix is of size
+    # [5 feature vocab + 1 feature OOV, 16 (embedding dimension)], where the
+    # last vocab row (2nd last row) is newly initialized (wasn't found in
+    # previous vocab) and the actual last row is OOV and also newly initialized.
+    # Use a partitioned variable to confirm that the offset logic works.
+    remapped_embeddings = variable_scope.get_variable(
+        name='embedding/obtained_embedding_matrix',
+        shape=[6, 16],
+        initializer=embedding_loading_initializer,
+        partitioner=partitioned_variables.fixed_size_partitioner(2))
+
+    with self.test_session():
+      variables.global_variables_initializer().run()
+      self.assertAllClose(expected_remapped_embeddings,
+                          remapped_embeddings.as_tensor().eval())
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 9d6221b5606920e494785573e2b3e2de8a749303..ba9f7638311284548c1900c3ba1d510de2d03210 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -21,9 +21,11 @@ from __future__ import print_function
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
@@ -43,66 +45,82 @@ class MomentumOptimizerTest(test.TestCase):
     return var, accum
 
   def doTestBasic(self, use_resource=False):
-    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.test_session():
-        if use_resource:
-          var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
-          var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
-        else:
-          var0 = variables.Variable([1.0, 2.0], dtype=dtype)
-          var1 = variables.Variable([3.0, 4.0], dtype=dtype)
-        grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
-        grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
-        mom_opt = momentum_lib.MomentumOptimizer(
-            learning_rate=2.0, momentum=0.9)
-        mom_update = mom_opt.apply_gradients(
-            zip([grads0, grads1], [var0, var1]))
-        variables.global_variables_initializer().run()
-        # Check we have slots
-        self.assertEqual(["momentum"], mom_opt.get_slot_names())
-        slot0 = mom_opt.get_slot(var0, "momentum")
-        self.assertEquals(slot0.get_shape(), var0.get_shape())
-        self.assertFalse(slot0 in variables.trainable_variables())
-        slot1 = mom_opt.get_slot(var1, "momentum")
-        self.assertEquals(slot1.get_shape(), var1.get_shape())
-        self.assertFalse(slot1 in variables.trainable_variables())
+    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+      if use_resource:
+        var0 = resource_variable_ops.ResourceVariable(
+            [1.0, 2.0], dtype=dtype, name="var0_%d" % i)
+        var1 = resource_variable_ops.ResourceVariable(
+            [3.0, 4.0], dtype=dtype, name="var1_%d" % i)
+      else:
+        var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+        var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+      grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+      grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+      mom_opt = momentum_lib.MomentumOptimizer(
+          learning_rate=2.0, momentum=0.9)
+      mom_update = mom_opt.apply_gradients(
+          zip([grads0, grads1], [var0, var1]))
 
+      if context.in_graph_mode():
+        self.evaluate(variables.global_variables_initializer())
         # Fetch params to validate initial values
-        self.assertAllClose([1.0, 2.0], var0.eval())
-        self.assertAllClose([3.0, 4.0], var1.eval())
-        # Step 1: the momentum accumulators where 0. So we should see a normal
-        # update: v -= grad * learning_rate
-        mom_update.run()
-        # Check that the momentum accumulators have been updated.
-        self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
-        self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
-        # Check that the parameters have been updated.
-        self.assertAllCloseAccordingToType(
-            np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
-        self.assertAllCloseAccordingToType(
-            np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
-        # Step 2: the momentum accumulators contain the previous update.
-        mom_update.run()
-        # Check that the momentum accumulators have been updated.
-        self.assertAllCloseAccordingToType(
-            np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
-        self.assertAllCloseAccordingToType(
-            np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
-        # Check that the parameters have been updated.
-        self.assertAllCloseAccordingToType(
-            np.array([
-                1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
-                2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
-            ]), var0.eval())
-        self.assertAllCloseAccordingToType(
-            np.array([
-                2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
-                    (0.9 * 0.01 + 0.01) * 2.0)
-            ]), var1.eval())
+        self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+        self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+      # Check we have slots
+      self.assertEqual(["momentum"], mom_opt.get_slot_names())
+      slot0 = mom_opt.get_slot(var0, "momentum")
+      self.assertEquals(slot0.get_shape(), var0.get_shape())
+      self.assertFalse(slot0 in variables.trainable_variables())
+      slot1 = mom_opt.get_slot(var1, "momentum")
+      self.assertEquals(slot1.get_shape(), var1.get_shape())
+      self.assertFalse(slot1 in variables.trainable_variables())
+
+      # Step 1: the momentum accumulators where 0. So we should see a normal
+      # update: v -= grad * learning_rate
+      if context.in_graph_mode():
+        self.evaluate(mom_update)
+      # Check that the momentum accumulators have been updated.
+      self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
+                                         self.evaluate(slot0))
+      self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
+                                         self.evaluate(slot1))
+      # Check that the parameters have been updated.
+      self.assertAllCloseAccordingToType(
+          np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+          self.evaluate(var0))
+      self.assertAllCloseAccordingToType(
+          np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+          self.evaluate(var1))
+      # Step 2: the momentum accumulators contain the previous update.
+      if context.in_graph_mode():
+        self.evaluate(mom_update)
+      else:
+        mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      # Check that the momentum accumulators have been updated.
+      self.assertAllCloseAccordingToType(
+          np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+          self.evaluate(slot0))
+      self.assertAllCloseAccordingToType(
+          np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+          self.evaluate(slot1))
+      # Check that the parameters have been updated.
+      self.assertAllCloseAccordingToType(
+          np.array([
+              1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+              2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+          ]), self.evaluate(var0))
+      self.assertAllCloseAccordingToType(
+          np.array([
+              2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+                  (0.9 * 0.01 + 0.01) * 2.0)
+          ]), self.evaluate(var1))
 
   def testBasic(self):
-    self.doTestBasic(use_resource=False)
+    with self.test_session():
+      self.doTestBasic(use_resource=False)
 
+  @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testResourceBasic(self):
     self.doTestBasic(use_resource=True)
 
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 1562f656753bbe1ec40e68d9c951913af5deda80..e6162dd34b42e874bd896e04408d73ba3206ac69 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -20,6 +20,9 @@ from __future__ import division
 from __future__ import print_function
 
 import abc
+import sys
+
+import six
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.framework import errors
@@ -947,20 +950,21 @@ class _CoordinatedSession(_WrappedSession):
   def run(self, *args, **kwargs):
     try:
       return self._sess.run(*args, **kwargs)
-    except _PREEMPTION_ERRORS as original_exception:
-      raise original_exception
-    except Exception as original_exception:  # pylint: disable=broad-except
+    except _PREEMPTION_ERRORS:
+      raise
+    except Exception:  # pylint: disable=broad-except
       # A non-preemption error could have been caused by a preemption error
       # in the coordinator. If this is the case, raise that exception instead,
-      # since it's the root cause. Otherwise, stick to the `original_exception`.
+      # since it's the root cause. Otherwise, stick to the `original_exc_info`.
+      original_exc_info = sys.exc_info()
       try:
         self._coord.raise_requested_exception()
-      except _PREEMPTION_ERRORS as preemption_in_coordinator:
-        raise preemption_in_coordinator
+      except _PREEMPTION_ERRORS:
+        raise
       except Exception:  # pylint: disable=broad-except
-        raise original_exception
+        raise six.reraise(*original_exc_info)
       else:
-        raise original_exception
+        raise six.reraise(*original_exc_info)
 
 
 class _HookedSession(_WrappedSession):
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index a7c34cdd1b8cdffa43406825abf8f11eaac84393..d88b187fde58fc41d8aa19141df87bb518f5c681 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -22,8 +22,10 @@ from __future__ import print_function
 import collections
 import glob
 import os
+import sys
 import threading
 import time
+import traceback
 
 from tensorflow.contrib.framework.python.ops import variables as variables_lib
 from tensorflow.contrib.testing.python.framework import util_test
@@ -34,6 +36,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -506,6 +509,31 @@ class CoordinatedSessionTest(test.TestCase):
       self.assertTrue(coord.should_stop())
       self.assertTrue(coord_sess.should_stop())
 
+  def test_propagates_exception_trace(self):
+    assertion = control_flow_ops.Assert(False, ['This should fail.'])
+    with self.test_session() as sess:
+      coord = coordinator.Coordinator(clean_stop_exception_types=())
+      coord_sess = monitored_session._CoordinatedSession(sess, coord)
+      try:
+        coord_sess.run([assertion])
+        self.fail('No exception was raised by assertion.')
+      except errors_impl.InvalidArgumentError:
+        # Extract the name of the file where the exception was first raised.
+        _, _, exc_traceback = sys.exc_info()
+        tb = traceback.extract_tb(exc_traceback)
+        exc_source_file = tb[-1][0]
+        exc_source_basename = os.path.basename(exc_source_file)
+        # If it's monitored_session.py then the original stack trace was not
+        # correctly propagated.
+        self.assertIn(
+            exc_source_basename, ['session.py', 'monitored_session.py'],
+            'The exception was raised from an unrecognized file. This unit '
+            'test probably needs to be updated. Traceback:\n%s\n' % tb)
+        self.assertEqual(
+            exc_source_basename, 'session.py',
+            'Original stack trace was not propagated by MonitoredSession. '
+            'Traceback:\n%s' % tb)
+
 
 class AbortAtNSession(object):
   """A mock session that aborts at the N-th run call."""
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 250e22f91eee081c6ce9dd526b3890214b61b946..86ba8e2c8e471f453f9155778943aeacbad9941f 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -69,6 +69,8 @@ def _deduplicate_indexed_slices(values, indices):
 
 
 def _var_key(var):
+  if context.in_eager_mode():
+    return var._shared_name  # pylint: disable=protected-access
   return (var.op.graph, var.op.name)
 
 
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index fb1ac0029ab0e20a5831257012a054a629018e08..dc9bd3a8c2367e4b9e72d6963c103377fea59b04 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -33,6 +33,7 @@ from google.protobuf import text_format
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import saver_pb2
 from tensorflow.python.client import session
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import errors
@@ -92,14 +93,18 @@ class BaseSaverBuilder(object):
       """Creates a `SaveSpec` object.
 
       Args:
-        tensor: the tensor to save.
+        tensor: the tensor to save or callable that produces a tensor to save.
         slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
         name: the name to save the tensor under.
       """
-      self.tensor = tensor
+      self._tensor = tensor
       self.slice_spec = slice_spec
       self.name = name
 
+    @property
+    def tensor(self):
+      return self._tensor() if callable(self._tensor) else self._tensor
+
   class SaveableObject(object):
     """Base class for saving and restoring saveable objects."""
 
@@ -160,13 +165,15 @@ class BaseSaverBuilder(object):
     def __init__(self, var, slice_spec, name):
       if isinstance(var, ops.Tensor):
         self.handle_op = var.op.inputs[0]
+        tensor = var
       elif isinstance(var, resource_variable_ops.ResourceVariable):
         self.handle_op = var.handle
+        tensor = var.read_value
       else:
         raise ValueError(
             "Saveable is neither a resource variable nor a read operation."
             " Got: %s" % repr(var))
-      spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name)
+      spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name)
       super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__(
           var, [spec], name)
 
@@ -521,17 +528,24 @@ class BaseSaverBuilder(object):
         else:
           names_to_saveables[name] = [var]
       else:
-        var = ops.internal_convert_to_tensor(var, as_ref=True)
-        if not BaseSaverBuilder._IsVariable(var):
-          raise TypeError("Variable to save is not a Variable: %s" % var)
-        if var.op.type == "ReadVariableOp":
-          name = var.op.inputs[0].op.name
+        if context.in_graph_mode():
+          var = ops.internal_convert_to_tensor(var, as_ref=True)
+          if not BaseSaverBuilder._IsVariable(var):
+            raise TypeError("Variable to save is not a Variable: %s" % var)
+          if var.op.type == "ReadVariableOp":
+            name = var.op.inputs[0].op.name
+          else:
+            name = var.op.name
+          if name in names_to_saveables:
+            raise ValueError("At least two variables have the same name: %s" %
+                             name)
+          names_to_saveables[name] = var
         else:
-          name = var.op.name
-        if name in names_to_saveables:
-          raise ValueError("At least two variables have the same name: %s" %
-                           name)
-        names_to_saveables[name] = var
+          if not isinstance(var, resource_variable_ops.ResourceVariable):
+            raise ValueError("Can only save/restore ResourceVariable eager "
+                             "mode is enabled, type: %s." % type(var))
+          names_to_saveables[var._shared_name] = var
+
       # pylint: enable=protected-access
     return names_to_saveables
 
@@ -592,16 +606,23 @@ class BaseSaverBuilder(object):
         # pylint: enable=protected-access
       else:
         # A variable or tensor.
-        variable = ops.internal_convert_to_tensor(op, as_ref=True)
-        if not BaseSaverBuilder._IsVariable(variable):
-          raise TypeError("names_to_saveables must be a dict mapping string "
-                          "names to Tensors/Variables. Not a variable: %s" %
-                          variable)
-        if variable.op.type in ["Variable", "VariableV2", "AutoReloadVariable"]:
-          saveable = BaseSaverBuilder.VariableSaveable(variable, "", name)
+        if context.in_eager_mode():
+          if not isinstance(op, resource_variable_ops.ResourceVariable):
+            raise ValueError("Can only save/restore ResourceVariable eager "
+                             "mode is enabled, type: %s." % type(op))
+          saveable = BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
         else:
-          saveable = BaseSaverBuilder.ResourceVariableSaveable(
-              variable, "", name)
+          variable = ops.internal_convert_to_tensor(op, as_ref=True)
+          if not BaseSaverBuilder._IsVariable(variable):
+            raise TypeError("names_to_saveables must be a dict mapping string "
+                            "names to Tensors/Variables. Not a variable: %s" %
+                            variable)
+          if variable.op.type in ["Variable", "VariableV2",
+                                  "AutoReloadVariable"]:
+            saveable = BaseSaverBuilder.VariableSaveable(variable, "", name)
+          else:
+            saveable = BaseSaverBuilder.ResourceVariableSaveable(
+                variable, "", name)
         self._AddSaveable(saveables, seen_ops, saveable)
     return saveables
 
@@ -632,7 +653,7 @@ class BaseSaverBuilder(object):
             name=None,
             restore_sequentially=False,
             filename="model"):
-    """Adds save/restore nodes to the graph and creates a SaverDef proto.
+    """Builds save/restore graph nodes or runs save/restore in eager mode.
 
     Args:
       names_to_saveables: A dictionary mapping name to a Variable or
@@ -667,6 +688,31 @@ class BaseSaverBuilder(object):
       ValueError: If any of the keys or values in 'names_to_saveables' is not
         unique.
     """
+    return self._build_internal(
+        names_to_saveables=names_to_saveables,
+        reshape=reshape,
+        sharded=sharded,
+        max_to_keep=max_to_keep,
+        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
+        name=name,
+        restore_sequentially=restore_sequentially,
+        filename=filename)
+
+  def _build_internal(self,
+                      names_to_saveables,
+                      reshape=False,
+                      sharded=False,
+                      max_to_keep=5,
+                      keep_checkpoint_every_n_hours=10000.0,
+                      name=None,
+                      restore_sequentially=False,
+                      filename="model",
+                      build_save=True,
+                      build_restore=True):
+    """build() with option to only perform save and restore."""
+    if context.in_graph_mode() and (not build_save or not build_restore):
+      raise ValueError("Graph mode needs to build save and restore together.")
+
     saveables = self._ValidateAndSliceInputs(names_to_saveables)
     if max_to_keep is None:
       max_to_keep = 0
@@ -679,13 +725,17 @@ class BaseSaverBuilder(object):
       # Add the save ops.
       if sharded:
         per_device = self._GroupByDevices(saveables)
-        save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
-        restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
-                                                restore_sequentially, reshape)
+        if build_save:
+          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
+        if build_restore:
+          restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
+                                                  restore_sequentially, reshape)
       else:
-        save_tensor = self._AddSaveOps(filename_tensor, saveables)
-        restore_op = self._AddRestoreOps(filename_tensor, saveables,
-                                         restore_sequentially, reshape)
+        if build_save:
+          save_tensor = self._AddSaveOps(filename_tensor, saveables)
+        if build_restore:
+          restore_op = self._AddRestoreOps(filename_tensor, saveables,
+                                           restore_sequentially, reshape)
 
     # In the following use case, it's possible to have restore_ops be called
     # something else:
@@ -698,15 +748,26 @@ class BaseSaverBuilder(object):
     # such usage model makes sense.
     #
     # assert restore_op.name.endswith("restore_all"), restore_op.name
-
-    return saver_pb2.SaverDef(
-        filename_tensor_name=filename_tensor.name,
-        save_tensor_name=save_tensor.name,
-        restore_op_name=restore_op.name,
-        max_to_keep=max_to_keep,
-        sharded=sharded,
-        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
-        version=self._write_version)
+    if context.in_graph_mode():
+      return saver_pb2.SaverDef(
+          filename_tensor_name=filename_tensor.name,
+          save_tensor_name=save_tensor.name,
+          restore_op_name=restore_op.name,
+          max_to_keep=max_to_keep,
+          sharded=sharded,
+          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
+          version=self._write_version)
+    else:
+      # Store the tensor values to the tensor_names.
+      save_tensor_name = save_tensor.numpy() if build_save else ""
+      return saver_pb2.SaverDef(
+          filename_tensor_name=filename_tensor.numpy(),
+          save_tensor_name=save_tensor_name,
+          restore_op_name="",
+          max_to_keep=max_to_keep,
+          sharded=sharded,
+          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
+          version=self._write_version)
 
 
 def _get_saver_or_default():
@@ -1136,7 +1197,7 @@ class Saver(object):
     self._write_version = write_version
     self._pad_step_number = pad_step_number
     self._filename = filename
-    if not defer_build:
+    if not defer_build and context.in_graph_mode():
       self.build()
     if self.saver_def:
       self._check_saver_def()
@@ -1144,11 +1205,22 @@ class Saver(object):
     self._save_relative_paths = save_relative_paths
 
   def build(self):
+    if context.in_eager_mode():
+      raise ValueError("Use save/restore instead of build in eager mode.")
+    self._build(self._filename, build_save=True, build_restore=True)
+
+  def _build_eager(self, checkpoint_path, build_save, build_restore):
+    self._build(
+        checkpoint_path, build_save=build_save, build_restore=build_restore)
+
+  def _build(self, checkpoint_path, build_save, build_restore):
     """Builds saver_def."""
-    if self._is_built:
-      return
-    self._is_built = True
-    if not self.saver_def:
+    if context.in_graph_mode():
+      if self._is_built:
+        return
+      self._is_built = True
+
+    if not self.saver_def or context.in_eager_mode():
       if self._builder is None:
         self._builder = BaseSaverBuilder(self._write_version)
       if self._var_list is None:
@@ -1161,7 +1233,8 @@ class Saver(object):
         else:
           raise ValueError("No variables to save")
       self._is_empty = False
-      self.saver_def = self._builder.build(
+
+      self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
           self._var_list,
           reshape=self._reshape,
           sharded=self._sharded,
@@ -1169,7 +1242,8 @@ class Saver(object):
           keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
           name=self._name,
           restore_sequentially=self._restore_sequentially,
-          filename=self._filename)
+          filename=checkpoint_path,
+          build_save=build_save, build_restore=build_restore)
     elif self.saver_def and self._name:
       # Since self._name is used as a name_scope by builder(), we are
       # overloading the use of this field to represent the "import_scope" as
@@ -1191,12 +1265,13 @@ class Saver(object):
     if not isinstance(self.saver_def, saver_pb2.SaverDef):
       raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
                        self.saver_def)
-    if not self.saver_def.save_tensor_name:
-      raise ValueError("saver_def must specify the save_tensor_name: %s" %
-                       str(self.saver_def))
-    if not self.saver_def.restore_op_name:
-      raise ValueError("saver_def must specify the restore_op_name: %s" %
-                       str(self.saver_def))
+    if context.in_graph_mode():
+      if not self.saver_def.save_tensor_name:
+        raise ValueError("saver_def must specify the save_tensor_name: %s" %
+                         str(self.saver_def))
+      if not self.saver_def.restore_op_name:
+        raise ValueError("saver_def must specify the restore_op_name: %s" %
+                         str(self.saver_def))
 
   def _CheckpointFilename(self, p):
     """Returns the checkpoint filename given a `(filename, time)` pair.
@@ -1402,7 +1477,7 @@ class Saver(object):
     path can be passed directly to a call to `restore()`.
 
     Args:
-      sess: A Session to use to save the variables.
+      sess: A Session to use to save the variables. None in eager mode.
       save_path: String.  Path to the checkpoint filename.  If the saver is
         `sharded`, this is the prefix of the sharded checkpoint filename.
       global_step: If provided the global step number is appended to
@@ -1431,7 +1506,7 @@ class Saver(object):
         collides with `save_path`.
       RuntimeError: If save and restore ops weren't built.
     """
-    if not self._is_built:
+    if not self._is_built and context.in_graph_mode():
       raise RuntimeError(
           "`build()` should be called before save if defer_build==True")
     if latest_filename is None:
@@ -1457,21 +1532,28 @@ class Saver(object):
     else:
       checkpoint_file = save_path
       if os.path.basename(
-          save_path) == latest_filename and not self.saver_def.sharded:
+          save_path) == latest_filename and not self._sharded:
         # Guard against collision between data file and checkpoint state file.
         raise ValueError(
             "'latest_filename' collides with 'save_path': '%s' and '%s'" %
             (latest_filename, save_path))
 
-    if not isinstance(sess, session.SessionInterface):
+    if (context.in_graph_mode() and
+        not isinstance(sess, session.SessionInterface)):
       raise TypeError("'sess' must be a Session; %s" % sess)
 
     save_path_parent = os.path.dirname(save_path)
     if not self._is_empty:
       try:
-        model_checkpoint_path = sess.run(
-            self.saver_def.save_tensor_name,
-            {self.saver_def.filename_tensor_name: checkpoint_file})
+        if context.in_graph_mode():
+          model_checkpoint_path = sess.run(
+              self.saver_def.save_tensor_name,
+              {self.saver_def.filename_tensor_name: checkpoint_file})
+        else:
+          self._build_eager(
+              checkpoint_file, build_save=True, build_restore=False)
+          model_checkpoint_path = self.saver_def.save_tensor_name
+
         model_checkpoint_path = compat.as_str(model_checkpoint_path)
         if write_state:
           self._MaybeDeleteOldCheckpoints(
@@ -1492,8 +1574,11 @@ class Saver(object):
     if write_meta_graph:
       meta_graph_filename = self._MetaGraphFilename(
           checkpoint_file, meta_graph_suffix=meta_graph_suffix)
-      with sess.graph.as_default():
+      if context.in_eager_mode():
         self.export_meta_graph(meta_graph_filename)
+      else:
+        with sess.graph.as_default():
+          self.export_meta_graph(meta_graph_filename)
 
     if self._is_empty:
       return None
@@ -1545,7 +1630,7 @@ class Saver(object):
     `save()` call, or a call to `latest_checkpoint()`.
 
     Args:
-      sess: A `Session` to use to restore the parameters.
+      sess: A `Session` to use to restore the parameters. None in eager mode.
       save_path: Path where parameters were previously saved.
 
     Raises:
@@ -1556,8 +1641,11 @@ class Saver(object):
     if save_path is None:
       raise ValueError("Can't load save_path when it is None.")
     logging.info("Restoring parameters from %s", save_path)
-    sess.run(self.saver_def.restore_op_name,
-             {self.saver_def.filename_tensor_name: save_path})
+    if context.in_graph_mode():
+      sess.run(self.saver_def.restore_op_name,
+               {self.saver_def.filename_tensor_name: save_path})
+    else:
+      self._build_eager(save_path, build_save=False, build_restore=True)
 
   @staticmethod
   def _add_collection_def(meta_graph_def, key, export_scope=None):
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 2d878e3057fa411a93c974a1a194aa1aac601125..e66993f50b970f637ef3967a09e005c9ba7a3ba7 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -38,6 +38,7 @@ from tensorflow.core.protobuf import queue_runner_pb2
 from tensorflow.core.protobuf import saver_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.client import session
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -77,97 +78,170 @@ class SaverTest(test.TestCase):
   def basicSaveRestore(self, variable_op):
     save_path = os.path.join(self.get_temp_dir(), "basic_save_restore")
 
-    # Build a graph with 2 parameter nodes, and Save and
-    # Restore nodes for them.
-    v0 = variable_op(10.0, name="v0")
-    v1 = variable_op(20.0, name="v1")
-    v2 = saver_test_utils.CheckpointedOp(name="v2")
-    v2_init = v2.insert("k1", 30.0)
-    save = saver_module.Saver(
-        {
-            "v0": v0,
-            "v1": v1,
-            "v2": v2.saveable
-        }, restore_sequentially=True)
-    init_all_op = [variables.global_variables_initializer(), v2_init]
+    with self.test_session(graph=ops_lib.Graph()) as sess:
+      # Build a graph with 2 parameter nodes, and Save and
+      # Restore nodes for them.
+      v0 = variable_op(10.0, name="v0")
+      v1 = variable_op(20.0, name="v1")
+      v2 = saver_test_utils.CheckpointedOp(name="v2")
+      v2_init = v2.insert("k1", 30.0)
 
-    with self.test_session() as sess:
       # Initialize all variables
-      sess.run(init_all_op)
+      if context.in_graph_mode():
+        self.evaluate([variables.global_variables_initializer(), v2_init])
 
-      # Check that the parameter nodes have been initialized.
-      self.assertEqual(10.0, v0.eval())
-      self.assertEqual(20.0, v1.eval())
-      self.assertEqual(b"k1", v2.keys().eval())
-      self.assertEqual(30.0, v2.values().eval())
+        # Check that the parameter nodes have been initialized.
+      self.assertEqual(10.0, self.evaluate(v0))
+      self.assertEqual(20.0, self.evaluate(v1))
+      self.assertEqual(b"k1", self.evaluate(v2.keys()))
+      self.assertEqual(30.0, self.evaluate(v2.values()))
 
       # Save the initialized values in the file at "save_path"
+      save = saver_module.Saver(
+          {
+              "v0": v0,
+              "v1": v1,
+              "v2": v2.saveable
+          }, restore_sequentially=True)
       val = save.save(sess, save_path)
       self.assertTrue(isinstance(val, six.string_types))
       self.assertEqual(save_path, val)
 
     # Start a second session.  In that session the parameter nodes
     # have not been initialized either.
-    with self.test_session() as sess:
+    with self.test_session(graph=ops_lib.Graph()) as sess:
       v0 = variable_op(-1.0, name="v0")
       v1 = variable_op(-1.0, name="v1")
       v2 = saver_test_utils.CheckpointedOp(name="v2")
-      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
 
       # Assert that the variables are not initialized.
-      self.assertEqual(
-          len(variables.report_uninitialized_variables().eval()), 2)
-      self.assertEqual(0, len(v2.keys().eval()))
-      self.assertEqual(0, len(v2.values().eval()))
-
+      if context.in_graph_mode():
+        self.assertEqual(
+            len(variables.report_uninitialized_variables().eval()), 2)
+        self.assertEqual(0, len(v2.keys().eval()))
+        self.assertEqual(0, len(v2.values().eval()))
       # Restore the saved values in the parameter nodes.
+      save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
       save.restore(sess, save_path)
       # Check that the parameter nodes have been restored.
-      self.assertEqual(10.0, v0.eval())
-      self.assertEqual(20.0, v1.eval())
-      self.assertEqual(b"k1", v2.keys().eval())
-      self.assertEqual(30.0, v2.values().eval())
+      self.assertEqual(10.0, self.evaluate(v0))
+      self.assertEqual(20.0, self.evaluate(v1))
+      self.assertEqual(b"k1", self.evaluate(v2.keys()))
+      self.assertEqual(30.0, self.evaluate(v2.values()))
 
     # Build another graph with 2 nodes, initialized
     # differently, and a Restore node for them.
-    with self.test_session() as sess:
+    with self.test_session(graph=ops_lib.Graph()) as sess:
       v0_2 = variable_op(1000.0, name="v0")
       v1_2 = variable_op(2000.0, name="v1")
       v2_2 = saver_test_utils.CheckpointedOp(name="v2")
-      save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})
-      v2_2.insert("k1000", 3000.0).run()
-      variables.global_variables_initializer().run()
+      v2_init = v2_2.insert("k1000", 3000.0)
 
       # Check that the parameter nodes have been initialized.
-      self.assertEqual(1000.0, v0_2.eval())
-      self.assertEqual(2000.0, v1_2.eval())
-      self.assertEqual(b"k1000", v2_2.keys().eval())
-      self.assertEqual(3000.0, v2_2.values().eval())
+      if context.in_graph_mode():
+        init_all_op = [variables.global_variables_initializer(), v2_init]
+        self.evaluate(init_all_op)
+        # TODO(xpan): Why _mutable_hash_table_v2 doesn't create empty
+        # table as it claims in eager mode?
+        self.assertEqual(b"k1000", self.evaluate(v2_2.keys()))
+        self.assertEqual(3000.0, self.evaluate(v2_2.values()))
+      self.assertEqual(1000.0, self.evaluate(v0_2))
+      self.assertEqual(2000.0, self.evaluate(v1_2))
+
       # Restore the values saved earlier in the parameter nodes.
+      save2 = saver_module.Saver({"v0": v0_2, "v1": v1_2, "v2": v2_2.saveable})
       save2.restore(sess, save_path)
       # Check that the parameter nodes have been restored.
-      self.assertEqual(10.0, v0_2.eval())
-      self.assertEqual(20.0, v1_2.eval())
-      self.assertEqual(b"k1", v2_2.keys().eval())
-      self.assertEqual(30.0, v2_2.values().eval())
+      self.assertEqual(10.0, self.evaluate(v0_2))
+      self.assertEqual(20.0, self.evaluate(v1_2))
+      self.assertEqual(b"k1", self.evaluate(v2_2.keys()))
+      self.assertEqual(30.0, self.evaluate(v2_2.values()))
 
   def testBasic(self):
     self.basicSaveRestore(variables.Variable)
 
+  @test_util.run_in_graph_and_eager_modes()
   def testResourceBasic(self):
     self.basicSaveRestore(resource_variable_ops.ResourceVariable)
 
+  def testEagerBasic(self):
+    with context.eager_mode():
+      ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
+
+      v1 = resource_variable_ops.ResourceVariable(3.14, name="v1")
+      v2 = resource_variable_ops.ResourceVariable([1, 2], name="v2")
+      save = saver_module.Saver([v1, v2])
+      save.save(None, ckpt_prefix)
+
+      v1.assign(0.0)
+      v2.assign([0, 0])
+      self.assertNear(0.0, self.evaluate(v1), 1e-5)
+      self.assertAllEqual([0, 0], self.evaluate(v2))
+
+      save.restore(None, ckpt_prefix)
+      self.assertNear(3.14, self.evaluate(v1), 1e-5)
+      self.assertAllEqual([1, 2], self.evaluate(v2))
+
+  def testEagerGraphCompatibility(self):
+    # Save from graph mode and restore from eager mode.
+    graph_ckpt_prefix = os.path.join(self.get_temp_dir(), "graph_ckpt")
+    with context.graph_mode():
+      with self.test_session(graph=ops_lib.Graph()) as sess:
+        # Create a graph model and save the checkpoint.
+        w1 = resource_variable_ops.ResourceVariable(1.0, name="w1")
+        w2 = resource_variable_ops.ResourceVariable(2.0, name="w2")
+        graph_saver = saver_module.Saver([w1, w2])
+        sess.run(variables.global_variables_initializer())
+        graph_saver.save(sess, graph_ckpt_prefix)
+
+    with context.eager_mode():
+      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access
+      ops_lib.reset_default_graph()
+
+      w1 = resource_variable_ops.ResourceVariable(0.0, name="w1")
+      w2 = resource_variable_ops.ResourceVariable(0.0, name="w2")
+
+      graph_saver = saver_module.Saver()
+      graph_saver.restore(None, graph_ckpt_prefix)
+
+      self.assertAllEqual(self.evaluate(w1), 1.0)
+      self.assertAllEqual(self.evaluate(w2), 2.0)
+
+    # Save from eager mode and restore from graph mode.
+    eager_ckpt_prefix = os.path.join(self.get_temp_dir(), "eager_ckpt")
+    with context.eager_mode():
+      ops_lib._default_graph_stack.reset()  # pylint: disable=protected-access
+      ops_lib.reset_default_graph()
+
+      w3 = resource_variable_ops.ResourceVariable(3.0, name="w3")
+      w4 = resource_variable_ops.ResourceVariable(4.0, name="w4")
+
+      graph_saver = saver_module.Saver()
+      graph_saver.save(None, eager_ckpt_prefix)
+
+    with context.graph_mode():
+      with self.test_session(graph=ops_lib.Graph()) as sess:
+        w3 = resource_variable_ops.ResourceVariable(0.0, name="w3")
+        w4 = resource_variable_ops.ResourceVariable(0.0, name="w4")
+        graph_saver = saver_module.Saver([w3, w4])
+        sess.run(variables.global_variables_initializer())
+        graph_saver.restore(sess, eager_ckpt_prefix)
+        self.assertAllEqual(w3.eval(), 3.0)
+        self.assertAllEqual(w4.eval(), 4.0)
+
+  @test_util.run_in_graph_and_eager_modes()
   def testResourceSaveRestoreCachingDevice(self):
     save_path = os.path.join(self.get_temp_dir(), "resource_cache")
-    v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0")
-    with self.test_session() as sess:
-      variables.global_variables_initializer().run()
+    with self.test_session(graph=ops_lib.Graph()) as sess:
+      v = resource_variable_ops.ResourceVariable([1], caching_device="/cpu:0")
+      if context.in_graph_mode():
+        self.evaluate(variables.global_variables_initializer())
       save = saver_module.Saver()
       save.save(sess, save_path)
-    with self.test_session() as sess:
+
       save2 = saver_module.Saver()
       save2.restore(sess, save_path)
-      self.assertEquals(v.eval(), [1])
+      self.assertEquals(self.evaluate(v), [1])
 
   def testSaveCopyRestoreWithSaveRelativePaths(self):
     """Save, copy checkpoint dir and restore from copied dir.
@@ -404,16 +478,17 @@ class SaverTest(test.TestCase):
 
   def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
     with self.test_session() as sess:
-      var = variables.Variable(var_value, name=var_name)
+      var = resource_variable_ops.ResourceVariable(var_value, name=var_name)
       save = saver_module.Saver({var_name: var})
-      var.initializer.run()
+      if context.in_graph_mode():
+        self.evaluate(var.initializer)
       val = save.save(sess, save_path)
       self.assertEqual(save_path, val)
     with self.test_session() as sess:
-      var = variables.Variable(other_value, name=var_name)
+      var = resource_variable_ops.ResourceVariable(other_value, name=var_name)
       save = saver_module.Saver({var_name: var})
       save.restore(sess, save_path)
-      self.assertAllClose(var_value, var.eval())
+      self.assertAllClose(var_value, self.evaluate(var))
 
   def testCacheRereadsFile(self):
     save_path = os.path.join(self.get_temp_dir(), "cache_rereads")
@@ -535,30 +610,32 @@ class SaverTest(test.TestCase):
       save.restore(sess, save_path)
       self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval())
 
+  @test_util.run_in_graph_and_eager_modes()
   def testSaveWithGlobalStep(self, pad_step_number=False):
     save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
     global_step_int = 5
     # Save and reload one Variable named "var0".
     self._SaveAndLoad("var0", 0.0, 1.0, save_path)
     for use_tensor in [True, False]:
-      with self.test_session() as sess:
-        var = variables.Variable(1.0, name="var0")
-        save = saver_module.Saver(
-            {
-                var.op.name: var
-            }, pad_step_number=pad_step_number)
-        var.initializer.run()
-        if use_tensor:
-          global_step = constant_op.constant(global_step_int)
-          val = save.save(sess, save_path, global_step=global_step)
-        else:
-          val = save.save(sess, save_path, global_step=global_step_int)
-        if pad_step_number:
-          expected_save_path = "%s-%s" % (save_path,
-                                          "{:08d}".format(global_step_int))
-        else:
-          expected_save_path = "%s-%d" % (save_path, global_step_int)
-        self.assertEqual(expected_save_path, val)
+      var = resource_variable_ops.ResourceVariable(1.0, name="var0")
+      save = saver_module.Saver(
+          {
+              var._shared_name: var
+          }, pad_step_number=pad_step_number)
+      if context.in_graph_mode():
+        self.evaluate(var.initializer)
+      sess = ops_lib.get_default_session() if context.in_graph_mode() else None
+      if use_tensor:
+        global_step = constant_op.constant(global_step_int)
+        val = save.save(sess, save_path, global_step=global_step)
+      else:
+        val = save.save(sess, save_path, global_step=global_step_int)
+      if pad_step_number:
+        expected_save_path = "%s-%s" % (save_path,
+                                        "{:08d}".format(global_step_int))
+      else:
+        expected_save_path = "%s-%d" % (save_path, global_step_int)
+      self.assertEqual(expected_save_path, val)
 
   def testSaveWithGlobalStepWithPadding(self):
     self.testSaveWithGlobalStep(pad_step_number=True)
diff --git a/tensorflow/python/training/saver_test_utils.py b/tensorflow/python/training/saver_test_utils.py
index bcabb4130409891aaabdad65a6ff238424f3e8bd..44b06b357ecbe4c8e330a2ccc49e83ddd4bf8c7d 100644
--- a/tensorflow/python/training/saver_test_utils.py
+++ b/tensorflow/python/training/saver_test_utils.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops as ops_lib
 from tensorflow.python.ops import gen_lookup_ops
@@ -39,9 +40,10 @@ class CheckpointedOp(object):
     else:
       self.table_ref = table_ref
     self._name = name
-    self._saveable = CheckpointedOp.CustomSaveable(self, name)
-    ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
-                              self._saveable)
+    if context.in_graph_mode():
+      self._saveable = CheckpointedOp.CustomSaveable(self, name)
+      ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
+                                self._saveable)
 
   @property
   def name(self):
@@ -49,7 +51,10 @@ class CheckpointedOp(object):
 
   @property
   def saveable(self):
-    return self._saveable
+    if context.in_graph_mode():
+      return self._saveable
+    else:
+      return CheckpointedOp.CustomSaveable(self, self.name)
 
   def insert(self, keys, values):
     return gen_lookup_ops._lookup_table_insert_v2(self.table_ref, keys, values)
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py
index 4371e92bd300984836fcb4efb879b7d8720e4121..ea28b5ddfc2dbbf65ec60e86d29ff2a9988d2b97 100644
--- a/tensorflow/python/training/slot_creator.py
+++ b/tensorflow/python/training/slot_creator.py
@@ -39,6 +39,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -139,7 +140,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
   # and the same name has been previously used, the scope name will add '_N'
   # as suffix for unique identifications.
   validate_shape = shape.is_fully_defined()
-  with variable_scope.variable_scope(None, primary.op.name + "/" + name):
+  prefix = primary.op.name if context.in_graph_mode() else primary._shared_name  # pylint: disable=protected-access
+  with variable_scope.variable_scope(None, prefix + "/" + name):
     if colocate_with_primary:
       with ops.colocate_with(primary):
         return _create_slot_var(primary, initializer, "", validate_shape, shape,
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index a40f63ccb329ad6c1c7cc97157c64b8a1f563377..9f2f9b7479efd354cc4c5ae675458a652a64c86f 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import graph_io
 from tensorflow.python.framework import ops
@@ -55,6 +56,8 @@ def global_step(sess, global_step_tensor):
   Returns:
     The global step value.
   """
+  if context.in_eager_mode():
+    return int(global_step_tensor.numpy())
   return int(sess.run(global_step_tensor))
 
 
@@ -154,6 +157,7 @@ def assert_global_step(global_step_tensor):
     raise TypeError('Existing "global_step" does not have integer type: %s' %
                     global_step_tensor.dtype)
 
-  if global_step_tensor.get_shape().ndims != 0:
+  if (global_step_tensor.get_shape().ndims != 0 and
+      global_step_tensor.get_shape().is_fully_defined()):
     raise TypeError('Existing "global_step" is not scalar: %s' %
                     global_step_tensor.get_shape())
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..47396997d34586543a4bec183f832786261c0678
--- /dev/null
+++ b/tensorflow/python/util/tf_export.py
@@ -0,0 +1,128 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for exporting TensorFlow symbols to the API.
+
+Exporting a function or a class:
+
+To export a function or a class use tf_export decorator. For e.g.:
+```python
+@tf_export('foo', 'bar.foo')
+def foo(...):
+  ...
+```
+
+If a function is assigned to a variable, you can export it by calling
+tf_export explicitly. For e.g.:
+```python
+foo = get_foo(...)
+tf_export('foo', 'bar.foo')(foo)
+```
+
+
+Exporting a constant
+```python
+foo = 1
+tf_export("consts.foo").export_constant(__name__, foo)
+```
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from tensorflow.python.util import tf_decorator
+
+
+class SymbolAlreadyExposedError(Exception):
+  """Raised when adding API names to symbol that already has API names."""
+  pass
+
+
+class tf_export(object):  # pylint: disable=invalid-name
+  """Provides ways to export symbols to the TensorFlow API."""
+
+  def __init__(self, *args, **kwargs):
+    """Export under the names *args (first one is considered canonical).
+
+    Args:
+      *args: API names in dot delimited format.
+      **kwargs: Optional keyed arguments. Currently only supports 'overrides'
+        argument. overrides: List of symbols that this is overriding
+        (those overrided api exports will be removed). Note: passing overrides
+        has no effect on exporting a constant.
+    """
+    self._names = args
+    self._overrides = kwargs.get('overrides', [])
+
+  def __call__(self, func):
+    """Calls this decorator.
+
+    Args:
+      func: decorated symbol (function or class).
+
+    Returns:
+      The input function with _tf_api_names attribute set.
+
+    Raises:
+      SymbolAlreadyExposedError: Raised when a symbol already has API names.
+    """
+    # Undecorate overridden names
+    for f in self._overrides:
+      _, undecorated_f = tf_decorator.unwrap(f)
+      del undecorated_f._tf_api_names  # pylint: disable=protected-access
+
+    _, undecorated_func = tf_decorator.unwrap(func)
+
+    # Check for an existing api. We check if attribute name is in
+    # __dict__ instead of using hasattr to verify that subclasses have
+    # their own _tf_api_names as opposed to just inheriting it.
+    if '_tf_api_names' in undecorated_func.__dict__:
+      # pylint: disable=protected-access
+      raise SymbolAlreadyExposedError(
+          'Symbol %s is already exposed as %s.' %
+          (undecorated_func.__name__, undecorated_func._tf_api_names))
+      # pylint: enable=protected-access
+
+    # Complete the export by creating/overriding attribute
+    # pylint: disable=protected-access
+    undecorated_func._tf_api_names = self._names
+    # pylint: enable=protected-access
+    return func
+
+  def export_constant(self, module_name, value):
+    """Store export information for constants/string literals.
+
+    Export information is stored in the module where constants/string literals
+    are defined.
+
+    e.g.
+    ```python
+    foo = 1
+    bar = 2
+    tf_export("consts.foo").export_constant(__name__, foo)
+    tf_export("consts.bar").export_constant(__name__, bar)
+    ```
+
+    Args:
+      module_name: (string) Name of the module to store constant at.
+      value: Value of the constant.
+    """
+    module = sys.modules[module_name]
+    if not hasattr(module, '_tf_api_constants'):
+      module._tf_api_constants = []  # pylint: disable=protected-access
+    # pylint: disable=protected-access
+    module._tf_api_constants.append((self._names, value))
+
diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b7636c34e4e9dc9e772faa664ec7c48402f9233
--- /dev/null
+++ b/tensorflow/python/util/tf_export_test.py
@@ -0,0 +1,157 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tf_export tests."""
+
+# pylint: disable=unused-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from tensorflow.python.platform import test
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_export
+
+
+def _test_function(unused_arg=0):
+  pass
+
+
+def _test_function2(unused_arg=0):
+  pass
+
+
+class TestClassA(object):
+  pass
+
+
+class TestClassB(TestClassA):
+  pass
+
+
+class ValidateExportTest(test.TestCase):
+  """Tests for tf_export class."""
+
+  class MockModule(object):
+
+    def __init__(self, name):
+      self.__name__ = name
+
+  def setUp(self):
+    self._modules = []
+
+  def tearDown(self):
+    for name in self._modules:
+      del sys.modules[name]
+    self._modules = []
+    for symbol in [_test_function, _test_function, TestClassA, TestClassB]:
+      if hasattr(symbol, '_tf_api_names'):
+        del symbol._tf_api_names
+
+  def _CreateMockModule(self, name):
+    mock_module = self.MockModule(name)
+    sys.modules[name] = mock_module
+    self._modules.append(name)
+    return mock_module
+
+  def testExportSingleFunction(self):
+    export_decorator = tf_export.tf_export('nameA', 'nameB')
+    decorated_function = export_decorator(_test_function)
+    self.assertEquals(decorated_function, _test_function)
+    self.assertEquals(('nameA', 'nameB'), decorated_function._tf_api_names)
+
+  def testExportMultipleFunctions(self):
+    export_decorator1 = tf_export.tf_export('nameA', 'nameB')
+    export_decorator2 = tf_export.tf_export('nameC', 'nameD')
+    decorated_function1 = export_decorator1(_test_function)
+    decorated_function2 = export_decorator2(_test_function2)
+    self.assertEquals(decorated_function1, _test_function)
+    self.assertEquals(decorated_function2, _test_function2)
+    self.assertEquals(('nameA', 'nameB'), decorated_function1._tf_api_names)
+    self.assertEquals(('nameC', 'nameD'), decorated_function2._tf_api_names)
+
+  def testExportClasses(self):
+    export_decorator_a = tf_export.tf_export('TestClassA1')
+    export_decorator_a(TestClassA)
+    self.assertEquals(('TestClassA1',), TestClassA._tf_api_names)
+    self.assertTrue('_tf_api_names' not in TestClassB.__dict__)
+
+    export_decorator_b = tf_export.tf_export('TestClassB1')
+    export_decorator_b(TestClassB)
+    self.assertEquals(('TestClassA1',), TestClassA._tf_api_names)
+    self.assertEquals(('TestClassB1',), TestClassB._tf_api_names)
+
+  def testExportSingleConstant(self):
+    module1 = self._CreateMockModule('module1')
+
+    test_constant = 123
+    export_decorator = tf_export.tf_export('NAME_A', 'NAME_B')
+    export_decorator.export_constant('module1', test_constant)
+    self.assertEquals([(('NAME_A', 'NAME_B'), 123)],
+                      module1._tf_api_constants)
+
+  def testExportMultipleConstants(self):
+    module1 = self._CreateMockModule('module1')
+    module2 = self._CreateMockModule('module2')
+
+    test_constant1 = 123
+    test_constant2 = 'abc'
+    test_constant3 = 0.5
+
+    export_decorator1 = tf_export.tf_export('NAME_A', 'NAME_B')
+    export_decorator2 = tf_export.tf_export('NAME_C', 'NAME_D')
+    export_decorator3 = tf_export.tf_export('NAME_E', 'NAME_F')
+    export_decorator1.export_constant('module1', test_constant1)
+    export_decorator2.export_constant('module2', test_constant2)
+    export_decorator3.export_constant('module2', test_constant3)
+    self.assertEquals([(('NAME_A', 'NAME_B'), 123)],
+                      module1._tf_api_constants)
+    self.assertEquals([(('NAME_C', 'NAME_D'), 'abc'),
+                       (('NAME_E', 'NAME_F'), 0.5)],
+                      module2._tf_api_constants)
+
+  def testRaisesExceptionIfAlreadyHasAPINames(self):
+    _test_function._tf_api_names = ['abc']
+    export_decorator = tf_export.tf_export('nameA', 'nameB')
+    with self.assertRaises(tf_export.SymbolAlreadyExposedError):
+      export_decorator(_test_function)
+
+  def testOverridesFunction(self):
+    _test_function2._tf_api_names = ['abc']
+
+    export_decorator = tf_export.tf_export(
+        'nameA', 'nameB', overrides=[_test_function2])
+    export_decorator(_test_function)
+
+    # _test_function overrides _test_function2. So, _tf_api_names
+    # should be removed from _test_function2.
+    self.assertFalse(hasattr(_test_function2, '_tf_api_names'))
+
+  def testMultipleDecorators(self):
+    def get_wrapper(func):
+      def wrapper(*unused_args, **unused_kwargs):
+        pass
+      return tf_decorator.make_decorator(func, wrapper)
+    decorated_function = get_wrapper(_test_function)
+
+    export_decorator = tf_export.tf_export('nameA', 'nameB')
+    exported_function = export_decorator(decorated_function)
+    self.assertEquals(decorated_function, exported_function)
+    self.assertEquals(('nameA', 'nameB'), _test_function._tf_api_names)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py
index ab9e82a3cce660e60b0aa69c37228ff75f919f23..d9b2e6fcd799db019adf40c717efc09845aa216f 100644
--- a/tensorflow/python/util/tf_should_use.py
+++ b/tensorflow/python/util/tf_should_use.py
@@ -17,57 +17,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import functools
-import itertools
-import traceback
 import types
 
 import six  # pylint: disable=unused-import
 
-# pylint: disable=g-bad-import-order,g-import-not-at-top
-try:
-  from weakref import finalize
-except ImportError:
-  from backports.weakref import finalize
-
-from tensorflow.python.platform import tf_logging
 from tensorflow.python.util import tf_decorator
 # pylint: enable=g-bad-import-order,g-import-not-at-top
 
 
-class _RefInfoField(
-    collections.namedtuple(
-        '_RefInfoField', ('type_', 'repr_', 'creation_stack', 'object_used'))):
-  pass
-
-
-# Thread-safe up to int32max/2 thanks to python's GIL; and may be safe even for
-# higher values in Python 3.4+.  We don't expect to ever count higher than this.
-# https://mail.python.org/pipermail/python-list/2005-April/342279.html
-_REF_ITER = itertools.count()
-
-# Dictionary mapping id(obj) => _RefInfoField.
-_REF_INFO = {}
-
-
-def _deleted(obj_id, fatal_error):
-  obj = _REF_INFO[obj_id]
-  del _REF_INFO[obj_id]
-  if not obj.object_used:
-    if fatal_error:
-      logger = tf_logging.fatal
-    else:
-      logger = tf_logging.error
-    logger(
-        '==================================\n'
-        'Object was never used (type %s):\n%s\nIf you want to mark it as '
-        'used call its "mark_used()" method.\nIt was originally created '
-        'here:\n%s\n'
-        '==================================' %
-        (obj.type_, obj.repr_, obj.creation_stack))
-
-
+# TODO(b/65412899): Re-implement to avoid leaking python objects.
+# This function / class remains since the API is public (mark_used()).
 def _add_should_use_warning(x, fatal_error=False):
   """Wraps object x so that if it is never used, a warning is logged.
 
@@ -80,16 +40,12 @@ def _add_should_use_warning(x, fatal_error=False):
     An instance of `TFShouldUseWarningWrapper` which subclasses `type(x)`
     and is a very shallow wrapper for `x` which logs access into `x`.
   """
+  del fatal_error
   if x is None:  # special corner case where x is None
     return x
-  if hasattr(x, '_tf_ref_id'):  # this is already a TFShouldUseWarningWrapper
-    return x
 
   def override_method(method):
     def fn(self, *args, **kwargs):
-      # pylint: disable=protected-access
-      _REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
-          object_used=True)
       return method(self, *args, **kwargs)
     return fn
 
@@ -98,38 +54,16 @@ def _add_should_use_warning(x, fatal_error=False):
 
     def __init__(self, true_self):
       self.__dict__ = true_self.__dict__
-      stack = [s.strip() for s in traceback.format_stack()]
-      # Remove top three stack entries from adding the wrapper
-      self.creation_stack = '\n'.join(stack[:-3])
-      self._tf_ref_id = next(_REF_ITER)
-      _REF_INFO[self._tf_ref_id] = _RefInfoField(
-          type_=type(x),
-          repr_=repr(x),
-          creation_stack=stack,
-          object_used=False)
-
-      # Create a finalizer for self, which will be called when self is
-      # garbage collected.  Can't add self as the args because the
-      # loop will break garbage collection.  We keep track of
-      # ourselves via python ids.
-      finalize(self, _deleted, self._tf_ref_id, fatal_error)
 
     # Not sure why this pylint warning is being used; this is not an
     # old class form.
     # pylint: disable=super-on-old-class
     def __getattribute__(self, name):
-      if name == '_tf_ref_id':
-        return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
-      if self._tf_ref_id in _REF_INFO:
-        _REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
-            object_used=True)
       return super(TFShouldUseWarningWrapper, self).__getattribute__(name)
 
     def mark_used(self, *args, **kwargs):
-      _REF_INFO[self._tf_ref_id] = _REF_INFO[self._tf_ref_id]._replace(
-          object_used=True)
-      if hasattr(super(TFShouldUseWarningWrapper, self), 'mark_used'):
-        return super(TFShouldUseWarningWrapper, self).mark_used(*args, **kwargs)
+      return
+
     # pylint: enable=super-on-old-class
 
   for name in dir(TFShouldUseWarningWrapper):
@@ -143,8 +77,6 @@ def _add_should_use_warning(x, fatal_error=False):
 
   wrapped = TFShouldUseWarningWrapper(x)
   wrapped.__doc__ = x.__doc__  # functools.wraps fails on some objects.
-  ref_id = wrapped._tf_ref_id  # pylint: disable=protected-access
-  _REF_INFO[ref_id] = _REF_INFO[ref_id]._replace(object_used=False)
   return wrapped
 
 
diff --git a/tensorflow/python/util/tf_should_use_test.py b/tensorflow/python/util/tf_should_use_test.py
index c8268744004cb5826b605f0eb355ef17e7a964eb..4c6e48b11c1d013d1e4c6cdfc376973baa7bb9a2 100644
--- a/tensorflow/python/util/tf_should_use_test.py
+++ b/tensorflow/python/util/tf_should_use_test.py
@@ -46,6 +46,7 @@ def reroute_error(captured):
 class TfShouldUseTest(test.TestCase):
 
   def testAddShouldUseWarningWhenNotUsed(self):
+    self.skipTest('b/65412899')
     c = constant_op.constant(0, name='blah0')
     captured = []
     with reroute_error(captured):
@@ -70,6 +71,7 @@ class TfShouldUseTest(test.TestCase):
     self.assertNotIn('%s:0' % name, '\n'.join(captured))
 
   def testAddShouldUseWarningWhenUsedWithAdd(self):
+    self.skipTest('b/65412899')
     def add(h):
       _ = h + 1
     self._testAddShouldUseWarningWhenUsed(add, name='blah_add')
@@ -77,6 +79,7 @@ class TfShouldUseTest(test.TestCase):
     self.assertFalse(gc.garbage)
 
   def testAddShouldUseWarningWhenUsedWithGetName(self):
+    self.skipTest('b/65412899')
     def get_name(h):
       _ = h.name
     self._testAddShouldUseWarningWhenUsed(get_name, name='blah_get_name')
@@ -84,6 +87,7 @@ class TfShouldUseTest(test.TestCase):
     self.assertFalse(gc.garbage)
 
   def testShouldUseResult(self):
+    self.skipTest('b/65412899')
     @tf_should_use.should_use_result
     def return_const(value):
       return constant_op.constant(value, name='blah2')
@@ -97,6 +101,7 @@ class TfShouldUseTest(test.TestCase):
     self.assertFalse(gc.garbage)
 
   def testShouldUseResultWhenNotReallyUsed(self):
+    self.skipTest('b/65412899')
     @tf_should_use.should_use_result
     def return_const(value):
       return constant_op.constant(value, name='blah3')
@@ -114,6 +119,13 @@ class TfShouldUseTest(test.TestCase):
     gc.collect()
     self.assertFalse(gc.garbage)
 
+  # Tests that mark_used is available in the API.
+  def testMarkUsed(self):
+    @tf_should_use.should_use_result
+    def return_const(value):
+      return constant_op.constant(value, name='blah3')
+    with self.test_session():
+      return_const(0.0).mark_used()
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 904c8c7818e678f1036175e602c446b2225c48ce..6b5ad1b5fb97b0a603bc86cedc8d42126eb64db0 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1913,6 +1913,106 @@ bool CudnnSupport::DoRnnBackward(
 #endif  // CUDNN_VERSION
 }
 
+namespace {
+
+inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo(
+    Stream* stream, CUDAExecutor* parent, void* dnn_handle,
+    const ScopedTensorDescriptor& input_nd,
+    const ScopedFilterDescriptor& filter,
+    const ScopedConvolutionDescriptor& conv,
+    const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit,
+    ScratchAllocator* scratch_allocator) {
+  cudnnConvolutionFwdPreference_t preference =
+      specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
+                              : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
+  auto memory_limit_bytes =
+      scratch_allocator == nullptr
+          ? 0
+          : scratch_allocator->GetMemoryLimitInBytes(stream);
+  if (memory_limit_bytes < 0) {
+    memory_limit_bytes = 0;
+  }
+
+  cudnnConvolutionFwdAlgo_t algo_to_use;
+  auto status = wrap::cudnnGetConvolutionForwardAlgorithm(
+      parent, ToHandle(dnn_handle), input_nd.handle(), filter.handle(),
+      conv.handle(), output_nd.handle(), preference, memory_limit_bytes,
+      &algo_to_use);
+  CHECK_EQ(status, CUDNN_STATUS_SUCCESS)
+      << "Unable to find a suitable algorithm for doing forward convolution";
+  return algo_to_use;
+}
+
+dnn::AlgorithmType GetCudnnConvolutionForwardAlgorithm(
+    Stream* stream, CUDAExecutor* parent, void* dnn_handle,
+    int cudnn_type,  // Actually cudnnDataType_t.
+    const dnn::AlgorithmConfig& algorithm_config, bool is_profiling,
+    const ScopedTensorDescriptor& input_nd,
+    const ScopedFilterDescriptor& filter,
+    const ScopedConvolutionDescriptor& conv,
+    const ScopedTensorDescriptor& output_nd,
+    ScratchAllocator* scratch_allocator, DeviceMemory* scratch) {
+  cudnnConvolutionFwdAlgo_t algo =
+      (algorithm_config.algorithm() == dnn::kDefaultAlgorithm)
+          ? GetCudnnConvolutionForwardAlgo(
+                stream, parent, dnn_handle, input_nd, filter, conv, output_nd,
+                /*specify_workspace_limit=*/scratch_allocator != nullptr,
+                scratch_allocator)
+          : ToConvForwardAlgo(algorithm_config.algorithm());
+  size_t size_in_bytes;
+  auto status = wrap::cudnnGetConvolutionForwardWorkspaceSize(
+      parent, ToHandle(dnn_handle), /*srcDesc=*/input_nd.handle(),
+      /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
+      /*destDesc=*/output_nd.handle(), /*algo=*/algo,
+      /*sizeInBytes=*/&size_in_bytes);
+  int64 size_in_bytes_int64 = size_in_bytes;
+  if (TF_PREDICT_FALSE(status != CUDNN_STATUS_SUCCESS)) {
+    CHECK(is_profiling) << "Cannot query the size of workspace needed "
+                           "for the specified algorithm: "
+                        << algorithm_config.algorithm() << " "
+                        << ToString(status);
+    // Silently return when we are profiling.
+    return dnn::kNoSuitableAlgorithmFound;
+  }
+  if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
+    LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned "
+                    "negative sizeInBytes value. This could be a cudnn bug.";
+    if (TF_PREDICT_TRUE(is_profiling)) {
+      return dnn::kNoSuitableAlgorithmFound;
+    }
+  } else if (size_in_bytes_int64 > 0) {
+    port::StatusOr> allocated;
+    if (TF_PREDICT_TRUE(scratch_allocator)) {
+      allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
+      if (TF_PREDICT_TRUE(allocated.ok())) {
+        *scratch = allocated.ValueOrDie();
+      } else {
+        if (TF_PREDICT_TRUE(is_profiling)) {
+          // Silently return when we are profiling.
+          return dnn::kNoSuitableAlgorithmFound;
+        }
+        LOG(WARNING) << allocated.status().error_message();
+        // For the int8 case, we fail at this point since the no_scratch
+        // algorithm should be set to dnn::kDefaultAlgorithm.
+        CHECK(algorithm_config.algorithm_no_scratch() != dnn::kDefaultAlgorithm)
+            << "The primary convolution algorithm failed memory allocation, "
+               "while a secondary algorithm is not provided.";
+      }
+    }
+    if (TF_PREDICT_FALSE(!allocated.ok())) {
+      algo = (algorithm_config.algorithm_no_scratch() == dnn::kDefaultAlgorithm)
+                 ? GetCudnnConvolutionForwardAlgo(
+                       stream, parent, dnn_handle, input_nd, filter, conv,
+                       output_nd, /*specify_workspace_limit=*/false, nullptr)
+                 : ToConvForwardAlgo(algorithm_config.algorithm_no_scratch());
+    }
+  }
+
+  return algo;
+}
+
+}  // namespace
+
 template 
 bool CudnnSupport::DoConvolveImpl(
     Stream* stream, int cudnn_type,  // Actually cudnnDataType_t.
@@ -1920,7 +2020,6 @@ bool CudnnSupport::DoConvolveImpl(
     const FilterDescriptor& filter_descriptor,
     const DeviceMemory& filter_data,
     const ConvolutionDescriptor& convolution_descriptor,
-    const DeviceMemory& biases, dnn::ActivationMode activation_mode,
     const BatchDescriptor& output_descriptor, DeviceMemory* output_data,
     ScratchAllocator* scratch_allocator,
     const dnn::AlgorithmConfig& algorithm_config,
@@ -1953,6 +2052,8 @@ bool CudnnSupport::DoConvolveImpl(
   cudnnConvolutionFwdAlgo_t algo;
   DeviceMemory scratch;
 
+  // TODO(pauldonnelly): Replace the following code with a call to
+  //   GetCudnnConvolutionForwardAlgorithm().
   if (algorithm_config.algorithm() == dnn::kDefaultAlgorithm) {
     // With the default algorithm, use Cudnn's heuristics.
     auto get_algorithm =
@@ -2059,27 +2160,117 @@ bool CudnnSupport::DoConvolveImpl(
                       "negative sizeInBytes value. This could be a cudnn bug.";
     }
   }
-  const bool has_biases = (biases != nullptr);
-  const bool supported_activation_mode =
-      (activation_mode == dnn::ActivationMode::kRelu);
+  std::unique_ptr timer;
+  if (is_profiling) {
+    timer.reset(new CUDATimer(parent_));  // NOLINT
+    if (!timer->Init()) {
+      return false;
+    }
+    // The start and stop of the timer should be as close to the Cudnn call as
+    // possible. It is still possible for other threads to issue workload on
+    // to this stream. So it could take multiple profiling measurements.
+    if (!timer->Start(AsCUDAStream(stream))) {
+      timer->Destroy();
+      return false;
+    }
+  }
+  status = wrap::cudnnConvolutionForward(
+      parent_, ToHandle(dnn_handle_),
+      /*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
+      /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
+      /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
+      /*algo=*/algo, /*workSpace=*/scratch.opaque(),
+      /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta,
+      /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
+
+  if (is_profiling) {
+    if (!timer->Stop(AsCUDAStream(stream))) {
+      timer->Destroy();
+      return false;
+    }
+    if (status == CUDNN_STATUS_SUCCESS) {
+      output_profile_result->set_algorithm(algo);
+      output_profile_result->set_elapsed_time_in_ms(
+          timer->GetElapsedMilliseconds());
+    }
+    timer->Destroy();
+  }
+
+  if (status != CUDNN_STATUS_SUCCESS) {
+    // Silently return when we are profiling.
+    if (!is_profiling) {
+      LOG(ERROR) << "failed to enqueue convolution on stream: "
+                 << ToString(status);
+    }
+    return false;
+  }
+
+  return true;
+}
+
+template 
+bool CudnnSupport::DoFusedConvolveImpl(
+    Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+    const DeviceMemory& conv_input_data, ScaleType conv_input_scale,
+    const dnn::FilterDescriptor& filter_descriptor,
+    const DeviceMemory& filter_data,
+    const dnn::ConvolutionDescriptor& convolution_descriptor,
+    const DeviceMemory& side_input_data, ScaleType side_input_scale,
+    const dnn::BatchDescriptor& bias_descriptor,
+    const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+    const dnn::BatchDescriptor& output_descriptor,
+    DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+    const dnn::AlgorithmConfig& algorithm_config,
+    dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION < 6000
+  LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only "
+                "supported for cuDNN version >= 6";
+  return false;
+#else
+  ScopedTensorDescriptor conv_input_nd{
+      parent_, conv_input_descriptor,
+      static_cast(cudnn_data_type)};
+  ScopedTensorDescriptor output_nd{
+      parent_, output_descriptor,
+      static_cast(cudnn_data_type)};
+  ScopedFilterDescriptor filter{parent_, filter_descriptor,
+                                conv_input_descriptor,
+                                static_cast(cudnn_data_type)};
+  ScopedTensorDescriptor bias_nd{parent_, bias_descriptor, CUDNN_DATA_FLOAT};
+  ScopedConvolutionDescriptor conv{
+      parent_, convolution_descriptor,
+      static_cast(cudnn_compute_type)};
 
-  if (has_biases && !supported_activation_mode) {
-    LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only "
-                  "support relu activation.";
+  mutex_lock lock{dnn_handle_mutex_};
+  auto status = wrap::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+                                     AsCUDAStreamValue(stream));
+  CHECK(status == CUDNN_STATUS_SUCCESS)
+      << "failed to set stream for cudnn handle: " << ToString(status);
+
+  const bool is_profiling = output_profile_result != nullptr;
+  DeviceMemory scratch;
+  dnn::AlgorithmType algorithm_type = GetCudnnConvolutionForwardAlgorithm(
+      stream, parent_, dnn_handle_, cudnn_data_type, algorithm_config,
+      is_profiling, conv_input_nd, filter, conv, output_nd, scratch_allocator,
+      &scratch);
+  if (algorithm_type == dnn::kNoSuitableAlgorithmFound) {
+    if (!is_profiling) {
+      LOG(ERROR) << "No suitable algorithm found";
+    }
     return false;
   }
+  auto algo = static_cast(algorithm_type);
 
-  if (has_biases && activation_mode == dnn::ActivationMode::kNone) {
-    LOG(ERROR) << "To use cudnnConvolutionBiasActivationForward() "
-                  "with a valid biases tensor, need to also provide "
-                  "a valid activation mode (currently only supports "
-                  "kRelu).";
+  if (activation_mode != dnn::ActivationMode::kRelu) {
+    LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only supports Relu "
+                  "activation.";
     return false;
   }
 
   std::unique_ptr timer;
   if (is_profiling) {
-    timer.reset(new CUDATimer(parent_));
+    timer.reset(new CUDATimer(parent_));  // NOLINT
     if (!timer->Init()) {
       return false;
     }
@@ -2091,50 +2282,44 @@ bool CudnnSupport::DoConvolveImpl(
       return false;
     }
   }
-  if (has_biases) {
-    CHECK(supported_activation_mode);
-#if CUDNN_VERSION < 6000
-    LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only "
-                  "supported for cuDNN version >= 6.";
-    return false;
-#else
-    BatchDescriptor bias_dimensions;
-    bias_dimensions.set_count(1)
-        .set_feature_map_count(output_descriptor.feature_map_count())
-        .set_height(1)
-        .set_width(1)
-        .set_layout(dnn::DataLayout::kBatchYXDepth);
-    ScopedTensorDescriptor bias_descriptor{
-        parent_, bias_dimensions, static_cast(cudnn_type)};
-    // CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for
-    // activation descriptor. Note that this will change the nan propagation
-    // behavior from separate conv, bias, and relu (which by default is
-    // CUDNN_PROPAGATE_NAN.
-    ScopedActivationDescriptor activation_desc{parent_, activation_mode,
-                                               CUDNN_NOT_PROPAGATE_NAN,
-                                               output_descriptor.value_max()};
-    status = wrap::cudnnConvolutionBiasActivationForward(
-        parent_, ToHandle(dnn_handle_),
-        /*alpha1=*/&alpha, /*srcDesc=*/input_nd.handle(),
-        /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
-        /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
-        /*algo=*/algo, /*workSpace=*/scratch.opaque(),
-        /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&beta,
-        /*zDesc=*/output_nd.handle(), /*z=*/input_data.opaque(),
-        /*biasDesc=*/bias_descriptor.handle(),
-        /*bias=*/biases.opaque(), /*activationDesc=*/activation_desc.handle(),
-        /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
-#endif  // CUDNN_VERSION < 6000
-  } else {
-    status = wrap::cudnnConvolutionForward(
-        parent_, ToHandle(dnn_handle_),
-        /*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
-        /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
-        /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
-        /*algo=*/algo, /*workSpace=*/scratch.opaque(),
-        /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta,
-        /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
-  }
+  // CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for
+  // activation descriptor. Note that this will change the nan propagation
+  // behavior from separate conv, bias, and relu (which by default is
+  // CUDNN_PROPAGATE_NAN.
+  ScopedActivationDescriptor activation_desc{parent_, activation_mode,
+                                             CUDNN_NOT_PROPAGATE_NAN,
+                                             output_descriptor.value_max()};
+  auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
+                                                     : side_input_data.opaque();
+
+  VLOG(2) << "\nconv_input_scale = " << conv_input_scale
+          << "\nconv_input_nd.handle() = " << conv_input_nd.handle()
+          << "\nconv_input_data.opaque() = " << conv_input_data.opaque()
+          << "\nfilter.handle() = " << filter.handle()
+          << "\nfilter_data.opaque() = " << filter_data.opaque()
+          << "\nconv.handle() = " << conv.handle() << "\nalgo = " << algo
+          << "\nscratch.opaque() = " << scratch.opaque()
+          << "\nscratch.size() = " << scratch.size()
+          << "\nside_input_scale = " << side_input_scale
+          << "\noutput_nd.handle() = " << output_nd.handle()
+          << "\nside_input_data_ptr = " << side_input_data_ptr
+          << "\nbias_nd.handle() = " << bias_nd.handle()
+          << "\nbiases.opaque() = " << biases.opaque()
+          << "\nactivation_desc.handle() = " << activation_desc.handle()
+          << "\noutput_nd.handle() = " << output_nd.handle()
+          << "\noutput_data->opaque() = " << output_data->opaque();
+
+  status = wrap::cudnnConvolutionBiasActivationForward(
+      parent_, ToHandle(dnn_handle_), /*alpha1=*/&conv_input_scale,
+      /*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
+      /*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
+      /*convDesc=*/conv.handle(), algo, /*workSpace=*/scratch.opaque(),
+      /*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
+      /*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
+      /*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
+      /*activationDesc=*/activation_desc.handle(),
+      /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
+
   if (is_profiling) {
     if (!timer->Stop(AsCUDAStream(stream))) {
       timer->Destroy();
@@ -2158,6 +2343,7 @@ bool CudnnSupport::DoConvolveImpl(
   }
 
   return true;
+#endif  // CUDNN_VERSION < 6000
 }
 
 // A helper class to decide whether to enable the WINOGRAD_NONFUSED algorithms.
@@ -2401,24 +2587,6 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
   return true;
 }
 
-bool CudnnSupport::DoConvolve(
-    Stream* stream, const BatchDescriptor& batch_descriptor,
-    const DeviceMemory& input_data,
-    const FilterDescriptor& filter_descriptor,
-    const DeviceMemory& filter_data,
-    const ConvolutionDescriptor& convolution_descriptor,
-    const DeviceMemory& biases, dnn::ActivationMode activation_mode,
-    const BatchDescriptor& output_descriptor, DeviceMemory* output_data,
-    ScratchAllocator* scratch_allocator,
-    const dnn::AlgorithmConfig& algorithm_config,
-    dnn::ProfileResult* output_profile_result) {
-  return DoConvolveImpl(
-      stream, CUDNN_DATA_FLOAT, batch_descriptor, input_data, filter_descriptor,
-      filter_data, convolution_descriptor, biases, activation_mode,
-      output_descriptor, output_data, scratch_allocator, algorithm_config,
-      output_profile_result);
-}
-
 bool CudnnSupport::DoConvolve(
     Stream* stream, const BatchDescriptor& batch_descriptor,
     const DeviceMemory& input_data,
@@ -2431,24 +2599,10 @@ bool CudnnSupport::DoConvolve(
     dnn::ProfileResult* output_profile_result) {
   return DoConvolveImpl(
       stream, CUDNN_DATA_FLOAT, batch_descriptor, input_data, filter_descriptor,
-      filter_data, convolution_descriptor, /*biases=*/nullptr,
-      dnn::ActivationMode::kNone, output_descriptor, output_data,
+      filter_data, convolution_descriptor, output_descriptor, output_data,
       scratch_allocator, algorithm_config, output_profile_result);
 }
 
-bool CudnnSupport::DoConvolve(
-    Stream* stream, const BatchDescriptor& batch_descriptor,
-    const DeviceMemory& input_data,
-    const FilterDescriptor& filter_descriptor,
-    const DeviceMemory& filter_data,
-    const ConvolutionDescriptor& convolution_descriptor,
-    const DeviceMemory& biases, dnn::ActivationMode activation_mode,
-    const BatchDescriptor& output_descriptor,
-    DeviceMemory* output_data) {
-  LOG(ERROR) << "double-based DNN not yet implemented";
-  return false;
-}
-
 bool CudnnSupport::DoConvolve(
     Stream* stream, const BatchDescriptor& batch_descriptor,
     const DeviceMemory& input_data,
@@ -2467,34 +2621,113 @@ bool CudnnSupport::DoConvolve(
     const FilterDescriptor& filter_descriptor,
     const DeviceMemory& filter_data,
     const ConvolutionDescriptor& convolution_descriptor,
-    const DeviceMemory& biases,
-    dnn::ActivationMode activation_mode,
     const BatchDescriptor& output_descriptor,
     DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
   return DoConvolveImpl(
       stream, CUDNN_DATA_HALF, batch_descriptor, input_data, filter_descriptor,
-      filter_data, convolution_descriptor, biases, activation_mode,
+      filter_data, convolution_descriptor, output_descriptor, output_data,
+      scratch_allocator, algorithm_config, output_profile_result);
+}
+
+bool CudnnSupport::DoFusedConvolve(
+    Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+    const DeviceMemory& conv_input_data, double conv_input_scale,
+    const dnn::FilterDescriptor& filter_descriptor,
+    const DeviceMemory& filter_data,
+    const dnn::ConvolutionDescriptor& convolution_descriptor,
+    const DeviceMemory& side_input_data, double side_input_scale,
+    const dnn::BatchDescriptor& bias_descriptor,
+    const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+    const dnn::BatchDescriptor& output_descriptor,
+    DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+    const dnn::AlgorithmConfig& algorithm_config,
+    dnn::ProfileResult* output_profile_result) {
+  return DoFusedConvolveImpl(
+      stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+      filter_descriptor, filter_data, convolution_descriptor, side_input_data,
+      side_input_scale, bias_descriptor, biases, activation_mode,
+      output_descriptor, output_data, scratch_allocator, algorithm_config,
+      output_profile_result);
+  return true;
+}
+
+bool CudnnSupport::DoFusedConvolve(
+    Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+    const DeviceMemory& conv_input_data, float conv_input_scale,
+    const dnn::FilterDescriptor& filter_descriptor,
+    const DeviceMemory& filter_data,
+    const dnn::ConvolutionDescriptor& convolution_descriptor,
+    const DeviceMemory& side_input_data, float side_input_scale,
+    const dnn::BatchDescriptor& bias_descriptor,
+    const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+    const dnn::BatchDescriptor& output_descriptor,
+    DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+    const dnn::AlgorithmConfig& algorithm_config,
+    dnn::ProfileResult* output_profile_result) {
+  return DoFusedConvolveImpl(
+      stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+      filter_descriptor, filter_data, convolution_descriptor, side_input_data,
+      side_input_scale, bias_descriptor, biases, activation_mode,
       output_descriptor, output_data, scratch_allocator, algorithm_config,
       output_profile_result);
+  return true;
 }
 
-bool CudnnSupport::DoConvolve(
-    Stream* stream, const BatchDescriptor& batch_descriptor,
-    const DeviceMemory& input_data,
-    const FilterDescriptor& filter_descriptor,
+bool CudnnSupport::DoFusedConvolve(
+    Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+    const DeviceMemory& conv_input_data, float conv_input_scale,
+    const dnn::FilterDescriptor& filter_descriptor,
     const DeviceMemory& filter_data,
-    const ConvolutionDescriptor& convolution_descriptor,
-    const BatchDescriptor& output_descriptor,
+    const dnn::ConvolutionDescriptor& convolution_descriptor,
+    const DeviceMemory& side_input_data, float side_input_scale,
+    const dnn::BatchDescriptor& bias_descriptor,
+    const DeviceMemory& biases,
+    dnn::ActivationMode activation_mode,
+    const dnn::BatchDescriptor& output_descriptor,
     DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
     const dnn::AlgorithmConfig& algorithm_config,
     dnn::ProfileResult* output_profile_result) {
-  return DoConvolveImpl(
-      stream, CUDNN_DATA_HALF, batch_descriptor, input_data, filter_descriptor,
-      filter_data, convolution_descriptor, /*biases=*/nullptr,
-      dnn::ActivationMode::kNone, output_descriptor, output_data,
-      scratch_allocator, algorithm_config, output_profile_result);
+  return DoFusedConvolveImpl(
+      stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+      filter_descriptor, filter_data, convolution_descriptor, side_input_data,
+      side_input_scale, bias_descriptor, biases, activation_mode,
+      output_descriptor, output_data, scratch_allocator, algorithm_config,
+      output_profile_result);
+  return true;
+}
+
+bool CudnnSupport::DoFusedConvolve(
+    Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+    const DeviceMemory& conv_input_data, float conv_input_scale,
+    const dnn::FilterDescriptor& filter_descriptor,
+    const DeviceMemory& filter_data,
+    const dnn::ConvolutionDescriptor& convolution_descriptor,
+    const DeviceMemory& side_input_data, float side_input_scale,
+    const dnn::BatchDescriptor& bias_descriptor,
+    const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+    const dnn::BatchDescriptor& output_descriptor,
+    DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+    const dnn::AlgorithmConfig& algorithm_config,
+    dnn::ProfileResult* output_profile_result) {
+#if CUDNN_VERSION < 6000
+  LOG(ERROR) << "cudnnConvolutionBiasActivationForward() is only "
+                "supported for cuDNN version >= 6";
+  return false;
+#else
+  return DoFusedConvolveImpl(
+      stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+      filter_descriptor, filter_data, convolution_descriptor, side_input_data,
+      side_input_scale, bias_descriptor, biases, activation_mode,
+      output_descriptor, output_data, scratch_allocator, algorithm_config,
+      output_profile_result);
+  return true;
+#endif
 }
 
 template
@@ -2730,7 +2963,7 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
 
   std::unique_ptr timer;
   if (is_profiling) {
-    timer.reset(new CUDATimer(parent_));
+    timer.reset(new CUDATimer(parent_));  // NOLINT
     timer->Init();
     // The start and stop of the timer should be as close to the Cudnn call as
     // possible. It is still possible for other threads to issue workload on
@@ -2981,7 +3214,7 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
 
   std::unique_ptr timer;
   if (is_profiling) {
-    timer.reset(new CUDATimer(parent_));
+    timer.reset(new CUDATimer(parent_));  // NOLINT
     timer->Init();
     // The start and stop of the timer should be as close to the Cudnn call as
     // possible. It is still possible for other threads to issue workload on
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index b094cf76e94bfdbd2fbe89b8e7ff917145cd0fd5..db376e2a66967edcae578e9181bfb952d58feae4 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -183,8 +183,6 @@ class CudnnSupport : public dnn::DnnSupport {
                   const dnn::FilterDescriptor& filter_descriptor,
                   const DeviceMemory& filter_data,
                   const dnn::ConvolutionDescriptor& convolution_descriptor,
-                  const DeviceMemory& biases,
-                  dnn::ActivationMode activation_mode,
                   const dnn::BatchDescriptor& output_descriptor,
                   DeviceMemory* output_data,
                   ScratchAllocator* scratch_allocator,
@@ -196,8 +194,6 @@ class CudnnSupport : public dnn::DnnSupport {
                   const dnn::FilterDescriptor& filter_descriptor,
                   const DeviceMemory& filter_data,
                   const dnn::ConvolutionDescriptor& convolution_descriptor,
-                  const DeviceMemory& biases,
-                  dnn::ActivationMode activation_mode,
                   const dnn::BatchDescriptor& output_descriptor,
                   DeviceMemory* output_data) override;
 
@@ -206,43 +202,71 @@ class CudnnSupport : public dnn::DnnSupport {
                   const dnn::FilterDescriptor& filter_descriptor,
                   const DeviceMemory& filter_data,
                   const dnn::ConvolutionDescriptor& convolution_descriptor,
-                  const DeviceMemory& biases,
-                  dnn::ActivationMode activation_mode,
                   const dnn::BatchDescriptor& output_descriptor,
                   DeviceMemory* output_data,
                   ScratchAllocator* scratch_allocator,
                   const dnn::AlgorithmConfig& algorithm_config,
                   dnn::ProfileResult* output_profile_result) override;
 
-  bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
-                  const DeviceMemory& input_data,
-                  const dnn::FilterDescriptor& filter_descriptor,
-                  const DeviceMemory& filter_data,
-                  const dnn::ConvolutionDescriptor& convolution_descriptor,
-                  const dnn::BatchDescriptor& output_descriptor,
-                  DeviceMemory* output_data,
-                  ScratchAllocator* scratch_allocator,
-                  const dnn::AlgorithmConfig& algorithm_config,
-                  dnn::ProfileResult* output_profile_result) override;
+  bool DoFusedConvolve(
+      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+      const DeviceMemory& conv_input_data, double conv_input_scale,
+      const dnn::FilterDescriptor& filter_descriptor,
+      const DeviceMemory& filter_data,
+      const dnn::ConvolutionDescriptor& convolution_descriptor,
+      const DeviceMemory& side_input_data, double side_input_scale,
+      const dnn::BatchDescriptor& bias_descriptor,
+      const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+      const dnn::BatchDescriptor& output_descriptor,
+      DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      dnn::ProfileResult* output_profile_result) override;
 
-  bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
-                  const DeviceMemory& input_data,
-                  const dnn::FilterDescriptor& filter_descriptor,
-                  const DeviceMemory& filter_data,
-                  const dnn::ConvolutionDescriptor& convolution_descriptor,
-                  const dnn::BatchDescriptor& output_descriptor,
-                  DeviceMemory* output_data) override;
+  bool DoFusedConvolve(
+      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+      const DeviceMemory& conv_input_data, float conv_input_scale,
+      const dnn::FilterDescriptor& filter_descriptor,
+      const DeviceMemory& filter_data,
+      const dnn::ConvolutionDescriptor& convolution_descriptor,
+      const DeviceMemory& side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor& bias_descriptor,
+      const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+      const dnn::BatchDescriptor& output_descriptor,
+      DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      dnn::ProfileResult* output_profile_result) override;
 
-  bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
-                  const DeviceMemory& input_data,
-                  const dnn::FilterDescriptor& filter_descriptor,
-                  const DeviceMemory& filter_data,
-                  const dnn::ConvolutionDescriptor& convolution_descriptor,
-                  const dnn::BatchDescriptor& output_descriptor,
-                  DeviceMemory* output_data,
-                  ScratchAllocator* scratch_allocator,
-                  const dnn::AlgorithmConfig& algorithm_config,
-                  dnn::ProfileResult* output_profile_result) override;
+  bool DoFusedConvolve(Stream* stream,
+                       const dnn::BatchDescriptor& conv_input_descriptor,
+                       const DeviceMemory& conv_input_data,
+                       float conv_input_scale,
+                       const dnn::FilterDescriptor& filter_descriptor,
+                       const DeviceMemory& filter_data,
+                       const dnn::ConvolutionDescriptor& convolution_descriptor,
+                       const DeviceMemory& side_input_data,
+                       float side_input_scale,
+                       const dnn::BatchDescriptor& bias_descriptor,
+                       const DeviceMemory& biases,
+                       dnn::ActivationMode activation_mode,
+                       const dnn::BatchDescriptor& output_descriptor,
+                       DeviceMemory* output_data,
+                       ScratchAllocator* scratch_allocator,
+                       const dnn::AlgorithmConfig& algorithm_config,
+                       dnn::ProfileResult* output_profile_result) override;
+
+  bool DoFusedConvolve(
+      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+      const DeviceMemory& conv_input_data, float conv_input_scale,
+      const dnn::FilterDescriptor& filter_descriptor,
+      const DeviceMemory& filter_data,
+      const dnn::ConvolutionDescriptor& convolution_descriptor,
+      const DeviceMemory& side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor& bias_descriptor,
+      const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+      const dnn::BatchDescriptor& output_descriptor,
+      DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      dnn::ProfileResult* output_profile_result) override;
 
   bool DoConvolveQuantized(
       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
@@ -561,14 +585,28 @@ class CudnnSupport : public dnn::DnnSupport {
                       const dnn::FilterDescriptor& filter_descriptor,
                       const DeviceMemory& filter_data,
                       const dnn::ConvolutionDescriptor& convolution_descriptor,
-                      const DeviceMemory& biases,
-                      dnn::ActivationMode activation_mode,
                       const dnn::BatchDescriptor& output_descriptor,
                       DeviceMemory* output_data,
                       ScratchAllocator* scratch_allocator,
                       const dnn::AlgorithmConfig& algorithm_config,
                       dnn::ProfileResult* output_profile_result);
 
+  template 
+  bool DoFusedConvolveImpl(
+      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+      const DeviceMemory& conv_input_data, ScaleType conv_input_scale,
+      const dnn::FilterDescriptor& filter_descriptor,
+      const DeviceMemory& filter_data,
+      const dnn::ConvolutionDescriptor& convolution_descriptor,
+      const DeviceMemory& side_input_data, ScaleType side_input_scale,
+      const dnn::BatchDescriptor& bias_descriptor,
+      const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+      const dnn::BatchDescriptor& output_descriptor,
+      DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      dnn::ProfileResult* output_profile_result);
+
   template 
   bool DoConvolveBackwardDataImpl(
       Stream* stream,
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 0a0ad7d9fb94c9b3c387a98efea7cf037ddfe51a..0a4525c1b7c77ee2c6de63a21c6d134be3f38612 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -669,6 +669,7 @@ class PoolingDescriptor {
 
 typedef int64 AlgorithmType;
 constexpr AlgorithmType kDefaultAlgorithm = -1;
+constexpr AlgorithmType kNoSuitableAlgorithmFound = -2;
 
 // Describes the result from a perf experiment.
 //
@@ -912,20 +913,32 @@ class DnnSupport {
     return false;
   }
 
-  // Enqueues a single-precision convolution operation onto the stream.
+  // Enqueues a fused convolution operation onto the stream.
+  // We provide several variants with different types for inputs, biases and
+  // scaling parameters.
   //
   // Arguments (all borrowed):
   //  stream: borrowed pointer to the stream that the 'convolve' operation
   //    should be enqueued onto.
-  //  input_descriptor: dimensions of the input layer.
-  //  input_data: un-owned device memory region which contains the
+  //  conv_input_descriptor: dimensions of the convolution input layer.
+  //  conv_input_data: un-owned device memory region which contains the
   //    convolution input.
+  //  conv_input_scale: a floating point scale to multiply with each element
+  //    of conv_input_data.
   //  filter_descriptor: dimensions of the convolution filter.
+  //  filter_data: un-owned device memory region which contains the
+  //    convolution filter weights.
   //  convolution_descriptor: stride of the convolution filter.
   //  biases: un-owned device memory region containing biases to add to the
-  //  input. This can be DeviceMemory pointing to NULL only when activation_mode
-  //  is kNone.
+  //    input.
   //  activation_mode: Type of activation to perform.
+  //  side_input_data: un-owned device memory region which contains optional
+  //    side input data. If 'side_input_scale' is non-zero, then this must
+  //    point to data in the tensor shape specified by output_shape.
+  //    It will be scaled by 'side_input_scale' and added to the convolution
+  //    result and bias prior to applying the activation function.
+  //  side_input_scale: a floating point scale to multiply with each element
+  //    of side_input_data.
   //  output_descriptor: dimensions of the output layer.
   //  output_data: un-owned device memory region in which to place the
   //    convolution result.
@@ -938,7 +951,7 @@ class DnnSupport {
   //  output_profile_result: the output profile result for this call. The
   //    profiling is only enabled when this is not nullptr.
   //
-  // input_descriptor, filter_descriptor, convolution_descriptor and
+  // conv_input_descriptor, filter_descriptor, convolution_descriptor and
   // output_descriptor together specify exactly how the convolution is aligned
   // with the input data:
   //
@@ -952,55 +965,115 @@ class DnnSupport {
   //   that if the inverse of the filter is applied to the output in VALID mode
   //   the result is the same size as the input - this requires even more
   //   padding of the input.
-  virtual bool DoConvolve(
-      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
-      const DeviceMemory& input_data,
+  virtual bool DoFusedConvolve(
+      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+      const DeviceMemory& conv_input_data, double conv_input_scale,
       const dnn::FilterDescriptor& filter_descriptor,
-      const DeviceMemory& filter_data,
+      const DeviceMemory& filter_data,
       const dnn::ConvolutionDescriptor& convolution_descriptor,
-      const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+      const DeviceMemory& side_input_data, double side_input_scale,
+      const dnn::BatchDescriptor& bias_descriptor,
+      const DeviceMemory& biases, dnn::ActivationMode activation_mode,
       const dnn::BatchDescriptor& output_descriptor,
-      DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+      DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
       const dnn::AlgorithmConfig& algorithm_config,
-      ProfileResult* output_profile_result) {
+      dnn::ProfileResult* output_profile_result) {
     return false;
   }
 
-  // Enqueues a double-precision fused convolution, bias add, and activation
-  // operation onto the stream. See DoConvolve above for argument details.
-  virtual bool DoConvolve(
-      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
-      const DeviceMemory& input_data,
+  // This is the float version of DoFusedConvolve.
+  virtual bool DoFusedConvolve(
+      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+      const DeviceMemory& conv_input_data, float conv_input_scale,
       const dnn::FilterDescriptor& filter_descriptor,
-      const DeviceMemory& filter_data,
+      const DeviceMemory& filter_data,
       const dnn::ConvolutionDescriptor& convolution_descriptor,
-      const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+      const DeviceMemory& side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor& bias_descriptor,
+      const DeviceMemory& biases, dnn::ActivationMode activation_mode,
       const dnn::BatchDescriptor& output_descriptor,
-      DeviceMemory* output_data) {
+      DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      dnn::ProfileResult* output_profile_result) {
     return false;
   }
 
-  // Enqueues a half-precision fused convolution, bias add, and activation
-  // operation onto the stream. See DoConvolve above for argument details.
-  virtual bool DoConvolve(
-      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
-      const DeviceMemory& input_data,
+  // This is the Eigen::half version of DoFusedConvolve.
+  // The scaling parameters are still floats.
+  virtual bool DoFusedConvolve(
+      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+      const DeviceMemory& conv_input_data, float conv_input_scale,
       const dnn::FilterDescriptor& filter_descriptor,
       const DeviceMemory& filter_data,
       const dnn::ConvolutionDescriptor& convolution_descriptor,
+      const DeviceMemory& side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor& bias_descriptor,
       const DeviceMemory& biases,
       dnn::ActivationMode activation_mode,
       const dnn::BatchDescriptor& output_descriptor,
       DeviceMemory* output_data,
       ScratchAllocator* scratch_allocator,
       const dnn::AlgorithmConfig& algorithm_config,
-      ProfileResult* output_profile_result) {
+      dnn::ProfileResult* output_profile_result) {
     return false;
   }
 
-  // Enqueues a single-precision convolution operation (without bias add
-  // or activation) onto the stream.
-  // See DoConvolve above for argument details.
+  // This is the int8 version of DoFusedConvolve.
+  // The bias input and scaling parameters are floats.
+  virtual bool DoFusedConvolve(
+      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
+      const DeviceMemory& conv_input_data, float conv_input_scale,
+      const dnn::FilterDescriptor& filter_descriptor,
+      const DeviceMemory& filter_data,
+      const dnn::ConvolutionDescriptor& convolution_descriptor,
+      const DeviceMemory& side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor& bias_descriptor,
+      const DeviceMemory& biases, dnn::ActivationMode activation_mode,
+      const dnn::BatchDescriptor& output_descriptor,
+      DeviceMemory* output_data, ScratchAllocator* scratch_allocator,
+      const dnn::AlgorithmConfig& algorithm_config,
+      dnn::ProfileResult* output_profile_result) {
+    return false;
+  }
+
+  // Enqueues a single-precision convolution operation onto the stream.
+  //
+  // Arguments (all borrowed):
+  //  stream: borrowed pointer to the stream that the 'convolve' operation
+  //    should be enqueued onto.
+  //  input_descriptor: dimensions of the input layer.
+  //  input_data: un-owned device memory region which contains the
+  //    convolution input.
+  //  filter_descriptor: dimensions of the convolution filter.
+  //  convolution_descriptor: stride of the convolution filter.
+  //  input. This can be DeviceMemory pointing to NULL only when activation_mode
+  //  is kNone.
+  //  output_descriptor: dimensions of the output layer.
+  //  output_data: un-owned device memory region in which to place the
+  //    convolution result.
+  //  scratch_allocator: un-owned, may-be-null object that may allocate scratch
+  //    space in order to speed up the convolution operation.
+  //  algorithm: an integer to specify which algorithm should be used for the
+  //    operation. kDefaultAlgorithm means the system will pick an algorithm
+  //    by default. The coding of the algorithm is be interpretted by the
+  //    underlying implementation.
+  //  output_profile_result: the output profile result for this call. The
+  //    profiling is only enabled when this is not nullptr.
+  //
+  // input_descriptor, filter_descriptor, convolution_descriptor and
+  // output_descriptor together specify exactly how the convolution is aligned
+  // with the input data:
+  //
+  // * (input dimensions - filter size + 1) / filter stride == output dimensions
+  //   corresponds to dist_belief padding = VALID, i.e. the input is not padded.
+  // * input dimensions / filter stride == output dimensions
+  //   corresponds to dist_belief padding = SAME, i.e. input and output are the
+  //   same size - this requires padding the input.
+  // * (input dimensions + filter size - 1) / filter stride == output dimensions
+  //   corresponds to dist_belief padding = FULL, i.e. the output is sized so
+  //   that if the inverse of the filter is applied to the output in VALID mode
+  //   the result is the same size as the input - this requires even more
+  //   padding of the input.
   virtual bool DoConvolve(
       Stream* stream, const dnn::BatchDescriptor& input_descriptor,
       const DeviceMemory& input_data,
@@ -1012,8 +1085,7 @@ class DnnSupport {
       const dnn::AlgorithmConfig& algorithm_config,
       ProfileResult* output_profile_result) = 0;
 
-  // Enqueues a double-precision convolution operation (without bias add
-  // or activation) onto the stream.
+  // Enqueues a double-precision convolution operation onto the stream.
   // See DoConvolve above for argument details.
   virtual bool DoConvolve(
       Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
@@ -1024,8 +1096,7 @@ class DnnSupport {
       const dnn::BatchDescriptor& output_descriptor,
       DeviceMemory* output_data) = 0;
 
-  // Enqueues a half-precision convolution operation (without bias add
-  // or activation) onto the stream.
+  // Enqueues a half-precision convolution operation onto the stream.
   // See DoConvolve above for argument details.
   virtual bool DoConvolve(
       Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index c9b36ba7ab35df1229d04b2ef5f73edeaa2e3c1f..dc768e02730ecba64cfa67fb314bc5a28f5e212e 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -361,28 +361,66 @@ Stream &Stream::ThenBatchNormalizationBackward(
   return *this;
 }
 
-Stream &Stream::ThenConvolveWithScratch(
-    const dnn::BatchDescriptor &input_descriptor,
-    const DeviceMemory &input_data,
+Stream &Stream::ThenFusedConvolveWithScratch(
+    const dnn::BatchDescriptor &conv_input_descriptor,
+    const DeviceMemory &conv_input_data, float conv_input_scale,
+    const dnn::FilterDescriptor &filter_descriptor,
+    const DeviceMemory &filter_data,
+    const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const DeviceMemory &side_input_data, float side_input_scale,
+    const dnn::BatchDescriptor &bias_descriptor,
+    const DeviceMemory &biases, dnn::ActivationMode activation_mode,
+    const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output,
+    ScratchAllocator *scratch_allocator) {
+  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+            PARAM(conv_input_scale), PARAM(filter_descriptor),
+            PARAM(filter_data), PARAM(convolution_descriptor),
+            PARAM(side_input_data), PARAM(side_input_scale),
+            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
+            PARAM(output_descriptor), PARAM(output));
+
+  if (ok()) {
+    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+      CheckError(dnn->DoFusedConvolve(
+          this, conv_input_descriptor, conv_input_data, conv_input_scale,
+          filter_descriptor, filter_data, convolution_descriptor,
+          side_input_data, side_input_scale, bias_descriptor, biases,
+          activation_mode, output_descriptor, output, scratch_allocator,
+          dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
+    } else {
+      SetErrorAndLogNoDnnSupport();
+    }
+  }
+  return *this;
+}
+
+Stream &Stream::ThenFusedConvolveWithScratch(
+    const dnn::BatchDescriptor &conv_input_descriptor,
+    const DeviceMemory &conv_input_data, float conv_input_scale,
     const dnn::FilterDescriptor &filter_descriptor,
     const DeviceMemory &filter_data,
     const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const DeviceMemory &side_input_data, float side_input_scale,
+    const dnn::BatchDescriptor &bias_descriptor,
     const DeviceMemory &biases,
     dnn::ActivationMode activation_mode,
     const dnn::BatchDescriptor &output_descriptor,
     DeviceMemory *output, ScratchAllocator *scratch_allocator) {
-  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
-            PARAM(filter_descriptor), PARAM(filter_data),
-            PARAM(convolution_descriptor), PARAM(biases),
-            PARAM(activation_mode), PARAM(output_descriptor), PARAM(output));
+  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+            PARAM(conv_input_scale), PARAM(filter_descriptor),
+            PARAM(filter_data), PARAM(convolution_descriptor),
+            PARAM(side_input_data), PARAM(side_input_scale),
+            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
+            PARAM(output_descriptor), PARAM(output));
 
   if (ok()) {
     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
-      CheckError(dnn->DoConvolve(
-          this, input_descriptor, input_data, filter_descriptor, filter_data,
-          convolution_descriptor, biases, activation_mode, output_descriptor,
-          output, scratch_allocator, dnn::AlgorithmConfig(),
-          /*output_profile_result=*/nullptr));
+      CheckError(dnn->DoFusedConvolve(
+          this, conv_input_descriptor, conv_input_data, conv_input_scale,
+          filter_descriptor, filter_data, convolution_descriptor,
+          side_input_data, side_input_scale, bias_descriptor, biases,
+          activation_mode, output_descriptor, output, scratch_allocator,
+          dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
     } else {
       SetErrorAndLogNoDnnSupport();
     }
@@ -390,27 +428,32 @@ Stream &Stream::ThenConvolveWithScratch(
   return *this;
 }
 
-Stream &Stream::ThenConvolveWithScratch(
-    const dnn::BatchDescriptor &input_descriptor,
-    const DeviceMemory &input_data,
+Stream &Stream::ThenFusedConvolveWithScratch(
+    const dnn::BatchDescriptor &conv_input_descriptor,
+    const DeviceMemory &conv_input_data, float conv_input_scale,
     const dnn::FilterDescriptor &filter_descriptor,
     const DeviceMemory &filter_data,
     const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const DeviceMemory &side_input_data, float side_input_scale,
+    const dnn::BatchDescriptor &bias_descriptor,
     const DeviceMemory &biases, dnn::ActivationMode activation_mode,
     const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output,
     ScratchAllocator *scratch_allocator) {
-  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
-            PARAM(filter_descriptor), PARAM(filter_data),
-            PARAM(convolution_descriptor), PARAM(biases),
-            PARAM(activation_mode), PARAM(output_descriptor), PARAM(output));
+  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+            PARAM(conv_input_scale), PARAM(filter_descriptor),
+            PARAM(filter_data), PARAM(convolution_descriptor),
+            PARAM(side_input_data), PARAM(side_input_scale),
+            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
+            PARAM(output_descriptor), PARAM(output));
 
   if (ok()) {
     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
-      CheckError(dnn->DoConvolve(
-          this, input_descriptor, input_data, filter_descriptor, filter_data,
-          convolution_descriptor, biases, activation_mode, output_descriptor,
-          output, scratch_allocator, dnn::AlgorithmConfig(),
-          /*output_profile_result=*/nullptr));
+      CheckError(dnn->DoFusedConvolve(
+          this, conv_input_descriptor, conv_input_data, conv_input_scale,
+          filter_descriptor, filter_data, convolution_descriptor,
+          side_input_data, side_input_scale, bias_descriptor, biases,
+          activation_mode, output_descriptor, output, scratch_allocator,
+          dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
     } else {
       SetErrorAndLogNoDnnSupport();
     }
@@ -472,29 +515,34 @@ Stream &Stream::ThenConvolveWithScratch(
   return *this;
 }
 
-Stream &Stream::ThenConvolveWithAlgorithm(
-    const dnn::BatchDescriptor &input_descriptor,
-    const DeviceMemory &input_data,
+Stream &Stream::ThenFusedConvolveWithAlgorithm(
+    const dnn::BatchDescriptor &conv_input_descriptor,
+    const DeviceMemory &conv_input_data, float conv_input_scale,
     const dnn::FilterDescriptor &filter_descriptor,
     const DeviceMemory &filter_data,
     const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const DeviceMemory &side_input_data, float side_input_scale,
+    const dnn::BatchDescriptor &bias_descriptor,
     const DeviceMemory &biases, dnn::ActivationMode activation_mode,
     const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output,
     ScratchAllocator *scratch_allocator,
     const dnn::AlgorithmConfig &algorithm_config,
     dnn::ProfileResult *output_profile_result) {
-  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
-            PARAM(filter_descriptor), PARAM(filter_data),
-            PARAM(convolution_descriptor), PARAM(biases),
+  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+            PARAM(conv_input_scale), PARAM(filter_descriptor),
+            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
+            PARAM(side_input_data), PARAM(side_input_scale),
             PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
             PARAM(algorithm_config));
 
   if (ok()) {
     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
-      auto status = dnn->DoConvolve(
-          this, input_descriptor, input_data, filter_descriptor, filter_data,
-          convolution_descriptor, biases, activation_mode, output_descriptor,
-          output, scratch_allocator, algorithm_config, output_profile_result);
+      auto status = dnn->DoFusedConvolve(
+          this, conv_input_descriptor, conv_input_data, conv_input_scale,
+          filter_descriptor, filter_data, convolution_descriptor,
+          side_input_data, side_input_scale, bias_descriptor, biases,
+          activation_mode, output_descriptor, output, scratch_allocator,
+          algorithm_config, output_profile_result);
       if (!status && !output_profile_result) {
         SetError();
       }
@@ -505,30 +553,73 @@ Stream &Stream::ThenConvolveWithAlgorithm(
   return *this;
 }
 
-Stream &Stream::ThenConvolveWithAlgorithm(
-    const dnn::BatchDescriptor &input_descriptor,
-    const DeviceMemory &input_data,
+Stream &Stream::ThenFusedConvolveWithAlgorithm(
+    const dnn::BatchDescriptor &conv_input_descriptor,
+    const DeviceMemory &conv_input_data, float conv_input_scale,
     const dnn::FilterDescriptor &filter_descriptor,
     const DeviceMemory &filter_data,
     const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const DeviceMemory &side_input_data, float side_input_scale,
+    const dnn::BatchDescriptor &bias_descriptor,
     const DeviceMemory &biases,
     dnn::ActivationMode activation_mode,
     const dnn::BatchDescriptor &output_descriptor,
     DeviceMemory *output, ScratchAllocator *scratch_allocator,
     const dnn::AlgorithmConfig &algorithm_config,
     dnn::ProfileResult *output_profile_result) {
-  VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
-            PARAM(filter_descriptor), PARAM(filter_data),
-            PARAM(convolution_descriptor), PARAM(biases),
-            PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
-            PARAM(algorithm_config));
+  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+            PARAM(conv_input_scale), PARAM(filter_descriptor),
+            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
+            PARAM(side_input_data), PARAM(side_input_scale),
+            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
+            PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
 
   if (ok()) {
     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
-      auto status = dnn->DoConvolve(
-          this, input_descriptor, input_data, filter_descriptor, filter_data,
-          convolution_descriptor, biases, activation_mode, output_descriptor,
-          output, scratch_allocator, algorithm_config, output_profile_result);
+      auto status = dnn->DoFusedConvolve(
+          this, conv_input_descriptor, conv_input_data, conv_input_scale,
+          filter_descriptor, filter_data, convolution_descriptor,
+          side_input_data, side_input_scale, bias_descriptor, biases,
+          activation_mode, output_descriptor, output, scratch_allocator,
+          algorithm_config, output_profile_result);
+      if (!status && !output_profile_result) {
+        SetError();
+      }
+    } else {
+      SetErrorAndLogNoDnnSupport();
+    }
+  }
+  return *this;
+}
+
+Stream &Stream::ThenFusedConvolveWithAlgorithm(
+    const dnn::BatchDescriptor &conv_input_descriptor,
+    const DeviceMemory &conv_input_data, float conv_input_scale,
+    const dnn::FilterDescriptor &filter_descriptor,
+    const DeviceMemory &filter_data,
+    const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const DeviceMemory &side_input_data, float side_input_scale,
+    const dnn::BatchDescriptor &bias_descriptor,
+    const DeviceMemory &biases, dnn::ActivationMode activation_mode,
+    const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output,
+    ScratchAllocator *scratch_allocator,
+    const dnn::AlgorithmConfig &algorithm_config,
+    dnn::ProfileResult *output_profile_result) {
+  VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+            PARAM(conv_input_scale), PARAM(filter_descriptor),
+            PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
+            PARAM(side_input_data), PARAM(side_input_scale),
+            PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
+            PARAM(output_descriptor), PARAM(output), PARAM(algorithm_config));
+
+  if (ok()) {
+    if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+      auto status = dnn->DoFusedConvolve(
+          this, conv_input_descriptor, conv_input_data, conv_input_scale,
+          filter_descriptor, filter_data, convolution_descriptor,
+          side_input_data, side_input_scale, bias_descriptor, biases,
+          activation_mode, output_descriptor, output, scratch_allocator,
+          algorithm_config, output_profile_result);
       if (!status && !output_profile_result) {
         SetError();
       }
@@ -601,19 +692,22 @@ Stream &Stream::ThenConvolveWithAlgorithm(
   return *this;
 }
 
-Stream &Stream::ThenConvolve(
-    const dnn::BatchDescriptor &input_descriptor,
-    const DeviceMemory &input_data,
+Stream &Stream::ThenFusedConvolve(
+    const dnn::BatchDescriptor &conv_input_descriptor,
+    const DeviceMemory &conv_input_data, float conv_input_scale,
     const dnn::FilterDescriptor &filter_descriptor,
-    const DeviceMemory &filter_data,
+    const DeviceMemory &filter_data,
     const dnn::ConvolutionDescriptor &convolution_descriptor,
+    const DeviceMemory &side_input_data, float side_input_scale,
+    const dnn::BatchDescriptor &bias_descriptor,
     const DeviceMemory &biases, dnn::ActivationMode activation_mode,
-    const dnn::BatchDescriptor &output_descriptor,
-    DeviceMemory *output) {
-  return ThenConvolveWithScratch(
-      input_descriptor, input_data, filter_descriptor, filter_data,
-      convolution_descriptor, biases, activation_mode, output_descriptor,
-      output, /*scratch_allocator=*/nullptr);
+    const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output) {
+  return ThenFusedConvolveWithScratch(
+      conv_input_descriptor, conv_input_data, conv_input_scale,
+      filter_descriptor, filter_data, convolution_descriptor, side_input_data,
+      side_input_scale, bias_descriptor, biases, activation_mode,
+      output_descriptor, output,
+      /*scratch_allocator=*/nullptr);
 }
 
 Stream &Stream::ThenConvolve(
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 9bd4c21a66e1ffc68cc32e01e166131847be338c..a418fe961c822c1991382038b30917d5a2a0af18 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -240,15 +240,17 @@ class Stream {
       DeviceMemory *offset_backprop);
 
   // TODO(leary) add double-precision version of this interface.
-  Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
-                       const DeviceMemory &input_data,
-                       const dnn::FilterDescriptor &filter_descriptor,
-                       const DeviceMemory &filter_data,
-                       const dnn::ConvolutionDescriptor &convolution_descriptor,
-                       const DeviceMemory &biases,
-                       dnn::ActivationMode activation_mode,
-                       const dnn::BatchDescriptor &output_descriptor,
-                       DeviceMemory *output);
+  Stream &ThenFusedConvolve(
+      const dnn::BatchDescriptor &conv_input_descriptor,
+      const DeviceMemory &conv_input_data, float conv_input_scale,
+      const dnn::FilterDescriptor &filter_descriptor,
+      const DeviceMemory &filter_data,
+      const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const DeviceMemory &side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor &bias_descriptor,
+      const DeviceMemory &biases, dnn::ActivationMode activation_mode,
+      const dnn::BatchDescriptor &output_descriptor,
+      DeviceMemory *output);
 
   Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
                        const DeviceMemory &input_data,
@@ -278,23 +280,39 @@ class Stream {
       const dnn::BatchDescriptor &output_descriptor,
       DeviceMemory *output_data);
 
-  Stream &ThenConvolveWithScratch(
-      const dnn::BatchDescriptor &input_descriptor,
-      const DeviceMemory &input_data,
+  Stream &ThenFusedConvolveWithScratch(
+      const dnn::BatchDescriptor &conv_input_descriptor,
+      const DeviceMemory &conv_input_data, float conv_input_scale,
+      const dnn::FilterDescriptor &filter_descriptor,
+      const DeviceMemory &filter_data,
+      const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const DeviceMemory &side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor &bias_descriptor,
+      const DeviceMemory &biases, dnn::ActivationMode activation_mode,
+      const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output,
+      ScratchAllocator *scratch_allocator);
+
+  Stream &ThenFusedConvolveWithScratch(
+      const dnn::BatchDescriptor &conv_input_descriptor,
+      const DeviceMemory &conv_input_data, float conv_input_scale,
       const dnn::FilterDescriptor &filter_descriptor,
       const DeviceMemory &filter_data,
       const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const DeviceMemory &side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor &bias_descriptor,
       const DeviceMemory &biases,
       dnn::ActivationMode activation_mode,
       const dnn::BatchDescriptor &output_descriptor,
       DeviceMemory *output, ScratchAllocator *scratch_allocator);
 
-  Stream &ThenConvolveWithScratch(
-      const dnn::BatchDescriptor &input_descriptor,
-      const DeviceMemory &input_data,
+  Stream &ThenFusedConvolveWithScratch(
+      const dnn::BatchDescriptor &conv_input_descriptor,
+      const DeviceMemory &conv_input_data, float conv_input_scale,
       const dnn::FilterDescriptor &filter_descriptor,
       const DeviceMemory &filter_data,
       const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const DeviceMemory &side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor &bias_descriptor,
       const DeviceMemory &biases, dnn::ActivationMode activation_mode,
       const dnn::BatchDescriptor &output_descriptor,
       DeviceMemory *output, ScratchAllocator *scratch_allocator);
@@ -323,7 +341,6 @@ class Stream {
       const dnn::FilterDescriptor &filter_descriptor,
       const DeviceMemory &filter_data,
       const dnn::ConvolutionDescriptor &convolution_descriptor,
-      const DeviceMemory &biases, dnn::ActivationMode activation_mode,
       const dnn::BatchDescriptor &output_descriptor,
       DeviceMemory *output, ScratchAllocator *scratch_allocator,
       const dnn::AlgorithmConfig &algorithm_config,
@@ -335,35 +352,68 @@ class Stream {
       const dnn::FilterDescriptor &filter_descriptor,
       const DeviceMemory &filter_data,
       const dnn::ConvolutionDescriptor &convolution_descriptor,
-      const DeviceMemory &biases,
-      dnn::ActivationMode activation_mode,
       const dnn::BatchDescriptor &output_descriptor,
       DeviceMemory *output, ScratchAllocator *scratch_allocator,
       const dnn::AlgorithmConfig &algorithm_config,
       dnn::ProfileResult *output_profile_result);
 
-  Stream &ThenConvolveWithAlgorithm(
-      const dnn::BatchDescriptor &input_descriptor,
-      const DeviceMemory &input_data,
+  Stream &ThenFusedConvolveWithAlgorithm(
+      const dnn::BatchDescriptor &conv_input_descriptor,
+      const DeviceMemory &conv_input_data, double conv_input_scale,
+      const dnn::FilterDescriptor &filter_descriptor,
+      const DeviceMemory &filter_data,
+      const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const DeviceMemory &side_input_data, double side_input_scale,
+      const dnn::BatchDescriptor &bias_descriptor,
+      const DeviceMemory &biases, dnn::ActivationMode activation_mode,
+      const dnn::BatchDescriptor &output_descriptor,
+      DeviceMemory *output, ScratchAllocator *scratch_allocator,
+      const dnn::AlgorithmConfig &algorithm_config,
+      dnn::ProfileResult *output_profile_result);
+
+  Stream &ThenFusedConvolveWithAlgorithm(
+      const dnn::BatchDescriptor &conv_input_descriptor,
+      const DeviceMemory &conv_input_data, float conv_input_scale,
       const dnn::FilterDescriptor &filter_descriptor,
       const DeviceMemory &filter_data,
       const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const DeviceMemory &side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor &bias_descriptor,
+      const DeviceMemory &biases, dnn::ActivationMode activation_mode,
       const dnn::BatchDescriptor &output_descriptor,
       DeviceMemory *output, ScratchAllocator *scratch_allocator,
       const dnn::AlgorithmConfig &algorithm_config,
       dnn::ProfileResult *output_profile_result);
 
-  Stream &ThenConvolveWithAlgorithm(
-      const dnn::BatchDescriptor &input_descriptor,
-      const DeviceMemory &input_data,
+  Stream &ThenFusedConvolveWithAlgorithm(
+      const dnn::BatchDescriptor &conv_input_descriptor,
+      const DeviceMemory &conv_input_data, float conv_input_scale,
       const dnn::FilterDescriptor &filter_descriptor,
       const DeviceMemory &filter_data,
       const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const DeviceMemory &side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor &bias_descriptor,
+      const DeviceMemory &biases,
+      dnn::ActivationMode activation_mode,
       const dnn::BatchDescriptor &output_descriptor,
       DeviceMemory *output, ScratchAllocator *scratch_allocator,
       const dnn::AlgorithmConfig &algorithm_config,
       dnn::ProfileResult *output_profile_result);
 
+  Stream &ThenFusedConvolveWithAlgorithm(
+      const dnn::BatchDescriptor &conv_input_descriptor,
+      const DeviceMemory &conv_input_data, float conv_input_scale,
+      const dnn::FilterDescriptor &filter_descriptor,
+      const DeviceMemory &filter_data,
+      const dnn::ConvolutionDescriptor &convolution_descriptor,
+      const DeviceMemory &side_input_data, float side_input_scale,
+      const dnn::BatchDescriptor &bias_descriptor,
+      const DeviceMemory &biases, dnn::ActivationMode activation_mode,
+      const dnn::BatchDescriptor &output_descriptor, DeviceMemory *output,
+      ScratchAllocator *scratch_allocator,
+      const dnn::AlgorithmConfig &algorithm_config,
+      dnn::ProfileResult *output_profile_result);
+
   Stream &ThenSeparableConvolve(
       const dnn::BatchDescriptor &input_descriptor,
       const DeviceMemory &input_data,
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index f1e153e4c525f0b1989f5b47631f6e1ca33d1cd0..f0301937fbad224bf106859d0dd63cbedf8336d3 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -167,7 +167,7 @@ def tf_copts():
       "-fno-exceptions",
       "-ftemplate-depth=900",
   ]) + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm(
-      ["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + select({
+      ["-mfpu=neon", "-fomit-frame-pointer"]) + if_linux_x86_64(["-msse3"]) + select({
           clean_dep("//tensorflow:android"): [
               "-std=c++11",
               "-DTF_LEAN_BINARY",
diff --git a/tensorflow/tf_exported_symbols.lds b/tensorflow/tf_exported_symbols.lds
index 4597d929a1d6c5b12de28333ff7da75040ff2811..bddb87f00cb5fd1ede2cb9d5cc4079d6e66f7896 100644
--- a/tensorflow/tf_exported_symbols.lds
+++ b/tensorflow/tf_exported_symbols.lds
@@ -1,5 +1,6 @@
 *tensorflow*
 *perftools*gputools*
 *tf_*
-TF_*
+*TF_*
+*TFE_*
 *nsync_*
diff --git a/tensorflow/tf_version_script.lds b/tensorflow/tf_version_script.lds
index 88b64eb1f09c1a26d5a20e406149b6fd64219f17..11f66c5c8b27f412b2023d6f3036c56d3d1e530c 100644
--- a/tensorflow/tf_version_script.lds
+++ b/tensorflow/tf_version_script.lds
@@ -2,7 +2,8 @@ tensorflow {
   global:
     *tensorflow*;
     *perftools*gputools*;
-    TF_*;
+    *TF_*;
+    *TFE_*;
     *nsync_*;
   local:
     *;
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index 661d6fc586271a18255b4d3127011c2d616ff4cc..764ffbb4b71e2c95e4c1b4f0c81662b7f5317616 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -76,6 +76,10 @@ tf_module {
     name: "extract_glimpse"
     argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], "
   }
+  member_method {
+    name: "extract_jpeg_shape"
+    argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], "
+  }
   member_method {
     name: "flip_left_right"
     argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..2cd83baf65cf4114e58f52cdc40de7e4b6df7554
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt
@@ -0,0 +1,55 @@
+path: "tensorflow.keras.activations"
+tf_module {
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "elu"
+    argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], "
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "hard_sigmoid"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "linear"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "relu"
+    argspec: "args=[\'x\', \'alpha\', \'max_value\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
+  }
+  member_method {
+    name: "selu"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'activation\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "sigmoid"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "softmax"
+    argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], "
+  }
+  member_method {
+    name: "softplus"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "softsign"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "tanh"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.inception_v3.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.inception_v3.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..b67cee80ab04cdab617837efe42b6e7deb3c3b69
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.inception_v3.pbtxt
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.applications.inception_v3"
+tf_module {
+  member_method {
+    name: "InceptionV3"
+    argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "decode_predictions"
+    argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
+  }
+  member_method {
+    name: "preprocess_input"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.mobilenet.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.mobilenet.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ef774e1dd742aca59aa642f15340e26869a5fa17
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.mobilenet.pbtxt
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.applications.mobilenet"
+tf_module {
+  member_method {
+    name: "MobileNet"
+    argspec: "args=[\'input_shape\', \'alpha\', \'depth_multiplier\', \'dropout\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'1\', \'0.001\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "decode_predictions"
+    argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
+  }
+  member_method {
+    name: "preprocess_input"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..f50dc7d7fe432d80e91c8bbfbd8cfc36b5682fb7
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.pbtxt
@@ -0,0 +1,51 @@
+path: "tensorflow.keras.applications"
+tf_module {
+  member {
+    name: "inception_v3"
+    mtype: ""
+  }
+  member {
+    name: "mobilenet"
+    mtype: ""
+  }
+  member {
+    name: "resnet50"
+    mtype: ""
+  }
+  member {
+    name: "vgg16"
+    mtype: ""
+  }
+  member {
+    name: "vgg19"
+    mtype: ""
+  }
+  member {
+    name: "xception"
+    mtype: ""
+  }
+  member_method {
+    name: "InceptionV3"
+    argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "MobileNet"
+    argspec: "args=[\'input_shape\', \'alpha\', \'depth_multiplier\', \'dropout\', \'include_top\', \'weights\', \'input_tensor\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'1\', \'0.001\', \'True\', \'imagenet\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "ResNet50"
+    argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "VGG16"
+    argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "VGG19"
+    argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "Xception"
+    argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.resnet50.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.resnet50.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..57c48df2e365528d8c3812ec502661eb9576e89e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.resnet50.pbtxt
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.applications.resnet50"
+tf_module {
+  member_method {
+    name: "ResNet50"
+    argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "decode_predictions"
+    argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
+  }
+  member_method {
+    name: "preprocess_input"
+    argspec: "args=[\'x\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.vgg16.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.vgg16.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..29d45daea44cd51ef8bc4590218c3a30a7d9f39f
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.vgg16.pbtxt
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.applications.vgg16"
+tf_module {
+  member_method {
+    name: "VGG16"
+    argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "decode_predictions"
+    argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
+  }
+  member_method {
+    name: "preprocess_input"
+    argspec: "args=[\'x\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.vgg19.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.vgg19.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..124aa7e5e5dd6f9863790b86bf8c767f21304235
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.vgg19.pbtxt
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.applications.vgg19"
+tf_module {
+  member_method {
+    name: "VGG19"
+    argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "decode_predictions"
+    argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
+  }
+  member_method {
+    name: "preprocess_input"
+    argspec: "args=[\'x\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.applications.xception.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.applications.xception.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..59dd2108f2a3673d25f894795817e01a4311cc1c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.applications.xception.pbtxt
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.applications.xception"
+tf_module {
+  member_method {
+    name: "Xception"
+    argspec: "args=[\'include_top\', \'weights\', \'input_tensor\', \'input_shape\', \'pooling\', \'classes\'], varargs=None, keywords=None, defaults=[\'True\', \'imagenet\', \'None\', \'None\', \'None\', \'1000\'], "
+  }
+  member_method {
+    name: "decode_predictions"
+    argspec: "args=[\'preds\', \'top\'], varargs=None, keywords=None, defaults=[\'5\'], "
+  }
+  member_method {
+    name: "preprocess_input"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..6204ffa814680563f80a59640ff5757b6a7c7adf
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.backend.pbtxt
@@ -0,0 +1,555 @@
+path: "tensorflow.keras.backend"
+tf_module {
+  member_method {
+    name: "abs"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "all"
+    argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
+  member_method {
+    name: "any"
+    argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
+  member_method {
+    name: "arange"
+    argspec: "args=[\'start\', \'stop\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'int32\'], "
+  }
+  member_method {
+    name: "argmax"
+    argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], "
+  }
+  member_method {
+    name: "argmin"
+    argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], "
+  }
+  member_method {
+    name: "backend"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "batch_dot"
+    argspec: "args=[\'x\', \'y\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "batch_flatten"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "batch_get_value"
+    argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "batch_normalization"
+    argspec: "args=[\'x\', \'mean\', \'var\', \'beta\', \'gamma\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'0.001\'], "
+  }
+  member_method {
+    name: "batch_set_value"
+    argspec: "args=[\'tuples\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "bias_add"
+    argspec: "args=[\'x\', \'bias\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "binary_crossentropy"
+    argspec: "args=[\'target\', \'output\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "cast"
+    argspec: "args=[\'x\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "cast_to_floatx"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "categorical_crossentropy"
+    argspec: "args=[\'target\', \'output\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "clear_session"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "clip"
+    argspec: "args=[\'x\', \'min_value\', \'max_value\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "concatenate"
+    argspec: "args=[\'tensors\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], "
+  }
+  member_method {
+    name: "constant"
+    argspec: "args=[\'value\', \'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "conv1d"
+    argspec: "args=[\'x\', \'kernel\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'1\', \'valid\', \'None\', \'1\'], "
+  }
+  member_method {
+    name: "conv2d"
+    argspec: "args=[\'x\', \'kernel\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\'], "
+  }
+  member_method {
+    name: "conv2d_transpose"
+    argspec: "args=[\'x\', \'kernel\', \'output_shape\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\'], "
+  }
+  member_method {
+    name: "conv3d"
+    argspec: "args=[\'x\', \'kernel\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\'], "
+  }
+  member_method {
+    name: "cos"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "ctc_batch_cost"
+    argspec: "args=[\'y_true\', \'y_pred\', \'input_length\', \'label_length\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "ctc_decode"
+    argspec: "args=[\'y_pred\', \'input_length\', \'greedy\', \'beam_width\', \'top_paths\'], varargs=None, keywords=None, defaults=[\'True\', \'100\', \'1\'], "
+  }
+  member_method {
+    name: "ctc_label_dense_to_sparse"
+    argspec: "args=[\'labels\', \'label_lengths\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "dot"
+    argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "dropout"
+    argspec: "args=[\'x\', \'level\', \'noise_shape\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "dtype"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "elu"
+    argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], "
+  }
+  member_method {
+    name: "epsilon"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "equal"
+    argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "eval"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "exp"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "expand_dims"
+    argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], "
+  }
+  member_method {
+    name: "eye"
+    argspec: "args=[\'size\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "flatten"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "floatx"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "foldl"
+    argspec: "args=[\'fn\', \'elems\', \'initializer\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "foldr"
+    argspec: "args=[\'fn\', \'elems\', \'initializer\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "function"
+    argspec: "args=[\'inputs\', \'outputs\', \'updates\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "gather"
+    argspec: "args=[\'reference\', \'indices\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_session"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_uid"
+    argspec: "args=[\'prefix\'], varargs=None, keywords=None, defaults=[\'\'], "
+  }
+  member_method {
+    name: "get_value"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "gradients"
+    argspec: "args=[\'loss\', \'variables\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "greater"
+    argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "greater_equal"
+    argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "hard_sigmoid"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "image_data_format"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "in_test_phase"
+    argspec: "args=[\'x\', \'alt\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "in_top_k"
+    argspec: "args=[\'predictions\', \'targets\', \'k\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "in_train_phase"
+    argspec: "args=[\'x\', \'alt\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "int_shape"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "is_sparse"
+    argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "l2_normalize"
+    argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "learning_phase"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "less"
+    argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "less_equal"
+    argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "log"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "manual_variable_initialization"
+    argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "map_fn"
+    argspec: "args=[\'fn\', \'elems\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "max"
+    argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
+  member_method {
+    name: "maximum"
+    argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "mean"
+    argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
+  member_method {
+    name: "min"
+    argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
+  member_method {
+    name: "minimum"
+    argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "moving_average_update"
+    argspec: "args=[\'x\', \'value\', \'momentum\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "name_scope"
+    argspec: "args=[\'name\', \'default_name\', \'values\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "ndim"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "normalize_batch_in_training"
+    argspec: "args=[\'x\', \'gamma\', \'beta\', \'reduction_axes\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'0.001\'], "
+  }
+  member_method {
+    name: "not_equal"
+    argspec: "args=[\'x\', \'y\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "one_hot"
+    argspec: "args=[\'indices\', \'num_classes\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "ones"
+    argspec: "args=[\'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "ones_like"
+    argspec: "args=[\'x\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "permute_dimensions"
+    argspec: "args=[\'x\', \'pattern\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "placeholder"
+    argspec: "args=[\'shape\', \'ndim\', \'dtype\', \'sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
+  }
+  member_method {
+    name: "pool2d"
+    argspec: "args=[\'x\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'pool_mode\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\', \'max\'], "
+  }
+  member_method {
+    name: "pool3d"
+    argspec: "args=[\'x\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'pool_mode\'], varargs=None, keywords=None, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'max\'], "
+  }
+  member_method {
+    name: "pow"
+    argspec: "args=[\'x\', \'a\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "print_tensor"
+    argspec: "args=[\'x\', \'message\'], varargs=None, keywords=None, defaults=[\'\'], "
+  }
+  member_method {
+    name: "prod"
+    argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
+  member_method {
+    name: "random_binomial"
+    argspec: "args=[\'shape\', \'p\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "random_normal"
+    argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "random_normal_variable"
+    argspec: "args=[\'shape\', \'mean\', \'scale\', \'dtype\', \'name\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "random_uniform"
+    argspec: "args=[\'shape\', \'minval\', \'maxval\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "random_uniform_variable"
+    argspec: "args=[\'shape\', \'low\', \'high\', \'dtype\', \'name\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "relu"
+    argspec: "args=[\'x\', \'alpha\', \'max_value\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
+  }
+  member_method {
+    name: "repeat"
+    argspec: "args=[\'x\', \'n\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "repeat_elements"
+    argspec: "args=[\'x\', \'rep\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_uids"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reshape"
+    argspec: "args=[\'x\', \'shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "resize_images"
+    argspec: "args=[\'x\', \'height_factor\', \'width_factor\', \'data_format\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "resize_volumes"
+    argspec: "args=[\'x\', \'depth_factor\', \'height_factor\', \'width_factor\', \'data_format\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reverse"
+    argspec: "args=[\'x\', \'axes\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "rnn"
+    argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\'], "
+  }
+  member_method {
+    name: "round"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "separable_conv2d"
+    argspec: "args=[\'x\', \'depthwise_kernel\', \'pointwise_kernel\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\'], "
+  }
+  member_method {
+    name: "set_epsilon"
+    argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_floatx"
+    argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_image_data_format"
+    argspec: "args=[\'data_format\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_learning_phase"
+    argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_session"
+    argspec: "args=[\'session\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_value"
+    argspec: "args=[\'x\', \'value\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "shape"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "sigmoid"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "sign"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "sin"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "softmax"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "softplus"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "softsign"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "sparse_categorical_crossentropy"
+    argspec: "args=[\'target\', \'output\', \'from_logits\'], varargs=None, keywords=None, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "spatial_2d_padding"
+    argspec: "args=[\'x\', \'padding\', \'data_format\'], varargs=None, keywords=None, defaults=[\'((1, 1), (1, 1))\', \'None\'], "
+  }
+  member_method {
+    name: "spatial_3d_padding"
+    argspec: "args=[\'x\', \'padding\', \'data_format\'], varargs=None, keywords=None, defaults=[\'((1, 1), (1, 1), (1, 1))\', \'None\'], "
+  }
+  member_method {
+    name: "sqrt"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "square"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "squeeze"
+    argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "stack"
+    argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+  member_method {
+    name: "std"
+    argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
+  member_method {
+    name: "stop_gradient"
+    argspec: "args=[\'variables\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "sum"
+    argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
+  member_method {
+    name: "switch"
+    argspec: "args=[\'condition\', \'then_expression\', \'else_expression\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "tanh"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "temporal_padding"
+    argspec: "args=[\'x\', \'padding\'], varargs=None, keywords=None, defaults=[\'(1, 1)\'], "
+  }
+  member_method {
+    name: "to_dense"
+    argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "transpose"
+    argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "truncated_normal"
+    argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "update"
+    argspec: "args=[\'x\', \'new_x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "update_add"
+    argspec: "args=[\'x\', \'increment\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "update_sub"
+    argspec: "args=[\'x\', \'decrement\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "var"
+    argspec: "args=[\'x\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
+  member_method {
+    name: "variable"
+    argspec: "args=[\'value\', \'dtype\', \'name\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "zeros"
+    argspec: "args=[\'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "zeros_like"
+    argspec: "args=[\'x\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ea4d5143540611f0585b67910cb319454b8560dc
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.BaseLogger"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..86b264c79f63ff78133f0989b5086984a3b16dbd
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-c-s-v-logger.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.CSVLogger"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filename\', \'separator\', \'append\'], varargs=None, keywords=None, defaults=[\',\', \'False\'], "
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-callback.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-callback.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..1474b392ff38c0c224725867006721096b951567
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-callback.pbtxt
@@ -0,0 +1,41 @@
+path: "tensorflow.keras.callbacks.Callback"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..27d4a208a4108b107ed6a0ffbab733cb1e3d8f46
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.EarlyStopping"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'monitor\', \'min_delta\', \'patience\', \'verbose\', \'mode\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'0\', \'0\', \'auto\'], "
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-history.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-history.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a7b2deea8286df935db3a85e9569c3097b0b39ce
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-history.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.History"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-lambda-callback.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-lambda-callback.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5ee22948ad52ed082a8790a2127bbe4afc182049
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-lambda-callback.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.LambdaCallback"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'on_epoch_begin\', \'on_epoch_end\', \'on_batch_begin\', \'on_batch_end\', \'on_train_begin\', \'on_train_end\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..8719c07ca385d2794e5c7e77f75d6d2bc734b7cb
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-learning-rate-scheduler.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.LearningRateScheduler"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'schedule\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-model-checkpoint.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..79f9c88bbcaba136c544be1cb4b620b4ae55e17a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.ModelCheckpoint"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'period\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'1\'], "
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..0e6901f28affdfc73092c2b9f3af07d17db61a9f
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.ProgbarLogger"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'count_mode\'], varargs=None, keywords=None, defaults=[\'samples\'], "
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5838d583125e462e1961beabfd80130a794f468d
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-reduce-l-r-on-plateau.pbtxt
@@ -0,0 +1,46 @@
+path: "tensorflow.keras.callbacks.ReduceLROnPlateau"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'monitor\', \'factor\', \'patience\', \'verbose\', \'mode\', \'epsilon\', \'cooldown\', \'min_lr\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0.1\', \'10\', \'0\', \'auto\', \'0.0001\', \'0\', \'0\'], "
+  }
+  member_method {
+    name: "in_cooldown"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..3d0acfed1d8f5a2c811b442784992400ed958537
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-remote-monitor.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.RemoteMonitor"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'root\', \'path\', \'field\', \'headers\'], varargs=None, keywords=None, defaults=[\'http://localhost:9000\', \'/publish/epoch/end/\', \'data\', \'None\'], "
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..6620a9d308f46cd87cedf482929e75bb5afdbaea
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-tensor-board.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.TensorBoard"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'batch_size\', \'write_graph\', \'write_grads\', \'write_images\'], varargs=None, keywords=None, defaults=[\'./logs\', \'0\', \'32\', \'True\', \'False\', \'False\'], "
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..bf17e8736c50031c484f5c08bac65ee3566f7da3
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-terminate-on-na-n.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.callbacks.TerminateOnNaN"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "on_batch_begin"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_batch_end"
+    argspec: "args=[\'self\', \'batch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_begin"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\', \'epoch\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_begin"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "on_train_end"
+    argspec: "args=[\'self\', \'logs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_model"
+    argspec: "args=[\'self\', \'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..1e9085e034ccf22fda7be7565aabb86992a8b0b7
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.pbtxt
@@ -0,0 +1,55 @@
+path: "tensorflow.keras.callbacks"
+tf_module {
+  member {
+    name: "BaseLogger"
+    mtype: ""
+  }
+  member {
+    name: "CSVLogger"
+    mtype: ""
+  }
+  member {
+    name: "Callback"
+    mtype: ""
+  }
+  member {
+    name: "EarlyStopping"
+    mtype: ""
+  }
+  member {
+    name: "History"
+    mtype: ""
+  }
+  member {
+    name: "LambdaCallback"
+    mtype: ""
+  }
+  member {
+    name: "LearningRateScheduler"
+    mtype: ""
+  }
+  member {
+    name: "ModelCheckpoint"
+    mtype: ""
+  }
+  member {
+    name: "ProgbarLogger"
+    mtype: ""
+  }
+  member {
+    name: "ReduceLROnPlateau"
+    mtype: ""
+  }
+  member {
+    name: "RemoteMonitor"
+    mtype: ""
+  }
+  member {
+    name: "TensorBoard"
+    mtype: ""
+  }
+  member {
+    name: "TerminateOnNaN"
+    mtype: ""
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-constraint.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-constraint.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..14977c696fbe70a9d19f37581c926b6c0fdb3d11
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-constraint.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.keras.constraints.Constraint"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-max-norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-max-norm.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a2269f8a18f5b55ffa88031e8ef3d1c39e0bd423
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-max-norm.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.constraints.MaxNorm"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'max_value\', \'axis\'], varargs=None, keywords=None, defaults=[\'2\', \'0\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-min-max-norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-min-max-norm.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..afe0d6478dde929aa98556d52ceece03c28c8e5f
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-min-max-norm.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.constraints.MinMaxNorm"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'min_value\', \'max_value\', \'rate\', \'axis\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'1.0\', \'0\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-non-neg.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-non-neg.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..e8c4bb90881ae389cd5215c21e44380b62cb7c9c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-non-neg.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.keras.constraints.NonNeg"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.-unit-norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-unit-norm.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..d457cb6419ef86e83b5440554b2e97706440a734
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.-unit-norm.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.constraints.UnitNorm"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.max_norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.max_norm.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..48128096d4638388c99cc62ecc23322a8d368124
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.max_norm.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.constraints.max_norm"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'max_value\', \'axis\'], varargs=None, keywords=None, defaults=[\'2\', \'0\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.min_max_norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.min_max_norm.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..02eb3fb00c0ae516bac336066fc8ae5818e455d8
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.min_max_norm.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.constraints.min_max_norm"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'min_value\', \'max_value\', \'rate\', \'axis\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'1.0\', \'0\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.non_neg.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.non_neg.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..cc1101097ce9c4888e4b239f8ae16a58cabf31db
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.non_neg.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.keras.constraints.non_neg"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..655685956f0e42e2d92dca0ac36f4cca075f474a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.pbtxt
@@ -0,0 +1,51 @@
+path: "tensorflow.keras.constraints"
+tf_module {
+  member {
+    name: "Constraint"
+    mtype: ""
+  }
+  member {
+    name: "MaxNorm"
+    mtype: ""
+  }
+  member {
+    name: "MinMaxNorm"
+    mtype: ""
+  }
+  member {
+    name: "NonNeg"
+    mtype: ""
+  }
+  member {
+    name: "UnitNorm"
+    mtype: ""
+  }
+  member {
+    name: "max_norm"
+    mtype: ""
+  }
+  member {
+    name: "min_max_norm"
+    mtype: ""
+  }
+  member {
+    name: "non_neg"
+    mtype: ""
+  }
+  member {
+    name: "unit_norm"
+    mtype: ""
+  }
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'constraint\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.constraints.unit_norm.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.constraints.unit_norm.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..086f9f2d43c3d340850f02df3e5bcb0cc5a5b8e5
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.constraints.unit_norm.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.constraints.unit_norm"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ef08f9b20f4c95f3692a03be7f4220f20aae9a58
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.boston_housing.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.keras.datasets.boston_housing"
+tf_module {
+  member_method {
+    name: "load_data"
+    argspec: "args=[\'path\', \'seed\', \'test_split\'], varargs=None, keywords=None, defaults=[\'boston_housing.npz\', \'113\', \'0.2\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.cifar10.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.cifar10.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..8a5142f793d67b3a923f3033c0da14442c4f680f
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.cifar10.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.keras.datasets.cifar10"
+tf_module {
+  member_method {
+    name: "load_data"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.cifar100.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.cifar100.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..16f184eeb5e8ee4f126b943c8988ec28ceab89a4
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.cifar100.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.keras.datasets.cifar100"
+tf_module {
+  member_method {
+    name: "load_data"
+    argspec: "args=[\'label_mode\'], varargs=None, keywords=None, defaults=[\'fine\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..8b1c17e9da13a76dcc2c09f3c01a0375bf0cb9fe
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.imdb.pbtxt
@@ -0,0 +1,11 @@
+path: "tensorflow.keras.datasets.imdb"
+tf_module {
+  member_method {
+    name: "get_word_index"
+    argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'imdb_word_index.json\'], "
+  }
+  member_method {
+    name: "load_data"
+    argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=None, defaults=[\'imdb.npz\', \'None\', \'0\', \'None\', \'113\', \'1\', \'2\', \'3\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.mnist.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.mnist.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..530bb0755060f243281523c68b9c554dcbdbc634
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.mnist.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.keras.datasets.mnist"
+tf_module {
+  member_method {
+    name: "load_data"
+    argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'mnist.npz\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..d4aa436f328487479b81f3bdd26062a339581c0e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.pbtxt
@@ -0,0 +1,27 @@
+path: "tensorflow.keras.datasets"
+tf_module {
+  member {
+    name: "boston_housing"
+    mtype: ""
+  }
+  member {
+    name: "cifar10"
+    mtype: ""
+  }
+  member {
+    name: "cifar100"
+    mtype: ""
+  }
+  member {
+    name: "imdb"
+    mtype: ""
+  }
+  member {
+    name: "mnist"
+    mtype: ""
+  }
+  member {
+    name: "reuters"
+    mtype: ""
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..6b3ed1e9af0ea7ab4fa83c07c520adf6727a93ac
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.datasets.reuters.pbtxt
@@ -0,0 +1,11 @@
+path: "tensorflow.keras.datasets.reuters"
+tf_module {
+  member_method {
+    name: "get_word_index"
+    argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=[\'reuters_word_index.json\'], "
+  }
+  member_method {
+    name: "load_data"
+    argspec: "args=[\'path\', \'num_words\', \'skip_top\', \'maxlen\', \'test_split\', \'seed\', \'start_char\', \'oov_char\', \'index_from\'], varargs=None, keywords=None, defaults=[\'reuters.npz\', \'None\', \'0\', \'None\', \'0.2\', \'113\', \'1\', \'2\', \'3\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-constant.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-constant.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..cbaba78ed5a851c3d6e29ab67c89fdfd5db01754
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-constant.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.Constant"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'value\', \'dtype\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'0\', \"\", \'False\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-identity.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-identity.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a5f7f348de9d9899d962e7647d7943ddb6a60604
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-identity.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.Identity"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'gain\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \"\"], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-initializer.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..8f10d1698e7b7b2afa9c2664c7dca38045eda85b
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-initializer.pbtxt
@@ -0,0 +1,16 @@
+path: "tensorflow.keras.initializers.Initializer"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-ones.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-ones.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..2fbfa774f8ed020164e32bb3cfb69b8a235609ba
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-ones.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.Ones"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-orthogonal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-orthogonal.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..874d320d73d1f1cdbd817db587ea9dcfea4d352b
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-orthogonal.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.Orthogonal"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'gain\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \"\"], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-random-normal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-random-normal.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..23cd02c0b069d3cb2d7b9e7ebc754db288e4637a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-random-normal.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.RandomNormal"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-random-uniform.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-random-uniform.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..d98628f42253603178cdff2624f639afa846a66a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-random-uniform.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.RandomUniform"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'minval\', \'maxval\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\', \"\"], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-truncated-normal.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-truncated-normal.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..86d48257c1ffb95fc217de475efba41002f8e7a5
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-truncated-normal.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.TruncatedNormal"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'mean\', \'stddev\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\', \"\"], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..32a6f6ee88815b3dc70e9cca855f73099554953b
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.VarianceScaling"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"\"], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-zeros.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-zeros.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..b6ab68e5beb47c9bcfbc52f9808255bbb03d2dc0
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-zeros.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.Zeros"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..093c56595bd54eef4062d4ac9134e4bb3e7f7d98
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt
@@ -0,0 +1,79 @@
+path: "tensorflow.keras.initializers"
+tf_module {
+  member {
+    name: "Constant"
+    mtype: ""
+  }
+  member {
+    name: "Identity"
+    mtype: ""
+  }
+  member {
+    name: "Initializer"
+    mtype: ""
+  }
+  member {
+    name: "Ones"
+    mtype: ""
+  }
+  member {
+    name: "Orthogonal"
+    mtype: ""
+  }
+  member {
+    name: "RandomNormal"
+    mtype: ""
+  }
+  member {
+    name: "RandomUniform"
+    mtype: ""
+  }
+  member {
+    name: "TruncatedNormal"
+    mtype: ""
+  }
+  member {
+    name: "VarianceScaling"
+    mtype: ""
+  }
+  member {
+    name: "Zeros"
+    mtype: ""
+  }
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "glorot_normal"
+    argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "glorot_uniform"
+    argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "he_normal"
+    argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "he_uniform"
+    argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "lecun_normal"
+    argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "lecun_uniform"
+    argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'initializer\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..52b65bb9163b254de52ce837bc92cf925dd1d704
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.Activation"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'activation\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5ef00eada9fce3245436b770246214a2e15c5df7
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.ActivityRegularization"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'l1\', \'l2\'], varargs=None, keywords=kwargs, defaults=[\'0.0\', \'0.0\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a75a51a41145a3abefb015f7cc3bca5fc9c7016e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.Add"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..560295eb3ea100fb935d63812b865d3fac248e13
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.AlphaDropout"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'rate\', \'noise_shape\', \'seed\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..f05a216e95268d9848b0bdc0fd74b033926fac08
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.AveragePooling1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..2a71a5a2e64d746f1d7906a7e46544909fa9f3cc
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.AveragePooling2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2)\', \'None\', \'valid\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..8756b96297a89f2341e359508fdb12260a1f3bbd
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.AveragePooling3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2, 2)\', \'None\', \'valid\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..9a2940d29820f8d02d0735197281b09453ef3398
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.Average"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..62a53b8ab64a9db8d71733478aa1c3af4e17c781
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.AvgPool1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..d4423110877f1903e3a4a9928a33adaaa6ed3033
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.AvgPool2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2)\', \'None\', \'valid\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..812118f34017a3f3a5a64464bcb0e682260d61a4
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.AvgPool3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2, 2)\', \'None\', \'valid\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..3aa6a990b66682606d1998a59744a4989cfa5e78
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.BatchNormalization"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'axis\', \'momentum\', \'epsilon\', \'center\', \'scale\', \'beta_initializer\', \'gamma_initializer\', \'moving_mean_initializer\', \'moving_variance_initializer\', \'beta_regularizer\', \'gamma_regularizer\', \'beta_constraint\', \'gamma_constraint\'], varargs=None, keywords=kwargs, defaults=[\'-1\', \'0.99\', \'0.001\', \'True\', \'True\', \'zeros\', \'ones\', \'zeros\', \'ones\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a0f64a8245d9c6b6db36dcd5ed2462f48447c976
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -0,0 +1,172 @@
+path: "tensorflow.keras.layers.Bidirectional"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "activity_regularizer"
+    mtype: ""
+  }
+  member {
+    name: "constraints"
+    mtype: ""
+  }
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'layer\', \'merge_mode\', \'weights\'], varargs=None, keywords=kwargs, defaults=[\'concat\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..fe8fc4fd6df393a3f411900d34f5641997c96ba7
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.Concatenate"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a482dec23f84d461d81fc543f060abd2861989c3
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -0,0 +1,189 @@
+path: "tensorflow.keras.layers.ConvLSTM2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'return_sequences\', \'go_backwards\', \'stateful\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'False\', \'0.0\', \'0.0\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_constants"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_initial_state"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "input_conv"
+    argspec: "args=[\'self\', \'x\', \'w\', \'b\', \'padding\'], varargs=None, keywords=None, defaults=[\'None\', \'valid\'], "
+  }
+  member_method {
+    name: "preprocess_input"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "reccurent_conv"
+    argspec: "args=[\'self\', \'x\', \'w\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "step"
+    argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..977a0035bfea8871c7df9fc1c4d9fd50f18f315e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.Conv1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..d63c5a23b4adf3c7d9455f3949e930b3352949e3
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -0,0 +1,162 @@
+path: "tensorflow.keras.layers.Conv2DTranspose"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..3cc9a2267f8721e6b9b88b9407d2d714bb8b2ab2
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.Conv2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..3653eb5b3b15a03ffe6b01549110c85b9a09bcb4
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.Conv3DTranspose"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..e5494449865595d568dcd868546a07ddf189fce0
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.Conv3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a8984deb2ba8f6f6c1dfc6602eeffe824561c700
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.Convolution1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..bd6114323506a07e8d57277569454ea526583bc6
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -0,0 +1,162 @@
+path: "tensorflow.keras.layers.Convolution2DTranspose"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..0a87c40e27fdccbec446179d7f370064113036e3
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.Convolution2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..005cec9748f6bc5e745b7fefe4901b22718f139d
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.Convolution3DTranspose"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..caf06b130d7496abacbbc3bd46d301d080e7dba9
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.Convolution3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..e3287554a6d2c120c4a5f5e0b052ed4f01ea505a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.Cropping1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'cropping\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..7aecf7fe3396993c3a11fe268092180991d4f9b5
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.Cropping2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'cropping\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'((0, 0), (0, 0))\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a7bd30675b4879e3350f12155d565b472a99661e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.Cropping3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'cropping\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'((1, 1), (1, 1), (1, 1))\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..c502083af82b8bfab0ad22e7035e3abb298f4689
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.Dense"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ebc21b016851a5a109a790a6124123e9d9f19e52
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.Dot"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'axes\', \'normalize\'], varargs=None, keywords=kwargs, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..19a8a3cc03839b2ae722764e9d8b16908f4e71a8
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.Dropout"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'rate\', \'noise_shape\', \'seed\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..2c8f19068b011b50781548e5839f9aebdffb2eac
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.ELU"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'alpha\'], varargs=None, keywords=kwargs, defaults=[\'1.0\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..e5a9273009eb9873c24afbcb74d99b8bb4e4c575
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.Embedding"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'input_dim\', \'output_dim\', \'embeddings_initializer\', \'embeddings_regularizer\', \'activity_regularizer\', \'embeddings_constraint\', \'mask_zero\', \'input_length\'], varargs=None, keywords=kwargs, defaults=[\'uniform\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..0f1898bcfae673055d17fd19baddd5bd08b41679
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.Flatten"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..c8cd8faaac23fdeb58e5e8c7bde93830293567b6
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -0,0 +1,180 @@
+path: "tensorflow.keras.layers.GRU"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_constants"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_initial_state"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "preprocess_input"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "step"
+    argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..98c8b96719e7a3d8e09b25155580ff340efad9e4
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.GaussianDropout"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'rate\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..f961291110d193641466cd9c48065748bf77c71b
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.GaussianNoise"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'stddev\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..e120da36495ee3b6b6265bc80c17b4f4d8a88c83
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalAveragePooling1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..89eb90efd91e7ded9a3a8dbc421f472fbaee21f8
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalAveragePooling2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..d6d35c45dff7300fed3c9a7479e986d958476fb3
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalAveragePooling3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..3d28cb068edb34a6081ea230e09f039a234fbd31
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalAvgPool1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..2bc4297b8315e39f957b1467c37d83c421f2ba75
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalAvgPool2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..83de1acdcf2a703b88bfd19a42d535d082a25631
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalAvgPool3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..58dee9406c038aee705fd2475f014df76a079f9a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalMaxPool1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..6490cd4b59c0ee34c35dda0798a764d270da1e30
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalMaxPool2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..15e1a609f3d738db9521553203b6b744528c7619
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalMaxPool3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..4a795aa66332bd290a2db14ce451dbe2fbd71968
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalMaxPooling1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..dab26b5627b769aa1edef6bf204ad261b6bf314c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalMaxPooling2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..cbe05ed7a47e999292d6b9d80ff5b21c591bce4c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.GlobalMaxPooling3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..b3f81cc4595bc1d4af34c8d97dca9b28ddc5a52a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.InputLayer"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-spec.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-spec.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..3aeef347ae1f96a3ef40493cc6b722a887e81786
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-spec.pbtxt
@@ -0,0 +1,9 @@
+path: "tensorflow.keras.layers.InputSpec"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..36a7e4a176c4471463fdbdc051bd3882762a7755
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -0,0 +1,180 @@
+path: "tensorflow.keras.layers.LSTM"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_constants"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_initial_state"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "preprocess_input"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "step"
+    argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..1d62867eb4bdfd81efd9c85d44b2963219c00eeb
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.Lambda"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'function\', \'mask\', \'arguments\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..7326d87cda9e7d5dad31a04e924265f33651fd08
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
@@ -0,0 +1,158 @@
+path: "tensorflow.keras.layers.Layer"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..6a0c72ecdf057ba115cc996dcceae120a72f2e47
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.LeakyReLU"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'alpha\'], varargs=None, keywords=kwargs, defaults=[\'0.3\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a8338314b8297a5904101c5283dbb52bb7297c42
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.LocallyConnected1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a74f1a7c2ae3e795c207325fbf6dc2116cd02ec8
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.LocallyConnected2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..8c5d9b0fc994fab9edcc26c961f6fbdea9f08e7c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.Masking"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'mask_value\'], varargs=None, keywords=kwargs, defaults=[\'0.0\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..0d1998dff60aaeff5b9e1e6c572c98c953345dac
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.MaxPool1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..4858920ea72778db8d813a6bd0591b5fa89bb21a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.MaxPool2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2)\', \'None\', \'valid\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..57df6727cf3bd7869f1d0a83f7595ee683b4706f
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.MaxPool3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2, 2)\', \'None\', \'valid\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5ddc879399a41dbd5ceab1f20139c5e4a72f38b1
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.MaxPooling1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..b8186c15f39cf62c684207a8244001d5416dd639
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.MaxPooling2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2)\', \'None\', \'valid\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..16fe3372f775a0091d30888fae4e7fd251fd758c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.MaxPooling3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2, 2)\', \'None\', \'valid\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..baeb3d835385d28f60d43c4a75eaf923efc141d1
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.Maximum"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5c1d511cf72fd1a64eba5a875de405e8c1fc939a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
@@ -0,0 +1,160 @@
+path: "tensorflow.keras.layers.Multiply"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a8f938cc6e56436aae6fb6bc77ec7c3b929bb282
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.PReLU"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'alpha_initializer\', \'alpha_regularizer\', \'alpha_constraint\', \'shared_axes\'], varargs=None, keywords=kwargs, defaults=[\'zeros\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..eac826b965fbef7c91a0b76cb79f5f5328c01edb
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.Permute"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'dims\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..dfae244356fe74005669e50e879940ef3516db96
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.RepeatVector"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'n\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5c8192b22608c654940d763e0152136a74590faf
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.Reshape"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'target_shape\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..3da1d8406009d0c01a86a212e70f8045cea0d39b
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
@@ -0,0 +1,162 @@
+path: "tensorflow.keras.layers.SeparableConv2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..4b593c19c7af848efa236d608c14475e04371313
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
@@ -0,0 +1,162 @@
+path: "tensorflow.keras.layers.SeparableConvolution2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'depth_multiplier\', \'activation\', \'use_bias\', \'depthwise_initializer\', \'pointwise_initializer\', \'bias_initializer\', \'depthwise_regularizer\', \'pointwise_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'depthwise_constraint\', \'pointwise_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..862032223097e81ef5b308fcc86e2c9ff9df296d
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -0,0 +1,180 @@
+path: "tensorflow.keras.layers.SimpleRNN"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_constants"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_initial_state"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "preprocess_input"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "step"
+    argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..156943a201a14fa38427c3783b61a8769af51c17
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.SpatialDropout1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'rate\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5368b5468abcb0525698fdb90ee205c1bb27fdd8
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.SpatialDropout2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'rate\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..568b5ad66eca4d9392589af1ede938c7331275cb
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
@@ -0,0 +1,161 @@
+path: "tensorflow.keras.layers.SpatialDropout3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'rate\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..445f2df59d78dd7ffb1341ef66f1db559e3ba844
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.ThresholdedReLU"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'theta\'], varargs=None, keywords=kwargs, defaults=[\'1.0\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..b6ebf02b2a4dc1f8a65e6760bb28dcc2f3cea426
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
@@ -0,0 +1,168 @@
+path: "tensorflow.keras.layers.TimeDistributed"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "activity_regularizer"
+    mtype: ""
+  }
+  member {
+    name: "constraints"
+    mtype: ""
+  }
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..868805a563ba472c2e18d38cf142eb5ff1c5defc
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.UpSampling1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'size\'], varargs=None, keywords=kwargs, defaults=[\'2\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..caa85afa151fb80941156af0a2a5a4d91e294d4a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.UpSampling2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'size\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2)\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..d3362faefa30dccccaf44b74d177588d387370f3
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.UpSampling3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'size\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(2, 2, 2)\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ede827f4ecbce6ce019a73eec1cf3684af91486d
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
@@ -0,0 +1,167 @@
+path: "tensorflow.keras.layers.Wrapper"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "activity_regularizer"
+    mtype: ""
+  }
+  member {
+    name: "constraints"
+    mtype: ""
+  }
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..3472bb45140f6cd0dbacaefa0298bad2b0a024af
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.ZeroPadding1D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'padding\'], varargs=None, keywords=kwargs, defaults=[\'1\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5af56bd135c6ad2aad1e6d49dd387510ed51aaea
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.ZeroPadding2D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..1caf07fedcb286779521adc436459bcb4925ac8b
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
@@ -0,0 +1,159 @@
+path: "tensorflow.keras.layers.ZeroPadding3D"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..8466c3e0390255c74be92900b40a738b5c4eb0dc
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
@@ -0,0 +1,371 @@
+path: "tensorflow.keras.layers"
+tf_module {
+  member {
+    name: "Activation"
+    mtype: ""
+  }
+  member {
+    name: "ActivityRegularization"
+    mtype: ""
+  }
+  member {
+    name: "Add"
+    mtype: ""
+  }
+  member {
+    name: "AlphaDropout"
+    mtype: ""
+  }
+  member {
+    name: "Average"
+    mtype: ""
+  }
+  member {
+    name: "AveragePooling1D"
+    mtype: ""
+  }
+  member {
+    name: "AveragePooling2D"
+    mtype: ""
+  }
+  member {
+    name: "AveragePooling3D"
+    mtype: ""
+  }
+  member {
+    name: "AvgPool1D"
+    mtype: ""
+  }
+  member {
+    name: "AvgPool2D"
+    mtype: ""
+  }
+  member {
+    name: "AvgPool3D"
+    mtype: ""
+  }
+  member {
+    name: "BatchNormalization"
+    mtype: ""
+  }
+  member {
+    name: "Bidirectional"
+    mtype: ""
+  }
+  member {
+    name: "Concatenate"
+    mtype: ""
+  }
+  member {
+    name: "Conv1D"
+    mtype: ""
+  }
+  member {
+    name: "Conv2D"
+    mtype: ""
+  }
+  member {
+    name: "Conv2DTranspose"
+    mtype: ""
+  }
+  member {
+    name: "Conv3D"
+    mtype: ""
+  }
+  member {
+    name: "Conv3DTranspose"
+    mtype: ""
+  }
+  member {
+    name: "ConvLSTM2D"
+    mtype: ""
+  }
+  member {
+    name: "Convolution1D"
+    mtype: ""
+  }
+  member {
+    name: "Convolution2D"
+    mtype: ""
+  }
+  member {
+    name: "Convolution2DTranspose"
+    mtype: ""
+  }
+  member {
+    name: "Convolution3D"
+    mtype: ""
+  }
+  member {
+    name: "Convolution3DTranspose"
+    mtype: ""
+  }
+  member {
+    name: "Cropping1D"
+    mtype: ""
+  }
+  member {
+    name: "Cropping2D"
+    mtype: ""
+  }
+  member {
+    name: "Cropping3D"
+    mtype: ""
+  }
+  member {
+    name: "Dense"
+    mtype: ""
+  }
+  member {
+    name: "Dot"
+    mtype: ""
+  }
+  member {
+    name: "Dropout"
+    mtype: ""
+  }
+  member {
+    name: "ELU"
+    mtype: ""
+  }
+  member {
+    name: "Embedding"
+    mtype: ""
+  }
+  member {
+    name: "Flatten"
+    mtype: ""
+  }
+  member {
+    name: "GRU"
+    mtype: ""
+  }
+  member {
+    name: "GaussianDropout"
+    mtype: ""
+  }
+  member {
+    name: "GaussianNoise"
+    mtype: ""
+  }
+  member {
+    name: "GlobalAveragePooling1D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalAveragePooling2D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalAveragePooling3D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalAvgPool1D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalAvgPool2D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalAvgPool3D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalMaxPool1D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalMaxPool2D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalMaxPool3D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalMaxPooling1D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalMaxPooling2D"
+    mtype: ""
+  }
+  member {
+    name: "GlobalMaxPooling3D"
+    mtype: ""
+  }
+  member {
+    name: "InputLayer"
+    mtype: ""
+  }
+  member {
+    name: "InputSpec"
+    mtype: ""
+  }
+  member {
+    name: "LSTM"
+    mtype: ""
+  }
+  member {
+    name: "Lambda"
+    mtype: ""
+  }
+  member {
+    name: "Layer"
+    mtype: ""
+  }
+  member {
+    name: "LeakyReLU"
+    mtype: ""
+  }
+  member {
+    name: "LocallyConnected1D"
+    mtype: ""
+  }
+  member {
+    name: "LocallyConnected2D"
+    mtype: ""
+  }
+  member {
+    name: "Masking"
+    mtype: ""
+  }
+  member {
+    name: "MaxPool1D"
+    mtype: ""
+  }
+  member {
+    name: "MaxPool2D"
+    mtype: ""
+  }
+  member {
+    name: "MaxPool3D"
+    mtype: ""
+  }
+  member {
+    name: "MaxPooling1D"
+    mtype: ""
+  }
+  member {
+    name: "MaxPooling2D"
+    mtype: ""
+  }
+  member {
+    name: "MaxPooling3D"
+    mtype: ""
+  }
+  member {
+    name: "Maximum"
+    mtype: ""
+  }
+  member {
+    name: "Multiply"
+    mtype: ""
+  }
+  member {
+    name: "PReLU"
+    mtype: ""
+  }
+  member {
+    name: "Permute"
+    mtype: ""
+  }
+  member {
+    name: "RepeatVector"
+    mtype: ""
+  }
+  member {
+    name: "Reshape"
+    mtype: ""
+  }
+  member {
+    name: "SeparableConv2D"
+    mtype: ""
+  }
+  member {
+    name: "SeparableConvolution2D"
+    mtype: ""
+  }
+  member {
+    name: "SimpleRNN"
+    mtype: ""
+  }
+  member {
+    name: "SpatialDropout1D"
+    mtype: ""
+  }
+  member {
+    name: "SpatialDropout2D"
+    mtype: ""
+  }
+  member {
+    name: "SpatialDropout3D"
+    mtype: ""
+  }
+  member {
+    name: "ThresholdedReLU"
+    mtype: ""
+  }
+  member {
+    name: "TimeDistributed"
+    mtype: ""
+  }
+  member {
+    name: "UpSampling1D"
+    mtype: ""
+  }
+  member {
+    name: "UpSampling2D"
+    mtype: ""
+  }
+  member {
+    name: "UpSampling3D"
+    mtype: ""
+  }
+  member {
+    name: "Wrapper"
+    mtype: ""
+  }
+  member {
+    name: "ZeroPadding1D"
+    mtype: ""
+  }
+  member {
+    name: "ZeroPadding2D"
+    mtype: ""
+  }
+  member {
+    name: "ZeroPadding3D"
+    mtype: ""
+  }
+  member_method {
+    name: "Input"
+    argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
+  }
+  member_method {
+    name: "add"
+    argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "average"
+    argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "concatenate"
+    argspec: "args=[\'inputs\', \'axis\'], varargs=None, keywords=kwargs, defaults=[\'-1\'], "
+  }
+  member_method {
+    name: "dot"
+    argspec: "args=[\'inputs\', \'axes\', \'normalize\'], varargs=None, keywords=kwargs, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "maximum"
+    argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "multiply"
+    argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ae5f6305b7d1bb85c1c6acd8daf5628d83814b27
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.losses.pbtxt
@@ -0,0 +1,71 @@
+path: "tensorflow.keras.losses"
+tf_module {
+  member_method {
+    name: "binary_crossentropy"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "categorical_crossentropy"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "categorical_hinge"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "cosine_proximity"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "hinge"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "kullback_leibler_divergence"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "logcosh"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "mean_absolute_error"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "mean_absolute_percentage_error"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "mean_squared_error"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "mean_squared_logarithmic_error"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "poisson"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'loss\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "sparse_categorical_crossentropy"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "squared_hinge"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..de285c1aab197ea5cae9c94048a5131f8463ebde
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt
@@ -0,0 +1,79 @@
+path: "tensorflow.keras.metrics"
+tf_module {
+  member_method {
+    name: "binary_accuracy"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "binary_crossentropy"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "categorical_accuracy"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "categorical_crossentropy"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "cosine_proximity"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "hinge"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "kullback_leibler_divergence"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "mean_absolute_error"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "mean_absolute_percentage_error"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "mean_squared_error"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "mean_squared_logarithmic_error"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "poisson"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'metric\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "sparse_categorical_crossentropy"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "sparse_top_k_categorical_accuracy"
+    argspec: "args=[\'y_true\', \'y_pred\', \'k\'], varargs=None, keywords=None, defaults=[\'5\'], "
+  }
+  member_method {
+    name: "squared_hinge"
+    argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "top_k_categorical_accuracy"
+    argspec: "args=[\'y_true\', \'y_pred\', \'k\'], varargs=None, keywords=None, defaults=[\'5\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ade551d02a9ea58741ace58c55cf157bbc56453c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
@@ -0,0 +1,249 @@
+path: "tensorflow.keras.models.Model"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "input_spec"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "state_updates"
+    mtype: ""
+  }
+  member {
+    name: "stateful"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "uses_learning_phase"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'inputs\', \'outputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compile"
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "evaluate"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "evaluate_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "fit"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "fit_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_layer"
+    argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "load_weights"
+    argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+  }
+  member_method {
+    name: "predict_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=kwargs, defaults=[\'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "predict_on_batch"
+    argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "save"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
+  }
+  member_method {
+    name: "save_weights"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'True\'], "
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "summary"
+    argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "test_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "to_json"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "to_yaml"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "train_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..cadd74eb5ff034d1c465c2b535b0232aa636443c
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
@@ -0,0 +1,274 @@
+path: "tensorflow.keras.models.Sequential"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "graph"
+    mtype: ""
+  }
+  member {
+    name: "input"
+    mtype: ""
+  }
+  member {
+    name: "input_mask"
+    mtype: ""
+  }
+  member {
+    name: "input_shape"
+    mtype: ""
+  }
+  member {
+    name: "input_spec"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "output"
+    mtype: ""
+  }
+  member {
+    name: "output_mask"
+    mtype: ""
+  }
+  member {
+    name: "output_shape"
+    mtype: ""
+  }
+  member {
+    name: "regularizers"
+    mtype: ""
+  }
+  member {
+    name: "scope_name"
+    mtype: ""
+  }
+  member {
+    name: "state_updates"
+    mtype: ""
+  }
+  member {
+    name: "stateful"
+    mtype: ""
+  }
+  member {
+    name: "trainable"
+    mtype: ""
+  }
+  member {
+    name: "trainable_variables"
+    mtype: ""
+  }
+  member {
+    name: "trainable_weights"
+    mtype: ""
+  }
+  member {
+    name: "updates"
+    mtype: ""
+  }
+  member {
+    name: "uses_learning_phase"
+    mtype: ""
+  }
+  member {
+    name: "variables"
+    mtype: ""
+  }
+  member {
+    name: "weights"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "add"
+    argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compile"
+    argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'sample_weight_mode\', \'weighted_metrics\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "evaluate"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'32\', \'1\', \'None\'], "
+  }
+  member_method {
+    name: "evaluate_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'10\', \'1\', \'False\'], "
+  }
+  member_method {
+    name: "fit"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'32\', \'10\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\'], "
+  }
+  member_method {
+    name: "fit_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'initial_epoch\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_layer"
+    argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "load_weights"
+    argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
+  }
+  member_method {
+    name: "pop"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
+  }
+  member_method {
+    name: "predict_classes"
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'1\'], "
+  }
+  member_method {
+    name: "predict_generator"
+    argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=kwargs, defaults=[\'10\', \'1\', \'False\', \'0\'], "
+  }
+  member_method {
+    name: "predict_on_batch"
+    argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "predict_proba"
+    argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'1\'], "
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "save"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
+  }
+  member_method {
+    name: "save_weights"
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'True\'], "
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "summary"
+    argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "test_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "to_json"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "to_yaml"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "train_on_batch"
+    argspec: "args=[\'self\', \'x\', \'y\', \'class_weight\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..8ba0e7480bf5100e4bb10ceaf220cfaac0f43f52
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.pbtxt
@@ -0,0 +1,31 @@
+path: "tensorflow.keras.models"
+tf_module {
+  member {
+    name: "Model"
+    mtype: ""
+  }
+  member {
+    name: "Sequential"
+    mtype: ""
+  }
+  member_method {
+    name: "load_model"
+    argspec: "args=[\'filepath\', \'custom_objects\', \'compile\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
+  }
+  member_method {
+    name: "model_from_config"
+    argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "model_from_json"
+    argspec: "args=[\'json_string\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "model_from_yaml"
+    argspec: "args=[\'yaml_string\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "save_model"
+    argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ed040c15864b4f4c386d2d9e1f664d35d651fa14
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adadelta.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.optimizers.Adadelta"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'1.0\', \'0.95\', \'1e-08\', \'0.0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_gradients"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a24651429a3db49a96b217259c5c6ef09efed2f2
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adagrad.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.optimizers.Adagrad"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'lr\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'1e-08\', \'0.0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_gradients"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..a0d978fded3825bafcd8d60e34677029495b1245
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adam.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.optimizers.Adam"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-08\', \'0.0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_gradients"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..1b70c93ad5f0a8fd52d65fb4b8132a87878c26dd
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-adamax.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.optimizers.Adamax"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'1e-08\', \'0.0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_gradients"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..b49dbe5cf82ea838076134a0feecc120bfb88f84
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-nadam.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.optimizers.Nadam"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'lr\', \'beta_1\', \'beta_2\', \'epsilon\', \'schedule_decay\'], varargs=None, keywords=kwargs, defaults=[\'0.002\', \'0.9\', \'0.999\', \'1e-08\', \'0.004\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_gradients"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-optimizer.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ca47e952282e0c1a9ee85d8912e479a0ed5b4e86
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-optimizer.pbtxt
@@ -0,0 +1,33 @@
+path: "tensorflow.keras.optimizers.Optimizer"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_gradients"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..c8860d80d40353211df65f08fda5deb26af91d66
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-r-m-sprop.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.optimizers.RMSprop"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'lr\', \'rho\', \'epsilon\', \'decay\'], varargs=None, keywords=kwargs, defaults=[\'0.001\', \'0.9\', \'1e-08\', \'0.0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_gradients"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-s-g-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-s-g-d.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..25adfd3f0bc89d9dbd3b2b8068e7b4ff99170909
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.-s-g-d.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.optimizers.SGD"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'lr\', \'momentum\', \'decay\', \'nesterov\'], varargs=None, keywords=kwargs, defaults=[\'0.01\', \'0.0\', \'0.0\', \'False\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_gradients"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates"
+    argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.optimizers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..7257b02087e237eaa47ed6a042559aa1332fc87b
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.optimizers.pbtxt
@@ -0,0 +1,47 @@
+path: "tensorflow.keras.optimizers"
+tf_module {
+  member {
+    name: "Adadelta"
+    mtype: ""
+  }
+  member {
+    name: "Adagrad"
+    mtype: ""
+  }
+  member {
+    name: "Adam"
+    mtype: ""
+  }
+  member {
+    name: "Adamax"
+    mtype: ""
+  }
+  member {
+    name: "Nadam"
+    mtype: ""
+  }
+  member {
+    name: "Optimizer"
+    mtype: ""
+  }
+  member {
+    name: "RMSprop"
+    mtype: ""
+  }
+  member {
+    name: "SGD"
+    mtype: ""
+  }
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'optimizer\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..b198bde7afe2faad0d449376f7455e7491588db4
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.pbtxt
@@ -0,0 +1,71 @@
+path: "tensorflow.keras"
+tf_module {
+  member {
+    name: "activations"
+    mtype: ""
+  }
+  member {
+    name: "applications"
+    mtype: ""
+  }
+  member {
+    name: "backend"
+    mtype: ""
+  }
+  member {
+    name: "callbacks"
+    mtype: ""
+  }
+  member {
+    name: "constraints"
+    mtype: ""
+  }
+  member {
+    name: "datasets"
+    mtype: ""
+  }
+  member {
+    name: "initializers"
+    mtype: ""
+  }
+  member {
+    name: "layers"
+    mtype: ""
+  }
+  member {
+    name: "losses"
+    mtype: ""
+  }
+  member {
+    name: "metrics"
+    mtype: ""
+  }
+  member {
+    name: "models"
+    mtype: ""
+  }
+  member {
+    name: "optimizers"
+    mtype: ""
+  }
+  member {
+    name: "preprocessing"
+    mtype: ""
+  }
+  member {
+    name: "regularizers"
+    mtype: ""
+  }
+  member {
+    name: "utils"
+    mtype: ""
+  }
+  member {
+    name: "wrappers"
+    mtype: ""
+  }
+  member_method {
+    name: "Input"
+    argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..8ad1f32551dda913cd98ce544d27af63310a6450
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.preprocessing.image.DirectoryIterator"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'directory\', \'image_data_generator\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'None\', \'\', \'png\', \'False\'], "
+  }
+  member_method {
+    name: "next"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..7e33285e7abbc10df7f697e10071e429c5183d9e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt
@@ -0,0 +1,29 @@
+path: "tensorflow.keras.preprocessing.image.ImageDataGenerator"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'featurewise_center\', \'samplewise_center\', \'featurewise_std_normalization\', \'samplewise_std_normalization\', \'zca_whitening\', \'zca_epsilon\', \'rotation_range\', \'width_shift_range\', \'height_shift_range\', \'shear_range\', \'zoom_range\', \'channel_shift_range\', \'fill_mode\', \'cval\', \'horizontal_flip\', \'vertical_flip\', \'rescale\', \'preprocessing_function\', \'data_format\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'1e-06\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'nearest\', \'0.0\', \'False\', \'False\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "fit"
+    argspec: "args=[\'self\', \'x\', \'augment\', \'rounds\', \'seed\'], varargs=None, keywords=None, defaults=[\'False\', \'1\', \'None\'], "
+  }
+  member_method {
+    name: "flow"
+    argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\'], varargs=None, keywords=None, defaults=[\'None\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\'], "
+  }
+  member_method {
+    name: "flow_from_directory"
+    argspec: "args=[\'self\', \'directory\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\', \'False\'], "
+  }
+  member_method {
+    name: "random_transform"
+    argspec: "args=[\'self\', \'x\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "standardize"
+    argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..d30462a8eb6dfe963ab32a41a5faabcd2b743b74
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-iterator.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.keras.preprocessing.image.Iterator"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'n\', \'batch_size\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..841f1c5585e4d8dffb782ddd989b0ba313dc2caa
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.preprocessing.image.NumpyArrayIterator"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'x\', \'y\', \'image_data_generator\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'None\', \'None\', \'\', \'png\'], "
+  }
+  member_method {
+    name: "next"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5652687033559a53235056e35906140dab2d0079
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt
@@ -0,0 +1,59 @@
+path: "tensorflow.keras.preprocessing.image"
+tf_module {
+  member {
+    name: "DirectoryIterator"
+    mtype: ""
+  }
+  member {
+    name: "ImageDataGenerator"
+    mtype: ""
+  }
+  member {
+    name: "Iterator"
+    mtype: ""
+  }
+  member {
+    name: "NumpyArrayIterator"
+    mtype: ""
+  }
+  member_method {
+    name: "apply_transform"
+    argspec: "args=[\'x\', \'transform_matrix\', \'channel_axis\', \'fill_mode\', \'cval\'], varargs=None, keywords=None, defaults=[\'0\', \'nearest\', \'0.0\'], "
+  }
+  member_method {
+    name: "array_to_img"
+    argspec: "args=[\'x\', \'data_format\', \'scale\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
+  }
+  member_method {
+    name: "flip_axis"
+    argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "img_to_array"
+    argspec: "args=[\'img\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "load_img"
+    argspec: "args=[\'path\', \'grayscale\', \'target_size\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+  }
+  member_method {
+    name: "random_channel_shift"
+    argspec: "args=[\'x\', \'intensity\', \'channel_axis\'], varargs=None, keywords=None, defaults=[\'0\'], "
+  }
+  member_method {
+    name: "random_rotation"
+    argspec: "args=[\'x\', \'rg\', \'row_axis\', \'col_axis\', \'channel_axis\', \'fill_mode\', \'cval\'], varargs=None, keywords=None, defaults=[\'1\', \'2\', \'0\', \'nearest\', \'0.0\'], "
+  }
+  member_method {
+    name: "random_shear"
+    argspec: "args=[\'x\', \'intensity\', \'row_axis\', \'col_axis\', \'channel_axis\', \'fill_mode\', \'cval\'], varargs=None, keywords=None, defaults=[\'1\', \'2\', \'0\', \'nearest\', \'0.0\'], "
+  }
+  member_method {
+    name: "random_shift"
+    argspec: "args=[\'x\', \'wrg\', \'hrg\', \'row_axis\', \'col_axis\', \'channel_axis\', \'fill_mode\', \'cval\'], varargs=None, keywords=None, defaults=[\'1\', \'2\', \'0\', \'nearest\', \'0.0\'], "
+  }
+  member_method {
+    name: "random_zoom"
+    argspec: "args=[\'x\', \'zoom_range\', \'row_axis\', \'col_axis\', \'channel_axis\', \'fill_mode\', \'cval\'], varargs=None, keywords=None, defaults=[\'1\', \'2\', \'0\', \'nearest\', \'0.0\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5a78581fc56ba547ee56560367884c571f18279e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.pbtxt
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.preprocessing"
+tf_module {
+  member {
+    name: "image"
+    mtype: ""
+  }
+  member {
+    name: "sequence"
+    mtype: ""
+  }
+  member {
+    name: "text"
+    mtype: ""
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..1b01935cc53b450c3e7009f945f86c8e1c10bf8e
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.preprocessing.sequence"
+tf_module {
+  member_method {
+    name: "make_sampling_table"
+    argspec: "args=[\'size\', \'sampling_factor\'], varargs=None, keywords=None, defaults=[\'1e-05\'], "
+  }
+  member_method {
+    name: "pad_sequences"
+    argspec: "args=[\'sequences\', \'maxlen\', \'dtype\', \'padding\', \'truncating\', \'value\'], varargs=None, keywords=None, defaults=[\'None\', \'int32\', \'pre\', \'pre\', \'0.0\'], "
+  }
+  member_method {
+    name: "skipgrams"
+    argspec: "args=[\'sequence\', \'vocabulary_size\', \'window_size\', \'negative_samples\', \'shuffle\', \'categorical\', \'sampling_table\', \'seed\'], varargs=None, keywords=None, defaults=[\'4\', \'1.0\', \'True\', \'False\', \'None\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5bc8c4012049b0414936fb56a853fc32430df3d9
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.-tokenizer.pbtxt
@@ -0,0 +1,33 @@
+path: "tensorflow.keras.preprocessing.text.Tokenizer"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'num_words\', \'filters\', \'lower\', \'split\', \'char_level\'], varargs=None, keywords=None, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \', \'False\'], "
+  }
+  member_method {
+    name: "fit_on_sequences"
+    argspec: "args=[\'self\', \'sequences\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "fit_on_texts"
+    argspec: "args=[\'self\', \'texts\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "sequences_to_matrix"
+    argspec: "args=[\'self\', \'sequences\', \'mode\'], varargs=None, keywords=None, defaults=[\'binary\'], "
+  }
+  member_method {
+    name: "texts_to_matrix"
+    argspec: "args=[\'self\', \'texts\', \'mode\'], varargs=None, keywords=None, defaults=[\'binary\'], "
+  }
+  member_method {
+    name: "texts_to_sequences"
+    argspec: "args=[\'self\', \'texts\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "texts_to_sequences_generator"
+    argspec: "args=[\'self\', \'texts\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..d106429df0273929472aa58909f554bcffde9bca
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt
@@ -0,0 +1,15 @@
+path: "tensorflow.keras.preprocessing.text"
+tf_module {
+  member {
+    name: "Tokenizer"
+    mtype: ""
+  }
+  member_method {
+    name: "one_hot"
+    argspec: "args=[\'text\', \'n\', \'filters\', \'lower\', \'split\'], varargs=None, keywords=None, defaults=[\'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \'], "
+  }
+  member_method {
+    name: "text_to_word_sequence"
+    argspec: "args=[\'text\', \'filters\', \'lower\', \'split\'], varargs=None, keywords=None, defaults=[\'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-l1-l2.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-l1-l2.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..04dcda38609c7114bdf6e2784938905fc3ef8af3
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-l1-l2.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.regularizers.L1L2"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'l1\', \'l2\'], varargs=None, keywords=None, defaults=[\'0.0\', \'0.0\'], "
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-regularizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-regularizer.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..b0a125f238e58fb8b1213f52fc1fb85781ca5807
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.regularizers.-regularizer.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.keras.regularizers.Regularizer"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.regularizers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.regularizers.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..bb10d41d704ca456fbf5b8bd19324ee71f17ba8d
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.regularizers.pbtxt
@@ -0,0 +1,35 @@
+path: "tensorflow.keras.regularizers"
+tf_module {
+  member {
+    name: "L1L2"
+    mtype: ""
+  }
+  member {
+    name: "Regularizer"
+    mtype: ""
+  }
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "l1"
+    argspec: "args=[\'l\'], varargs=None, keywords=None, defaults=[\'0.01\'], "
+  }
+  member_method {
+    name: "l1_l2"
+    argspec: "args=[\'l1\', \'l2\'], varargs=None, keywords=None, defaults=[\'0.01\', \'0.01\'], "
+  }
+  member_method {
+    name: "l2"
+    argspec: "args=[\'l\'], varargs=None, keywords=None, defaults=[\'0.01\'], "
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'regularizer\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-custom-object-scope.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-custom-object-scope.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..dda39ed221a06827601a9432f887ddc5f5ee9b01
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-custom-object-scope.pbtxt
@@ -0,0 +1,9 @@
+path: "tensorflow.keras.utils.CustomObjectScope"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\'], varargs=args, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..bf27a97cf25ee1ec64efa1aaeb4b10ed200f81fc
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-generator-enqueuer.pbtxt
@@ -0,0 +1,26 @@
+path: "tensorflow.keras.utils.GeneratorEnqueuer"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'generator\', \'use_multiprocessing\', \'wait_time\', \'random_seed\'], varargs=None, keywords=None, defaults=[\'False\', \'0.05\', \'None\'], "
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "is_running"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "start"
+    argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], "
+  }
+  member_method {
+    name: "stop"
+    argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..ce62c8bafcaec1bf2e6ab3989da68588f7c848e9
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-h-d-f5-matrix.pbtxt
@@ -0,0 +1,29 @@
+path: "tensorflow.keras.utils.HDF5Matrix"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member {
+    name: "dtype"
+    mtype: ""
+  }
+  member {
+    name: "ndim"
+    mtype: ""
+  }
+  member {
+    name: "refs"
+    mtype: ""
+  }
+  member {
+    name: "shape"
+    mtype: ""
+  }
+  member {
+    name: "size"
+    mtype: ""
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'datapath\', \'dataset\', \'start\', \'end\', \'normalizer\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..3adc6b6faa6f62330f9ac3d621f29adfc380a09d
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt
@@ -0,0 +1,17 @@
+path: "tensorflow.keras.utils.Progbar"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'target\', \'width\', \'verbose\', \'interval\'], varargs=None, keywords=None, defaults=[\'30\', \'1\', \'0.05\'], "
+  }
+  member_method {
+    name: "add"
+    argspec: "args=[\'self\', \'n\', \'values\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "update"
+    argspec: "args=[\'self\', \'current\', \'values\', \'force\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence-enqueuer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence-enqueuer.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5cf2a07b0b265ba88d7942698640520d53a2f407
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence-enqueuer.pbtxt
@@ -0,0 +1,24 @@
+path: "tensorflow.keras.utils.SequenceEnqueuer"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "get"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "is_running"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "start"
+    argspec: "args=[\'self\', \'workers\', \'max_queue_size\'], varargs=None, keywords=None, defaults=[\'1\', \'10\'], "
+  }
+  member_method {
+    name: "stop"
+    argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..5b272253e3767941b10d42ef5fef9c09433e9f59
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-sequence.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.keras.utils.Sequence"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "on_epoch_end"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..e840f331426c52f01db9d6280204ce3ff34a7db2
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.pbtxt
@@ -0,0 +1,63 @@
+path: "tensorflow.keras.utils"
+tf_module {
+  member {
+    name: "CustomObjectScope"
+    mtype: ""
+  }
+  member {
+    name: "GeneratorEnqueuer"
+    mtype: ""
+  }
+  member {
+    name: "HDF5Matrix"
+    mtype: ""
+  }
+  member {
+    name: "Progbar"
+    mtype: ""
+  }
+  member {
+    name: "Sequence"
+    mtype: ""
+  }
+  member {
+    name: "SequenceEnqueuer"
+    mtype: ""
+  }
+  member_method {
+    name: "convert_all_kernels_in_model"
+    argspec: "args=[\'model\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "custom_object_scope"
+    argspec: "args=[], varargs=args, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "deserialize_keras_object"
+    argspec: "args=[\'identifier\', \'module_objects\', \'custom_objects\', \'printable_module_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'object\'], "
+  }
+  member_method {
+    name: "get_custom_objects"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_file"
+    argspec: "args=[\'fname\', \'origin\', \'untar\', \'md5_hash\', \'file_hash\', \'cache_subdir\', \'hash_algorithm\', \'extract\', \'archive_format\', \'cache_dir\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'datasets\', \'auto\', \'False\', \'auto\', \'None\'], "
+  }
+  member_method {
+    name: "normalize"
+    argspec: "args=[\'x\', \'axis\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'2\'], "
+  }
+  member_method {
+    name: "plot_model"
+    argspec: "args=[\'model\', \'to_file\', \'show_shapes\', \'show_layer_names\', \'rankdir\'], varargs=None, keywords=None, defaults=[\'model.png\', \'False\', \'True\', \'TB\'], "
+  }
+  member_method {
+    name: "serialize_keras_object"
+    argspec: "args=[\'instance\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "to_categorical"
+    argspec: "args=[\'y\', \'num_classes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..0b2fac9b7d998312d1bc080d7464d17b2b5543f5
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.keras.wrappers"
+tf_module {
+  member {
+    name: "scikit_learn"
+    mtype: ""
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..8d200f99fd14d6a7735e1a74299159d6b198cd68
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-classifier.pbtxt
@@ -0,0 +1,42 @@
+path: "tensorflow.keras.wrappers.scikit_learn.KerasClassifier"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'build_fn\'], varargs=None, keywords=sk_params, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "check_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "filter_sk_params"
+    argspec: "args=[\'self\', \'fn\', \'override\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "fit"
+    argspec: "args=[\'self\', \'x\', \'y\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "get_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=params, defaults=None"
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'x\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "predict_proba"
+    argspec: "args=[\'self\', \'x\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "score"
+    argspec: "args=[\'self\', \'x\', \'y\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=params, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..7a971346d86f4930c7bba872031e049a93445d1d
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.-keras-regressor.pbtxt
@@ -0,0 +1,38 @@
+path: "tensorflow.keras.wrappers.scikit_learn.KerasRegressor"
+tf_class {
+  is_instance: ""
+  is_instance: ""
+  is_instance: ""
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'build_fn\'], varargs=None, keywords=sk_params, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "check_params"
+    argspec: "args=[\'self\', \'params\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "filter_sk_params"
+    argspec: "args=[\'self\', \'fn\', \'override\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "fit"
+    argspec: "args=[\'self\', \'x\', \'y\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "get_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=params, defaults=None"
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'x\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "score"
+    argspec: "args=[\'self\', \'x\', \'y\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "set_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=params, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.pbtxt
new file mode 100644
index 0000000000000000000000000000000000000000..fbd4d13387a931c3c947d8d0babcbfa978070de9
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.wrappers.scikit_learn.pbtxt
@@ -0,0 +1,11 @@
+path: "tensorflow.keras.wrappers.scikit_learn"
+tf_module {
+  member {
+    name: "KerasClassifier"
+    mtype: ""
+  }
+  member {
+    name: "KerasRegressor"
+    mtype: ""
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
index 8d726ce8f1d0038fe1db7f5de6f6757991eb46fc..4dc2e8e67db981cc17b68b71daaa7fc8669336db 100644
--- a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
@@ -1,99 +1,7 @@
 path: "tensorflow.linalg"
 tf_module {
   member_method {
-    name: "band_part"
-    argspec: "args=[\'input\', \'num_lower\', \'num_upper\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "cholesky"
-    argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "cholesky_solve"
-    argspec: "args=[\'chol\', \'rhs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "det"
-    argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "diag"
-    argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "diag_part"
-    argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "eigh"
-    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "eigvalsh"
-    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "einsum"
-    argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None"
-  }
-  member_method {
-    name: "eye"
-    argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\", \'None\'], "
-  }
-  member_method {
-    name: "inv"
-    argspec: "args=[\'input\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
-  }
-  member_method {
-    name: "lstsq"
-    argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
-  }
-  member_method {
-    name: "norm"
-    argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keep_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'False\', \'None\'], "
-  }
-  member_method {
-    name: "qr"
-    argspec: "args=[\'input\', \'full_matrices\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
-  }
-  member_method {
-    name: "self_adjoint_eig"
-    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "self_adjoint_eigvals"
-    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "set_diag"
-    argspec: "args=[\'input\', \'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "solve"
-    argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
-  }
-  member_method {
-    name: "solve_ls"
-    argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
-  }
-  member_method {
-    name: "svd"
-    argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
-  }
-  member_method {
-    name: "tensordot"
-    argspec: "args=[\'a\', \'b\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "trace"
-    argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "transpose"
-    argspec: "args=[\'a\', \'name\'], varargs=None, keywords=None, defaults=[\'matrix_transpose\'], "
-  }
-  member_method {
-    name: "triangular_solve"
-    argspec: "args=[\'matrix\', \'rhs\', \'lower\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], "
+    name: "logdet"
+    argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 667ae5cf6e59274b73024ae44ec3bdb71b320d78..8935bcda3dcf2589f3c08f3542aa50fcce43395f 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -364,6 +364,10 @@ tf_module {
     name: "int8"
     mtype: ""
   }
+  member {
+    name: "keras"
+    mtype: ""
+  }
   member {
     name: "layers"
     mtype: ""
@@ -1098,7 +1102,7 @@ tf_module {
   }
   member_method {
     name: "gradients"
-    argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\'], "
+    argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
   }
   member_method {
     name: "greater"
@@ -1684,6 +1688,10 @@ tf_module {
     name: "serialize_sparse"
     argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "serialize_tensor"
+    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "set_random_seed"
     argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index 68e826ccd5576100380c727a710482db1d5433f7..3e0eaa26bc0fb5479476f059b21f44fffeb6ac36 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -95,7 +95,8 @@ do_pylint() {
 "^tensorflow/python/platform/default/_googletest\.py.*\[E0102.*function\salready\sdefined "\
 "^tensorflow/python/feature_column/feature_column_test\.py.*\[E0110.*abstract-class-instantiated "\
 "^tensorflow/contrib/layers/python/layers/feature_column\.py.*\[E0110.*abstract-class-instantiated "\
-"^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator"
+"^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
+"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable"
 
   echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
 
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
index 817df6a434e8088f3b9e60293f49ae04fd4e32df..08fc82d04c94fd79f40935b506e7b26fde28974e 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_cc_core.sh
@@ -33,7 +33,7 @@ export PYTHON_BIN_PATH=`which python`
 yes "" | $PYTHON_BIN_PATH configure.py
 
 # Run bazel test command. Double test timeouts to avoid flakes.
-bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=cc -k \
+bazel test --test_tag_filters=-no_oss,-gpu,-benchmark-test --test_lang_filters=cc,java -k \
     --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
     --test_output=errors -- \
     //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
index 9e6cfc017e97aceb721911fb1eddf318c1347132..f53bfb59ff4eaf0e119c9ed9d5de89dfc43b3a5b 100755
--- a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
+++ b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
@@ -47,8 +47,30 @@ sed -i 's/ca7beac153d4059c02c8fc59816c82d54ea47fe58365e8aded4082ded0b820c4/a34b2
 sudo sed -i 's/define CURL_SIZEOF_LONG 8/define CURL_SIZEOF_LONG 4/g' /usr/include/curl/curlbuild.h
 sudo sed -i 's/define CURL_SIZEOF_CURL_OFF_T 8/define CURL_SIZEOF_CURL_OFF_T 4/g' /usr/include/curl/curlbuild.h
 
+# Build the OpenBLAS library, which is faster than Eigen on the Pi Zero/One.
+# TODO(petewarden) - It would be nicer to move this into the main Bazel build
+# process if we can maintain a build file for this.
+mkdir toolchain
+cd toolchain
+curl -L https://github.com/raspberrypi/tools/archive/0e906ebc527eab1cdbf7adabff5b474da9562e9f.tar.gz -o toolchain.tar.gz
+tar xzf toolchain.tar.gz
+mv tools-0e906ebc527eab1cdbf7adabff5b474da9562e9f/ tools
+cd ..
+
+CROSSTOOL_CC=$(pwd)/toolchain/tools/arm-bcm2708/arm-rpi-4.9.3-linux-gnueabihf/bin/arm-linux-gnueabihf-gcc
+
+git clone https://github.com/xianyi/OpenBLAS openblas
+cd openblas
+make CC=${CROSSTOOL_CC} FC=${CROSSTOOL_CC} HOSTCC=gcc TARGET=ARMV6
+make PREFIX=$(pwd)/toolchain/openblas/ install
+cd ..
+
 if [[ $1 == "PI_ONE" ]]; then
-  PI_COPTS="--copt=-march=armv6 --copt=-mfpu=vfp"
+  PI_COPTS="--copt=-march=armv6 --copt=-mfpu=vfp
+  --copt=-DUSE_GEMM_FOR_CONV --copt=-DUSE_OPENBLAS
+  --copt=-isystem=$(pwd)/toolchain/openblas/include/
+  --linkopt=-L$(pwd)/toolchain/openblas/lib/
+  --linkopt=-l:libopenblas.a"
   echo "Building for the Pi One/Zero, with no NEON support"
 else
   PI_COPTS='--copt=-march=armv7-a --copt=-mfpu=neon-vfpv4
diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py
index 6f3c3f151032e4baa514ff3ba7d0c14c9a779597..4405678a6b84e97e67cb194bbbd2cb2ceb478303 100755
--- a/tensorflow/tools/ci_build/update_version.py
+++ b/tensorflow/tools/ci_build/update_version.py
@@ -277,8 +277,8 @@ def check_for_lingering_string(lingering_string):
   formatted_string = lingering_string.replace(".", r"\.")
   try:
     linger_str_output = subprocess.check_output(
-        ['grep', '-rnoH', formatted_string, TF_SRC_DIR])
-    linger_strs = linger_str_output.decode('utf8').split("\n")
+        ["grep", "-rnoH", formatted_string, TF_SRC_DIR])
+    linger_strs = linger_str_output.decode("utf8").split("\n")
   except subprocess.CalledProcessError:
     linger_strs = []
 
diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
index 61f5ed084ccb8a6237faa637b54a6f773f471291..f6e3d2e6c716178609b4aeb7e25d4dc12ac12f34 100644
--- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
@@ -60,8 +60,11 @@ reinstall_tensorflow_pip ${PIP_NAME}
 
 # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
 # which will result testing system installed tensorflow
+# TODO(pcloudy): Remove TF_SAVER_LENIENT_NAMES after
+# https://github.com/tensorflow/tensorflow/issues/12844 is fixed.
 bazel test -c opt $BUILD_OPTS -k --test_output=errors \
   --define=no_tensorflow_py_deps=true --test_lang_filters=py \
   --test_tag_filters=-no_pip,-no_windows \
   --build_tag_filters=-no_pip,-no_windows --build_tests_only \
+  --test_env=TF_SAVER_LENIENT_NAMES=True \
   //${PY_TEST_DIR}/tensorflow/python/...
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index e1972a310046784a74e279560927963a6e725133..25d327c8188666e34477daa0e888a9169c709c66 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -61,8 +61,11 @@ reinstall_tensorflow_pip ${PIP_NAME}
 # Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
 # which will result testing system installed tensorflow
 # GPU tests are very flaky when running concurrently, so set local_test_jobs=1
+# TODO(pcloudy): Remove TF_SAVER_LENIENT_NAMES after
+# https://github.com/tensorflow/tensorflow/issues/12844 is fixed.
 bazel test -c opt $BUILD_OPTS -k --test_output=errors \
   --define=no_tensorflow_py_deps=true --test_lang_filters=py \
   --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu \
   --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu \
+  --test_env=TF_SAVER_LENIENT_NAMES=True \
   --local_test_jobs=1 --build_tests_only //${PY_TEST_DIR}/tensorflow/python/...
diff --git a/tensorflow/tools/gcs_test/gcs_smoke.sh b/tensorflow/tools/gcs_test/gcs_smoke.sh
index ec7ee4fbb0180ec9b68de3b51cb0082c672ddc73..69c632f2cf2ea1edc9f9507b31b7dcee71b1d865 100755
--- a/tensorflow/tools/gcs_test/gcs_smoke.sh
+++ b/tensorflow/tools/gcs_test/gcs_smoke.sh
@@ -65,10 +65,6 @@ echo "Building in temporary directory: ${BUILD_DIR}"
 cp -r ${SCRIPT_DIR}/* "${BUILD_DIR}"/ || \
     die "Failed to copy files to ${BUILD_DIR}"
 
-# Download whl file into the build context directory.
-wget -P "${BUILD_DIR}" ${WHL_URL} || \
-    die "Failed to download tensorflow whl file from URL: ${WHL_URL}"
-
 DOCKERFILE="${BUILD_DIR}/Dockerfile"
 if [[ ! -f "${DOCKERFILE}" ]]; then
   die "ERROR: Cannot find Dockerfile at expected path ${DOCKERFILE}"
diff --git a/tensorflow/tools/graph_transforms/remove_attribute.cc b/tensorflow/tools/graph_transforms/remove_attribute.cc
index d76c3ff87d0c0f9b6fa4d083df200c30f5f7317e..b1a04c0f283bf6bc03da702447694558c5b98538 100644
--- a/tensorflow/tools/graph_transforms/remove_attribute.cc
+++ b/tensorflow/tools/graph_transforms/remove_attribute.cc
@@ -34,7 +34,7 @@ Status RemoveAttribute(const GraphDef& input_graph_def,
   if (!context.params.count("attribute_name") ||
       (context.params.at("attribute_name").size() != 1)) {
     return errors::InvalidArgument(
-        "remove_nodes expects exactly one 'attribute_name' "
+        "remove_attribute expects exactly one 'attribute_name' "
         "argument, e.g. remove_attribute(op_name=Mul, attribute_name=foo)");
   }
 
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index ae93cf210be254a2ca477b046f3948fecbd585bf..d62316964f86a21721bddedfb0329b5cb86ce86a 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -84,6 +84,7 @@ py_binary(
         "//tensorflow/python/saved_model",
         "//tensorflow/python:spectral_ops_test_util",
         "//tensorflow/python/tools:tools_pip",
+        "//tensorflow/python/eager:eager_pip",
         # These targets don't build on Windows yet. Exclude them for now.
         # "//tensorflow/contrib/ndlstm",
         # "//tensorflow/contrib/slim",
@@ -157,6 +158,7 @@ sh_binary(
             "//tensorflow/contrib/ndlstm:ndlstm",
             "//tensorflow/contrib/nn:nn_py",
             "//tensorflow/contrib/predictor:predictor_pip",
+            "//tensorflow/contrib/receptive_field:receptive_field_pip",
             "//tensorflow/contrib/session_bundle:session_bundle_pip",
             "//tensorflow/contrib/signal:signal_py",
             "//tensorflow/contrib/slim:slim",
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index b009859e415ed30c2c572a4d4329c2022f91300f..2b0d24d9a8928b8522b989ed0b37ab9a1062c567 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -193,7 +193,7 @@ setup(
     version=_VERSION.replace('-', ''),
     description='TensorFlow helps the tensors flow',
     long_description='',
-    url='http://tensorflow.org/',
+    url='https://www.tensorflow.org/',
     author='Google Inc.',
     author_email='opensource@google.com',
     # Contained modules and scripts.
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD
index 6607f629e71375b86e477dc2fea88c08de462909..3a60c8c95838886979e0a02dc2574af599a1f1a7 100644
--- a/tensorflow/tools/proto_text/BUILD
+++ b/tensorflow/tools/proto_text/BUILD
@@ -34,7 +34,7 @@ cc_binary(
     visibility = ["//tensorflow:internal"],
     deps = [
         ":gen_proto_text_functions_lib",
-        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_proto_parsing",
     ],
 )
 
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 0d171760a53fdfb76177355dccce9cd5828bca8e..e86aebe3b3a7bd1fc5ec458f74c15ab68f125122 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -99,6 +99,9 @@ def _execute_and_check_ret_code(repo_ctx, cmd_and_args):
 # Apply a patch_file to the repository root directory
 # Runs 'patch -p1'
 def _apply_patch(repo_ctx, patch_file):
+  if not repo_ctx.which("patch"):
+    fail("patch command is not found, please install it")
+
   cmd = [
       "patch", "-p1", "-d", repo_ctx.path("."), "-i", repo_ctx.path(patch_file)
   ]
@@ -296,6 +299,14 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
       build_file = str(Label("//third_party:png.BUILD")),
   )
 
+  native.new_http_archive(
+      name = "sqlite_archive",
+      urls = ["http://www.sqlite.org/2017/sqlite-amalgamation-3200000.zip"],
+      sha256 = "208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4",
+      strip_prefix = "sqlite-amalgamation-3200000",
+      build_file = str(Label("//third_party:sqlite.BUILD"))
+  )
+
   native.new_http_archive(
       name = "gif_archive",
       urls = [
@@ -565,11 +576,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   patched_http_archive(
       name = "boringssl",
       urls = [
-          "http://mirror.bazel.build/github.com/google/boringssl/archive/bbcaa15b0647816b9a1a9b9e0d209cd6712f0105.tar.gz",
-          "https://github.com/google/boringssl/archive/bbcaa15b0647816b9a1a9b9e0d209cd6712f0105.tar.gz",  # 2016-07-11
+          "http://mirror.bazel.build/github.com/google/boringssl/archive/e3860009a091cd1bd2bc189cdbc3c6d095abde84.tar.gz",
+          "https://github.com/google/boringssl/archive/e3860009a091cd1bd2bc189cdbc3c6d095abde84.tar.gz",  # 2017-07-07
       ],
-      sha256 = "025264d6e9a7ad371f2f66d17a28b6627de0c9592dc2eb54afd062f68f1f9aa3",
-      strip_prefix = "boringssl-bbcaa15b0647816b9a1a9b9e0d209cd6712f0105",
+      sha256 = "02f5950f93c4fd3691771c07c9d04cf2999ab01383ff99da345249e93b0fcfb2",
+      strip_prefix = "boringssl-e3860009a091cd1bd2bc189cdbc3c6d095abde84",
       # Add patch to boringssl code to support s390x
       patch_file = str(Label("//third_party/boringssl:add_boringssl_s390x.patch")),
   )
@@ -675,11 +686,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   native.new_http_archive(
       name = "cub_archive",
       urls = [
-          "http://mirror.bazel.build/github.com/NVlabs/cub/archive/69ceda618313df8e9cac6659d607b08949455d14.tar.gz",
-          "https://github.com/NVlabs/cub/archive/69ceda618313df8e9cac6659d607b08949455d14.tar.gz",
+          "http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip",
+          "https://github.com/NVlabs/cub/archive/1.7.3.zip",
       ],
-      sha256 = "87e856522c283b8ea887c3b61d7d5b252d2dd74abac4f1d756d776e721223e82",
-      strip_prefix = "cub-69ceda618313df8e9cac6659d607b08949455d14",
+      sha256 = "b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe",
+      strip_prefix = "cub-1.7.3",
       build_file = str(Label("//third_party:cub.BUILD")),
   )
 
@@ -691,9 +702,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   native.http_archive(
       name = "bazel_toolchains",
       urls = [
-          "http://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/bccee4855c049d34bac481083b4c68e2fab8cc50.tar.gz",
-          "https://github.com/bazelbuild/bazel-toolchains/archive/bccee4855c049d34bac481083b4c68e2fab8cc50.tar.gz",
+          "http://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/9dbd803ad3b9447430a296810197b09b3a710956.tar.gz",
+          "https://github.com/bazelbuild/bazel-toolchains/archive/9dbd803ad3b9447430a296810197b09b3a710956.tar.gz",
       ],
-      sha256 = "3903fd93b96b42067e00b7973a2c16c34e761ad7a0b55e1557d408f352849e41",
-      strip_prefix = "bazel-toolchains-bccee4855c049d34bac481083b4c68e2fab8cc50",
+      sha256 = "0799aa12db5260a499beb40f81744e760c59d055bfc5d271dd2c2ed4d5419faa",
+      strip_prefix = "bazel-toolchains-9dbd803ad3b9447430a296810197b09b3a710956",
   )
diff --git a/third_party/boringssl/add_boringssl_s390x.patch b/third_party/boringssl/add_boringssl_s390x.patch
index 9a34a59a1d12908b76bab216adf2bf1bda48dca8..8b42d10e6871b1e452abffab7b35de095c0be06c 100644
--- a/third_party/boringssl/add_boringssl_s390x.patch
+++ b/third_party/boringssl/add_boringssl_s390x.patch
@@ -3,9 +3,9 @@ index 7a3adfb..88012ad 100644
 --- a/src/include/openssl/base.h
 +++ b/src/include/openssl/base.h
 @@ -94,6 +94,8 @@ extern "C" {
- #elif defined(__pnacl__)
- #define OPENSSL_32_BIT
  #define OPENSSL_PNACL
+ #elif defined(__myriad2__)
+ #define OPENSSL_32_BIT
 +#elif defined(__s390x__)
 +#define OPENSSL_64_BIT
  #else
diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl
index 7d8b60051350aac9a6d352424b6480a5c253fe17..98cb326572e75ac3ea15a656d821c1eade53d313 100644
--- a/third_party/gpus/crosstool/BUILD.tpl
+++ b/third_party/gpus/crosstool/BUILD.tpl
@@ -12,12 +12,12 @@ cc_toolchain_suite(
 
 cc_toolchain(
     name = "cc-compiler-local",
-    all_files = ":crosstool_wrapper_driver_is_not_gcc",
+    all_files = "%{linker_files}",
     compiler_files = ":empty",
     cpu = "local",
     dwp_files = ":empty",
     dynamic_runtime_libs = [":empty"],
-    linker_files = ":crosstool_wrapper_driver_is_not_gcc",
+    linker_files = "%{linker_files}",
     objcopy_files = ":empty",
     static_runtime_libs = [":empty"],
     strip_files = ":empty",
@@ -30,12 +30,12 @@ cc_toolchain(
 
 cc_toolchain(
     name = "cc-compiler-darwin",
-    all_files = ":crosstool_wrapper_driver_is_not_gcc",
+    all_files = "%{linker_files}",
     compiler_files = ":empty",
     cpu = "darwin",
     dwp_files = ":empty",
     dynamic_runtime_libs = [":empty"],
-    linker_files = ":crosstool_wrapper_driver_is_not_gcc",
+    linker_files = "%{linker_files}",
     objcopy_files = ":empty",
     static_runtime_libs = [":empty"],
     strip_files = ":empty",
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index b85e565f362633fdd8b78057f49f796710ebd4ca..4a0f47108813f77dfad1a1c9f17d98169332c515 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -971,7 +971,6 @@ def _create_local_cuda_repository(repository_ctx):
                                '        ":cudnn-include",')
        })
   # Set up crosstool/
-  _file(repository_ctx, "crosstool:BUILD")
   cc = find_cc(repository_ctx)
   host_compiler_includes = _host_compiler_includes(repository_ctx, cc)
   cuda_defines = {
@@ -981,11 +980,14 @@ def _create_local_cuda_repository(repository_ctx):
        }
   if _use_cuda_clang(repository_ctx):
     cuda_defines["%{clang_path}"] = cc
+    _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty"})
     _tpl(repository_ctx, "crosstool:CROSSTOOL_clang", cuda_defines, out="crosstool/CROSSTOOL")
   else:
     nvcc_path = str(repository_ctx.path("%s/bin/nvcc%s" %
         (cuda_config.cuda_toolkit_path,
         ".exe" if cuda_config.cpu_value == "Windows" else "")))
+    _tpl(repository_ctx, "crosstool:BUILD",
+         {"%{linker_files}": ":crosstool_wrapper_driver_is_not_gcc"})
     _tpl(repository_ctx, "crosstool:CROSSTOOL_nvcc", cuda_defines, out="crosstool/CROSSTOOL")
     _tpl(repository_ctx,
          "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
diff --git a/third_party/sqlite.BUILD b/third_party/sqlite.BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..9840d7b15147b8d830c40b3156245613dc3ddc12
--- /dev/null
+++ b/third_party/sqlite.BUILD
@@ -0,0 +1,16 @@
+# Description:
+#   Sqlite3 library. Provides utilities for interacting
+#   with sqlite3 databases.
+
+licenses(["unencumbered"])  # Public Domain
+
+# exports_files(["LICENSE"])
+
+cc_library(
+    name = "sqlite",
+    srcs = ["sqlite3.c"],
+    hdrs = ["sqlite3.h"],
+    includes = ["."],
+    linkopts = ["-lm"],
+    visibility = ["//visibility:public"],
+)
diff --git a/third_party/toolchains/cpus/arm/CROSSTOOL.tpl b/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
index 6753476c153b9b624436ed0af436ed0841864632..04e399bed1d2cb7fe62a1a40797da544eada5709 100644
--- a/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
+++ b/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
@@ -77,8 +77,8 @@ toolchain {
   cxx_builtin_include_directory: "%{ARM_COMPILER_PATH}%/lib/gcc/arm-linux-gnueabihf/4.9.3/include-fixed"
   cxx_builtin_include_directory: "%{ARM_COMPILER_PATH}%/local_include"
   cxx_builtin_include_directory: "/usr/include"
-
-  cxx_flag: "-std=c++11"	
+  cxx_builtin_include_directory: "/workspace/toolchain/openblas/include/"
+  cxx_flag: "-std=c++11"
   # The cxx_builtin_include_directory directives don't seem to be adding these, so
   # explicitly set them as flags. There's a query to the Bazel team outstanding about
   # why this is necessary.