From 75259500f80266998f232a94853b0bc08d2925cc Mon Sep 17 00:00:00 2001 From: KB Sriram Date: Wed, 28 Feb 2018 07:16:20 -0800 Subject: [PATCH 001/839] C++ gradients: Fractional*Pool, Soft{Plus,Sign} 1. Adds gradients for four nn ops: FractionalAvgPool FractionalMaxPool SoftPlus SoftSign 2. Update randomization to allow numeric gradient checks on max pooling algorithms with more than one pool. Resolves https://github.com/tensorflow/tensorflow/issues/17330 --- tensorflow/cc/gradients/nn_grad.cc | 47 ++++++++++++++ tensorflow/cc/gradients/nn_grad_test.cc | 84 ++++++++++++++++++++----- 2 files changed, 115 insertions(+), 16 deletions(-) diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 63a67f09f6..4b89dac4c0 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -272,6 +272,53 @@ Status LRNGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("LRN", LRNGradHelper); +Status SoftplusGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softplus", SoftplusGradHelper); + +Status SoftsignGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softsign", SoftsignGradHelper); + +Status FractionalAvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalAvgPoolGrad( + scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)), + grad_inputs[0], op.output(1), op.output(2), + internal::FractionalAvgPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalAvgPool", FractionalAvgPoolGradHelper); + +Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalMaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], op.output(1), + op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index c4eba7ecb0..b4d457a9d1 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -28,6 +28,8 @@ namespace { using ops::BiasAdd; using ops::Conv2D; using ops::Elu; +using ops::FractionalAvgPool; +using ops::FractionalMaxPool; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; @@ -41,6 +43,8 @@ using ops::Relu; using ops::Relu6; using ops::Selu; using ops::Softmax; +using ops::Softplus; +using ops::Softsign; class NNGradTest : public ::testing::Test { protected: @@ -71,22 +75,30 @@ class NNGradTest : public ::testing::Test { EXPECT_LT(max_error, 1e-3); } - // Sets tensor with random values, ensuring that the max value is largest by - // a reasonable amount. - // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which - // perturbations by the numeric gradient computation in the gradient checker - // can change the max value if values are too close together. + // Sets tensor with random values, ensuring that every pair of elements are at + // least a reasonable amount apart. + // This is an issue for max pooling operations, in which perturbations by the + // numeric gradient computation in the gradient checker can change the max + // value if a pool has values that are too close together. template - void SetRandomValuesWithBumpedMax(Tensor* tensor) { + void SetRandomValuesForMaxPooling(Tensor* tensor) { auto tensor_flat = tensor->flat(); - tensor_flat.setRandom(); - int32 max_index = 0; - for (size_t i = 1; i < tensor->NumElements(); i++) { - if (tensor_flat(i) > tensor_flat(max_index)) { - max_index = i; - } + // First set the array to an increasing sequence of values spaced + // a reasonable amount apart + T cur = 0; + for (size_t i = 0; i < tensor->NumElements(); i++) { + tensor_flat(i) = cur; + cur += 5e-2; + } + // Fischer-Yates shuffle the array + for (size_t i = tensor->NumElements() - 1; i >= 1; i--) { + // j <- random integer 0 <= j <= i + size_t j = random::New64() % (i + 1); + // swap values at i, j + T tmp = tensor_flat(i); + tensor_flat(i) = tensor_flat(j); + tensor_flat(j) = tmp; } - tensor_flat(max_index) += 1e-2; } Scope scope_; @@ -189,7 +201,7 @@ TEST_F(NNGradTest, MaxPoolGradHelper) { const std::vector strides{1, 2, 2, 1}; auto y = MaxPool(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -202,7 +214,7 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { Tensor strides = test::AsTensor({1, 2, 2, 1}, {4}); auto y = MaxPoolV2(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -215,7 +227,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) { const std::vector strides{1, 3, 3, 3, 1}; auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -248,5 +260,45 @@ TEST_F(NNGradTest, LRN){ RunTest(x, x_shape, y, x_shape); } +TEST_F(NNGradTest, SoftplusGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softplus(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, SoftsignGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softsign(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, FractionalAvgPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalAvgPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalAvgPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_shape, y.output, y_shape); +} + +TEST_F(NNGradTest, FractionalMaxPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalMaxPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalMaxPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesForMaxPooling(&x_init_value); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_init_value, y.output, y_shape); +} + } // namespace } // namespace tensorflow -- GitLab From 1da3a47287aa911287d6667dd837dc2a7ddaa8f1 Mon Sep 17 00:00:00 2001 From: Smit Shilu Date: Thu, 22 Mar 2018 10:58:51 -0400 Subject: [PATCH 002/839] Update BUILD exports_files(["LICENSE"]) gives error while building on Mac and Ubuntu --- tensorflow/contrib/lite/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index dafe6f136e..1c5bc29763 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -6,8 +6,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") -exports_files(["LICENSE"]) - exports_files(glob([ "testdata/*.bin", "testdata/*.pb", -- GitLab From e7f3ed2477c7910e68573880efd2310e149ca785 Mon Sep 17 00:00:00 2001 From: mbhuiyan Date: Wed, 4 Apr 2018 10:52:49 -0700 Subject: [PATCH 003/839] Fixing a unit test failure for INTEL MKL where memeory allocation check failed because of use of INTEL MKL --- .../direct_session_with_tracking_alloc_test.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index 31fb128f93..0ff022a8bc 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -101,11 +101,24 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { EXPECT_EQ(2, shape.dim_size()); EXPECT_EQ(2, shape.dim(0).size()); EXPECT_EQ(1, shape.dim(1).size()); +#ifndef INTEL_MKL + // if MKL is used, it goes through various additional + // graph rewrite pass. In TF, everytime a graph pass + // happens, "constant" nodes are allocated + // and deallocated. Each allocation calls the + // (FindChunkPtr of BFCAllocator) + // , which increments the value of AllocationId. + // Thus AllocationId becomes more than 3 and 4 if + // MKL is used, they can be 10 and 11 or + // other numbers. If MKL is used + // following check will not hold. + // Thus, skipping the check if MKL is used. if (node->name() == y->name()) { EXPECT_EQ(3, cm->AllocationId(node, 0)); } else { EXPECT_EQ(4, cm->AllocationId(node, 0)); } +#endif } EXPECT_LE(0, cm->MaxExecutionTime(node)); EXPECT_GE(run_duration_micros, cm->MaxExecutionTime(node)); -- GitLab From 9d1aa895adda8644ddbb55b5e1dbb0797ea6cbb0 Mon Sep 17 00:00:00 2001 From: Jie Date: Wed, 11 Apr 2018 14:42:15 -0700 Subject: [PATCH 004/839] [tftrt update] Added support for TRT plugin during conversion - converter & shape inference are now aware of plugin factory. - each plugin does serialization of plugin type & input dimensions - wrapper for nvinfer1::IPlugin & nvinfer1::PluginFactory * compatible with TRT 3.0.4 plugin API. * future plugin API changes willl be updated. --- tensorflow/contrib/tensorrt/BUILD | 26 ++++++ .../contrib/tensorrt/convert/convert_graph.cc | 4 +- .../contrib/tensorrt/convert/convert_nodes.cc | 84 ++++++++++++++--- .../contrib/tensorrt/kernels/trt_engine_op.cc | 4 +- .../contrib/tensorrt/plugin/trt_plugin.cc | 89 +++++++++++++++++++ .../contrib/tensorrt/plugin/trt_plugin.h | 81 +++++++++++++++++ .../tensorrt/plugin/trt_plugin_factory.cc | 81 +++++++++++++++++ .../tensorrt/plugin/trt_plugin_factory.h | 83 +++++++++++++++++ .../tensorrt/plugin/trt_plugin_utils.cc | 36 ++++++++ .../tensorrt/plugin/trt_plugin_utils.h | 51 +++++++++++ .../contrib/tensorrt/shape_fn/trt_shfn.cc | 4 +- 11 files changed, 528 insertions(+), 15 deletions(-) create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin.cc create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin.h create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 2f316767b3..98f18835b0 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -67,6 +67,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_plugins", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]) + tf_custom_op_library_additional_deps(), @@ -86,6 +87,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_plugins", ":trt_resources", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", @@ -222,6 +224,7 @@ tf_cuda_library( ], deps = [ ":segment", + ":trt_plugins", ":trt_logging", ":trt_resources", "//tensorflow/core/grappler:grappler_item", @@ -272,3 +275,26 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +# Library for the plugin factory +#cc_library( +tf_cuda_library( + name = "trt_plugins", + srcs = [ + "plugin/trt_plugin.cc", + "plugin/trt_plugin_factory.cc", + "plugin/trt_plugin_utils.cc", + ], + hdrs = [ + "plugin/trt_plugin.h", + "plugin/trt_plugin_factory.h", + "plugin/trt_plugin_utils.h", + ], + linkstatic = 1, + deps = [ + #"@protobuf_archive//:protobuf_headers", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index b412b296e0..899e1721e6 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -75,7 +76,8 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(ben,jie): ... }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) - return candidate_ops.count(node->type_string()); + return (candidate_ops.count(node->type_string()) || + PluginFactoryTensorRT::GetInstance().IsPlugin(&node->type_string())); } void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 567b4af88d..a03c1e224a 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -246,6 +247,15 @@ class TFAttrs { return attrs_.count(key) ? this->get(key) : default_value; } + std::vector GetAllAttrKey() { + std::vector attr_list; + for (AttrMap::iterator iter = attrs_.begin(); iter != attrs_.end(); + iter++) { + attr_list.emplace_back(iter->first); + } + return attr_list; + } + private: typedef std::map AttrMap; AttrMap attrs_; @@ -262,6 +272,12 @@ std::vector TFAttrs::get>(string key) const { return std::vector(attr.begin(), attr.end()); } +template <> +std::vector TFAttrs::get>(string key) const { + auto attr = this->at(key)->list().f(); + return std::vector(attr.begin(), attr.end()); +} + template <> std::vector TFAttrs::get>(string key) const { auto attr = this->at(key)->list().s(); @@ -424,6 +440,7 @@ using OpConverter = class Converter { std::unordered_map trt_tensors_; std::unordered_map op_registry_; + OpConverter plugin_converter_; nvinfer1::INetworkDefinition* trt_network_; std::list> temp_bufs_; tensorflow::tensorrt::TRTWeightStore* weight_store_; @@ -444,8 +461,8 @@ class Converter { * remove this and annotate the edge as a control dependency. ************************************************************************/ // skip control nodes - if (input_name[0] == '^' ) continue; - string name = input_name; + if (input_name[0] == '^') continue; + string name = input_name; auto first = name.find_first_of(':'); if (first != string::npos && first + 2 == name.size() && name[first + 1] == '0') @@ -490,13 +507,17 @@ class Converter { std::vector inputs; TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); string op = node_def.op(); - if (!op_registry_.count(op)) { - return tensorflow::errors::Unimplemented( - "No converter registered for op: " + op); - } - OpConverter op_converter = op_registry_.at(op); std::vector outputs; - TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); + if (PluginFactoryTensorRT::GetInstance().IsPlugin(&op)) { + TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); + } else { + if (!op_registry_.count(op)) { + return tensorflow::errors::Unimplemented( + "No converter registered for op: " + op); + } + OpConverter op_converter = op_registry_.at(op); + TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); + } for (size_t i = 0; i < outputs.size(); ++i) { TRT_TensorOrWeights output = outputs.at(i); // TODO(jie): tf protobuf seems to be omitting the :0 suffix @@ -1158,9 +1179,9 @@ tensorflow::Status BinaryTensorOpTensor( CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) - return tensorflow::errors::Unimplemented( - "binary op: " + node_def.op() + - " not supported at: " + node_def.name()); + return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + + " not supported at: " + + node_def.name()); nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( *const_cast(tensor_l), @@ -1173,6 +1194,43 @@ tensorflow::Status BinaryTensorOpTensor( return tensorflow::Status::OK(); } +tensorflow::Status ConvertPlugin(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + // prepare input + std::vector all_inputs; + for (auto input : inputs) { + all_inputs.emplace_back(const_cast(input.tensor())); + } + + // plugin is owned by PluginFactory + // TODO(jie): destroy plugins later (resource management) + PluginTensorRT* plugin = + PluginFactoryTensorRT::GetInstance().CreatePlugin(&node_def.op()); + + // passing attributes + // TODO(jie): support more general attribute + TFAttrs attrs(node_def); + auto attr_key_vector = attrs.GetAllAttrKey(); + for (auto attr_key : attr_key_vector) { + std::cout << attr_key << std::endl; + // TODO(jie): support only list of float for toy example here. + auto data = attrs.get>(attr_key); + size_t size_data = data.size() * sizeof(float); + plugin->SetAttribute(attr_key, static_cast(data.data()), size_data); + } + + nvinfer1::IPluginLayer* layer = + ctx.network()->addPlugin(&all_inputs[0], int(inputs.size()), *plugin); + + for (int i = 0; i < layer->getNbOutputs(); i++) { + nvinfer1::ITensor* output_tensor = layer->getOutput(i); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + } + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertPlaceholder( Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, @@ -2073,6 +2131,8 @@ void Converter::register_op_converters() { op_registry_["Reshape"] = ConvertReshape; op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; + + plugin_converter_ = ConvertPlugin; } } // namespace @@ -2511,7 +2571,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( std::vector input_names; std::vector input_dtypes; for (const std::pair& input : s.input_inds) { - VLOG(2) << "parsing input. Node id= " << input.first ; + VLOG(2) << "parsing input. Node id= " << input.first; int node_id = input.first; int output_idx = input.second; tensorflow::Node* node = s.graph.FindNodeId(node_id); diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index b32371b642..8881c48fe6 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/core/platform/logging.h" @@ -58,7 +59,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { IRuntime* infer = nvinfer1::createInferRuntime(logger); trt_engine_ptr_.reset(infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), nullptr)); + serialized_engine.c_str(), serialized_engine.size(), + &PluginFactoryTensorRT::GetInstance())); trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); // Runtime is safe to delete after engine creation infer->destroy(); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc new file mode 100644 index 0000000000..0e4a157d79 --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include +#include +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { + // sanity check. + assert(EncodeOpName(GetPluginName()) != + *static_cast(serialized_data)); + const char* buffer = static_cast(serialized_data) + + sizeof(input_dim_list_.size()); + + size_t count = *reinterpret_cast(buffer); + buffer += sizeof(size_t); + + for (int i = 0; i < count; i++) { + nvinfer1::Dims dim; + std::memcpy(&(dim.nbDims), buffer, sizeof(dim.nbDims)); + buffer += sizeof(dim.nbDims); + std::memcpy(dim.d, buffer, sizeof(dim.d)); + buffer += sizeof(dim.d); + std::memcpy(dim.type, buffer, sizeof(dim.type)); + buffer += sizeof(dim.type); + input_dim_list_.emplace_back(dim); + } +} + +size_t PluginTensorRT::getSerializationSize() { + nvinfer1::Dims dim; + return sizeof(size_t) + sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + + sizeof(dim.d) + sizeof(dim.type); +} + +void PluginTensorRT::serialize(void* serialized_data) { + size_t encode_op_name = EncodeOpName(GetPluginName()); + char* buffer = static_cast(serialized_data); + std::memcpy(buffer, &encode_op_name, sizeof(size_t)); + buffer += sizeof(size_t); + + auto list_size = input_dim_list_.size(); + std::memcpy(buffer, &list_size, sizeof(input_dim_list_.size())); + buffer += sizeof(input_dim_list_.size()); + + for (int i = 0; i < input_dim_list_.size(); i++) { + auto dim = input_dim_list_[i]; + std::memcpy(buffer, &(dim.nbDims), sizeof(dim.nbDims)); + buffer += sizeof(dim.nbDims); + std::memcpy(buffer, dim.d, sizeof(dim.d)); + buffer += sizeof(dim.d); + std::memcpy(buffer, dim.type, sizeof(dim.type)); + buffer += sizeof(dim.type); + } +} + +bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr, + const size_t size) { + if (attr_map_.count(key) != 0) return false; + + attr_map_.emplace(key, std::vector(size)); + std::memcpy(attr_map_[key].data(), ptr, size); + return true; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h new file mode 100644 index 0000000000..1bbfe62a4e --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN + +#include +#include +#include +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +using std::string; +using std::unordered_map; + +class PluginTensorRT : public nvinfer1::IPlugin { + public: + PluginTensorRT(){}; + PluginTensorRT(const void* serialized_data, size_t length); + // PluginTensorRT(const void* serialized_data, size_t length, size_t + // &incremental); + virtual string GetPluginName() = 0; + virtual bool Finalize() = 0; + + virtual bool SetAttribute(const string& key, const void* ptr, + const size_t size) = 0; + virtual bool GetAttribute(const string& key, const void* ptr, + size_t& size) = 0; + + void configure(const nvinfer1::Dims* inputs, int nbInputs, + const nvinfer1::Dims* outputs, int nbOutputs, + int maxBatchSize) override { + for (int index = 0; index < nbInputs; index++) { + nvinfer1::Dims dim; + dim.nbDims = inputs[index].nbDims; + for (int i = 0; i < dim.nbDims; i++) { + dim.d[i] = inputs[index].d[i]; + dim.type[i] = inputs[index].type[i]; + } + input_dim_list_.emplace_back(dim); + } + return; + } + + virtual bool StoreAttribute(const string& key, const void* ptr, + const size_t size); + + virtual size_t getSerializationSize() override; + virtual void serialize(void* buffer) override; + + protected: + std::unordered_map > attr_map_; + + std::vector input_dim_list_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc new file mode 100644 index 0000000000..799c609a3e --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layerName, + const void* serial_data, + size_t serial_length) { + size_t parsed_byte = 0; + // extract op_name from serial_data + size_t encoded_op_name = + ExtractOpName(serial_data, serial_length, parsed_byte); + + if (!IsPlugin(encoded_op_name)) { + return nullptr; + } + + // should I lock plugins here? + instance_m_.lock(); + auto plugin_ptr = + plugin_registry_[encoded_op_name].first(serial_data, serial_length); + // string op_name = "IncPluginTRT"; + // auto plugin_ptr = plugin_registry_[EncodeLayerName(&op_name)].second(); + // auto plugin_ptr = plugin_registry_.begin()->second.second(); + owned_plugins_.emplace_back(plugin_ptr); + instance_m_.unlock(); + + return plugin_ptr; +} + +PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string* op_name) { + if (!IsPlugin(op_name)) return nullptr; + + instance_m_.lock(); + auto plugin_ptr = plugin_registry_[EncodeLayerName(op_name)].second(); + owned_plugins_.emplace_back(plugin_ptr); + instance_m_.unlock(); + + return plugin_ptr; +} + +bool PluginFactoryTensorRT::RegisterPlugin( + const string* op_name, PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func) { + if (IsPlugin(op_name)) return false; + + // get instance_m_ first before write to registry; + instance_m_.lock(); + auto ret = plugin_registry_.emplace( + EncodeLayerName(op_name), + std::make_pair(deserialize_func, construct_func)); + instance_m_.unlock(); + + return ret.second; +} + +void PluginFactoryTensorRT::DestroyPlugins() { return; } + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h new file mode 100644 index 0000000000..e68f4629d0 --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY + +#include +#include +#include +#include "trt_plugin.h" +#include "trt_plugin_utils.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { + public: + // deserialization method + // virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const void* + // serialData, size_t serialLength) override; + PluginTensorRT* createPlugin(const char* layerName, const void* serialData, + size_t serialLength) override; + + // construction + PluginTensorRT* CreatePlugin(const string* op_name); + + static PluginFactoryTensorRT& GetInstance() { + static PluginFactoryTensorRT factory_instance; + return factory_instance; + } + + bool RegisterPlugin(const string* op_name, + PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func); + + bool IsPlugin(const size_t encode_name) { + return plugin_registry_.find(encode_name) != plugin_registry_.end(); + } + + bool IsPlugin(const string* op_name) { + return IsPlugin(EncodeLayerName(op_name)); + } + + size_t EncodeLayerName(const string* op_name) { + return EncodeOpName(*op_name); + } + + void DestroyPlugins(); + + protected: + std::unordered_map > + plugin_registry_; + + // TODO(jie): Owned plugin should be associated with different sessions; + // should really hand ownership of plugins to resource management; + std::vector > owned_plugins_; + std::mutex instance_m_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc new file mode 100644 index 0000000000..b14480cfa6 --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +size_t ExtractOpName(const void* serial_data, size_t serial_length, + size_t& incremental) { + incremental = sizeof(size_t); + if (serial_length < incremental) return 0; + size_t encoded_op_name = *static_cast(serial_data); + return encoded_op_name; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h new file mode 100644 index 0000000000..e9675d84cd --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS + +#include +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +typedef std::function + PluginDeserializeFunc; + +typedef std::function PluginConstructFunc; + +inline size_t EncodeOpName(std::string str) { + return std::hash{}(str); +} + +// TODO(jie): work on error handling here +size_t ExtractOpName(const void* serial_data, size_t serial_length, + size_t& incremental); + +// size_t Deserialize(const char* serial_data, size_t serial_length, size_t +// &incremental); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 8b475177bc..30b5616475 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -33,7 +34,8 @@ tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine)); nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), nullptr); + serialized_engine.c_str(), serialized_engine.size(), + &tensorrt::PluginFactoryTensorRT::GetInstance()); int num_batch = -1; std::vector<::tensorflow::DataType> input_type; -- GitLab From 0eb443db1a5654168f396702cae39f5dc3fc7e2e Mon Sep 17 00:00:00 2001 From: imsheridan Date: Tue, 17 Apr 2018 20:25:51 +0800 Subject: [PATCH 005/839] Add deprecated_args decoration to array_ops/ sparse_ops --- tensorflow/python/ops/array_ops.py | 2 ++ tensorflow/python/ops/sparse_ops.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index ceeabe090d..06da2485c3 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2690,6 +2690,8 @@ reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ # pylint: disable=redefined-builtin @tf_export("reverse_sequence") +@deprecation.deprecated_args(None, "Use the `seq_axis` argument instead", "seq_dim") +@deprecation.deprecated_args(None, "Use the `batch_axis` argument instead", "batch_dim") def reverse_sequence(input, seq_lengths, seq_axis=None, diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index c580052c32..73ab216f35 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -110,6 +110,7 @@ def _convert_to_sparse_tensors(sp_inputs): # pylint: disable=protected-access @tf_export("sparse_concat") +@deprecation.deprecated_args(None, "concat_dim is deprecated, use axis instead", "concat_dim") def sparse_concat(axis, sp_inputs, name=None, @@ -616,6 +617,7 @@ class KeywordRequired(object): @tf_export("sparse_split") +@deprecation.deprecated_args(None, "split_dim is deprecated, use axis instead", "split_dim") def sparse_split(keyword_required=KeywordRequired(), sp_input=None, num_split=None, -- GitLab From d6838f52c7daea81c57cdeab8e98c4cd617e5f8b Mon Sep 17 00:00:00 2001 From: imsheridan Date: Tue, 17 Apr 2018 20:30:13 +0800 Subject: [PATCH 006/839] fix typo to keep consistence --- tensorflow/python/ops/array_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 06da2485c3..b6a1f5a272 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2690,8 +2690,8 @@ reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ # pylint: disable=redefined-builtin @tf_export("reverse_sequence") -@deprecation.deprecated_args(None, "Use the `seq_axis` argument instead", "seq_dim") -@deprecation.deprecated_args(None, "Use the `batch_axis` argument instead", "batch_dim") +@deprecation.deprecated_args(None, "seq_dim is deprecated, use seq_axis instead", "seq_dim") +@deprecation.deprecated_args(None, "batch_dim is deprecated, use batch_axis instead", "batch_dim") def reverse_sequence(input, seq_lengths, seq_axis=None, -- GitLab From df0ce53aee6f4e14b3f1c9e0e772a1f7bd1bb95a Mon Sep 17 00:00:00 2001 From: Jie Date: Mon, 16 Apr 2018 17:47:00 -0700 Subject: [PATCH 007/839] [PR comment addressed] adding plugin test for registration updating plugin API wrapper addressing comments in the PR addressing coding style issues removing commented code --- tensorflow/contrib/tensorrt/BUILD | 15 ++- .../contrib/tensorrt/convert/convert_graph.cc | 2 +- .../contrib/tensorrt/convert/convert_nodes.cc | 10 +- .../custom_plugin_examples/inc_op_plugin.cc | 89 ++++++++++++++ .../custom_plugin_examples/inc_op_plugin.h | 114 ++++++++++++++++++ .../contrib/tensorrt/kernels/trt_engine_op.cc | 2 +- .../contrib/tensorrt/plugin/trt_plugin.cc | 37 ++++-- .../contrib/tensorrt/plugin/trt_plugin.h | 41 +++---- .../tensorrt/plugin/trt_plugin_factory.cc | 28 +++-- .../tensorrt/plugin/trt_plugin_factory.h | 33 +++-- .../tensorrt/plugin/trt_plugin_utils.cc | 18 ++- .../tensorrt/plugin/trt_plugin_utils.h | 11 +- .../tensorrt/plugin/trt_plugins_test.cc | 112 +++++++++++++++++ .../contrib/tensorrt/shape_fn/trt_shfn.cc | 2 +- 14 files changed, 423 insertions(+), 91 deletions(-) create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 98f18835b0..751f1d3482 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -277,7 +277,6 @@ tf_cc_test( ) # Library for the plugin factory -#cc_library( tf_cuda_library( name = "trt_plugins", srcs = [ @@ -292,9 +291,21 @@ tf_cuda_library( ], linkstatic = 1, deps = [ - #"@protobuf_archive//:protobuf_headers", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), ) +tf_cuda_cc_test( + name = "trt_plugins_test", + size = "small", + srcs = ["plugin/trt_plugins_test.cc"], + deps = [ + ":trt_plugins", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:nv_infer", + ]), +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 899e1721e6..91faba7e21 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -77,7 +77,7 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) return (candidate_ops.count(node->type_string()) || - PluginFactoryTensorRT::GetInstance().IsPlugin(&node->type_string())); + PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index a03c1e224a..d02c1ebf50 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -249,9 +249,8 @@ class TFAttrs { std::vector GetAllAttrKey() { std::vector attr_list; - for (AttrMap::iterator iter = attrs_.begin(); iter != attrs_.end(); - iter++) { - attr_list.emplace_back(iter->first); + for (auto & attr_item : attrs_) { + attr_list.emplace_back(attr_item.first); } return attr_list; } @@ -508,7 +507,7 @@ class Converter { TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); string op = node_def.op(); std::vector outputs; - if (PluginFactoryTensorRT::GetInstance().IsPlugin(&op)) { + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) { TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); } else { if (!op_registry_.count(op)) { @@ -1207,14 +1206,13 @@ tensorflow::Status ConvertPlugin(Converter& ctx, // plugin is owned by PluginFactory // TODO(jie): destroy plugins later (resource management) PluginTensorRT* plugin = - PluginFactoryTensorRT::GetInstance().CreatePlugin(&node_def.op()); + PluginFactoryTensorRT::GetInstance()->CreatePlugin(node_def.op()); // passing attributes // TODO(jie): support more general attribute TFAttrs attrs(node_def); auto attr_key_vector = attrs.GetAllAttrKey(); for (auto attr_key : attr_key_vector) { - std::cout << attr_key << std::endl; // TODO(jie): support only list of float for toy example here. auto data = attrs.get>(attr_key); size_t size_data = data.size() * sizeof(float); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc new file mode 100644 index 0000000000..2155079e8b --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "inc_op_plugin.h" + +namespace tensorflow { +namespace tensorrt { + +const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; + +IncOpPlugin* CreateIncPlugin() { + return new IncOpPlugin(); +} + + +IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { + return new IncOpPlugin(buffer, length); +} + +bool RegisterIncOpPlugin() { + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(IncOpPlugin::plugin_name_)) + return false; + return PluginFactoryTensorRT::GetInstance()->RegisterPlugin(IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); +} + + +IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) : + PluginTensorRT(serialized_data, length) +{ + // account for the consumed pointer. + size_t consumed_data = PluginTensorRT::getSerializationSize(); + assert(length-consumed_data >= sizeof(float)); + SetAttribute("inc", serialized_data+consumed_data, sizeof(float)); +} + +bool IncOpPlugin::SetAttribute(const string &key, const void *ptr, const size_t size) { + if (strcmp(key.c_str(), "inc")==0 && size == sizeof(float)) { + StoreAttribute(key, ptr, size); // save the attribute to own the data; + inc_ = *static_cast(ptr); + return true; + } + return false; +} + +bool IncOpPlugin::GetAttribute(const string &key, const void *ptr, size_t &size) { + if (attr_map_.find(key) != attr_map_.end()) { + ptr = attr_map_[key].data(); + size = attr_map_[key].size(); + return true; + } + return false; +} + +int IncOpPlugin::enqueue(int batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream) { + int count = 1; + for (int i=0; i(inputs[0]); + float *output = reinterpret_cast(outputs[0]); + IncrementKernel(input, inc_, output, count, stream); + return 0; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h new file mode 100644 index 0000000000..52b68487e6 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -0,0 +1,114 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN +#define TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include +#include +#include +#include +#include +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +using std::string; +using std::unordered_map; + +class IncOpPlugin : public PluginTensorRT +{ +public: + static const string plugin_name_; + IncOpPlugin() {}; + IncOpPlugin(const void* serialized_data, size_t length); + const string GetPluginName() override {return plugin_name_;}; + bool Finalize() override {return true;}; + bool SetAttribute(const string &key, const void *ptr, const size_t size) override; + bool GetAttribute(const string &key, const void *ptr, size_t &size) override; + + // TRT IPlugin methods + int getNbOutputs() const override {return 1;} + + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override { + assert(index==0); + assert(nbInputDims==1); + return inputs[0]; + } + + // no configure needed + // use configure to setup input dimensions + void configure(const nvinfer1::Dims *inputs, int nbInputs, const nvinfer1::Dims *outputs, int nbOutputs, int maxBatchSize) override { + assert(nbInputs==1); + PluginTensorRT::configure(inputs, nbInputs, outputs, nbOutputs, maxBatchSize); + return; + } + + int initialize() override { + return 0; + } + + void terminate() override { + return; + } + + size_t getWorkspaceSize(int maxBatchSize) const override { + return 0; + } + + int enqueue(int batchSize, const void*const *inputs, void** outputs, void* workspace, cudaStream_t stream) override; + + size_t getSerializationSize() override { + return PluginTensorRT::getSerializationSize() + sizeof(float); + } + + void serialize(void* buffer) override { + // serializa parent stuff + // OpName + PluginTensorRT::serialize(buffer); + + // incremented buffer after parent serialization; + buffer = static_cast(buffer) + PluginTensorRT::getSerializationSize(); + + std::memcpy(buffer, &inc_, sizeof(float)); + buffer = static_cast(buffer) + sizeof(float); + return; + } + +protected: + float inc_; + nvinfer1::Dims dim_; + // std::unordered_map > attr_map_; +}; + +IncOpPlugin* CreateIncPlugin(); +IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); +bool RegisterIncOpPlugin(); +void IncrementKernel(const float* d_input, float inc, float* d_output, int count, cudaStream_t stream); + + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 8881c48fe6..162301fb52 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -60,7 +60,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { IRuntime* infer = nvinfer1::createInferRuntime(logger); trt_engine_ptr_.reset(infer->deserializeCudaEngine( serialized_engine.c_str(), serialized_engine.size(), - &PluginFactoryTensorRT::GetInstance())); + PluginFactoryTensorRT::GetInstance())); trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); // Runtime is safe to delete after engine creation infer->destroy(); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc index 0e4a157d79..7600703775 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -26,10 +26,10 @@ namespace tensorrt { PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { // sanity check. - assert(EncodeOpName(GetPluginName()) != - *static_cast(serialized_data)); - const char* buffer = static_cast(serialized_data) + - sizeof(input_dim_list_.size()); + const char* buffer = static_cast(serialized_data); + size_t op_name_char_count = *reinterpret_cast(buffer); + buffer += sizeof(size_t); + buffer += op_name_char_count; size_t count = *reinterpret_cast(buffer); buffer += sizeof(size_t); @@ -46,18 +46,37 @@ PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { } } +void PluginTensorRT::configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) { + for (int index = 0; index < num_inputs; index++) { + nvinfer1::Dims dim; + dim.nbDims = inputs[index].nbDims; + for (int i = 0; i < dim.nbDims; i++) { + dim.d[i] = inputs[index].d[i]; + dim.type[i] = inputs[index].type[i]; + } + input_dim_list_.emplace_back(dim); + } + return; +} + size_t PluginTensorRT::getSerializationSize() { nvinfer1::Dims dim; - return sizeof(size_t) + sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + - sizeof(dim.d) + sizeof(dim.type); + return sizeof(size_t) + GetPluginName().size() + + sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + sizeof(dim.d) + + sizeof(dim.type); } void PluginTensorRT::serialize(void* serialized_data) { - size_t encode_op_name = EncodeOpName(GetPluginName()); + size_t op_name_size = GetPluginName().size(); char* buffer = static_cast(serialized_data); - std::memcpy(buffer, &encode_op_name, sizeof(size_t)); + std::memcpy(buffer, &op_name_size, sizeof(size_t)); buffer += sizeof(size_t); + std::memcpy(buffer, GetPluginName().data(), op_name_size); + buffer += op_name_size; + auto list_size = input_dim_list_.size(); std::memcpy(buffer, &list_size, sizeof(input_dim_list_.size())); buffer += sizeof(input_dim_list_.size()); @@ -73,7 +92,7 @@ void PluginTensorRT::serialize(void* serialized_data) { } } -bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr, +bool PluginTensorRT::StoreAttribute(const std::string& key, const void* ptr, const size_t size) { if (attr_map_.count(key) != 0) return false; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index 1bbfe62a4e..59b92657f6 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -28,46 +28,37 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -using std::string; -using std::unordered_map; - +// A wrapper class for TensorRT plugin +// User application should inherit from this class to write custom kernels. +// Allows user to insert custom op in TensorRT engine +// To register plugin in converter, user should also register custom +// tensorflow::tensorrt::PluginDeserializeFunc & +// tensorflow::tensorrt::PluginConstructFunc through +// tensorflow::tensorrt::PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: PluginTensorRT(){}; PluginTensorRT(const void* serialized_data, size_t length); - // PluginTensorRT(const void* serialized_data, size_t length, size_t - // &incremental); - virtual string GetPluginName() = 0; + virtual const std::string& GetPluginName() = 0; virtual bool Finalize() = 0; - virtual bool SetAttribute(const string& key, const void* ptr, + virtual bool SetAttribute(const std::string& key, const void* ptr, const size_t size) = 0; - virtual bool GetAttribute(const string& key, const void* ptr, + virtual bool GetAttribute(const std::string& key, const void* ptr, size_t& size) = 0; - void configure(const nvinfer1::Dims* inputs, int nbInputs, - const nvinfer1::Dims* outputs, int nbOutputs, - int maxBatchSize) override { - for (int index = 0; index < nbInputs; index++) { - nvinfer1::Dims dim; - dim.nbDims = inputs[index].nbDims; - for (int i = 0; i < dim.nbDims; i++) { - dim.d[i] = inputs[index].d[i]; - dim.type[i] = inputs[index].type[i]; - } - input_dim_list_.emplace_back(dim); - } - return; - } - - virtual bool StoreAttribute(const string& key, const void* ptr, + void configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) override; + + virtual bool StoreAttribute(const std::string& key, const void* ptr, const size_t size); virtual size_t getSerializationSize() override; virtual void serialize(void* buffer) override; protected: - std::unordered_map > attr_map_; + std::unordered_map > attr_map_; std::vector input_dim_list_; }; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 799c609a3e..44b10394c8 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -21,13 +21,13 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layerName, +PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) { size_t parsed_byte = 0; // extract op_name from serial_data - size_t encoded_op_name = - ExtractOpName(serial_data, serial_length, parsed_byte); + std::string encoded_op_name = + ExtractOpName(serial_data, serial_length, &parsed_byte); if (!IsPlugin(encoded_op_name)) { return nullptr; @@ -37,20 +37,18 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layerName, instance_m_.lock(); auto plugin_ptr = plugin_registry_[encoded_op_name].first(serial_data, serial_length); - // string op_name = "IncPluginTRT"; - // auto plugin_ptr = plugin_registry_[EncodeLayerName(&op_name)].second(); - // auto plugin_ptr = plugin_registry_.begin()->second.second(); owned_plugins_.emplace_back(plugin_ptr); instance_m_.unlock(); return plugin_ptr; } -PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string* op_name) { +PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( + const std::string& op_name) { if (!IsPlugin(op_name)) return nullptr; instance_m_.lock(); - auto plugin_ptr = plugin_registry_[EncodeLayerName(op_name)].second(); + auto plugin_ptr = plugin_registry_[op_name].second(); owned_plugins_.emplace_back(plugin_ptr); instance_m_.unlock(); @@ -58,21 +56,27 @@ PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string* op_name) { } bool PluginFactoryTensorRT::RegisterPlugin( - const string* op_name, PluginDeserializeFunc deserialize_func, + const std::string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func) { if (IsPlugin(op_name)) return false; // get instance_m_ first before write to registry; instance_m_.lock(); auto ret = plugin_registry_.emplace( - EncodeLayerName(op_name), - std::make_pair(deserialize_func, construct_func)); + op_name, std::make_pair(deserialize_func, construct_func)); instance_m_.unlock(); return ret.second; } -void PluginFactoryTensorRT::DestroyPlugins() { return; } +void PluginFactoryTensorRT::DestroyPlugins() { + instance_m_.lock(); + for (auto& owned_plugin_ptr : owned_plugins_) { + owned_plugin_ptr.release(); + } + owned_plugins_.clear(); + instance_m_.unlock(); +} } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index e68f4629d0..824efcff35 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -32,39 +32,34 @@ namespace tensorrt { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { public: // deserialization method - // virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const void* - // serialData, size_t serialLength) override; - PluginTensorRT* createPlugin(const char* layerName, const void* serialData, - size_t serialLength) override; + PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, + size_t serial_length) override; - // construction - PluginTensorRT* CreatePlugin(const string* op_name); + // plugin construction, PluginFactoryTensorRT owns the plugin; + PluginTensorRT* CreatePlugin(const std::string& op_name); - static PluginFactoryTensorRT& GetInstance() { - static PluginFactoryTensorRT factory_instance; + static PluginFactoryTensorRT* GetInstance() { + static PluginFactoryTensorRT* factory_instance = nullptr; + if (factory_instance == nullptr) { + factory_instance = new PluginFactoryTensorRT(); + } return factory_instance; } - bool RegisterPlugin(const string* op_name, + bool RegisterPlugin(const std::string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func); - bool IsPlugin(const size_t encode_name) { - return plugin_registry_.find(encode_name) != plugin_registry_.end(); + bool IsPlugin(const std::string& op_name) { + return plugin_registry_.find(op_name) != plugin_registry_.end(); } - bool IsPlugin(const string* op_name) { - return IsPlugin(EncodeLayerName(op_name)); - } - - size_t EncodeLayerName(const string* op_name) { - return EncodeOpName(*op_name); - } + size_t CountOwnedPlugins() { return owned_plugins_.size(); } void DestroyPlugins(); protected: - std::unordered_map > plugin_registry_; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc index b14480cfa6..8b65e8b41c 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -21,12 +22,17 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -size_t ExtractOpName(const void* serial_data, size_t serial_length, - size_t& incremental) { - incremental = sizeof(size_t); - if (serial_length < incremental) return 0; - size_t encoded_op_name = *static_cast(serial_data); - return encoded_op_name; +std::string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental) { + size_t op_name_char_count = *static_cast(serial_data); + *incremental = sizeof(size_t) + op_name_char_count; + + assert(serial_length >= *incremental); + + const char* buffer = static_cast(serial_data) + sizeof(size_t); + std::string op_name(buffer, op_name_char_count); + + return op_name; } } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index e9675d84cd..d4da8b261e 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -31,16 +31,9 @@ typedef std::function typedef std::function PluginConstructFunc; -inline size_t EncodeOpName(std::string str) { - return std::hash{}(str); -} - // TODO(jie): work on error handling here -size_t ExtractOpName(const void* serial_data, size_t serial_length, - size_t& incremental); - -// size_t Deserialize(const char* serial_data, size_t serial_length, size_t -// &incremental); +std::string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc new file mode 100644 index 0000000000..2856b0f87d --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +class StubPlugin : public PluginTensorRT { + public: + static const std::string plugin_name_; + StubPlugin(){}; + StubPlugin(const void* serialized_data, size_t length) + : PluginTensorRT(serialized_data, length){}; + const std::string& GetPluginName() override { return plugin_name_; }; + virtual bool Finalize() { return true; }; + virtual bool SetAttribute(const std::string& key, const void* ptr, + const size_t size) { + return true; + }; + virtual bool GetAttribute(const std::string& key, const void* ptr, + size_t& size) { + return true; + }; + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, + int nbInputDims) override { + return inputs[0]; + } + int initialize() override { return 0; } + void terminate() override { return; } + size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } + int enqueue(int batchSize, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override { + return 0; + } +}; + +const std::string StubPlugin::plugin_name_ = "StubPlugin"; + +StubPlugin* CreateStubPlugin() { return new StubPlugin(); } + +StubPlugin* CreateStubPluginDeserialize(const void* serialized_data, + size_t length) { + return new StubPlugin(serialized_data, length); +} + +class PluginTest : public ::testing::Test { + public: + bool RegisterStubPlugin() { + if (PluginFactoryTensorRT::GetInstance()->IsPlugin( + StubPlugin::plugin_name_)) + return true; + return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( + StubPlugin::plugin_name_, CreateStubPluginDeserialize, + CreateStubPlugin); + } + + protected: +}; + +TEST_F(PluginTest, Registration) { + EXPECT_FALSE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + EXPECT_TRUE(RegisterStubPlugin()); + + ASSERT_TRUE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); +} + +TEST_F(PluginTest, CreationDeletion) { + EXPECT_TRUE(RegisterStubPlugin()); + ASSERT_TRUE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + + PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); + ASSERT_TRUE(PluginFactoryTensorRT::GetInstance()->CreatePlugin( + StubPlugin::plugin_name_)); + ASSERT_EQ(1, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); + PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); + ASSERT_EQ(0, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); +} + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 30b5616475..f36495f6b6 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -35,7 +35,7 @@ tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( serialized_engine.c_str(), serialized_engine.size(), - &tensorrt::PluginFactoryTensorRT::GetInstance()); + tensorrt::PluginFactoryTensorRT::GetInstance()); int num_batch = -1; std::vector<::tensorflow::DataType> input_type; -- GitLab From 758f25e8168bf1ff76c63a5b54dfd50ff54e4e27 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 17 Apr 2018 15:47:33 -0700 Subject: [PATCH 008/839] Fix calculation of the histogram buckets and writing to the tensor and add a unit test --- .../tensorboard/db/summary_db_writer.cc | 21 +++++--- .../tensorboard/db/summary_db_writer_test.cc | 49 +++++++++++++++++++ 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index 6590d6f7df..046a2d3884 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -1182,14 +1182,19 @@ class SummaryDbWriter : public SummaryWriterInterface { // See tensorboard/plugins/histogram/summary.py and data_compat.py Tensor t{DT_DOUBLE, {k, 3}}; auto data = t.flat(); - for (int i = 0; i < k; ++i) { - double left_edge = ((i - 1 >= 0) ? histo.bucket_limit(i - 1) - : std::numeric_limits::min()); - double right_edge = ((i + 1 < k) ? histo.bucket_limit(i + 1) - : std::numeric_limits::max()); - data(i + 0) = left_edge; - data(i + 1) = right_edge; - data(i + 2) = histo.bucket(i); + for (int i = 0, j = 0; i < k; ++i) { + // From summary.proto + // Parallel arrays encoding the bucket boundaries and the bucket values. + // bucket(i) is the count for the bucket i. The range for + // a bucket is: + // i == 0: -DBL_MAX .. bucket_limit(0) + // i != 0: bucket_limit(i-1) .. bucket_limit(i) + double left_edge = (i == 0) ? std::numeric_limits::min() + : histo.bucket_limit(i - 1); + + data(j++) = left_edge; + data(j++) = histo.bucket_limit(i); + data(j++) = histo.bucket(i); } int64 tag_id; PatchPluginName(s->mutable_metadata(), kHistogramPluginName); diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index 29b8063218..cb51325d15 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -100,6 +100,55 @@ class SummaryDbWriterTest : public ::testing::Test { SummaryWriterInterface* writer_ = nullptr; }; +TEST_F(SummaryDbWriterTest, WriteHistogram_VerifyTensorValues) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "histtest", "test1", "user1", &env_, + &writer_)); + int step = 0; + std::unique_ptr e{new Event}; + e->set_step(step); + e->set_wall_time(123); + Summary::Value* s = e->mutable_summary()->add_value(); + s->set_tag("normal/myhisto"); + + double dummy_value = 10.123; + HistogramProto* proto = s->mutable_histo(); + proto->Clear(); + proto->set_min(dummy_value); + proto->set_max(dummy_value); + proto->set_num(dummy_value); + proto->set_sum(dummy_value); + proto->set_sum_squares(dummy_value); + + int size = 3; + double bucket_limits[] = {-30.5, -10.5, -5.5}; + double bucket[] = {-10, 10, 20}; + for (int i = 0; i < size; i++) { + proto->add_bucket_limit(bucket_limits[i]); + proto->add_bucket(bucket[i]); + } + TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); + TF_ASSERT_OK(writer_->Flush()); + writer_->Unref(); + writer_ = nullptr; + + // Verify the data + string result = QueryString("SELECT data FROM Tensors"); + const double* val = reinterpret_cast(result.data()); + double histarray[] = {std::numeric_limits::min(), + -30.5, + -10, + -30.5, + -10.5, + 10, + -10.5, + -5.5, + 20}; + int histarray_size = 9; + for (int i = 0; i < histarray_size; i++) { + EXPECT_EQ(histarray[i], val[i]); + } +} + TEST_F(SummaryDbWriterTest, NothingWritten_NoRowsCreated) { TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, &writer_)); -- GitLab From 419dbc8f44efe06612845ec291b98bb49e873639 Mon Sep 17 00:00:00 2001 From: Jie Date: Wed, 18 Apr 2018 14:42:42 -0700 Subject: [PATCH 009/839] [PR comment addressed] Added custom plugin example registered tensorflow custom op & plugin kernel python wrapper to import custom op & register plugin clang-format --- tensorflow/contrib/tensorrt/BUILD | 1 + .../contrib/tensorrt/convert/convert_nodes.cc | 2 +- .../tensorrt/custom_plugin_examples/BUILD | 110 ++++++++++++++++++ .../custom_plugin_examples/__init__.py | 24 ++++ .../tensorrt/custom_plugin_examples/inc_op.py | 30 +++++ .../inc_op_kernel.cu.cc | 44 +++++++ .../custom_plugin_examples/inc_op_kernel.h | 34 ++++++ .../custom_plugin_examples/inc_op_plugin.cc | 55 +++++---- .../custom_plugin_examples/inc_op_plugin.h | 85 +++++++------- .../custom_plugin_examples/ops/inc_op.cc | 34 ++++++ .../custom_plugin_examples/plugin_wrap.i | 31 +++++ .../test/plugin_test.py | 93 +++++++++++++++ .../contrib/tensorrt/plugin/trt_plugin.cc | 1 - .../contrib/tensorrt/plugin/trt_plugin.h | 10 +- .../tensorrt/plugin/trt_plugin_factory.cc | 14 +-- .../tensorrt/plugin/trt_plugin_factory.h | 6 +- .../tensorrt/plugin/trt_plugin_utils.cc | 4 +- .../tensorrt/plugin/trt_plugin_utils.h | 5 +- .../tensorrt/plugin/trt_plugins_test.cc | 6 +- 19 files changed, 485 insertions(+), 104 deletions(-) create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 751f1d3482..9c81c12705 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -291,6 +291,7 @@ tf_cuda_library( ], linkstatic = 1, deps = [ + "//tensorflow/core:platform_base", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index d02c1ebf50..874be96c78 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -249,7 +249,7 @@ class TFAttrs { std::vector GetAllAttrKey() { std::vector attr_list; - for (auto & attr_item : attrs_) { + for (const auto& attr_item : attrs_) { attr_list.emplace_back(attr_item.first); } return attr_list; diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD new file mode 100644 index 0000000000..5603ed0ccf --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -0,0 +1,110 @@ +package(default_visibility = ["//tensorflow:__subpackages__"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_cuda_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_py_wrap_cc", + "tf_copts", +) +load( + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", +) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") + +tf_kernel_library( + name = "_inc_op_plugin_kernel", + srcs = [ + "inc_op_plugin.cc", + ], + hdrs = [ + ], + gpu_srcs = [ + "inc_op_kernel.cu.cc", + "inc_op_kernel.h", + "inc_op_plugin.h", + ], + deps = if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + "//tensorflow/contrib/tensorrt:trt_plugins", + ]), +) + +tf_gen_op_libs( + op_lib_names = [ + "inc_op", + ], + deps = if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + "//tensorflow/contrib/tensorrt:trt_plugins", + ]), +) + +tf_gen_op_wrapper_py( + name = "inc_op", + gen_locally = True, + deps = [ + ":inc_op_op_lib", + ], +) + +tf_py_wrap_cc( + name = "plugin_wrap", + srcs = [ + "plugin_wrap.i", + ], + copts = tf_copts(), + deps = [ + ":_inc_op_plugin_kernel", + "//tensorflow/core:framework_lite", + "//util/python:python_headers", + ], +) + +tf_custom_op_library( + name = "_inc_op.so", + srcs = ["ops/inc_op.cc"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "//tensorflow/contrib/tensorrt:trt_plugins", + ]), +) + +tf_custom_op_py_library( + name = "inc_op_loader", + srcs = ["inc_op.py"], + dso = [ + ":_inc_op.so", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:resources", + ], +) + +py_library( + name = "inc_op_py", + srcs_version = "PY2AND3", + deps = [ + ":inc_op", + ":inc_op_loader", + ], +) + +py_library( + name = "init_py", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":inc_op_py", + ":plugin_wrap", + ], +) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py new file mode 100644 index 0000000000..a61d008941 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Import custom op for plugin and register it in plugin factory registry.""" + +from ops import gen_inc_op +from plugin_wrap import inc_op_register +from inc_op import * + +# pylint: disable=unused-import,wildcard-import,g-import-not-at-top +inc_op = gen_inc_op.inc_plugin_trt +inc_op_register() +# pylint: enable=unused-import,wildcard-import,g-import-not-at-top diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py new file mode 100644 index 0000000000..ef8e26fbde --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py @@ -0,0 +1,30 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import platform +import os + +if platform.system() != "Windows": + from tensorflow.contrib.util import loader + from tensorflow.python.platform import resource_loader + + _inc_op = loader.load_op_library( + os.path.join(os.path.dirname(os.path.realpath(__file__)),"_inc_op.so")) +else: + raise RuntimeError("Windows not supported") diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc new file mode 100644 index 0000000000..5dd6b9bf94 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +__global__ void VecInc(const float* vec, float inc, float* dest, int n) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < n) dest[i] = vec[i] + inc; +} + +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream) { + int threads_per_block = 256; + int blocks_per_grid = (count + threads_per_block - 1) / threads_per_block; + + VecInc<<>>(d_input, inc, + d_output, count); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h new file mode 100644 index 0000000000..ec269143e8 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP +#define TENSORFLOW_CONTRIB_TENSORRT_INC_OP + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +__global__ void VecInc(float* vec, float inc, float* dest, int n); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_INC_OP diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 2155079e8b..21617fa8b5 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -13,24 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "inc_op_plugin.h" namespace tensorflow { namespace tensorrt { -const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; - -IncOpPlugin* CreateIncPlugin() { - return new IncOpPlugin(); -} +const std::string IncOpPlugin::plugin_name_ = "IncPluginTRT"; +IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); } IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { return new IncOpPlugin(buffer, length); @@ -39,45 +34,49 @@ IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { bool RegisterIncOpPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin(IncOpPlugin::plugin_name_)) return false; - return PluginFactoryTensorRT::GetInstance()->RegisterPlugin(IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); + return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( + IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); } - -IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) : - PluginTensorRT(serialized_data, length) -{ +IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) + : PluginTensorRT(serialized_data, length) { // account for the consumed pointer. size_t consumed_data = PluginTensorRT::getSerializationSize(); - assert(length-consumed_data >= sizeof(float)); - SetAttribute("inc", serialized_data+consumed_data, sizeof(float)); + assert(length - consumed_data >= sizeof(float)); + const char* buffer = reinterpret_cast(serialized_data); + SetAttribute("inc", buffer + consumed_data, sizeof(float)); } -bool IncOpPlugin::SetAttribute(const string &key, const void *ptr, const size_t size) { - if (strcmp(key.c_str(), "inc")==0 && size == sizeof(float)) { - StoreAttribute(key, ptr, size); // save the attribute to own the data; +bool IncOpPlugin::SetAttribute(const std::string& key, const void* ptr, + const size_t size) { + if (strcmp(key.c_str(), "inc") == 0 && size == sizeof(float)) { + StoreAttribute(key, ptr, size); // save the attribute to own the data; inc_ = *static_cast(ptr); return true; } return false; } -bool IncOpPlugin::GetAttribute(const string &key, const void *ptr, size_t &size) { - if (attr_map_.find(key) != attr_map_.end()) { - ptr = attr_map_[key].data(); - size = attr_map_[key].size(); +bool IncOpPlugin::GetAttribute(const std::string& key, const void** ptr, + size_t* size) const { + const auto& iter = attr_map_.find(key); + if (iter != attr_map_.end()) { + *ptr = iter->second.data(); + *size = iter->second.size(); return true; } return false; } -int IncOpPlugin::enqueue(int batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream) { +int IncOpPlugin::enqueue(int batch_size, const void* const* inputs, + void** outputs, void*, cudaStream_t stream) { int count = 1; - for (int i=0; i(inputs[0]); - float *output = reinterpret_cast(outputs[0]); + count *= batch_size; + const float* input = reinterpret_cast(inputs[0]); + float* output = reinterpret_cast(outputs[0]); IncrementKernel(input, inc_, output, count, stream); return 0; } diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 52b68487e6..a4774d354c 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -16,13 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN #define TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include -#include -#include -#include #include +#include #include +#include +#include +#include +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -31,50 +31,44 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -using std::string; -using std::unordered_map; - -class IncOpPlugin : public PluginTensorRT -{ -public: - static const string plugin_name_; - IncOpPlugin() {}; +class IncOpPlugin : public PluginTensorRT { + public: + static const std::string plugin_name_; + IncOpPlugin(){}; IncOpPlugin(const void* serialized_data, size_t length); - const string GetPluginName() override {return plugin_name_;}; - bool Finalize() override {return true;}; - bool SetAttribute(const string &key, const void *ptr, const size_t size) override; - bool GetAttribute(const string &key, const void *ptr, size_t &size) override; - - // TRT IPlugin methods - int getNbOutputs() const override {return 1;} - - nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override { - assert(index==0); - assert(nbInputDims==1); + const std::string& GetPluginName() const override { return plugin_name_; }; + bool Finalize() override { return true; }; + bool SetAttribute(const std::string& key, const void* ptr, + const size_t size) override; + bool GetAttribute(const std::string& key, const void** ptr, + size_t* size) const override; + + int getNbOutputs() const override { return 1; } + + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, + int num_input_dims) override { + assert(index == 0); + assert(num_input_dims == 1); return inputs[0]; } - // no configure needed // use configure to setup input dimensions - void configure(const nvinfer1::Dims *inputs, int nbInputs, const nvinfer1::Dims *outputs, int nbOutputs, int maxBatchSize) override { - assert(nbInputs==1); - PluginTensorRT::configure(inputs, nbInputs, outputs, nbOutputs, maxBatchSize); - return; + void configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) override { + assert(nb_inputs == 1); + PluginTensorRT::configure(inputs, num_inputs, outputs, num_outputs, + max_batch_size); } - int initialize() override { - return 0; - } + int initialize() override { return 0; } - void terminate() override { - return; - } + void terminate() override {} - size_t getWorkspaceSize(int maxBatchSize) const override { - return 0; - } + size_t getWorkspaceSize(int max_batch_size) const override { return 0; } - int enqueue(int batchSize, const void*const *inputs, void** outputs, void* workspace, cudaStream_t stream) override; + int enqueue(int batch_size, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override; size_t getSerializationSize() override { return PluginTensorRT::getSerializationSize() + sizeof(float); @@ -86,24 +80,23 @@ public: PluginTensorRT::serialize(buffer); // incremented buffer after parent serialization; - buffer = static_cast(buffer) + PluginTensorRT::getSerializationSize(); + buffer = + static_cast(buffer) + PluginTensorRT::getSerializationSize(); std::memcpy(buffer, &inc_, sizeof(float)); buffer = static_cast(buffer) + sizeof(float); - return; } -protected: + protected: float inc_; nvinfer1::Dims dim_; - // std::unordered_map > attr_map_; }; -IncOpPlugin* CreateIncPlugin(); +IncOpPlugin* CreateIncPlugin(); IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); bool RegisterIncOpPlugin(); -void IncrementKernel(const float* d_input, float inc, float* d_output, int count, cudaStream_t stream); - +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc new file mode 100644 index 0000000000..0dfead8f57 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +using namespace tensorflow; + +REGISTER_OP("IncPluginTRT") + .Attr("inc: list(float)") + .Input("input: float32") + .Output("output: float32") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i new file mode 100644 index 0000000000..9882daa842 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* Wrap inc_op_plugin */ +%module inc_op_plugin +%{ +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); +%} + +%{ +bool inc_op_register() { + return tensorflow::tensorrt::RegisterIncOpPlugin(); +} +%} + +extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); + +bool inc_op_register(); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py new file mode 100644 index 0000000000..52f49ae00e --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py @@ -0,0 +1,93 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to show usage of TensorRT custom op & plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# normally we should do import tensorflow as tf and then +# tf.placeholder, tf.constant, tf.nn.conv2d etc but +# it looks like internal builds don't like it so +# importing every module individually + +from tensorflow.contrib import tensorrt as trt +from tensorflow.core.protobuf import config_pb2 as cpb2 +from tensorflow.python.client import session as csess +from tensorflow.python.framework import dtypes as dtypes +from tensorflow.python.framework import importer as importer +from tensorflow.python.framework import ops as ops +from tensorflow.python.ops import array_ops as aops +from tensorflow.python.ops import nn as nn +from tensorflow.python.ops import nn_ops as nn_ops +import numpy as np + +# import custom_op as plugin op +# the python api handles registration to the plugin factory +from tensorflow.contrib.tensorrt import custom_plugin_examples as cpe + +def get_plugin_graph_def(): + """Create a simple graph and return its graph_def.""" + g = ops.Graph() + with g.as_default(): + a = aops.placeholder( + dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") + relu = nn.relu(a, "relu") + v = nn_ops.max_pool( + relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + + # insert custom_op in the graph + v = cpe.inc_op(v, inc=[16.5], name="plugin_test") + + v = v*2.0 + v = nn.relu(v) + v = nn.relu(v) + aops.squeeze(v, name="output") + return g.as_graph_def() + +def run_graph(gdef, dumm_inp): + """Run given graphdef once.""" + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + ops.reset_default_graph() + g = ops.Graph() + with g.as_default(): + inp, out = importer.import_graph_def( + graph_def=gdef, return_elements=["input", "output"]) + inp = inp.outputs[0] + out = out.outputs[0] + + with csess.Session( + config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + val = sess.run(out, {inp: dumm_inp}) + return val + +if "__main__" in __name__: + inp_dims = (5, 24, 24, 2) + dummy_input = np.ones(inp_dims).astype(np.float32) + orig_graph = get_plugin_graph_def() # graph with plugin node + + # trigger conversion. + # plugin nodes have been registered during import, converter will be able to + # create corresponding plugin layer during conversion. + trt_graph = trt.create_inference_graph( + input_graph_def=orig_graph, + outputs=["output"], + max_batch_size=inp_dims[0], + max_workspace_size_bytes=1 << 25, + precision_mode="FP32", + minimum_segment_size=2 + ) + o2 = run_graph(trt_graph, dummy_input) + print (o2) diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc index 7600703775..82c549dbf5 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -58,7 +58,6 @@ void PluginTensorRT::configure(const nvinfer1::Dims* inputs, int num_inputs, } input_dim_list_.emplace_back(dim); } - return; } size_t PluginTensorRT::getSerializationSize() { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index 59b92657f6..772974a769 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -32,20 +32,18 @@ namespace tensorrt { // User application should inherit from this class to write custom kernels. // Allows user to insert custom op in TensorRT engine // To register plugin in converter, user should also register custom -// tensorflow::tensorrt::PluginDeserializeFunc & -// tensorflow::tensorrt::PluginConstructFunc through -// tensorflow::tensorrt::PluginFactoryTensorRT +// PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: PluginTensorRT(){}; PluginTensorRT(const void* serialized_data, size_t length); - virtual const std::string& GetPluginName() = 0; + virtual const std::string& GetPluginName() const = 0; virtual bool Finalize() = 0; virtual bool SetAttribute(const std::string& key, const void* ptr, const size_t size) = 0; - virtual bool GetAttribute(const std::string& key, const void* ptr, - size_t& size) = 0; + virtual bool GetAttribute(const std::string& key, const void** ptr, + size_t* size) const = 0; void configure(const nvinfer1::Dims* inputs, int num_inputs, const nvinfer1::Dims* outputs, int num_outputs, diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 44b10394c8..776bce119d 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -33,12 +33,10 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, return nullptr; } - // should I lock plugins here? - instance_m_.lock(); + std::lock_guard lock(instance_m_); auto plugin_ptr = plugin_registry_[encoded_op_name].first(serial_data, serial_length); owned_plugins_.emplace_back(plugin_ptr); - instance_m_.unlock(); return plugin_ptr; } @@ -47,10 +45,9 @@ PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( const std::string& op_name) { if (!IsPlugin(op_name)) return nullptr; - instance_m_.lock(); + std::lock_guard lock(instance_m_); auto plugin_ptr = plugin_registry_[op_name].second(); owned_plugins_.emplace_back(plugin_ptr); - instance_m_.unlock(); return plugin_ptr; } @@ -60,22 +57,19 @@ bool PluginFactoryTensorRT::RegisterPlugin( PluginConstructFunc construct_func) { if (IsPlugin(op_name)) return false; - // get instance_m_ first before write to registry; - instance_m_.lock(); + std::lock_guard lock(instance_m_); auto ret = plugin_registry_.emplace( op_name, std::make_pair(deserialize_func, construct_func)); - instance_m_.unlock(); return ret.second; } void PluginFactoryTensorRT::DestroyPlugins() { - instance_m_.lock(); + std::lock_guard lock(instance_m_); for (auto& owned_plugin_ptr : owned_plugins_) { owned_plugin_ptr.release(); } owned_plugins_.clear(); - instance_m_.unlock(); } } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 824efcff35..08fd376844 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -39,10 +39,8 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { PluginTensorRT* CreatePlugin(const std::string& op_name); static PluginFactoryTensorRT* GetInstance() { - static PluginFactoryTensorRT* factory_instance = nullptr; - if (factory_instance == nullptr) { - factory_instance = new PluginFactoryTensorRT(); - } + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); return factory_instance; } diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc index 8b65e8b41c..c5d3f38280 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -22,8 +22,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -std::string ExtractOpName(const void* serial_data, size_t serial_length, - size_t* incremental) { +string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental) { size_t op_name_char_count = *static_cast(serial_data); *incremental = sizeof(size_t) + op_name_char_count; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index d4da8b261e..a94c67bba0 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -32,8 +33,8 @@ typedef std::function typedef std::function PluginConstructFunc; // TODO(jie): work on error handling here -std::string ExtractOpName(const void* serial_data, size_t serial_length, - size_t* incremental); +string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc index 2856b0f87d..9ef0fce972 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -51,9 +51,9 @@ class StubPlugin : public PluginTensorRT { return inputs[0]; } int initialize() override { return 0; } - void terminate() override { return; } + void terminate() override {} size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } - int enqueue(int batchSize, const void* const* inputs, void** outputs, + int enqueue(int batch_size, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override { return 0; } @@ -78,8 +78,6 @@ class PluginTest : public ::testing::Test { StubPlugin::plugin_name_, CreateStubPluginDeserialize, CreateStubPlugin); } - - protected: }; TEST_F(PluginTest, Registration) { -- GitLab From abfbbb86295c67eb1ac7c92235dbd5fb4b707169 Mon Sep 17 00:00:00 2001 From: Haggai Date: Wed, 18 Apr 2018 22:23:35 -0700 Subject: [PATCH 010/839] Remove reliance on TF core in XLA CPU Fft --- tensorflow/compiler/xla/service/cpu/BUILD | 1 - .../xla/service/cpu/runtime_fft_impl.h | 18 +++--------------- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 246b802861..6428ca528c 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -513,7 +513,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 984cb0616e..4f6b363364 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -21,8 +21,6 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" // 'tensorflow' namespace is used so that int64 and other types don't require @@ -71,11 +69,9 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -88,8 +84,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(in_dims); + const Eigen::DSizes zero_start_indices; full_fft.device(device) = input.template fft(axes); @@ -112,11 +108,9 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -129,8 +123,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. @@ -179,7 +172,6 @@ template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, int32 fft_type, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { - CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; switch (fft_type) { case ::xla::FftType::FFT: EigenFftC2C( @@ -203,8 +195,6 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, device, static_cast(out), static_cast(operand), input_batch, fft_length0, fft_length1, fft_length2); break; - default: - LOG(FATAL) << "Unsupported FFT type: " << fft_type; } } @@ -229,8 +219,6 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; - default: - LOG(FATAL) << "Unsupported FFT rank " << fft_rank; } } -- GitLab From 6343b8dd77ba94c74acc3c04c985a5535b2b8169 Mon Sep 17 00:00:00 2001 From: Haggai Date: Wed, 18 Apr 2018 22:26:42 -0700 Subject: [PATCH 011/839] Add single-threaded support for XLA CPU Fft --- tensorflow/compiler/xla/service/cpu/BUILD | 17 ++++++++++ .../compiler/xla/service/cpu/cpu_runtime.cc | 2 ++ .../compiler/xla/service/cpu/cpu_runtime.h | 1 + .../compiler/xla/service/cpu/ir_emitter.cc | 8 ++++- .../cpu/runtime_single_threaded_fft.cc | 32 +++++++++++++++++++ .../service/cpu/runtime_single_threaded_fft.h | 31 ++++++++++++++++++ .../xla/service/cpu/simple_orc_jit.cc | 2 ++ 7 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 6428ca528c..4862f9e2f9 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -176,6 +176,7 @@ cc_library( ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", + ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", "@llvm//:core", @@ -574,6 +575,22 @@ cc_library( ], ) +cc_library( + name = "runtime_single_threaded_fft", + srcs = [ + "runtime_fft_impl.h", + "runtime_single_threaded_fft.cc", + ], + hdrs = ["runtime_single_threaded_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_single_threaded_matmul", srcs = ["runtime_single_threaded_matmul.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 872b0be1f8..4fcab483d6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -50,6 +50,8 @@ extern const char* const kEigenConvF16SymbolName = extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedFftSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedFft"; extern const char* const kEigenSingleThreadedMatMulF16SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index e392e231b4..0cc45dac61 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -51,6 +51,7 @@ extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 3405277d44..8c2ca7104c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1171,7 +1171,13 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - const char* fn_name = runtime::kEigenFftSymbolName; + + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + const char* fn_name = multi_threaded_eigen + ? runtime::kEigenFftSymbolName + : runtime::kEigenSingleThreadedFftSymbolName; + llvm::Function* fft_func = llvm::cast( module_->getOrInsertFunction(fn_name, fft_type)); fft_func->setCallingConv(llvm::CallingConv::C); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc new file mode 100644 index 0000000000..2613ddb127 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type, + fft_rank, input_batch, fft_length0, fft_length1, + fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h new file mode 100644 index 0000000000..dcd133d012 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index b7ce5bbe47..7bd17002e3 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" @@ -190,6 +191,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); -- GitLab From 8149077125f6c2701713fef12fe0f0caac729e27 Mon Sep 17 00:00:00 2001 From: Krish Ravindranath Date: Thu, 19 Apr 2018 14:53:10 -0400 Subject: [PATCH 012/839] changes error to ValueError, notes that shuffle must be provided and should be set True for training --- tensorflow/python/estimator/inputs/numpy_io.py | 5 +++-- tensorflow/python/estimator/inputs/pandas_io.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py index a6f4712910..5b5eb41466 100644 --- a/tensorflow/python/estimator/inputs/numpy_io.py +++ b/tensorflow/python/estimator/inputs/numpy_io.py @@ -139,8 +139,9 @@ def numpy_input_fn(x, TypeError: `x` is not a dict or array, or if `shuffle` is not bool. """ if not isinstance(shuffle, bool): - raise TypeError('shuffle must be explicitly set as boolean; ' - 'got {}'.format(shuffle)) + raise ValueError('shuffle must be provided and explicitly set as boolean ' + '(it is recommended to set it as True for training); ' + 'got {}'.format(shuffle)) def input_fn(): """Numpy input function.""" diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py index bd06843021..16825e09de 100644 --- a/tensorflow/python/estimator/inputs/pandas_io.py +++ b/tensorflow/python/estimator/inputs/pandas_io.py @@ -75,8 +75,9 @@ def pandas_input_fn(x, 'pandas_input_fn should not be called without pandas installed') if not isinstance(shuffle, bool): - raise TypeError('shuffle must be explicitly set as boolean; ' - 'got {}'.format(shuffle)) + raise ValueError('shuffle must be provided and explicitly set as boolean ' + '(it is recommended to set it as True for training); ' + 'got {}'.format(shuffle)) x = x.copy() if y is not None: -- GitLab From 459d61cbe8ab9cbb86b2bb7eac602ff565d54fde Mon Sep 17 00:00:00 2001 From: Jie Date: Thu, 19 Apr 2018 13:48:14 -0700 Subject: [PATCH 013/839] [PR comment addressed] switched from std::string to TF string custom_plugin_examples python test added (bazel) style guide violation addressed --- .../contrib/tensorrt/convert/convert_nodes.cc | 22 ++--- .../tensorrt/custom_plugin_examples/BUILD | 42 ++++++--- .../custom_plugin_examples/__init__.py | 12 +-- .../inc_op_kernel.cu.cc | 2 - .../custom_plugin_examples/inc_op_kernel.h | 3 +- .../{inc_op_plugin.cc => inc_op_plugin.cu.cc} | 9 +- .../custom_plugin_examples/inc_op_plugin.h | 18 ++-- .../custom_plugin_examples/ops/inc_op.cc | 4 +- .../{test => }/plugin_test.py | 46 +++++----- tensorflow/contrib/tensorrt/log/trt_logger.h | 2 +- .../contrib/tensorrt/plugin/trt_plugin.cc | 3 +- .../contrib/tensorrt/plugin/trt_plugin.h | 14 +-- .../tensorrt/plugin/trt_plugin_factory.cc | 7 +- .../tensorrt/plugin/trt_plugin_factory.h | 8 +- .../tensorrt/plugin/trt_plugin_utils.cc | 2 +- .../tensorrt/plugin/trt_plugins_test.cc | 19 ++-- tensorflow/contrib/tensorrt/plugin_test.py | 88 +++++++++++++++++++ .../tensorrt/resources/trt_resources.h | 2 +- 18 files changed, 205 insertions(+), 98 deletions(-) rename tensorflow/contrib/tensorrt/custom_plugin_examples/{inc_op_plugin.cc => inc_op_plugin.cu.cc} (91%) rename tensorflow/contrib/tensorrt/custom_plugin_examples/{test => }/plugin_test.py (67%) create mode 100644 tensorflow/contrib/tensorrt/plugin_test.py diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 874be96c78..c8a96e5dba 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -241,9 +241,9 @@ class TFAttrs { return attrs_.at(key); } template - T get(string key) const; + T get(const string& key) const; template - T get(string key, const T& default_value) const { + T get(const string& key, const T& default_value) const { return attrs_.count(key) ? this->get(key) : default_value; } @@ -261,29 +261,29 @@ class TFAttrs { }; template <> -string TFAttrs::get(string key) const { +string TFAttrs::get(const string& key) const { return this->at(key)->s(); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().i(); return std::vector(attr.begin(), attr.end()); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().f(); return std::vector(attr.begin(), attr.end()); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().s(); return std::vector(attr.begin(), attr.end()); } template <> -nvinfer1::Dims TFAttrs::get(string key) const { +nvinfer1::Dims TFAttrs::get(const string& key) const { auto values = this->get>(key); nvinfer1::Dims dims; dims.nbDims = values.size(); @@ -293,24 +293,24 @@ nvinfer1::Dims TFAttrs::get(string key) const { } template <> -nvinfer1::DataType TFAttrs::get(string key) const { +nvinfer1::DataType TFAttrs::get(const string& key) const { nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); return trt_dtype; } template <> -tensorflow::DataType TFAttrs::get(string key) const { +tensorflow::DataType TFAttrs::get(const string& key) const { return this->at(key)->type(); } template <> -float TFAttrs::get(string key) const { +float TFAttrs::get(const string& key) const { return this->at(key)->f(); } template <> -bool TFAttrs::get(string key) const { +bool TFAttrs::get(const string& key) const { return this->at(key)->b(); } diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 5603ed0ccf..3b1a7fb6f3 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -1,3 +1,9 @@ +# Description: +# Example for plugin support in TensorRT(http://developer.nvidia.com/tensorrt) +# through TensorFlow integration. Targeting TensorRT 3.0.4 +# APIs are meant to change while upgrading TRT. +# add init_py into pip package BUILD dependency to install it. + package(default_visibility = ["//tensorflow:__subpackages__"]) load( @@ -8,6 +14,7 @@ load( "tf_gen_op_wrapper_py", "tf_py_wrap_cc", "tf_copts", + "tf_py_test", ) load( "@local_config_tensorrt//:build_defs.bzl", @@ -18,19 +25,16 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library") tf_kernel_library( name = "_inc_op_plugin_kernel", - srcs = [ - "inc_op_plugin.cc", - ], - hdrs = [ - ], gpu_srcs = [ "inc_op_kernel.cu.cc", "inc_op_kernel.h", + "inc_op_plugin.cu.cc", "inc_op_plugin.h", ], - deps = if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + deps = [ "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", ]), ) @@ -38,9 +42,10 @@ tf_gen_op_libs( op_lib_names = [ "inc_op", ], - deps = if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + deps = [ "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", ]), ) @@ -70,9 +75,8 @@ tf_custom_op_library( srcs = ["ops/inc_op.cc"], deps = [ "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ "//tensorflow/contrib/tensorrt:trt_plugins", - ]), + ], ) tf_custom_op_py_library( @@ -97,6 +101,22 @@ py_library( ], ) +tf_py_test( + name = "plugin_test", + size = "small", + srcs = [ + "plugin_test.py", + ], + additional_deps = [ + ":init_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/contrib/tensorrt:init_py", + "//tensorflow/python:platform", + "//tensorflow/python:client_testlib", + "//tensorflow/python:tf_optimizer", + ], +) + py_library( name = "init_py", srcs = [ diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py index a61d008941..e4cd0ae8a0 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py @@ -14,11 +14,13 @@ # ============================================================================= """Import custom op for plugin and register it in plugin factory registry.""" -from ops import gen_inc_op -from plugin_wrap import inc_op_register -from inc_op import * +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tensorrt.custom_plugin_examples.ops import gen_inc_op +from tensorflow.contrib.tensorrt.custom_plugin_examples.plugin_wrap import inc_op_register +from tensorflow.contrib.tensorrt.custom_plugin_examples import inc_op as import_inc_op_so -# pylint: disable=unused-import,wildcard-import,g-import-not-at-top inc_op = gen_inc_op.inc_plugin_trt inc_op_register() -# pylint: enable=unused-import,wildcard-import,g-import-not-at-top diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 5dd6b9bf94..38e1e01d95 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -14,10 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" -#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" #if GOOGLE_CUDA -#define EIGEN_USE_GPU #if GOOGLE_TENSORRT namespace tensorflow { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h index ec269143e8..13156dad8f 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -17,13 +17,14 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_INC_OP #if GOOGLE_CUDA -#define EIGEN_USE_GPU #if GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { __global__ void VecInc(float* vec, float inc, float* dest, int n); +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc similarity index 91% rename from tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc rename to tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc index 21617fa8b5..508ced587b 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +#include +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA @@ -23,7 +24,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -const std::string IncOpPlugin::plugin_name_ = "IncPluginTRT"; +const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); } @@ -47,7 +48,7 @@ IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) SetAttribute("inc", buffer + consumed_data, sizeof(float)); } -bool IncOpPlugin::SetAttribute(const std::string& key, const void* ptr, +bool IncOpPlugin::SetAttribute(const string& key, const void* ptr, const size_t size) { if (strcmp(key.c_str(), "inc") == 0 && size == sizeof(float)) { StoreAttribute(key, ptr, size); // save the attribute to own the data; @@ -57,7 +58,7 @@ bool IncOpPlugin::SetAttribute(const std::string& key, const void* ptr, return false; } -bool IncOpPlugin::GetAttribute(const std::string& key, const void** ptr, +bool IncOpPlugin::GetAttribute(const string& key, const void** ptr, size_t* size) const { const auto& iter = attr_map_.find(key); if (iter != attr_map_.end()) { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index a4774d354c..87404a755c 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -18,10 +18,6 @@ limitations under the License. #include #include -#include -#include -#include -#include #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA @@ -33,14 +29,14 @@ namespace tensorrt { class IncOpPlugin : public PluginTensorRT { public: - static const std::string plugin_name_; - IncOpPlugin(){}; + static const string plugin_name_; + IncOpPlugin() {}; IncOpPlugin(const void* serialized_data, size_t length); - const std::string& GetPluginName() const override { return plugin_name_; }; + const string& GetPluginName() const override { return plugin_name_; }; bool Finalize() override { return true; }; - bool SetAttribute(const std::string& key, const void* ptr, + bool SetAttribute(const string& key, const void* ptr, const size_t size) override; - bool GetAttribute(const std::string& key, const void** ptr, + bool GetAttribute(const string& key, const void** ptr, size_t* size) const override; int getNbOutputs() const override { return 1; } @@ -56,7 +52,7 @@ class IncOpPlugin : public PluginTensorRT { void configure(const nvinfer1::Dims* inputs, int num_inputs, const nvinfer1::Dims* outputs, int num_outputs, int max_batch_size) override { - assert(nb_inputs == 1); + assert(num_inputs == 1); PluginTensorRT::configure(inputs, num_inputs, outputs, num_outputs, max_batch_size); } @@ -95,8 +91,6 @@ class IncOpPlugin : public PluginTensorRT { IncOpPlugin* CreateIncPlugin(); IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); bool RegisterIncOpPlugin(); -void IncrementKernel(const float* d_input, float inc, float* d_output, - int count, cudaStream_t stream); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc index 0dfead8f57..7466e59090 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc @@ -19,7 +19,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -using namespace tensorflow; +namespace tensorflow { REGISTER_OP("IncPluginTRT") .Attr("inc: list(float)") @@ -30,5 +30,7 @@ REGISTER_OP("IncPluginTRT") return Status::OK(); }); +} // namespace tensorflow + #endif // GOOGLE_CUDA #endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py similarity index 67% rename from tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py rename to tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index 52f49ae00e..9f773c66a9 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -23,43 +23,44 @@ from __future__ import print_function # it looks like internal builds don't like it so # importing every module individually -from tensorflow.contrib import tensorrt as trt -from tensorflow.core.protobuf import config_pb2 as cpb2 -from tensorflow.python.client import session as csess -from tensorflow.python.framework import dtypes as dtypes -from tensorflow.python.framework import importer as importer -from tensorflow.python.framework import ops as ops -from tensorflow.python.ops import array_ops as aops -from tensorflow.python.ops import nn as nn -from tensorflow.python.ops import nn_ops as nn_ops -import numpy as np +from tensorflow.contrib import tensorrt +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.framework import errors +import numpy # import custom_op as plugin op -# the python api handles registration to the plugin factory -from tensorflow.contrib.tensorrt import custom_plugin_examples as cpe +# the python api handles registration to the plugin factory +from tensorflow.contrib.tensorrt import custom_plugin_examples def get_plugin_graph_def(): """Create a simple graph and return its graph_def.""" g = ops.Graph() with g.as_default(): - a = aops.placeholder( + a = array_ops.placeholder( dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") relu = nn.relu(a, "relu") v = nn_ops.max_pool( relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") # insert custom_op in the graph - v = cpe.inc_op(v, inc=[16.5], name="plugin_test") + v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") v = v*2.0 v = nn.relu(v) v = nn.relu(v) - aops.squeeze(v, name="output") + array_ops.squeeze(v, name="output") return g.as_graph_def() def run_graph(gdef, dumm_inp): """Run given graphdef once.""" - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) ops.reset_default_graph() g = ops.Graph() with g.as_default(): @@ -68,20 +69,20 @@ def run_graph(gdef, dumm_inp): inp = inp.outputs[0] out = out.outputs[0] - with csess.Session( - config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + with session.Session( + config=config_pb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: val = sess.run(out, {inp: dumm_inp}) return val if "__main__" in __name__: inp_dims = (5, 24, 24, 2) - dummy_input = np.ones(inp_dims).astype(np.float32) + dummy_input = numpy.ones(inp_dims).astype(numpy.float32) orig_graph = get_plugin_graph_def() # graph with plugin node # trigger conversion. # plugin nodes have been registered during import, converter will be able to # create corresponding plugin layer during conversion. - trt_graph = trt.create_inference_graph( + trt_graph = tensorrt.create_inference_graph( input_graph_def=orig_graph, outputs=["output"], max_batch_size=inp_dims[0], @@ -90,4 +91,7 @@ if "__main__" in __name__: minimum_segment_size=2 ) o2 = run_graph(trt_graph, dummy_input) - print (o2) + if o2.reshape([-1])[0] == 35: + print("pass") + else: + raise RuntimeError("contrib/tensorrt/custom_plugin_examples wrong result") diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h index 7f3544f8cf..3495dc6318 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -28,7 +28,7 @@ namespace tensorrt { // Logger for GIE info/warning/errors class Logger : public nvinfer1::ILogger { public: - Logger(string name = "DefaultLogger") : name_(name){}; + Logger(string name = "DefaultLogger") : name_(name) {}; void log(nvinfer1::ILogger::Severity severity, const char* msg) override; private: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc index 82c549dbf5..062f86e8bb 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -25,7 +25,6 @@ namespace tensorflow { namespace tensorrt { PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { - // sanity check. const char* buffer = static_cast(serialized_data); size_t op_name_char_count = *reinterpret_cast(buffer); buffer += sizeof(size_t); @@ -91,7 +90,7 @@ void PluginTensorRT::serialize(void* serialized_data) { } } -bool PluginTensorRT::StoreAttribute(const std::string& key, const void* ptr, +bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr, const size_t size) { if (attr_map_.count(key) != 0) return false; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index 772974a769..dca377c2d2 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN #include -#include #include #include +#include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -35,28 +35,28 @@ namespace tensorrt { // PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: - PluginTensorRT(){}; + PluginTensorRT() {}; PluginTensorRT(const void* serialized_data, size_t length); - virtual const std::string& GetPluginName() const = 0; + virtual const string& GetPluginName() const = 0; virtual bool Finalize() = 0; - virtual bool SetAttribute(const std::string& key, const void* ptr, + virtual bool SetAttribute(const string& key, const void* ptr, const size_t size) = 0; - virtual bool GetAttribute(const std::string& key, const void** ptr, + virtual bool GetAttribute(const string& key, const void** ptr, size_t* size) const = 0; void configure(const nvinfer1::Dims* inputs, int num_inputs, const nvinfer1::Dims* outputs, int num_outputs, int max_batch_size) override; - virtual bool StoreAttribute(const std::string& key, const void* ptr, + virtual bool StoreAttribute(const string& key, const void* ptr, const size_t size); virtual size_t getSerializationSize() override; virtual void serialize(void* buffer) override; protected: - std::unordered_map > attr_map_; + std::unordered_map > attr_map_; std::vector input_dim_list_; }; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 776bce119d..736a1321fe 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -26,7 +26,7 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, size_t serial_length) { size_t parsed_byte = 0; // extract op_name from serial_data - std::string encoded_op_name = + string encoded_op_name = ExtractOpName(serial_data, serial_length, &parsed_byte); if (!IsPlugin(encoded_op_name)) { @@ -41,8 +41,7 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, return plugin_ptr; } -PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( - const std::string& op_name) { +PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string& op_name) { if (!IsPlugin(op_name)) return nullptr; std::lock_guard lock(instance_m_); @@ -53,7 +52,7 @@ PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( } bool PluginFactoryTensorRT::RegisterPlugin( - const std::string& op_name, PluginDeserializeFunc deserialize_func, + const string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func) { if (IsPlugin(op_name)) return false; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 08fd376844..4e4a3af4ca 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -36,7 +36,7 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { size_t serial_length) override; // plugin construction, PluginFactoryTensorRT owns the plugin; - PluginTensorRT* CreatePlugin(const std::string& op_name); + PluginTensorRT* CreatePlugin(const string& op_name); static PluginFactoryTensorRT* GetInstance() { static PluginFactoryTensorRT* factory_instance = @@ -44,11 +44,11 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { return factory_instance; } - bool RegisterPlugin(const std::string& op_name, + bool RegisterPlugin(const string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func); - bool IsPlugin(const std::string& op_name) { + bool IsPlugin(const string& op_name) { return plugin_registry_.find(op_name) != plugin_registry_.end(); } @@ -57,7 +57,7 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { void DestroyPlugins(); protected: - std::unordered_map > plugin_registry_; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc index c5d3f38280..a8f60886c0 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -30,7 +30,7 @@ string ExtractOpName(const void* serial_data, size_t serial_length, assert(serial_length >= *incremental); const char* buffer = static_cast(serial_data) + sizeof(size_t); - std::string op_name(buffer, op_name_char_count); + string op_name(buffer, op_name_char_count); return op_name; } diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc index 9ef0fce972..b834c5511f 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" @@ -31,18 +30,17 @@ namespace test { class StubPlugin : public PluginTensorRT { public: - static const std::string plugin_name_; - StubPlugin(){}; + static const string plugin_name_; + StubPlugin() {}; StubPlugin(const void* serialized_data, size_t length) - : PluginTensorRT(serialized_data, length){}; - const std::string& GetPluginName() override { return plugin_name_; }; + : PluginTensorRT(serialized_data, length) {}; + const string& GetPluginName() override { return plugin_name_; }; virtual bool Finalize() { return true; }; - virtual bool SetAttribute(const std::string& key, const void* ptr, + virtual bool SetAttribute(const string& key, const void* ptr, const size_t size) { return true; }; - virtual bool GetAttribute(const std::string& key, const void* ptr, - size_t& size) { + virtual bool GetAttribute(const string& key, const void* ptr, size_t& size) { return true; }; int getNbOutputs() const override { return 1; } @@ -59,7 +57,7 @@ class StubPlugin : public PluginTensorRT { } }; -const std::string StubPlugin::plugin_name_ = "StubPlugin"; +const string StubPlugin::plugin_name_ = "StubPlugin"; StubPlugin* CreateStubPlugin() { return new StubPlugin(); } @@ -72,8 +70,9 @@ class PluginTest : public ::testing::Test { public: bool RegisterStubPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin( - StubPlugin::plugin_name_)) + StubPlugin::plugin_name_)) { return true; + } return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( StubPlugin::plugin_name_, CreateStubPluginDeserialize, CreateStubPlugin); diff --git a/tensorflow/contrib/tensorrt/plugin_test.py b/tensorflow/contrib/tensorrt/plugin_test.py new file mode 100644 index 0000000000..7c3e765bff --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin_test.py @@ -0,0 +1,88 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to show usage of TensorRT custom op & plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import tensorrt +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +import numpy as np + +# import custom_op as plugin op +# the python api handles registration to the plugin factory +from tensorflow.contrib.tensorrt import custom_plugin_examples + +def get_plugin_graph_def(): + """Create a simple graph and return its graph_def.""" + g = ops.Graph() + with g.as_default(): + a = array_ops.placeholder( + dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") + relu = nn.relu(a, "relu") + v = nn_ops.max_pool( + relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + + # insert custom_op in the graph + v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") + + v = v*2.0 + v = nn.relu(v) + v = nn.relu(v) + array_ops.squeeze(v, name="output") + return g.as_graph_def() + +def run_graph(gdef, dumm_inp): + """Run given graphdef once.""" + gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + ops.reset_default_graph() + g = ops.Graph() + with g.as_default(): + inp, out = importer.import_graph_def( + graph_def=gdef, return_elements=["input", "output"]) + inp = inp.outputs[0] + out = out.outputs[0] + + with session.Session( + config=config_pb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + val = sess.run(out, {inp: dumm_inp}) + return val + +if "__main__" in __name__: + inp_dims = (5, 24, 24, 2) + dummy_input = np.ones(inp_dims).astype(np.float32) + orig_graph = get_plugin_graph_def() # graph with plugin node + + # trigger conversion. + # plugin nodes have been registered during import, converter will be able to + # create corresponding plugin layer during conversion. + trt_graph = tensorrt.create_inference_graph( + input_graph_def=orig_graph, + outputs=["output"], + max_batch_size=inp_dims[0], + max_workspace_size_bytes=1 << 25, + precision_mode="FP32", + minimum_segment_size=2 + ) + o2 = run_graph(trt_graph, dummy_input) + print (o2) diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index 3c85968ae7..5164247f93 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -82,7 +82,7 @@ class TRTWeightStore : public tensorflow::ResourceBase { class TRTEngineResource : public tensorflow::ResourceBase { public: - TRTEngineResource() : runtime_(nullptr), ctx_(nullptr){}; + TRTEngineResource() : runtime_(nullptr), ctx_(nullptr) {}; string DebugString() override { return string(""); } nvinfer1::IRuntime* runtime_; nvinfer1::IExecutionContext* ctx_; -- GitLab From 2ef955b6d354378a7ca19f1f3cafccfc17f79013 Mon Sep 17 00:00:00 2001 From: Haggai Date: Fri, 20 Apr 2018 18:57:12 -0700 Subject: [PATCH 014/839] Abort on invalid fft type or rank --- tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 4f6b363364..0bf693edd0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -195,6 +195,9 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, device, static_cast(out), static_cast(operand), input_batch, fft_length0, fft_length1, fft_length2); break; + default: + // Unsupported FFT type + abort(); } } @@ -219,6 +222,9 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; + default: + // Unsupported FFT rank + abort(); } } -- GitLab From 364f6eae07fa8f0e2f89a9f665d0af430ea96669 Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Sat, 21 Apr 2018 14:45:30 -0300 Subject: [PATCH 015/839] Create pull_request_template.md --- pull_request_template.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 pull_request_template.md diff --git a/pull_request_template.md b/pull_request_template.md new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/pull_request_template.md @@ -0,0 +1 @@ + -- GitLab From ea3d7ab5455f54a67e24428f159e9170be408d71 Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Sat, 21 Apr 2018 14:57:38 -0300 Subject: [PATCH 016/839] Create Pull Request Template --- PULL_REQUEST_TEMPLATE.md | 20 ++++++++++++++++++++ pull_request_template.md | 1 - 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 PULL_REQUEST_TEMPLATE.md delete mode 100644 pull_request_template.md diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000..075bbc9945 --- /dev/null +++ b/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,20 @@ + + +##### Pull Request Checklist + +- [ ] Read [contributing guideline](CONTRIBUTING.md). +- [ ] Read [code of conduct](CODE_OF_CONDUCT.md). +- [ ] Fill [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style) +- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). + +##### Issue Fix + +- [ ] Yes +- [ ] No + +Fixed issue: + +##### Description + diff --git a/pull_request_template.md b/pull_request_template.md deleted file mode 100644 index 8b13789179..0000000000 --- a/pull_request_template.md +++ /dev/null @@ -1 +0,0 @@ - -- GitLab From 7f78414776718a350b1beb612dd8b1c26ff3f6a4 Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Tue, 24 Apr 2018 22:52:29 -0300 Subject: [PATCH 017/839] Merge PR Template to Contributing - Remove pull request template. - Add check list in contributing as a kind of TL;DR for that file. --- CONTRIBUTING.md | 11 +++++++++++ PULL_REQUEST_TEMPLATE.md | 20 -------------------- 2 files changed, 11 insertions(+), 20 deletions(-) delete mode 100644 PULL_REQUEST_TEMPLATE.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3dad41a88c..2e9d8c65e2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,16 @@ # Contributing guidelines +## Pull Request Checklist + +Before sending your pull requests, make sure you followed this list. + +- [ ] Read [contributing guidelines](CONTRIBUTING.md). +- [ ] Read [Code of Conduct](CODE_OF_CONDUCT.md). +- [ ] Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). +- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). + ## How to become a contributor and submit your own code ### Contributor License Agreements diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index 075bbc9945..0000000000 --- a/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,20 +0,0 @@ - - -##### Pull Request Checklist - -- [ ] Read [contributing guideline](CONTRIBUTING.md). -- [ ] Read [code of conduct](CODE_OF_CONDUCT.md). -- [ ] Fill [Contributor License Agreement (CLA)](https://cla.developers.google.com/). -- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). -- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style) -- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). - -##### Issue Fix - -- [ ] Yes -- [ ] No - -Fixed issue: - -##### Description - -- GitLab From 7f70c7a38fc2f4aaa9ceb52240c9112886adda5c Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Tue, 24 Apr 2018 23:00:05 -0300 Subject: [PATCH 018/839] Make more like a table of contents --- CONTRIBUTING.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2e9d8c65e2..8669c25c45 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,12 +4,12 @@ Before sending your pull requests, make sure you followed this list. -- [ ] Read [contributing guidelines](CONTRIBUTING.md). -- [ ] Read [Code of Conduct](CODE_OF_CONDUCT.md). -- [ ] Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). -- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). -- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). -- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). +- Read [contributing guidelines](CONTRIBUTING.md). +- Read [Code of Conduct](CODE_OF_CONDUCT.md). +- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). +- Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). ## How to become a contributor and submit your own code -- GitLab From 87f7d4b894c08031ba5942c1c391199de793eb88 Mon Sep 17 00:00:00 2001 From: ManHyuk Date: Sun, 29 Apr 2018 16:07:33 +0900 Subject: [PATCH 019/839] fix typo --- .../tools/ci_build/windows/cpu/bazel/run_cc_test_windows.sh | 2 +- .../tools/ci_build/windows/gpu/bazel/run_cc_test_windows.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/ci_build/windows/cpu/bazel/run_cc_test_windows.sh b/tensorflow/tools/ci_build/windows/cpu/bazel/run_cc_test_windows.sh index 748a961e44..dc9af221ec 100644 --- a/tensorflow/tools/ci_build/windows/cpu/bazel/run_cc_test_windows.sh +++ b/tensorflow/tools/ci_build/windows/cpu/bazel/run_cc_test_windows.sh @@ -44,7 +44,7 @@ source "tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh" \ run_configure_for_cpu_build -# Compliling the following test is extremely slow with -c opt +# Compiling the following test is extremely slow with -c opt slow_compiling_test="//tensorflow/core/kernels:eigen_backward_spatial_convolutions_test" # Find all the passing cc_tests on Windows and store them in a variable diff --git a/tensorflow/tools/ci_build/windows/gpu/bazel/run_cc_test_windows.sh b/tensorflow/tools/ci_build/windows/gpu/bazel/run_cc_test_windows.sh index f26f8727e5..f1114f4ffa 100644 --- a/tensorflow/tools/ci_build/windows/gpu/bazel/run_cc_test_windows.sh +++ b/tensorflow/tools/ci_build/windows/gpu/bazel/run_cc_test_windows.sh @@ -46,7 +46,7 @@ clean_output_base run_configure_for_gpu_build -# Compliling the following test is extremely slow with -c opt +# Compiling the following test is extremely slow with -c opt slow_compiling_test="//tensorflow/core/kernels:eigen_backward_spatial_convolutions_test" # Find all the passing cc_tests on Windows and store them in a variable -- GitLab From ba1c33faeb6df1ae363888e2e7330e219f0679ea Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 07:51:53 -0700 Subject: [PATCH 020/839] ArraysExtraInfo: Add name_regexp field and regexp name matching. PiperOrigin-RevId: 195091587 --- tensorflow/contrib/lite/toco/BUILD | 1 + .../contrib/lite/toco/model_flags.proto | 3 +- tensorflow/contrib/lite/toco/tooling_util.cc | 79 ++++++++++++------- 3 files changed, 53 insertions(+), 30 deletions(-) diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index f16225fd66..a3eff8ac70 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -396,6 +396,7 @@ cc_library( ":toco_port", ":types_proto_cc", "//tensorflow/core:lib", + "//third_party/re2", "@com_google_absl//absl/strings", "@protobuf_archive//:protobuf_headers", ], diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto index d23e80c464..6c1c53658c 100644 --- a/tensorflow/contrib/lite/toco/model_flags.proto +++ b/tensorflow/contrib/lite/toco/model_flags.proto @@ -96,8 +96,9 @@ message RnnState { // model that does not already contain such MinMax information. message ArraysExtraInfo { message Entry { - // Next ID to use: 7. + // Next ID to use: 8. optional string name = 1; + optional string name_regexp = 7; optional double min = 2; optional double max = 3; optional IODataType data_type = 4; diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index f334c51bbb..36f38ba8b0 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" +#include "third_party/re2/re2.h" #include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" @@ -1983,38 +1984,58 @@ void FinishBuildingRNNStates(Model* model) { } } +// Returns the array names that match the ArraysExtraInfo's name and +// name_regexp. The regexp match is for a full match. +std::unordered_set ScanArrayNames( + const Model& model, const toco::ArraysExtraInfo_Entry& entry) { + std::unordered_set matches; + if (model.HasArray(entry.name())) { + matches.insert(entry.name()); + } + if (!entry.name_regexp().empty()) { + const auto& arrays = model.GetArrayMap(); + const RE2 name_regexp = {entry.name_regexp()}; + for (auto it = arrays.begin(); it != arrays.end(); ++it) { + if (RE2::FullMatch(it->first, name_regexp)) { + matches.insert(it->first); + } + } + } + return matches; +} + void UseArraysExtraInfo(Model* model, bool quantize_output) { for (const auto& entry : model->flags.arrays_extra_info().entries()) { - if (!model->HasArray(entry.name())) { - continue; - } - auto& array = model->GetArray(entry.name()); - if (entry.has_min() || entry.has_max()) { - CHECK_EQ(entry.has_min(), entry.has_max()); - auto& minmax = array.GetOrCreateMinMax(); - minmax.min = entry.min(); - minmax.max = entry.max(); - } - if (entry.has_data_type() && quantize_output) { - array.final_data_type = - ConvertIODataTypeToArrayDataType(entry.data_type()); - } - if (entry.has_shape()) { - array.clear_shape(); - // Make sure to create the shape even if there are no dims, to - // correctly record 0-D shapes. - array.mutable_shape(); - for (int dim : entry.shape().dims()) { - array.mutable_shape()->mutable_dims()->push_back(dim); + const auto matches = ScanArrayNames(*model, entry); + for (const auto& matched_name : matches) { + auto& array = model->GetArray(matched_name); + if (entry.has_min() || entry.has_max()) { + CHECK_EQ(entry.has_min(), entry.has_max()); + auto& minmax = array.GetOrCreateMinMax(); + minmax.min = entry.min(); + minmax.max = entry.max(); } - } - if (entry.has_constant_float_value()) { - CHECK(array.has_shape()); - if (array.data_type == ArrayDataType::kFloat) { - auto& data = array.GetMutableBuffer().data; - data.resize(RequiredBufferSizeForShape(array.shape())); - for (float& f : data) { - f = entry.constant_float_value(); + if (entry.has_data_type() && quantize_output) { + array.final_data_type = + ConvertIODataTypeToArrayDataType(entry.data_type()); + } + if (entry.has_shape()) { + array.clear_shape(); + // Make sure to create the shape even if there are no dims, to + // correctly record 0-D shapes. + array.mutable_shape(); + for (int dim : entry.shape().dims()) { + array.mutable_shape()->mutable_dims()->push_back(dim); + } + } + if (entry.has_constant_float_value()) { + CHECK(array.has_shape()); + if (array.data_type == ArrayDataType::kFloat) { + auto& data = array.GetMutableBuffer().data; + data.resize(RequiredBufferSizeForShape(array.shape())); + for (float& f : data) { + f = entry.constant_float_value(); + } } } } -- GitLab From 72fd2b8e97f301039ac0eb60435cbbddf36212f6 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Wed, 2 May 2018 08:04:09 -0700 Subject: [PATCH 021/839] Use experimental auto_sharding in multi worker dataset. PiperOrigin-RevId: 195092992 --- tensorflow/contrib/distribute/python/BUILD | 1 + tensorflow/contrib/distribute/python/values.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index cdb3a8d65e..aaafc184bf 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -21,6 +21,7 @@ py_library( srcs = ["values.py"], visibility = ["//tensorflow:internal"], deps = [ + ":input_ops", ":prefetching_ops_v2", "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/contrib/eager/python:datasets", diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 18afdaa7b0..aaf177d07e 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -27,6 +27,7 @@ import weakref import six from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import input_ops from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device @@ -651,8 +652,8 @@ class MultiWorkerDataset(object): six.iteritems(worker_device_map)): with ops.device(worker): worker_input = dataset_fn() - # TODO(yuefengz, priyag): support efficient sharding. - worker_input = worker_input.shard(len(worker_device_map), i) + worker_input = input_ops.auto_shard_dataset( + worker_input, len(worker_device_map), i) self._datasets[worker] = PerDeviceDataset( worker_input, worker_devices, prefetch_on_device=prefetch_on_device) -- GitLab From 22eed5405906de6df8846bd9ce4ee0a57917aa3c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 08:51:07 -0700 Subject: [PATCH 022/839] Automated g4 rollback of changelist 195091587 PiperOrigin-RevId: 195098224 --- tensorflow/contrib/lite/toco/BUILD | 1 - .../contrib/lite/toco/model_flags.proto | 3 +- tensorflow/contrib/lite/toco/tooling_util.cc | 79 +++++++------------ 3 files changed, 30 insertions(+), 53 deletions(-) diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index a3eff8ac70..f16225fd66 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -396,7 +396,6 @@ cc_library( ":toco_port", ":types_proto_cc", "//tensorflow/core:lib", - "//third_party/re2", "@com_google_absl//absl/strings", "@protobuf_archive//:protobuf_headers", ], diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto index 6c1c53658c..d23e80c464 100644 --- a/tensorflow/contrib/lite/toco/model_flags.proto +++ b/tensorflow/contrib/lite/toco/model_flags.proto @@ -96,9 +96,8 @@ message RnnState { // model that does not already contain such MinMax information. message ArraysExtraInfo { message Entry { - // Next ID to use: 8. + // Next ID to use: 7. optional string name = 1; - optional string name_regexp = 7; optional double min = 2; optional double max = 3; optional IODataType data_type = 4; diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 36f38ba8b0..f334c51bbb 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -26,7 +26,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" -#include "third_party/re2/re2.h" #include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" @@ -1984,58 +1983,38 @@ void FinishBuildingRNNStates(Model* model) { } } -// Returns the array names that match the ArraysExtraInfo's name and -// name_regexp. The regexp match is for a full match. -std::unordered_set ScanArrayNames( - const Model& model, const toco::ArraysExtraInfo_Entry& entry) { - std::unordered_set matches; - if (model.HasArray(entry.name())) { - matches.insert(entry.name()); - } - if (!entry.name_regexp().empty()) { - const auto& arrays = model.GetArrayMap(); - const RE2 name_regexp = {entry.name_regexp()}; - for (auto it = arrays.begin(); it != arrays.end(); ++it) { - if (RE2::FullMatch(it->first, name_regexp)) { - matches.insert(it->first); - } - } - } - return matches; -} - void UseArraysExtraInfo(Model* model, bool quantize_output) { for (const auto& entry : model->flags.arrays_extra_info().entries()) { - const auto matches = ScanArrayNames(*model, entry); - for (const auto& matched_name : matches) { - auto& array = model->GetArray(matched_name); - if (entry.has_min() || entry.has_max()) { - CHECK_EQ(entry.has_min(), entry.has_max()); - auto& minmax = array.GetOrCreateMinMax(); - minmax.min = entry.min(); - minmax.max = entry.max(); - } - if (entry.has_data_type() && quantize_output) { - array.final_data_type = - ConvertIODataTypeToArrayDataType(entry.data_type()); - } - if (entry.has_shape()) { - array.clear_shape(); - // Make sure to create the shape even if there are no dims, to - // correctly record 0-D shapes. - array.mutable_shape(); - for (int dim : entry.shape().dims()) { - array.mutable_shape()->mutable_dims()->push_back(dim); - } + if (!model->HasArray(entry.name())) { + continue; + } + auto& array = model->GetArray(entry.name()); + if (entry.has_min() || entry.has_max()) { + CHECK_EQ(entry.has_min(), entry.has_max()); + auto& minmax = array.GetOrCreateMinMax(); + minmax.min = entry.min(); + minmax.max = entry.max(); + } + if (entry.has_data_type() && quantize_output) { + array.final_data_type = + ConvertIODataTypeToArrayDataType(entry.data_type()); + } + if (entry.has_shape()) { + array.clear_shape(); + // Make sure to create the shape even if there are no dims, to + // correctly record 0-D shapes. + array.mutable_shape(); + for (int dim : entry.shape().dims()) { + array.mutable_shape()->mutable_dims()->push_back(dim); } - if (entry.has_constant_float_value()) { - CHECK(array.has_shape()); - if (array.data_type == ArrayDataType::kFloat) { - auto& data = array.GetMutableBuffer().data; - data.resize(RequiredBufferSizeForShape(array.shape())); - for (float& f : data) { - f = entry.constant_float_value(); - } + } + if (entry.has_constant_float_value()) { + CHECK(array.has_shape()); + if (array.data_type == ArrayDataType::kFloat) { + auto& data = array.GetMutableBuffer().data; + data.resize(RequiredBufferSizeForShape(array.shape())); + for (float& f : data) { + f = entry.constant_float_value(); } } } -- GitLab From d6d4355a39a56a1b0d0abc7ce74d8307a1925459 Mon Sep 17 00:00:00 2001 From: Tony Wang Date: Wed, 2 May 2018 09:52:10 -0700 Subject: [PATCH 023/839] Add Name String to GraphOptimizationPass and Log Registered Passes Added a name string to GraphOptimization class and set to the class name through REGISTER_OPTIMIZATION macro. Modified RunGrouping function to log the name and phase of optimization pass that's running. Added two additional functions to log all registered optimization passes in the order of execution. PiperOrigin-RevId: 195106355 --- .../common_runtime/optimization_registry.cc | 21 +++++++++++++++++++ .../common_runtime/optimization_registry.h | 18 ++++++++++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/optimization_registry.cc b/tensorflow/core/common_runtime/optimization_registry.cc index 7f270b4d4e..bf49a758b2 100644 --- a/tensorflow/core/common_runtime/optimization_registry.cc +++ b/tensorflow/core/common_runtime/optimization_registry.cc @@ -36,6 +36,8 @@ Status OptimizationPassRegistry::RunGrouping( for (auto& phase : group->second) { VLOG(1) << "Running optimization phase " << phase.first; for (auto& pass : phase.second) { + VLOG(1) << "Running optimization pass: " + << pass->GetOptimizationPassName(); Status s = pass->Run(options); if (!s.ok()) return s; } @@ -44,4 +46,23 @@ Status OptimizationPassRegistry::RunGrouping( return Status::OK(); } +void OptimizationPassRegistry::LogGrouping(Grouping grouping, int vlog_level) { + auto group = groups_.find(grouping); + if (group != groups_.end()) { + for (auto& phase : group->second) { + for (auto& pass : phase.second) { + VLOG(vlog_level) << "Registered optimization pass grouping " << grouping + << " phase " << phase.first << ": " + << pass->GetOptimizationPassName(); + } + } + } +} + +void OptimizationPassRegistry::LogAllGroupings(int vlog_level) { + for (auto group = groups_.begin(); group != groups_.end(); ++group) { + LogGrouping(group->first, vlog_level); + } +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/optimization_registry.h b/tensorflow/core/common_runtime/optimization_registry.h index a469c8aa4e..1b535faf19 100644 --- a/tensorflow/core/common_runtime/optimization_registry.h +++ b/tensorflow/core/common_runtime/optimization_registry.h @@ -65,6 +65,13 @@ class GraphOptimizationPass { public: virtual ~GraphOptimizationPass() {} virtual Status Run(const GraphOptimizationPassOptions& options) = 0; + void SetOptimizationPassName(string name) { _optimization_pass_name = name; } + string GetOptimizationPassName() { return _optimization_pass_name; } + + private: + // The name of the opitmization pass, which is the same as the inherited class + // name. + string _optimization_pass_name; }; // The key is a 'phase' number. Phases are executed in increasing @@ -95,6 +102,10 @@ class OptimizationPassRegistry { // Returns the global registry of optimization passes. static OptimizationPassRegistry* Global(); + // Prints registered optimization passes for debugging. + void LogGrouping(Grouping grouping, int vlog_level); + void LogAllGroupings(int vlog_level); + private: std::map groups_; }; @@ -105,7 +116,9 @@ class OptimizationPassRegistration { public: OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping, int phase, - std::unique_ptr pass) { + std::unique_ptr pass, + string optimization_pass_name) { + pass->SetOptimizationPassName(optimization_pass_name); OptimizationPassRegistry::Global()->Register(grouping, phase, std::move(pass)); } @@ -123,7 +136,8 @@ class OptimizationPassRegistration { static optimization_registration::OptimizationPassRegistration \ register_optimization_##ctr( \ grouping, phase, \ - std::unique_ptr(new optimization)) + std::unique_ptr(new optimization()), \ + #optimization) } // namespace tensorflow -- GitLab From c4394346027fa01f12261e6fea6a1b7f19ac21a9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 10:02:09 -0700 Subject: [PATCH 024/839] Instantiate SwapDimension1And2InTensor3 for Eigen::half PiperOrigin-RevId: 195107839 --- tensorflow/core/kernels/conv_ops_gpu_3.cu.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index 2503b475dc..180531b8c0 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -1027,6 +1027,7 @@ template struct functor::SwapDimension1And2InTensor3; template struct functor::SwapDimension1And2InTensor3; +template struct functor::SwapDimension1And2InTensor3; template struct functor::SwapDimension0And2InTensor3; template struct functor::SwapDimension0And2InTensor3; -- GitLab From 1f47bbd1e09a9ed4086d0484024e8989a65274a9 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 2 May 2018 10:07:13 -0700 Subject: [PATCH 025/839] Optimized the analysis of rank and size operations. PiperOrigin-RevId: 195108832 --- .../core/grappler/costs/graph_properties.cc | 32 +++++++++++++++++++ tensorflow/core/grappler/op_types.cc | 4 +++ tensorflow/core/grappler/op_types.h | 2 ++ 3 files changed, 38 insertions(+) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 2c7b57971a..69b22561b2 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -475,6 +475,38 @@ class SymbolicShapeRefiner { } } } + } else if (IsRank(*input)) { + if (c->inference_context->RankKnown(c->inference_context->input(0))) { + int32 rank = + c->inference_context->Rank(c->inference_context->input(0)); + Tensor t(DT_INT32, {}); + t.flat()(0) = rank; + const_values[dst_input] = t; + input_tensors[dst_input] = &const_values[dst_input]; + } + } else if (IsSize(*input)) { + DimensionHandle size = + c->inference_context->NumElements(c->inference_context->input(0)); + if (c->inference_context->ValueKnown(size)) { + int64 sz = c->inference_context->Value(size); + bool valid = false; + if (input->attr().at("T").type() == DT_INT32) { + if (sz < std::numeric_limits::max()) { + Tensor t(DT_INT32, {}); + t.flat()(0) = sz; + const_values[dst_input] = t; + valid = true; + } + } else { + Tensor t(DT_INT64, {}); + t.flat()(0) = sz; + const_values[dst_input] = t; + valid = true; + } + if (valid) { + input_tensors[dst_input] = &const_values[dst_input]; + } + } } if (c->output_tensors_as_shapes.size() > src_output) { diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index bf6d4c0921..7c936dfca1 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -262,6 +262,8 @@ bool IsRandomShuffle(const NodeDef& node) { return node.op() == "RandomShuffle"; } +bool IsRank(const NodeDef& node) { return node.op() == "Rank"; } + bool IsReal(const NodeDef& node) { return node.op() == "Real"; } bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; } @@ -317,6 +319,8 @@ bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; } bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; } +bool IsSize(const NodeDef& node) { return node.op() == "Size"; } + bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; } bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 3dddf3f1ea..7a1b438768 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -100,6 +100,7 @@ bool IsProd(const NodeDef& node); bool IsPow(const NodeDef& node); bool IsQueue(const NodeDef& node); bool IsRandomShuffle(const NodeDef& node); +bool IsRank(const NodeDef& node); bool IsReal(const NodeDef& node); bool IsRealDiv(const NodeDef& node); bool IsRelu6Grad(const NodeDef& node); @@ -116,6 +117,7 @@ bool IsRsqrtGrad(const NodeDef& node); bool IsSelect(const NodeDef& node); bool IsSeluGrad(const NodeDef& node); bool IsSend(const NodeDef& node); +bool IsSize(const NodeDef& node); bool IsSlice(const NodeDef& node); bool IsShape(const NodeDef& node); bool IsShapeN(const NodeDef& node); -- GitLab From e408e8171540386d3dfed0a7c4fa1d0e1cc812cd Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 2 May 2018 10:36:26 -0700 Subject: [PATCH 026/839] Internal-only change. PiperOrigin-RevId: 195113702 --- tensorflow/tensorflow.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index e5cc886b32..b2cec7655f 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1507,6 +1507,7 @@ def tf_py_wrap_cc(name, # 2. When --define=no_tensorflow_py_deps=false (by default), it's a normal py_test. def py_test(deps=[], data=[], **kwargs): native.py_test( + # TODO(jlebar): Ideally we'd use tcmalloc here., deps=select({ "//conditions:default": deps, clean_dep("//tensorflow:no_tensorflow_py_deps"): [], -- GitLab From 489640a0d00ea7b5826937781cd1bf3a520b0b5d Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 2 May 2018 10:43:25 -0700 Subject: [PATCH 027/839] Fix some nits in cpu_literal_caching_test that I noticed after submission PiperOrigin-RevId: 195114829 --- .../service/cpu/tests/cpu_literal_caching_test.cc | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index f0404d07d9..b10eb74635 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -20,9 +20,9 @@ limitations under the License. namespace xla { namespace cpu { namespace { -class CpuExternalConstantsTest : public CpuCodegenTest {}; +class CpuDuplicateConstantsTest : public CpuCodegenTest {}; -TEST_F(CpuExternalConstantsTest, RepeatedArrayConstants) { +TEST_F(CpuDuplicateConstantsTest, RepeatedArrayConstants) { // We use a while loop here to force the two constant HloInstructions to be in // different computations. Otherwise the HLO optimizer itself CSEs them. const string hlo_text = R"( @@ -56,6 +56,10 @@ ENTRY main { } )"; + // TODO(b/78879738): The fake "f32[] constant(1)" root is only needed to work + // around b/78879738. Once b/78879738 is fixed, we can set one of the + // outfeeds as the root. + string filecheck_pattern = R"( CHECK: private constant [2 x [3 x [2 x float]]] CHECK-NOT: private constant [2 x [3 x [2 x float]]] @@ -73,7 +77,7 @@ CHECK-NOT: private constant [2 x [3 x [2 x float]]] /*match_optimized_ir=*/false); } -TEST_F(CpuExternalConstantsTest, RepeatedTupleConstants) { +TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) { // We use a while loop here to force the two constant HloInstructions to be in // different computations. Otherwise the HLO optimizer itself CSEs them. const string hlo_text = R"( @@ -101,6 +105,10 @@ ENTRY main { } )"; + // TODO(b/78879738): The fake "f32[] constant(1)" root is only needed to work + // around b/78879738. Once b/78879738 is fixed, we can set one of the + // outfeeds as the root. + string filecheck_pattern = R"( CHECK: private constant [2 x float] CHECK: private constant [2 x [1 x float]] -- GitLab From 55972370f3243ca061b882120383545d70642cf8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 10:57:13 -0700 Subject: [PATCH 028/839] Initialize all members of CollectiveParams at construction time to avoid warnings of uninit memory access by dynamic analysis tools. PiperOrigin-RevId: 195117321 --- tensorflow/core/framework/collective.h | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index f6fe12e7ef..f8d27d3868 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -52,7 +52,8 @@ struct CollGroupParams { DeviceType device_type; int32 num_tasks; // number of distinct tasks in group string ToString() const; - CollGroupParams() : device_type(DEVICE_CPU) {} + CollGroupParams() + : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {} }; // The best implementation of a collective op depends on many factors @@ -71,10 +72,11 @@ struct CollImplDetails { // Data common to all members of a collective instance. struct CollInstanceParams { - int32 instance_key; // Identifies all participating graph nodes. - CollectiveType type; - DataType data_type; - TensorShape shape; + // Identifies all participating graph nodes. + int32 instance_key = -1; + CollectiveType type = UNDEFINED_COLLECTIVE; + DataType data_type = DT_FLOAT; + TensorShape shape = {0}; // Fully qualified name of device for each member, in default rank order. std::vector device_names; // Task name prefix of corresponding device name. @@ -99,8 +101,8 @@ struct CollectiveParams { CollInstanceParams instance; CollTaskParams task; - string name; // node name used only for log or error messages - int default_rank; // index of this op within device_names + string name = ""; // node name used only for log or error messages + int default_rank = -1; // index of this op within device_names bool is_source = false; // broadcast only // Rank of this device in each subdivision permutation. std::vector subdiv_rank; -- GitLab From c08bf79144b3acc731018147e92fd389bcb60b2d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 10:57:49 -0700 Subject: [PATCH 029/839] Renames _regression_head_with_mean_squared_error_loss to _regression_head. PiperOrigin-RevId: 195117425 --- .../estimator/python/estimator/head.py | 7 +- .../python/estimator/canned/baseline.py | 2 +- .../python/estimator/canned/boosted_trees.py | 2 +- tensorflow/python/estimator/canned/dnn.py | 3 +- .../estimator/canned/dnn_linear_combined.py | 3 +- tensorflow/python/estimator/canned/head.py | 2 +- .../python/estimator/canned/head_test.py | 105 ++++++++---------- tensorflow/python/estimator/canned/linear.py | 2 +- 8 files changed, 53 insertions(+), 73 deletions(-) diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 2a6d17e81b..5d19bf4714 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -235,7 +235,7 @@ def regression_head(weight_column=None, Raises: ValueError: If `label_dimension` or `loss_reduction` is invalid. """ - return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + return head_lib._regression_head( # pylint:disable=protected-access weight_column=weight_column, label_dimension=label_dimension, loss_reduction=loss_reduction, @@ -297,7 +297,7 @@ def poisson_regression_head( def _poisson_loss(labels, logits): return nn.log_poisson_loss( targets=labels, log_input=logits, compute_full_loss=compute_full_loss) - return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + return head_lib._regression_head( # pylint:disable=protected-access weight_column=weight_column, label_dimension=label_dimension, loss_reduction=loss_reduction, @@ -360,8 +360,7 @@ def logistic_regression_head( labels, n_classes=2, message='Labels must be in range [0, 1]') return nn.sigmoid_cross_entropy_with_logits( labels=labels, logits=logits) - # TODO(roumposg): Rename to _regression_head, since it supports loss_fn arg. - return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + return head_lib._regression_head( # pylint:disable=protected-access weight_column=weight_column, label_dimension=1, loss_reduction=loss_reduction, diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py index 3e92a77543..980c057372 100644 --- a/tensorflow/python/estimator/canned/baseline.py +++ b/tensorflow/python/estimator/canned/baseline.py @@ -344,7 +344,7 @@ class BaselineRegressor(estimator.Estimator): A `BaselineRegressor` estimator. """ - head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access + head = head_lib._regression_head( # pylint: disable=protected-access label_dimension=label_dimension, weight_column=weight_column, loss_reduction=loss_reduction) diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index d281fd90ea..6d7a3299f7 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -690,7 +690,7 @@ def _create_regression_head(label_dimension, weight_column=None): raise ValueError('For now only 1 dimension regression is supported.' 'label_dimension given as {}'.format(label_dimension)) # pylint: disable=protected-access - return head_lib._regression_head_with_mean_squared_error_loss( + return head_lib._regression_head( label_dimension=label_dimension, weight_column=weight_column, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 6382622e0b..973a6ec747 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -481,8 +481,7 @@ class DNNRegressor(estimator.Estimator): features=features, labels=labels, mode=mode, - head=head_lib. # pylint: disable=protected-access - _regression_head_with_mean_squared_error_loss( + head=head_lib._regression_head( # pylint: disable=protected-access label_dimension=label_dimension, weight_column=weight_column, loss_reduction=loss_reduction), hidden_units=hidden_units, diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index f47706db2f..95efc0a028 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -553,8 +553,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator): features=features, labels=labels, mode=mode, - head=head_lib. # pylint: disable=protected-access - _regression_head_with_mean_squared_error_loss( + head=head_lib._regression_head( # pylint: disable=protected-access label_dimension=label_dimension, weight_column=weight_column, loss_reduction=loss_reduction), linear_feature_columns=linear_feature_columns, diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index efa4bdf598..48f448d7f5 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -1197,7 +1197,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): train_op=train_op) -def _regression_head_with_mean_squared_error_loss( +def _regression_head( weight_column=None, label_dimension=1, loss_reduction=losses.Reduction.SUM, diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 7da3df01dc..32a6339936 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -2607,26 +2607,24 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): rtol=tol, atol=tol) -class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): +class RegressionHead(test.TestCase): def setUp(self): ops.reset_default_graph() def test_invalid_label_dimension(self): with self.assertRaisesRegexp(ValueError, r'Invalid label_dimension'): - head_lib._regression_head_with_mean_squared_error_loss(label_dimension=-1) + head_lib._regression_head(label_dimension=-1) with self.assertRaisesRegexp(ValueError, r'Invalid label_dimension'): - head_lib._regression_head_with_mean_squared_error_loss(label_dimension=0) + head_lib._regression_head(label_dimension=0) def test_invalid_loss_reduction(self): with self.assertRaisesRegexp( ValueError, r'Invalid loss_reduction: invalid_loss_reduction'): - head_lib._regression_head_with_mean_squared_error_loss( - loss_reduction='invalid_loss_reduction') + head_lib._regression_head(loss_reduction='invalid_loss_reduction') with self.assertRaisesRegexp( ValueError, r'Invalid loss_reduction: none'): - head_lib._regression_head_with_mean_squared_error_loss( - loss_reduction=losses.Reduction.NONE) + head_lib._regression_head(loss_reduction=losses.Reduction.NONE) def test_loss_fn_arg_labels_missing(self): def _loss_fn(logits): @@ -2635,7 +2633,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): ValueError, r'loss_fn must contain argument: labels\. ' r'Given arguments: \(\'logits\',\)'): - head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn) + head_lib._regression_head(loss_fn=_loss_fn) def test_loss_fn_arg_logits_missing(self): def _loss_fn(labels): @@ -2644,12 +2642,12 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): ValueError, r'loss_fn must contain argument: logits\. ' r'Given arguments: \(\'labels\',\)'): - head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn) + head_lib._regression_head(loss_fn=_loss_fn) def test_loss_fn_arg_features_ok(self): def _loss_fn(labels, logits, features): del labels, logits, features # Unused - head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn) + head_lib._regression_head(loss_fn=_loss_fn) def test_loss_fn_arg_invalid(self): def _loss_fn(labels, logits, name=None): @@ -2657,11 +2655,10 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): with self.assertRaisesRegexp( ValueError, r'loss_fn has unexpected args: \[\'name\'\]'): - head_lib._regression_head_with_mean_squared_error_loss(loss_fn=_loss_fn) + head_lib._regression_head(loss_fn=_loss_fn) def test_invalid_logits(self): - head = head_lib._regression_head_with_mean_squared_error_loss( - label_dimension=3) + head = head_lib._regression_head(label_dimension=3) self.assertEqual(3, head.logits_dimension) logits_1d = np.array(((45.,), (41.,),)) @@ -2685,8 +2682,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): }) def test_incompatible_labels_eval(self): - head = head_lib._regression_head_with_mean_squared_error_loss( - label_dimension=3) + head = head_lib._regression_head(label_dimension=3) self.assertEqual(3, head.logits_dimension) values_3d = np.array(((45., 46., 47.), (41., 42., 43.),)) values_1d = np.array(((43.,), (44.,),)) @@ -2732,8 +2728,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): }) def test_incompatible_labels_train(self): - head = head_lib._regression_head_with_mean_squared_error_loss( - label_dimension=3) + head = head_lib._regression_head(label_dimension=3) self.assertEqual(3, head.logits_dimension) values_3d = np.array(((45., 46., 47.), (41., 42., 43.),)) values_1d = np.array(((43.,), (44.,),)) @@ -2784,12 +2779,11 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): }) def test_name(self): - head = head_lib._regression_head_with_mean_squared_error_loss( - name='foo') + head = head_lib._regression_head(name='foo') self.assertEqual('foo', head.name) def test_predict(self): - head = head_lib._regression_head_with_mean_squared_error_loss() + head = head_lib._regression_head() self.assertEqual(1, head.logits_dimension) # Create estimator spec. @@ -2826,8 +2820,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_predict_with_inverse_link_fn(self): def _inverse_link_fn(logits): return logits - 10. - head = head_lib._regression_head_with_mean_squared_error_loss( - inverse_link_fn=_inverse_link_fn) + head = head_lib._regression_head(inverse_link_fn=_inverse_link_fn) # Create estimator spec. logits = np.array(((45,), (41,),), dtype=np.int32) @@ -2866,7 +2859,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): logits, spec.export_outputs['predict'].outputs['logits'].eval()) def test_eval_create_loss(self): - head = head_lib._regression_head_with_mean_squared_error_loss() + head = head_lib._regression_head() logits = np.array(((45,), (41,),), dtype=np.float32) labels = np.array(((43,), (44,),), dtype=np.int32) features = {'x': np.array(((42,),), dtype=np.float32)} @@ -2895,8 +2888,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): data=[logits]) with ops.control_dependencies([check_labels, check_logits]): return constant_op.constant(loss) - head = head_lib._regression_head_with_mean_squared_error_loss( - label_dimension=2, loss_fn=_loss_fn) + head = head_lib._regression_head(label_dimension=2, loss_fn=_loss_fn) actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -2913,8 +2905,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def _loss_fn(labels, logits): del labels, logits # Unused return constant_op.constant(loss) - head = head_lib._regression_head_with_mean_squared_error_loss( - label_dimension=2, loss_fn=_loss_fn) + head = head_lib._regression_head(label_dimension=2, loss_fn=_loss_fn) logits = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32) labels = np.array([[1., 0.], [2., -1.]], dtype=np.float32) @@ -2933,7 +2924,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_eval_labels_none(self): """Tests that error is raised when labels is None.""" - head = head_lib._regression_head_with_mean_squared_error_loss() + head = head_lib._regression_head() with self.assertRaisesRegexp( ValueError, r'You must provide a labels Tensor\. Given: None\.'): @@ -2944,7 +2935,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): labels=None) def test_eval(self): - head = head_lib._regression_head_with_mean_squared_error_loss() + head = head_lib._regression_head() self.assertEqual(1, head.logits_dimension) logits = np.array(((45,), (41,),), dtype=np.float32) @@ -2986,8 +2977,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval()) def test_eval_metric_ops_with_head_name_for_regression(self): - head = head_lib._regression_head_with_mean_squared_error_loss( - name='some_regression_head') + head = head_lib._regression_head(name='some_regression_head') logits = np.array(((1,), (9,)), dtype=np.float32) labels = np.array(((1,), (1,)), dtype=np.int64) features = {'x': np.array(((42,),), dtype=np.int32)} @@ -3004,7 +2994,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys()) def test_eval_with_regularization_losses(self): - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) self.assertEqual(1, head.logits_dimension) @@ -3049,7 +3039,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): expected_metrics, {k: value_ops[k].eval() for k in value_ops}) def test_train_create_loss(self): - head = head_lib._regression_head_with_mean_squared_error_loss() + head = head_lib._regression_head() logits = np.array(((45,), (41,),), dtype=np.float32) labels = np.array(((43,), (44,),), dtype=np.int32) features = {'x': np.array(((42,),), dtype=np.float32)} @@ -3073,7 +3063,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_train_create_loss_loss_reduction(self): """Tests create_loss with loss_reduction.""" - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS) logits = np.array(((45,), (41,),), dtype=np.float32) labels = np.array(((43,), (44,),), dtype=np.int32) @@ -3098,7 +3088,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_train_labels_none(self): """Tests that error is raised when labels is None.""" - head = head_lib._regression_head_with_mean_squared_error_loss() + head = head_lib._regression_head() def _no_op_train_fn(loss): del loss return control_flow_ops.no_op() @@ -3113,7 +3103,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): train_op_fn=_no_op_train_fn) def test_train(self): - head = head_lib._regression_head_with_mean_squared_error_loss() + head = head_lib._regression_head() self.assertEqual(1, head.logits_dimension) # Create estimator spec. @@ -3163,7 +3153,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): }, summary_str) def test_train_with_optimizer(self): - head = head_lib._regression_head_with_mean_squared_error_loss() + head = head_lib._regression_head() self.assertEqual(1, head.logits_dimension) # Create estimator spec. @@ -3197,8 +3187,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): self.assertEqual(expected_train_result, train_result) def test_train_summaries_with_head_name(self): - head = head_lib._regression_head_with_mean_squared_error_loss( - name='some_regression_head') + head = head_lib._regression_head(name='some_regression_head') self.assertEqual(1, head.logits_dimension) # Create estimator spec. @@ -3237,7 +3226,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): summary_str) def test_train_with_regularization_losses(self): - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) self.assertEqual(1, head.logits_dimension) @@ -3285,8 +3274,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_weighted_multi_example_eval(self): """1d label, 3 examples, 1 batch.""" - head = head_lib._regression_head_with_mean_squared_error_loss( - weight_column='label_weights') + head = head_lib._regression_head(weight_column='label_weights') self.assertEqual(1, head.logits_dimension) # Create estimator spec. @@ -3330,7 +3318,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_weight_with_numeric_column(self): """1d label, 3 examples, 1 batch.""" - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( weight_column=feature_column_lib.numeric_column( 'label_weights', normalizer_fn=lambda x: x + 1.)) @@ -3356,8 +3344,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_weighted_multi_example_train(self): """1d label, 3 examples, 1 batch.""" - head = head_lib._regression_head_with_mean_squared_error_loss( - weight_column='label_weights') + head = head_lib._regression_head(weight_column='label_weights') self.assertEqual(1, head.logits_dimension) # Create estimator spec. @@ -3408,8 +3395,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_train_one_dim_create_loss(self): """Tests create_loss with 1D labels and weights (shape [batch_size]).""" - head = head_lib._regression_head_with_mean_squared_error_loss( - weight_column='label_weights') + head = head_lib._regression_head(weight_column='label_weights') logits = np.array(((45,), (41,), (44,)), dtype=np.float32) x_feature_rank_1 = np.array((42., 43., 44.,), dtype=np.float32) weight_rank_1 = np.array((1., .1, 1.5,), dtype=np.float64) @@ -3435,8 +3421,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_train_one_dim(self): """Tests train with 1D labels and weights (shape [batch_size]).""" - head = head_lib._regression_head_with_mean_squared_error_loss( - weight_column='label_weights') + head = head_lib._regression_head(weight_column='label_weights') self.assertEqual(1, head.logits_dimension) # Create estimator spec. @@ -3493,7 +3478,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_weighted_multi_value_eval_create_loss(self): """3d label, 1 example, 1 batch.""" - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( weight_column='label_weights', label_dimension=3) logits = np.array(((45., 41., 44.),)) labels = np.array(((35., 42., 45.),)) @@ -3515,7 +3500,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_weighted_multi_value_eval(self): """3d label, 1 example, 1 batch.""" - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( weight_column='label_weights', label_dimension=3) self.assertEqual(3, head.logits_dimension) @@ -3562,7 +3547,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_weighted_multi_value_train_create_loss(self): """3d label, 1 example, 1 batch.""" - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( weight_column='label_weights', label_dimension=3) logits = np.array(((45., 41., 44.),)) labels = np.array(((35., 42., 45.),)) @@ -3584,7 +3569,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_weighted_multi_value_train(self): """3d label, 1 example, 1 batch.""" - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( weight_column='label_weights', label_dimension=3) self.assertEqual(3, head.logits_dimension) @@ -3639,8 +3624,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_weighted_multi_batch_eval(self): """1d label, 1 example, 3 batches.""" - head = head_lib._regression_head_with_mean_squared_error_loss( - weight_column='label_weights') + head = head_lib._regression_head(weight_column='label_weights') self.assertEqual(1, head.logits_dimension) # Create estimator spec. @@ -3705,8 +3689,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_weighted_multi_batch_train(self): """1d label, 1 example, 3 batches.""" - head = head_lib._regression_head_with_mean_squared_error_loss( - weight_column='label_weights') + head = head_lib._regression_head(weight_column='label_weights') self.assertEqual(1, head.logits_dimension) # Create estimator spec. @@ -3755,7 +3738,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_multi_dim_weighted_train_create_loss(self): """Logits, labels of shape [2, 2, 3], weight shape [2, 2].""" label_dimension = 3 - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( weight_column='label_weights', label_dimension=label_dimension) logits = np.array([[[00., 01., 02.], [10., 11., 12.]], [[20., 21., 22.], [30., 31., 32.]]]) @@ -3785,7 +3768,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_multi_dim_weighted_train(self): """Logits, labels of shape [2, 2, 3], weight shape [2, 2].""" - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( weight_column='label_weights', label_dimension=3) logits = np.array([[[00., 01., 02.], [10., 11., 12.]], [[20., 21., 22.], [30., 31., 32.]]]) @@ -3816,7 +3799,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_multi_dim_train_weights_wrong_inner_dim(self): """Logits, labels of shape [2, 2, 3], weight shape [2, 1].""" - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( weight_column='label_weights', label_dimension=3) logits = np.array([[[00., 01., 02.], [10., 11., 12.]], [[20., 21., 22.], [30., 31., 32.]]]) @@ -3844,7 +3827,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): def test_multi_dim_train_weights_wrong_outer_dim(self): """Logits, labels of shape [2, 2, 3], weight shape [2, 2, 2].""" - head = head_lib._regression_head_with_mean_squared_error_loss( + head = head_lib._regression_head( weight_column='label_weights', label_dimension=3) logits = np.array([[[00., 01., 02.], [10., 11., 12.]], [[20., 21., 22.], [30., 31., 32.]]]) diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index e7ec417991..81657f0c01 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -415,7 +415,7 @@ class LinearRegressor(estimator.Estimator): loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. """ - head = head_lib._regression_head_with_mean_squared_error_loss( # pylint: disable=protected-access + head = head_lib._regression_head( # pylint: disable=protected-access label_dimension=label_dimension, weight_column=weight_column, loss_reduction=loss_reduction) -- GitLab From 1cc225858f9d7fb4d8772a7f0e962b71f780ad54 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 2 May 2018 11:12:58 -0700 Subject: [PATCH 030/839] Automated g4 rollback of changelist 194981511 PiperOrigin-RevId: 195120627 --- tensorflow/core/common_runtime/device.h | 11 ----------- tensorflow/core/common_runtime/device_mgr.cc | 3 --- .../process_function_library_runtime.cc | 3 +-- 3 files changed, 1 insertion(+), 16 deletions(-) diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index b537666492..5918cd9bbf 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -51,8 +51,6 @@ limitations under the License. namespace tensorflow { -class DeviceMgr; - class Device : public DeviceBase { public: Device(Env* env, const DeviceAttributes& device_attributes); @@ -135,10 +133,6 @@ class Device : public DeviceBase { // Returns the resource manager associated w/ this device. virtual ResourceMgr* resource_manager() { return rmgr_; } - // Returns the device manager that owns this device, or nullptr if this Device - // is not owned by a device manager. - DeviceMgr* device_mgr() const { return device_mgr_; } - // Summarizes the status of this Device, for debugging. string DebugString() const { return ProtoDebugString(device_attributes_); } @@ -164,11 +158,6 @@ class Device : public DeviceBase { } private: - friend class DeviceMgr; - - // Pointer to the device manager that owns this device. Not owned. - DeviceMgr* device_mgr_ = nullptr; - const DeviceAttributes device_attributes_; DeviceNameUtils::ParsedName parsed_name_; diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc index 470abc1431..a77601ba79 100644 --- a/tensorflow/core/common_runtime/device_mgr.cc +++ b/tensorflow/core/common_runtime/device_mgr.cc @@ -27,9 +27,6 @@ namespace tensorflow { DeviceMgr::DeviceMgr(const std::vector& devices) : name_backing_store_(128) { for (Device* d : devices) { - CHECK(d->device_mgr_ == nullptr); - d->device_mgr_ = this; - devices_.push_back(d); // Register under the (1) full name and (2) canonical name. diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 668ce87749..e61ed8c479 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -144,8 +144,7 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( } Device* device = flr->device(); string device_type = device->parsed_name().type; - if (device_type == "CPU" || device_type == "TPU_SYSTEM" || - device_type == "TPU") { + if (device_type == "CPU" || device_type == "TPU_SYSTEM") { // "TPU_SYSTEM" indicates that `device` is a CPU. return Status::OK(); } -- GitLab From 156483f1dc6b2e706482976f09f866d226a4dfee Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 2 May 2018 11:17:47 -0700 Subject: [PATCH 031/839] [XLA:GPU] Unroll unfused elementwise op kernels. So far we only unrolled loop fusions, elementwise ops is a logical extension. We don't spend a lot of time in unfused elementwise ops in benchmarks, so this is only worth a small speedup on V100. PiperOrigin-RevId: 195121530 --- .../xla/service/gpu/ir_emitter_unnested.cc | 46 +++++++++++++------ .../compiler/xla/tests/hlo_test_base.cc | 5 +- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 26e497762f..9f37235d32 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -257,8 +257,36 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( return kernel; } +namespace { +// Computes the maximum valid unroll factor for a given instruction. +int ComputeMaxUnrollFactor(const HloInstruction* hlo) { + int max_unroll_factor = hlo->GetModule() + ->config() + .debug_options() + .xla_gpu_max_kernel_unroll_factor(); + + // Find the largest possible power of two to unroll by. + // TODO(kramerb): Make this smarter. + int64 num_elements = ShapeUtil::ElementsIn(hlo->shape()); + for (int i = max_unroll_factor; i > 1; i /= 2) { + if (num_elements % i == 0) { + return i; + } + } + + // Cannot unroll. + return 1; +} +} // namespace + Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { - thunk_sequence_->emplace_back(BuildKernelThunk(hlo)); + int unroll_factor = 1; + // Unfused elementwise operations are usually memory bound, unroll them. + if (hlo->IsElementwise()) { + unroll_factor = ComputeMaxUnrollFactor(hlo); + } + + thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor)); return IrEmitter::DefaultAction(hlo); } @@ -537,23 +565,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - int max_unroll_factor = fusion->GetModule() - ->config() - .debug_options() - .xla_gpu_max_kernel_unroll_factor(); - - // Find the largest possible power of two to unroll by. - // TODO(kramerb): Make this smarter. int unroll_factor = 1; + // TODO(kramerb): Unrolling multi-output loop fusions too. if (!fusion->IsMultiOutputFusion()) { CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); - int64 num_elements = ShapeUtil::ElementsIn(fusion->shape()); - for (int i = max_unroll_factor; i > 1; i /= 2) { - if (num_elements % i == 0) { - unroll_factor = i; - break; - } - } + unroll_factor = ComputeMaxUnrollFactor(fusion); } thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 8b64f2e631..12598579c7 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -95,7 +95,10 @@ HloTestBase::HloTestBase(se::Platform* test_platform, /* static */ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + config.set_debug_options(debug_options); + return MakeUnique(name, VersionedComputationHandle(), config); } -- GitLab From 9b6cba1d739e36ec2da59a593afb09bf17307650 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 2 May 2018 11:40:09 -0700 Subject: [PATCH 032/839] Internal-only change. PiperOrigin-RevId: 195125476 --- tensorflow/python/kernel_tests/linalg/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index 6573cb9a1a..052f11f92e 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -62,7 +62,10 @@ cuda_py_test( "//tensorflow/python:platform_test", ], shard_count = 5, - tags = ["noasan"], # times out b/63678675 + tags = [ + "noasan", # times out, b/63678675 + "optonly", # times out + ], ) cuda_py_test( -- GitLab From 1ea4a77c6ccd2c783aedb2ccaf76f46b018c12c5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 11:45:15 -0700 Subject: [PATCH 033/839] Replaced calls to tensorflow::StringPiece::ToString with std::string conversions. That is, instances of sp.ToString() are replaced with std::string(sp). This will allow tensorflow::StringPiece::ToString to be removed, which is necessary before it can be replaced with absl::string_view. PiperOrigin-RevId: 195126422 --- tensorflow/core/kernels/gpu_utils.h | 3 +- .../kernels/merge_v2_checkpoints_op_test.cc | 2 +- .../remote_fused_graph_execute_utils.cc | 32 ++++---- .../core/kernels/save_restore_v2_ops.cc | 4 +- tensorflow/core/kernels/string_strip_op.cc | 2 +- tensorflow/core/kernels/tensor_array_ops.cc | 2 +- .../core/kernels/whole_file_read_ops.cc | 2 +- tensorflow/core/lib/strings/numbers.h | 4 +- tensorflow/core/lib/strings/scanner_test.cc | 82 +++++++++---------- tensorflow/core/lib/strings/str_util.cc | 4 +- tensorflow/core/lib/strings/str_util.h | 2 +- tensorflow/core/platform/env.cc | 4 +- tensorflow/core/platform/env_test.cc | 2 +- tensorflow/core/platform/file_system.cc | 2 +- .../core/platform/file_system_helper.cc | 2 +- tensorflow/core/platform/file_system_test.cc | 2 +- tensorflow/core/util/command_line_flags.cc | 2 +- tensorflow/core/util/env_var.cc | 8 +- .../core/util/example_proto_fast_parsing.cc | 2 +- 19 files changed, 82 insertions(+), 81 deletions(-) diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index 2f64619afc..c7dbefa0b4 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -123,7 +123,8 @@ class AutoTuneMap { string GetActionSummary(StringPiece action, const Parameters& params, const Config& config) { return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(), - action.ToString().c_str(), params.ToString().c_str(), + std::string(action).c_str(), + params.ToString().c_str(), config.ToString().c_str()); } diff --git a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc index 3b9e9e9b75..10e468ce46 100644 --- a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc +++ b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc @@ -115,7 +115,7 @@ class MergeV2CheckpointsOpTest : public OpsTestBase { for (int i = 0; i < 2; ++i) { int directory_found = Env::Default() - ->IsDirectory(io::Dirname(prefixes[i]).ToString()) + ->IsDirectory(std::string(io::Dirname(prefixes[i]))) .code(); if (delete_old_dirs) { EXPECT_EQ(error::NOT_FOUND, directory_found); diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index cc4d9a49a0..194a711d98 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -47,7 +47,7 @@ std::unordered_set BuildNodeSetFromNodeNamesAndPorts( std::unordered_set retval; for (const string& node_name_and_port : node_names_and_ports) { const TensorId tid = ParseTensorName(node_name_and_port); - retval.emplace(tid.first.ToString()); + retval.emplace(std::string(tid.first)); } return retval; } @@ -64,7 +64,7 @@ Node* FindMutableNodeByName(const string& name, Graph* graph) { const NodeDef* FindNodeDefByName(const string& input, const GraphDef& graph_def) { const TensorId tid = ParseTensorName(input); - const string name = tid.first.ToString(); + const string name = std::string(tid.first); for (const NodeDef& node_def : graph_def.node()) { if (node_def.name() == name) { return &node_def; @@ -77,7 +77,7 @@ bool IsSameNodeName(const NodeDef& node_def, const string& node_name_and_port, TensorId* tid) { CHECK_NOTNULL(tid); *tid = ParseTensorName(node_name_and_port); - if (node_def.name() == tid->first.ToString()) { + if (node_def.name() == tid->first) { return true; } return false; @@ -326,7 +326,7 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() { const string& node_name) { for (const std::pair& pair : input_tensor_vector) { const TensorId tid = ParseTensorName(pair.first); - if (node_name == tid.first.ToString()) { + if (node_name == tid.first) { return true; } } @@ -423,7 +423,7 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( std::vector data_types; std::vector shapes; const TensorId tid = ParseTensorName(name_and_port); - const string node_name = tid.first.ToString(); + const string node_name = std::string(tid.first); const int port = tid.second; const NodeDef* node_def = FindNodeDefByName(node_name, graph_def); CHECK_NOTNULL(node_def); @@ -522,7 +522,7 @@ RemoteFusedGraphExecuteUtils::GetTensorShapeType( const TensorShapeMap& tensor_shape_map, const string& node_name) { if (node_name.find(':') != string::npos) { const TensorId tid = ParseTensorName(node_name); - return GetTensorShapeType(tensor_shape_map, tid.first.ToString(), + return GetTensorShapeType(tensor_shape_map, std::string(tid.first), tid.second); } else { return GetTensorShapeType(tensor_shape_map, node_name, 0); @@ -570,7 +570,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto( const TensorId tid = ParseTensorName(name); CHECK_EQ(tensor_shape_map->count(name), 0); tensor_shape_map->emplace( - tid.first.ToString(), + std::string(tid.first), std::make_pair(tid.second, std::make_pair(tensor.dtype(), tensor.shape()))); } @@ -692,7 +692,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( std::vector node_out_list; for (const string& input : inputs) { const TensorId tid = ParseTensorName(input); - Node* node = FindMutableNodeByName(tid.first.ToString(), graph); + Node* node = FindMutableNodeByName(std::string(tid.first), graph); CHECK_NOTNULL(node); node_out_list.emplace_back(node, tid.second); } @@ -848,7 +848,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (const string& subgraph_input : std::get<1>(cluster)) { const TensorId tid = ParseTensorName(subgraph_input); - const string subgraph_input_name = tid.first.ToString(); + const string subgraph_input_name = std::string(tid.first); const int subgraph_input_port = tid.second; const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def); CHECK_NOTNULL(node_def); @@ -895,7 +895,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( std::deque queue; for (const string& output : border_outputs) { const TensorId tid = ParseTensorName(output); - const string& output_node_name = tid.first.ToString(); + const string& output_node_name = std::string(tid.first); for (const Node* node : graph.nodes()) { if (output_node_name == node->name()) { queue.push_back(node); @@ -916,8 +916,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( bool input_found = false; for (const string& input : border_inputs) { const TensorId tid = ParseTensorName(input); - if (tid.first.ToString() == src_node->name() && - tid.second == src_port) { + if (tid.first == src_node->name() && tid.second == src_port) { input_found = true; border_input_nodes.insert(src_node); } @@ -976,7 +975,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (int j = 0; j < border_outputs.size(); ++j) { const string& output = border_outputs.at(j); const TensorId tid = ParseTensorName(output); - const string output_name = tid.first.ToString(); + const string output_name = std::string(tid.first); Node* src_node = edge->src(); if (src_node != nullptr && src_node->name() == output_name && edge->src_output() == tid.second) { @@ -996,11 +995,12 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( // RemoteFusedGraphExecuteOpNode for (const string& output : outputs) { const TensorId output_tid = ParseTensorName(output); - const string output_name = output_tid.first.ToString(); + const string output_name = std::string(output_tid.first); for (size_t i = 0; i < border_outputs.size(); ++i) { const TensorId subgraph_output_tid = ParseTensorName(border_outputs.at(i)); - const string& subgraph_output_name = subgraph_output_tid.first.ToString(); + const string& subgraph_output_name = + std::string(subgraph_output_tid.first); if (output_name == subgraph_output_name) { LOG(INFO) << "As graph output and subgraph output are same, " << "the graph output node is replaced by identity node"; @@ -1435,7 +1435,7 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions( GraphDef* graph_def) { const TensorId tid = ParseTensorName(input); CHECK_EQ(0, tid.second); - const string node_name = tid.first.ToString(); + const string node_name = std::string(tid.first); for (NodeDef& node : *graph_def->mutable_node()) { if (node.name() != node_name) { continue; diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index 3acf290ea2..ab4de6c815 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -220,9 +220,9 @@ class MergeV2Checkpoints : public OpKernel { context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix)); if (delete_old_dirs_) { - const string& merged_dir = io::Dirname(merged_prefix).ToString(); + const string& merged_dir = std::string(io::Dirname(merged_prefix)); for (const string& input_prefix : input_prefixes) { - const string& dirname = io::Dirname(input_prefix).ToString(); + const string& dirname = std::string(io::Dirname(input_prefix)); if (dirname == merged_dir) continue; Status status = env->DeleteDir(dirname); // For sharded save, only the first delete will go through and all diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc index ae700f4294..2aeafa28c4 100644 --- a/tensorflow/core/kernels/string_strip_op.cc +++ b/tensorflow/core/kernels/string_strip_op.cc @@ -43,7 +43,7 @@ class StringStripOp : public OpKernel { for (int64 i = 0; i < input.size(); ++i) { StringPiece entry(input(i)); str_util::RemoveWhitespaceContext(&entry); - output(i) = entry.ToString(); + output(i) = std::string(entry); } } }; diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 7ec26d95e6..ef9748b1aa 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -293,7 +293,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp { resource.name()); } tensor_array_name = - StringPiece(resource.name()).substr(container.size()).ToString(); + std::string(StringPiece(resource.name()).substr(container.size())); } auto output_handle = tensor_array_output_handle->flat(); diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index 17a39ce29b..ed2bf3e8e2 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -134,7 +134,7 @@ class WriteFileOp : public OpKernel { "Contents tensor must be scalar, but had shape: ", contents_input->shape().DebugString())); const string& filename = filename_input->scalar()(); - const string dir = io::Dirname(filename).ToString(); + const string dir = std::string(io::Dirname(filename)); if (!context->env()->FileExists(dir).ok()) { OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir)); } diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h index e9add42849..9cb56415cb 100644 --- a/tensorflow/core/lib/strings/numbers.h +++ b/tensorflow/core/lib/strings/numbers.h @@ -140,11 +140,11 @@ inline bool ProtoParseNumeric(StringPiece s, uint64* value) { } inline bool ProtoParseNumeric(StringPiece s, float* value) { - return safe_strtof(s.ToString().c_str(), value); + return safe_strtof(std::string(s).c_str(), value); } inline bool ProtoParseNumeric(StringPiece s, double* value) { - return safe_strtod(s.ToString().c_str(), value); + return safe_strtod(std::string(s).c_str(), value); } // Convert strings to number of type T. diff --git a/tensorflow/core/lib/strings/scanner_test.cc b/tensorflow/core/lib/strings/scanner_test.cc index 55ff3405c3..b0f568a03e 100644 --- a/tensorflow/core/lib/strings/scanner_test.cc +++ b/tensorflow/core/lib/strings/scanner_test.cc @@ -42,24 +42,24 @@ TEST_F(ScannerTest, Any) { .Any(Scanner::DIGIT) .Any(Scanner::LETTER) .GetResult(&remaining, &match)); - EXPECT_EQ(" horse", match.ToString()); - EXPECT_EQ("0123", remaining.ToString()); + EXPECT_EQ(" horse", match); + EXPECT_EQ("0123", remaining); EXPECT_TRUE(Scanner("") .Any(Scanner::SPACE) .Any(Scanner::DIGIT) .Any(Scanner::LETTER) .GetResult(&remaining, &match)); - EXPECT_EQ("", remaining.ToString()); - EXPECT_EQ("", match.ToString()); + EXPECT_EQ("", remaining); + EXPECT_EQ("", match); EXPECT_TRUE(Scanner("----") .Any(Scanner::SPACE) .Any(Scanner::DIGIT) .Any(Scanner::LETTER) .GetResult(&remaining, &match)); - EXPECT_EQ("----", remaining.ToString()); - EXPECT_EQ("", match.ToString()); + EXPECT_EQ("----", remaining); + EXPECT_EQ("", match); } TEST_F(ScannerTest, AnySpace) { @@ -69,8 +69,8 @@ TEST_F(ScannerTest, AnySpace) { .One(Scanner::LETTER) .AnySpace() .GetResult(&remaining, &match)); - EXPECT_EQ(" a ", match.ToString()); - EXPECT_EQ("b ", remaining.ToString()); + EXPECT_EQ(" a ", match); + EXPECT_EQ("b ", remaining); } TEST_F(ScannerTest, AnyEscapedNewline) { @@ -143,8 +143,8 @@ TEST_F(ScannerTest, ScanUntil) { .ScanUntil('\'') .OneLiteral("'") .GetResult(&remaining, &match)); - EXPECT_EQ(R"( \\'rest)", remaining.ToString()); - EXPECT_EQ(R"(' \1 \2 \3 \')", match.ToString()); + EXPECT_EQ(R"( \\'rest)", remaining); + EXPECT_EQ(R"(' \1 \2 \3 \')", match); // The "scan until" character is not present. remaining = match = "unset"; @@ -152,15 +152,15 @@ TEST_F(ScannerTest, ScanUntil) { .OneLiteral("'") .ScanUntil('\'') .GetResult(&remaining, &match)); - EXPECT_EQ("unset", remaining.ToString()); - EXPECT_EQ("unset", match.ToString()); + EXPECT_EQ("unset", remaining); + EXPECT_EQ("unset", match); // Scan until an escape character. remaining = match = ""; EXPECT_TRUE( Scanner(R"(123\456)").ScanUntil('\\').GetResult(&remaining, &match)); - EXPECT_EQ(R"(\456)", remaining.ToString()); - EXPECT_EQ("123", match.ToString()); + EXPECT_EQ(R"(\456)", remaining); + EXPECT_EQ("123", match); } TEST_F(ScannerTest, ScanEscapedUntil) { @@ -170,8 +170,8 @@ TEST_F(ScannerTest, ScanEscapedUntil) { .ScanEscapedUntil('\'') .OneLiteral("'") .GetResult(&remaining, &match)); - EXPECT_EQ("rest", remaining.ToString()); - EXPECT_EQ(R"(' \1 \2 \3 \' \\')", match.ToString()); + EXPECT_EQ("rest", remaining); + EXPECT_EQ(R"(' \1 \2 \3 \' \\')", match); // The "scan until" character is not present. remaining = match = "unset"; @@ -179,27 +179,27 @@ TEST_F(ScannerTest, ScanEscapedUntil) { .OneLiteral("'") .ScanEscapedUntil('\'') .GetResult(&remaining, &match)); - EXPECT_EQ("unset", remaining.ToString()); - EXPECT_EQ("unset", match.ToString()); + EXPECT_EQ("unset", remaining); + EXPECT_EQ("unset", match); } TEST_F(ScannerTest, ZeroOrOneLiteral) { StringPiece remaining, match; EXPECT_TRUE( Scanner("abc").ZeroOrOneLiteral("abC").GetResult(&remaining, &match)); - EXPECT_EQ("abc", remaining.ToString()); - EXPECT_EQ("", match.ToString()); + EXPECT_EQ("abc", remaining); + EXPECT_EQ("", match); EXPECT_TRUE( Scanner("abcd").ZeroOrOneLiteral("ab").ZeroOrOneLiteral("c").GetResult( &remaining, &match)); - EXPECT_EQ("d", remaining.ToString()); - EXPECT_EQ("abc", match.ToString()); + EXPECT_EQ("d", remaining); + EXPECT_EQ("abc", match); EXPECT_TRUE( Scanner("").ZeroOrOneLiteral("abc").GetResult(&remaining, &match)); - EXPECT_EQ("", remaining.ToString()); - EXPECT_EQ("", match.ToString()); + EXPECT_EQ("", remaining); + EXPECT_EQ("", match); } // Test output of GetResult (including the forms with optional params), @@ -215,24 +215,24 @@ TEST_F(ScannerTest, CaptureAndGetResult) { .StopCapture() .Any(Scanner::SPACE) .GetResult(&remaining, &match)); - EXPECT_EQ("second", remaining.ToString()); - EXPECT_EQ("first", match.ToString()); + EXPECT_EQ("second", remaining); + EXPECT_EQ("first", match); EXPECT_TRUE(scan.GetResult()); remaining = ""; EXPECT_TRUE(scan.GetResult(&remaining)); - EXPECT_EQ("second", remaining.ToString()); + EXPECT_EQ("second", remaining); remaining = ""; match = ""; EXPECT_TRUE(scan.GetResult(&remaining, &match)); - EXPECT_EQ("second", remaining.ToString()); - EXPECT_EQ("first", match.ToString()); + EXPECT_EQ("second", remaining); + EXPECT_EQ("first", match); scan.RestartCapture().One(Scanner::LETTER).One(Scanner::LETTER); remaining = ""; match = ""; EXPECT_TRUE(scan.GetResult(&remaining, &match)); - EXPECT_EQ("cond", remaining.ToString()); - EXPECT_EQ("se", match.ToString()); + EXPECT_EQ("cond", remaining); + EXPECT_EQ("se", match); } // Tests that if StopCapture is not called, then calling GetResult, then @@ -242,14 +242,14 @@ TEST_F(ScannerTest, MultipleGetResultExtendsCapture) { Scanner scan("one2three"); EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match)); - EXPECT_EQ("2three", remaining.ToString()); - EXPECT_EQ("one", match.ToString()); + EXPECT_EQ("2three", remaining); + EXPECT_EQ("one", match); EXPECT_TRUE(scan.Many(Scanner::DIGIT).GetResult(&remaining, &match)); - EXPECT_EQ("three", remaining.ToString()); - EXPECT_EQ("one2", match.ToString()); + EXPECT_EQ("three", remaining); + EXPECT_EQ("one2", match); EXPECT_TRUE(scan.Many(Scanner::LETTER).GetResult(&remaining, &match)); - EXPECT_EQ("", remaining.ToString()); - EXPECT_EQ("one2three", match.ToString()); + EXPECT_EQ("", remaining); + EXPECT_EQ("one2three", match); } TEST_F(ScannerTest, FailedMatchDoesntChangeResult) { @@ -258,8 +258,8 @@ TEST_F(ScannerTest, FailedMatchDoesntChangeResult) { StringPiece remaining = "rem"; StringPiece match = "match"; EXPECT_FALSE(scan.One(Scanner::SPACE).GetResult(&remaining, &match)); - EXPECT_EQ("rem", remaining.ToString()); - EXPECT_EQ("match", match.ToString()); + EXPECT_EQ("rem", remaining); + EXPECT_EQ("match", match); } TEST_F(ScannerTest, DefaultCapturesAll) { @@ -271,8 +271,8 @@ TEST_F(ScannerTest, DefaultCapturesAll) { .AnySpace() .Any(Scanner::LETTER) .GetResult(&remaining, &match)); - EXPECT_EQ("", remaining.ToString()); - EXPECT_EQ("a b", match.ToString()); + EXPECT_EQ("", remaining); + EXPECT_EQ("a b", match); } TEST_F(ScannerTest, AllCharClasses) { diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc index 4598b8ccc7..cab8f81585 100644 --- a/tensorflow/core/lib/strings/str_util.cc +++ b/tensorflow/core/lib/strings/str_util.cc @@ -332,7 +332,7 @@ string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub, bool replace_all) { // TODO(jlebar): We could avoid having to shift data around in the string if // we had a StringPiece::find() overload that searched for a StringPiece. - string res = s.ToString(); + string res = std::string(s); size_t pos = 0; while ((pos = res.find(oldsub.data(), pos, oldsub.size())) != string::npos) { res.replace(pos, oldsub.size(), newsub.data(), newsub.size()); @@ -449,7 +449,7 @@ bool SplitAndParseAsFloats(StringPiece text, char delim, return SplitAndParseAsInts(text, delim, [](StringPiece str, float* value) { return strings::safe_strtof( - str.ToString().c_str(), value); + std::string(str).c_str(), value); }, result); } diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h index e97d00b975..c887db7eff 100644 --- a/tensorflow/core/lib/strings/str_util.h +++ b/tensorflow/core/lib/strings/str_util.h @@ -205,7 +205,7 @@ std::vector Split(StringPiece text, StringPiece delims, Predicate p) { if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) { StringPiece token(text.data() + token_start, i - token_start); if (p(token)) { - result.push_back(token.ToString()); + result.push_back(std::string(token)); } token_start = i + 1; } diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc index b9a9ef85eb..fe7d0aa7d1 100644 --- a/tensorflow/core/platform/env.cc +++ b/tensorflow/core/platform/env.cc @@ -92,7 +92,7 @@ Env::Env() : file_system_registry_(new FileSystemRegistryImpl) {} Status Env::GetFileSystemForFile(const string& fname, FileSystem** result) { StringPiece scheme, host, path; io::ParseURI(fname, &scheme, &host, &path); - FileSystem* file_system = file_system_registry_->Lookup(scheme.ToString()); + FileSystem* file_system = file_system_registry_->Lookup(std::string(scheme)); if (!file_system) { if (scheme.empty()) { scheme = "[local]"; @@ -166,7 +166,7 @@ bool Env::FilesExist(const std::vector& files, for (const auto& file : files) { StringPiece scheme, host, path; io::ParseURI(file, &scheme, &host, &path); - files_per_fs[scheme.ToString()].push_back(file); + files_per_fs[std::string(scheme)].push_back(file); } std::unordered_map per_file_status; diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index a70a417e6a..c461a40086 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -357,7 +357,7 @@ TEST_F(DefaultEnvTest, LocalTempFilename) { CHECK_EQ(error::OUT_OF_RANGE, file_to_read->Read(0 /* offset */, 1024 /* n */, &content, scratch) .code()); - EXPECT_EQ("Null", content.ToString()); + EXPECT_EQ("Null", content); // Delete the temporary file. TF_CHECK_OK(env->DeleteFile(filename)); diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc index b55e94d552..922773684b 100644 --- a/tensorflow/core/platform/file_system.cc +++ b/tensorflow/core/platform/file_system.cc @@ -158,7 +158,7 @@ Status FileSystem::RecursivelyCreateDir(const string& dirname) { std::reverse(sub_dirs.begin(), sub_dirs.end()); // Now create the directories. - string built_path = remaining_dir.ToString(); + string built_path = std::string(remaining_dir); for (const StringPiece sub_dir : sub_dirs) { built_path = io::JoinPath(built_path, sub_dir); Status status = CreateDir(io::CreateURI(scheme, host, built_path)); diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc index 22c5057281..0ba0e6304f 100644 --- a/tensorflow/core/platform/file_system_helper.cc +++ b/tensorflow/core/platform/file_system_helper.cc @@ -59,7 +59,7 @@ Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern, string fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\")); string eval_pattern = pattern; std::vector all_files; - string dir = io::Dirname(fixed_prefix).ToString(); + string dir = std::string(io::Dirname(fixed_prefix)); // If dir is empty then we need to fix up fixed_prefix and eval_pattern to // include . as the top level directory. if (dir.empty()) { diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc index f261b8f576..c0a16c95f9 100644 --- a/tensorflow/core/platform/file_system_test.cc +++ b/tensorflow/core/platform/file_system_test.cc @@ -125,7 +125,7 @@ class InterPlanetaryFileSystem : public NullFileSystem { ASSERT_EQ(scheme, "ipfs"); ASSERT_EQ(host, "solarsystem"); str_util::ConsumePrefix(&path, "/"); - *parsed_path = path.ToString(); + *parsed_path = std::string(path); } std::map> celestial_bodies_ = { diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 480ce94fca..8c27d01917 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -32,7 +32,7 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, if (str_util::ConsumePrefix(&arg, "--") && str_util::ConsumePrefix(&arg, flag) && str_util::ConsumePrefix(&arg, "=")) { - *value_parsing_ok = hook(arg.ToString()); + *value_parsing_ok = hook(std::string(arg)); return true; } diff --git a/tensorflow/core/util/env_var.cc b/tensorflow/core/util/env_var.cc index c844850179..8d43bcc927 100644 --- a/tensorflow/core/util/env_var.cc +++ b/tensorflow/core/util/env_var.cc @@ -28,7 +28,7 @@ namespace tensorflow { Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val, bool* value) { *value = default_val; - const char* tf_env_var_val = getenv(env_var_name.ToString().c_str()); + const char* tf_env_var_val = getenv(std::string(env_var_name).c_str()); if (tf_env_var_val == nullptr) { return Status::OK(); } @@ -48,7 +48,7 @@ Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val, Status ReadInt64FromEnvVar(StringPiece env_var_name, int64 default_val, int64* value) { *value = default_val; - const char* tf_env_var_val = getenv(env_var_name.ToString().c_str()); + const char* tf_env_var_val = getenv(std::string(env_var_name).c_str()); if (tf_env_var_val == nullptr) { return Status::OK(); } @@ -62,11 +62,11 @@ Status ReadInt64FromEnvVar(StringPiece env_var_name, int64 default_val, Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val, string* value) { - const char* tf_env_var_val = getenv(env_var_name.ToString().c_str()); + const char* tf_env_var_val = getenv(std::string(env_var_name).c_str()); if (tf_env_var_val != nullptr) { *value = tf_env_var_val; } else { - *value = default_val.ToString(); + *value = std::string(default_val); } return Status::OK(); } diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index 7946fa1782..3ce7988057 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -353,7 +353,7 @@ bool TestFastParse(const string& serialized, Example* example) { // I.e. last entry in the map overwrites all the previous ones. parsed::FeatureMapEntry& name_and_feature = parsed_example[parsed_example_size - i - 1]; - string name = name_and_feature.first.ToString(); + string name = std::string(name_and_feature.first); if ((*features.mutable_feature()).count(name) > 0) continue; auto& value = (*features.mutable_feature())[name]; -- GitLab From 3c0afb1cf6679097c2316fda8803b3679b37871f Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Wed, 2 May 2018 11:57:24 -0700 Subject: [PATCH 034/839] Turn on two half precision tests for GPU. PiperOrigin-RevId: 195128326 --- tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 7fa61eb33c..6cb470caf8 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -52,12 +52,7 @@ class MatOpsSimpleTest : public ClientLibraryTestBase {}; template class MatOpsSimpleTest_F16F32 : public MatOpsSimpleTest {}; -// TODO(bixia): This test for F16 failed on GPU 02-25-2018. -#ifdef XLA_TEST_BACKEND_GPU -TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, ::testing::Types); -#else TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32); -#endif XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { using T = TypeParam; @@ -171,11 +166,8 @@ string PrintTestLinspaceMaxParam( } #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 -// TODO(bixia): This test failed on GPU 02-25-2018 -#ifdef XLA_TEST_BACKEND_CPU XLA_TEST_P(TestLinspaceMaxParametric, TestF16) { TestImpl(); } #endif -#endif XLA_TEST_P(TestLinspaceMaxParametric, TestF32) { TestImpl(); } INSTANTIATE_TEST_CASE_P( -- GitLab From ce0ef2275bda40a6edcd738ccede61ccd3dd824b Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Wed, 2 May 2018 12:32:28 -0700 Subject: [PATCH 035/839] docs: Link to the appropriately branched version of the live colab notebooks. And update that link on release changes. PiperOrigin-RevId: 195133689 --- tensorflow/docs_src/get_started/eager.md | 2 +- tensorflow/tools/ci_build/update_version.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/docs_src/get_started/eager.md b/tensorflow/docs_src/get_started/eager.md index ad89f0154c..f08ac74425 100644 --- a/tensorflow/docs_src/get_started/eager.md +++ b/tensorflow/docs_src/get_started/eager.md @@ -1,3 +1,3 @@ # Get Started with Eager Execution -[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/get_started/eager.ipynb) +[Colab notebook](https://colab.research.google.com/github/tensorflow/models/blob/r1.8.0/samples/core/get_started/eager.ipynb) diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py index 52a0da9a14..9ddb219048 100755 --- a/tensorflow/tools/ci_build/update_version.py +++ b/tensorflow/tools/ci_build/update_version.py @@ -248,6 +248,16 @@ def update_md_files(old_version, new_version): replace_string_in_line(r"%s<\/version>" % old_version, "%s" % new_version, filepath) + # Update any links to colab notebooks. + def colab_url(version): + version_string = "%d.%d.%d" % (version.major, version.minor, version.patch) + prefix = "https://colab.research.google.com/github/tensorflow/models/blob/r" + return prefix + version_string + "/" + + replace_string_in_line( + colab_url(old_version), colab_url(new_version), + "%s/docs_src/get_started/eager.md" % TF_SRC_DIR) + def major_minor_change(old_version, new_version): """Check if a major or minor change occurred.""" -- GitLab From 262b176e27a3bcd01d518bea6d57683625df42b6 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 2 May 2018 12:58:06 -0700 Subject: [PATCH 036/839] Added support for packing of symbolic shapes PiperOrigin-RevId: 195137239 --- .../core/grappler/costs/graph_properties.cc | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 69b22561b2..23d25cba8d 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -770,6 +770,29 @@ class SymbolicShapeRefiner { c->output_tensors_as_shapes.resize(1); c->output_tensors_as_shapes[0] = result; } + } else if (IsPack(node)) { + // A Pack node concatenating scalars is often used to generate a shape. + std::vector dims; + bool valid = true; + for (int i = 0; i < ic->num_inputs(); ++i) { + const Tensor* t = ic->input_tensor(i); + if (t) { + if (t->dims() != 0 || + (t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) { + valid = false; + break; + } + int64 size = t->dtype() == DT_INT32 ? t->scalar()() + : t->scalar()(); + dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size)); + } else { + dims.push_back(ic->UnknownDim()); + } + } + if (valid) { + c->output_tensors_as_shapes.resize(1); + c->output_tensors_as_shapes[0] = ic->MakeShape(dims); + } } else if (IsSlice(node)) { ShapeHandle input = ic->input_tensors_as_shapes()[0]; bool valid = ic->RankKnown(input); -- GitLab From ad491ad2c258fdb71cc0cea5bffe7931622e749f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 13:05:15 -0700 Subject: [PATCH 037/839] [XLA] Redesign: Dump HloSnapshot in local service as well. And support replaying HloSnapshot. PiperOrigin-RevId: 195138472 --- .../xla/service/compile_only_service.cc | 16 +++++++ tensorflow/compiler/xla/tools/BUILD | 1 + .../compiler/xla/tools/replay_computation.cc | 44 +++++++++++++++++-- 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index c9f78a0f9f..d39fd7307a 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -70,6 +70,22 @@ CompileOnlyService::CompileAheadOfTime( TF_RET_CHECK(instance.computation.has_program_shape()); const DebugOptions& debug_options = options.debug_options(); + + // Dump computation proto if flag is set. + const string& directory_path = debug_options.xla_dump_computations_to(); + if (!directory_path.empty()) { + HloSnapshot hlo_snapshot; + *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; + string filename = tensorflow::strings::StrCat( + "computation_", instance.computation.id(), "__", + instance.computation.entry_computation_name()); + const string& per_host_path = tensorflow::io::JoinPath( + directory_path, tensorflow::port::Hostname()); + + TF_RETURN_IF_ERROR( + Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); + } + const auto& program_shape = instance.computation.program_shape(); ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 0bc4045a54..78ab2dccaf 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -88,6 +88,7 @@ cc_library( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 62a353ad09..d8cedad65e 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -75,9 +76,14 @@ struct Options { // // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; // otherwise, no infeed is performed. -StatusOr> ReplayComputation( - const SessionModule& module, Client* client, const Options& opts) { - TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module)); +template +StatusOr> ReplayComputation(const ModuleT& module, + Client* client, + const Options& opts) { + static_assert(std::is_same::value || + std::is_same::value, + "Proto must be in HloSnapshot or SessionModule format"); + TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module)); std::vector> arguments; if (opts.use_fake_data) { @@ -153,6 +159,38 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { + HloSnapshot snapshot; + auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); + if (status.ok()) { + StatusOr> result_status = + ReplayComputation(snapshot, client, opts); + if (!result_status.ok()) { + fprintf(stderr, "%s: error: %s\n", arg, + result_status.status().ToString().c_str()); + exit_status = EXIT_FAILURE; + continue; + } + + std::unique_ptr result = result_status.ConsumeValueOrDie(); + if (result != nullptr) { + fprintf(stdout, "%s: %s :: %s:%s\n", arg, + snapshot.hlo().hlo_module().name().c_str(), + ShapeUtil::HumanString(result->shape()).c_str(), + result->ToString().c_str()); + if (snapshot.has_result()) { + std::unique_ptr literal = + Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); + fprintf(stdout, "was %s:%s\n", + ShapeUtil::HumanString(snapshot.result().shape()).c_str(), + literal->ToString().c_str()); + } + } + + continue; + } + fprintf(stderr, "%s: is not HloSnapshot: %s. Trying as SessionModule...\n", + arg, status.ToString().c_str()); + SessionModule module; TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); StatusOr> result_status = -- GitLab From b182fd88e10b1d36f30a349e312c3a7ae5f3cc95 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Wed, 2 May 2018 13:05:51 -0700 Subject: [PATCH 038/839] Increasing test size to reflect recent additions and prevent test timeouts. PiperOrigin-RevId: 195138565 --- tensorflow/contrib/data/python/kernel_tests/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index d59dd17aea..7643c2a9fc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -32,7 +32,7 @@ py_test( py_test( name = "bucketing_test", - size = "small", + size = "medium", srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", deps = [ -- GitLab From 79f6d50d784cf27c6e1fb5200ca5022a334198fe Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Wed, 2 May 2018 13:23:56 -0700 Subject: [PATCH 039/839] Fix tsan failure in batch_dataset_op_test. The error was being caused because we were trying to save invocation_results_` while the function call was in progress. Now we wait for all invocations to finish before saving both `invocation_results_` and `batch_results_`. Did local A/B testing for tsan. Before: 12/100 failed After: All passed PiperOrigin-RevId: 195141349 --- .../kernels/data/map_and_batch_dataset_op.cc | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 7bc43e2072..c9551fbf16 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -272,6 +272,15 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_batch_index"), current_batch_index_)); TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + // Wait for the map_fn dispatches made in `InvokeFunctionLocked` to + // finish. This may delay saving a checkpoint by a bit but keeps the + // code clean and also saves us from checkpointing the state of the + // `BlockingCounter`. + std::vector num_elements(batch_results_.size()); + for (size_t i = 0; i < batch_results_.size(); i++) { + WaitForBatch(i, &num_elements[i]).IgnoreError(); + } + TF_RETURN_IF_ERROR(writer->WriteScalar( full_name("invocation_results_size"), invocation_results_.size())); for (size_t i = 0; i < invocation_results_.size(); ++i) { @@ -280,7 +289,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("batch_results_size"), batch_results_.size())); for (size_t i = 0; i < batch_results_.size(); ++i) { - TF_RETURN_IF_ERROR(WriteBatchResultLocked(writer, i)); + TF_RETURN_IF_ERROR( + WriteBatchResultLocked(writer, i, num_elements[i])); } return Status::OK(); } @@ -567,15 +577,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - Status WriteBatchResultLocked(IteratorStateWriter* writer, size_t index) + Status WriteBatchResultLocked(IteratorStateWriter* writer, size_t index, + int64 num_elements) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // Wait for the map_fn dispatches made in `InvokeFunctionLocked` to - // finish. This may delay saving a checkpoint by a bit but keeps the - // code clean and also saves us from checkpointing the state of the - // `BlockingCounter`. - int64 num_elements = 0; - WaitForBatch(index, &num_elements).IgnoreError(); - const BatchResult& result = batch_results_[index]; string prefix = strings::StrCat("batch_results_", index); { -- GitLab From 2706eeb1fbd4a2cf0e1af8efa3c7f3539944079e Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Wed, 2 May 2018 13:29:01 -0700 Subject: [PATCH 040/839] Re-enabling a test. PiperOrigin-RevId: 195142105 --- .../contrib/data/python/kernel_tests/batch_dataset_op_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index a4a0ce79b6..6588fd04ac 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -630,9 +630,7 @@ class BatchDatasetSerializationTest( lambda x: array_ops.fill([x], x)).apply( batching.dense_to_sparse_batch(4, [12])) - # TODO(b/70988345): Re-enable when sparse tensors are properly supported by - # the DatasetSerializationTestBase. - def _testDenseToSparseBatchDatasetCore(self): + def testDenseToSparseBatchDatasetCore(self): components = np.random.randint(5, size=(40,)).astype(np.int32) diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) -- GitLab From f9e8a75036154a73f256783eccf53bca6612d606 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 2 May 2018 13:36:31 -0700 Subject: [PATCH 041/839] [XLA] Add new optimization that sinks constants into while loop bodies Example transformation: state = (..., const, ...) while (pred(state)) { (..., v, ...) = state use(v) state = (..., v, ...) } => state = (..., const, ...) while (pred(state)) { (..., v, ...) = state use(const) state = (..., v, ...) } PiperOrigin-RevId: 195143323 --- tensorflow/compiler/xla/service/BUILD | 27 +++ tensorflow/compiler/xla/service/cpu/BUILD | 1 + .../compiler/xla/service/cpu/cpu_compiler.cc | 2 + tensorflow/compiler/xla/service/hlo_module.cc | 8 + tensorflow/compiler/xla/service/hlo_module.h | 5 + .../service/while_loop_constant_sinking.cc | 128 +++++++++++ .../xla/service/while_loop_constant_sinking.h | 68 ++++++ .../while_loop_constant_sinking_test.cc | 200 ++++++++++++++++++ .../while_loop_invariant_code_motion.cc | 27 +-- tensorflow/compiler/xla/service/while_util.cc | 17 ++ tensorflow/compiler/xla/service/while_util.h | 6 + .../compiler/xla/service/while_util_test.cc | 37 ++++ 12 files changed, 506 insertions(+), 20 deletions(-) create mode 100644 tensorflow/compiler/xla/service/while_loop_constant_sinking.cc create mode 100644 tensorflow/compiler/xla/service/while_loop_constant_sinking.h create mode 100644 tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 6e2510aa10..17964cdd59 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2687,6 +2687,33 @@ tf_cc_test( ], ) +cc_library( + name = "while_loop_constant_sinking", + srcs = ["while_loop_constant_sinking.cc"], + hdrs = ["while_loop_constant_sinking.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_constant_sinking_test", + srcs = ["while_loop_constant_sinking_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_constant_sinking", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:test", + ], +) + cc_library( name = "despecializer", srcs = ["despecializer.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 2fc6c6bd55..cb81e413a3 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -131,6 +131,7 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_constant_sinking", "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e298d67e09..91ed6e427a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -87,6 +87,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" @@ -270,6 +271,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass(); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 987c4b2719..c7a7192867 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -540,6 +540,14 @@ uint64 HloModule::RandomNew64() const { return rng_(); } +HloComputation* HloModule::GetComputationWithName( + tensorflow::StringPiece name) { + auto it = c_find_if(computations(), [&](HloComputation* computation) { + return computation->name() == name; + }); + return it == computations().end() ? nullptr : *it; +} + /* static */ std::atomic HloModule::next_unique_module_id_(0); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 82d790ec3b..f9674df812 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" @@ -138,6 +139,10 @@ class HloModule { MakeUnwrappingIterator(computations_.end())}; } + // Returns the computation in this module that has the name `name`. Returns + // null if there is no such computation. + HloComputation* GetComputationWithName(tensorflow::StringPiece name); + // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc new file mode 100644 index 0000000000..10fc4958fa --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -0,0 +1,128 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +// Replaces all uses of old_instr with new_instr except the use at +// `while_body_root` (which must be a tuple instruction) at index `tuple_index`. +// This utility helps us replace an instruction in the while body with a +// constant while still keeping it trivially loop invariant. +static Status ReplaceUsesWhileKeepingLoopInvariance( + HloInstruction* old_instr, HloInstruction* new_instr, + HloInstruction* while_body_root, int64 tuple_index) { + CHECK_EQ(while_body_root->opcode(), HloOpcode::kTuple); + + std::vector users; + users.reserve(old_instr->user_count()); + c_copy(old_instr->users(), std::back_inserter(users)); + + for (auto* user : users) { + for (int64 i = 0, e = user->operand_count(); i < e; i++) { + if (user->operand(i) == old_instr && + !(user == while_body_root && i == tuple_index)) { + TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, new_instr)); + } + } + } + + return Status::OK(); +} + +StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( + HloInstruction* while_instr) { + HloComputation* while_body = while_instr->while_body(); + + const HloInstruction& init_value = *while_instr->operand(0); + if (init_value.opcode() != HloOpcode::kTuple) { + return false; + } + + bool changed = false; + + for (HloInstruction* invariant_gte : + WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { + int64 index = invariant_gte->tuple_index(); + const HloInstruction& invariant_value = *init_value.operand(index); + if (invariant_value.opcode() == HloOpcode::kConstant) { + auto* constant_instr = + while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk")); + TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( + invariant_gte, constant_instr, while_body->root_instruction(), + index)); + changed = true; + } + } + + return changed; +} + +StatusOr WhileLoopConstantSinking::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + + bool changed = false; + std::vector while_instrs; + for (auto* comp : module->MakeNonfusionComputations()) { + // Right now we don't particulary care about optimizing while-of-while + // patterns. If/When we do, we'll want to visit the outer while (while_0) + // before we visit the inner while (while_1): + // + // while_1_body(state) { + // val = gte(state, 0) // Loop invariant + // use(val) + // } + // + // while_0_body(state) { + // val = gte(state, 0) // Loop invariant + // while_1 = while(init=tuple(val, ...), body=while_1_body, ...) + // ... + // } + // + // main { + // while_0 = while(init=(constant, ...), body=while_0_body, ...) + // } + // + // This will let us sink the constant into the outer while first and then + // into the inner while in a single run of this pass. + c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); + } + + for (HloInstruction* while_instr : while_instrs) { + // We only sink into while loop bodies, but this can be extended to + // transform conditions as well. + TF_ASSIGN_OR_RETURN(bool result, + TrySinkingConstantsIntoWhileBody(while_instr)); + changed |= result; + } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + } + + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h new file mode 100644 index 0000000000..21fb8568a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Sinks while loop invariant values that happen to be constants into the while +// loop body. This is probably not a win in isolation but may unlock further +// optimizations like constant folding. +// +// state = (..., const, ...) +// while (pred(state)) { +// (..., v, ...) = state +// use(v) +// state = (..., v, ...) +// } +// +// => +// +// state = (..., const, ...) +// while (pred(state)) { +// (..., v, ...) = state +// use(const) +// state = (..., v, ...) +// } +// +// Note that it leaves the `v` in place to keep that component of the state +// tuple trivially loop invariant. WhileLoopSimplifier will later get rid of +// `v`. +// +// We only sink into while loop bodies, but this can be extended to transform +// conditions as well. +// +// TODO(b/79121449): We should also sink broadcasts of constants. +class WhileLoopConstantSinking : public HloPassInterface { + public: + ~WhileLoopConstantSinking() override = default; + + tensorflow::StringPiece name() const override { + return "while-loop-invariant-code-motion"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr TrySinkingConstantsIntoWhileBody(HloInstruction* while_instr); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc new file mode 100644 index 0000000000..0d2288d8ea --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -0,0 +1,200 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +class WhileLoopConstantSinkingTest : public ::testing::Test {}; + +TEST_F(WhileLoopConstantSinkingTest, SinkOneConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 + + add.0 = f32[2] add(p_body.0, p_body.1) + ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + while_init = (f32[2],f32[2]) tuple(const_0, const_1) + ROOT while = (f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(_, op::Constant()), _)); +} + +TEST_F(WhileLoopConstantSinkingTest, KeepConstantsLoopInvariant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=1 + p_body.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=2 + + add.0 = f32[2] add(p_body.1, p_body.2) + ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_body.1, p_body.2) +} + +condition { + p_cond = (f32[2],f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + const_2 = f32[2] constant({3, 1}) + while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2) + ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(op::Constant(), op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0)))); +} + +TEST_F(WhileLoopConstantSinkingTest, TupleShapedConstants) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[2],(f32[2],f32[2])) parameter(0) + p_b.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=0 + p_b.1 = (f32[2],f32[2]) get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=1 + + p_b.1.1 = f32[2] get-tuple-element(p_b.1), index=0 + + ROOT root = (f32[2],f32[2],f32[2]) tuple(p_b.1.1, p_b.1) +} + +condition { + p_cond = (f32[2],(f32[2],f32[2])) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = (f32[2], f32[2]) constant((f32[2], f32[2]) ({2, 1},{3,1})) + while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1) + ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(op::Constant(), 0), + op::GetTupleElement(op::Parameter(0)))); +} + +TEST_F(WhileLoopConstantSinkingTest, DuplicateGTEs) { + // This test shows that the pass fails to optimize non-canonical IR. + // + // Even though the input IR has a constant value for p_b.2.dup, + // WhileLoopConstantSinking doesn't try to detect this. Instead, it relies on + // prior runs of HLO CSE to have commoned these identical GTE instructions. + + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[2],f32[2],f32[2]) parameter(0) + + p_b.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=1 + p_b.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2 + p_b.2.dup = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2 + + add.0 = f32[2] add(p_b.1, p_b.2.dup) + ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_b.1, p_b.2) +} + +condition { + p_cond = (f32[2],f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + const_2 = f32[2] constant({3, 1}) + while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2) + ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(op::Constant(), ::testing::Not(op::Constant())), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0)))); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 3ef0cdff67..321fdeb1ea 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -115,25 +115,6 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { } } -// Populates `gte_set` with the GetTupleElement instructions in `while_body` -// that access elements in the parameter tuple that don't change across -// iterations. Assumes `while_body` is the body computation of the while loop -// in question. -static void GatherInvariantGTEs(HloComputation* while_body, - FlatSet* gte_set) { - const HloInstruction::InstructionVector root_operands = - while_body->root_instruction()->operands(); - for (int i = 0; i < root_operands.size(); i++) { - HloInstruction* instr = root_operands[i]; - if (instr->opcode() == HloOpcode::kGetTupleElement && - instr->tuple_index() == i && - instr->operand(0) == while_body->parameter_instruction(0) && - ShapeUtil::IsArray(instr->shape())) { - InsertOrDie(gte_set, instr); - } - } -} - static StatusOr TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -172,7 +153,13 @@ static StatusOr TryHoistingInvariantInstructionsFromWhileBody( // unhoisted_invariant_instructions -- they can be legally hoisted, but there // is no benefit to hoisting them unless something that uses it is also // hoisted. - GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions); + for (auto* instr : WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { + if (ShapeUtil::IsArray(instr->shape())) { + // TODO(b/79147885): We should try to generalize this to tuples for + // uniformity's sake, if nothing else. + InsertOrDie(&unhoisted_invariant_instructions, instr); + } + } if (unhoisted_invariant_instructions.empty()) { // There are no obviously loop invariant elements in the state being diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index bd07941843..ed20b36292 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -244,4 +244,21 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { } return result; } + +/*static*/ std::vector WhileUtil::GetInvariantGTEsForWhileBody( + const HloComputation& while_body) { + std::vector result; + const HloInstruction::InstructionVector root_operands = + while_body.root_instruction()->operands(); + for (int i = 0; i < root_operands.size(); i++) { + HloInstruction* instr = root_operands[i]; + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->tuple_index() == i && + instr->operand(0) == while_body.parameter_instruction(0)) { + result.push_back(instr); + } + } + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 1688d46742..322d27b88c 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -74,6 +74,12 @@ class WhileUtil { HloComputation* computation, int32 trip_count, const LoopStateTy& init_values, const LoopBodyGeneratorTy& loop_body_generator); + + // Returns the GetTupleElement instructions in `while_body` that access + // elements in the parameter tuple that don't change across iterations. + // Assumes `while_body` is the body computation of the while loop in question. + static std::vector GetInvariantGTEsForWhileBody( + const HloComputation& while_body); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index cf0d0db99b..974bc542a3 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -126,5 +126,42 @@ TEST(WhileUtilTest, MakeTwoInstructionsLive) { op::GetTupleElement(op::Parameter(0), 3))); } +TEST(WhileUtilTest, GetInvariantGTEsForWhileBody) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + HloComputation* while_body = module->GetComputationWithName("body"); + + ASSERT_NE(while_body, nullptr) + << "Expected exactly one while_body computation"; + + std::vector gte_list = + WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + + ASSERT_EQ(gte_list.size(), 1); + EXPECT_EQ((*gte_list.begin())->name(), "gte.0"); +} } // namespace } // namespace xla -- GitLab From bd6c00aabe9a34715a5b2026eeccac4bc2a8d0de Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 13:42:40 -0700 Subject: [PATCH 042/839] Fix a bug in create_python_api.py I got an error complaining about "RuntimeError: dictionary changed size during iteration", this change fixes it. PiperOrigin-RevId: 195144333 --- tensorflow/tools/api/generator/create_python_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py index c06a39bfbd..65baa6e4b4 100644 --- a/tensorflow/tools/api/generator/create_python_api.py +++ b/tensorflow/tools/api/generator/create_python_api.py @@ -158,7 +158,7 @@ def get_api_init_text(): # Traverse over everything imported above. Specifically, # we want to traverse over TensorFlow Python modules. - for module in sys.modules.values(): + for module in list(sys.modules.values()): # Only look at tensorflow modules. if (not module or not hasattr(module, '__name__') or 'tensorflow.' not in module.__name__): -- GitLab From 8f610384b61f7b1b62302b9a861c1d4b19b36b33 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 13:44:30 -0700 Subject: [PATCH 043/839] Updated ABSL to latest version in workspace.bzl. PiperOrigin-RevId: 195144612 --- tensorflow/workspace.bzl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 16da59c5cf..f4f7bc4615 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -96,11 +96,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "com_google_absl", urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/720c017e30339fd1786ce4aac68bc8559736e53f.tar.gz", - "https://github.com/abseil/abseil-cpp/archive/720c017e30339fd1786ce4aac68bc8559736e53f.tar.gz", + "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz", + "https://github.com/abseil/abseil-cpp/archive/9613678332c976568272c8f4a78631a29159271d.tar.gz", ], - sha256 = "5996380e3e8b981f55d1c8d58e709c00dbb4806ba367be75d0925a68cc2f6478", - strip_prefix = "abseil-cpp-720c017e30339fd1786ce4aac68bc8559736e53f", + sha256 = "1273a1434ced93bc3e703a48c5dced058c95e995c8c009e9bdcb24a69e2180e9", + strip_prefix = "abseil-cpp-9613678332c976568272c8f4a78631a29159271d", build_file = clean_dep("//third_party:com_google_absl.BUILD"), ) @@ -299,11 +299,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "absl_py", urls = [ - "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/acec853355ef987eae48a8d87a79351c15dff593.tar.gz", - "https://github.com/abseil/abseil-py/archive/acec853355ef987eae48a8d87a79351c15dff593.tar.gz", + "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/ea8c4d2ddbf3fba610c4d613260561699b776db8.tar.gz", + "https://github.com/abseil/abseil-py/archive/ea8c4d2ddbf3fba610c4d613260561699b776db8.tar.gz", ], - sha256 = "29e4584e778bee13aa4093824133d131d927cc160561892880118d9ff7b95a6a", - strip_prefix = "abseil-py-acec853355ef987eae48a8d87a79351c15dff593", + sha256 = "c30b48e0d2580ef1412e55c5c0e1dab8db2ee4ab56e2075eccff29c90c7c7059", + strip_prefix = "abseil-py-ea8c4d2ddbf3fba610c4d613260561699b776db8", ) tf_http_archive( -- GitLab From 08fec96547a673084589e1be45f4bde0246f6fdf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 14:58:57 -0700 Subject: [PATCH 044/839] Fix support for batch_normalization with mixed precision When the type of the input tensor `x` is not the same as the type of the parameters `mean`, `variance`, `offset`, and `scale`, a cast is required. This mixed precision case occurs when using the BatchNormalization layer with a data type of float16 or bfloat16. PiperOrigin-RevId: 195157279 --- tensorflow/python/ops/nn_batchnorm_test.py | 20 ++++++++++++++------ tensorflow/python/ops/nn_impl.py | 6 ++++-- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py index 3ac2c8eb17..1508ff44ce 100644 --- a/tensorflow/python/ops/nn_batchnorm_test.py +++ b/tensorflow/python/ops/nn_batchnorm_test.py @@ -292,12 +292,16 @@ class BatchNormalizationTest(test.TestCase): self.assertAllClose( tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001) - def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.0001): - x_val = np.random.random_sample(x_shape).astype(np.float32) - m_val = np.random.random_sample(param_shape).astype(np.float32) - v_val = np.random.random_sample(param_shape).astype(np.float32) - beta_val = np.random.random_sample(param_shape).astype(np.float32) - gamma_val = np.random.random_sample(param_shape).astype(np.float32) + def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.0001, + dtype=dtypes.float32, + param_dtype=dtypes.float32): + numpy_dtype = dtype.as_numpy_dtype + numpy_param_dtype = param_dtype.as_numpy_dtype + x_val = np.random.random_sample(x_shape).astype(numpy_dtype) + m_val = np.random.random_sample(param_shape).astype(numpy_param_dtype) + v_val = np.random.random_sample(param_shape).astype(numpy_param_dtype) + beta_val = np.random.random_sample(param_shape).astype(numpy_param_dtype) + gamma_val = np.random.random_sample(param_shape).astype(numpy_param_dtype) for use_gpu in [True, False]: with self.test_session(use_gpu=use_gpu) as sess: x = constant_op.constant(x_val, name="x") @@ -332,6 +336,10 @@ class BatchNormalizationTest(test.TestCase): self._testBatchNormArbitraryShapes( (2, 3, 2, 4, 5), (1, 1, 1, 4, 5), atol=0.005) + def testBatchNormMixedPrecision(self): + self._testBatchNormArbitraryShapes((3, 3), (1, 3), dtype=dtypes.float16, + param_dtype=dtypes.float32, atol=0.001) + @test_util.with_c_api class SufficientStatisticsTest(test.TestCase): diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 576627e78e..783d485892 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -830,8 +830,10 @@ def batch_normalization(x, inv = math_ops.rsqrt(variance + variance_epsilon) if scale is not None: inv *= scale - return x * inv + ( - offset - mean * inv if offset is not None else -mean * inv) + # Note: tensorflow/contrib/quantize/python/fold_batch_norms.py depends on + # the precise order of ops that are generated by the expression below. + return x * math_ops.cast(inv, x.dtype) + math_ops.cast( + offset - mean * inv if offset is not None else -mean * inv, x.dtype) @tf_export("nn.fused_batch_norm") -- GitLab From d030ea951a477e2e141c13fed42681bcc5e97b4a Mon Sep 17 00:00:00 2001 From: Chris Ying Date: Wed, 2 May 2018 15:03:33 -0700 Subject: [PATCH 045/839] Add steps_per_run to LoggingTensorHook and StepCounterHook and other logging bug fixes. PiperOrigin-RevId: 195158238 --- .../contrib/tpu/python/tpu/tpu_estimator.py | 61 ++++++++---- tensorflow/python/estimator/estimator.py | 21 +++-- .../training/basic_session_run_hooks.py | 16 ++-- .../training/basic_session_run_hooks_test.py | 93 +++++++++++++++++-- ...sorflow.train.-checkpoint-saver-hook.pbtxt | 2 +- 5 files changed, 151 insertions(+), 42 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index eb537b7b6a..534042b42c 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -459,11 +459,9 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): session.run(self._init_ops, options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) - logging.info('Start infeed thread controller') self._infeed_controller = self._create_infeed_controller( name='InfeedController', target=self._run_infeed, args=(session,)) - logging.info('Start outfeed thread controller') self._outfeed_controller = _OpQueueContext( name='OutfeedController', target=self._run_outfeed, args=(session,)) @@ -1553,7 +1551,7 @@ class _OutfeedHostCallHook(session_run_hook.SessionRunHook): class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): - """"Calculate and report the number of examples/sec during training.""" + """Calculate and report global_step/sec and examples/sec during runtime.""" def __init__(self, batch_size, @@ -1569,12 +1567,18 @@ class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): summary_writer=summary_writer) def _log_and_record(self, elapsed_steps, elapsed_time, global_step): - examples_per_sec = self._batch_size * elapsed_steps / elapsed_time + global_step_per_sec = elapsed_steps / elapsed_time + examples_per_sec = self._batch_size * global_step_per_sec if self._summary_writer is not None: + global_step_summary = Summary(value=[ + Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec) + ]) example_summary = Summary(value=[ - Summary.Value(tag='examples_sec', simple_value=examples_per_sec) + Summary.Value(tag='examples/sec', simple_value=examples_per_sec) ]) + self._summary_writer.add_summary(global_step_summary, global_step) self._summary_writer.add_summary(example_summary, global_step) + logging.info('global_step/sec: %g', global_step_per_sec) logging.info('examples/sec: %g', examples_per_sec) @@ -1844,6 +1848,12 @@ class TPUEstimator(estimator_lib.Estimator): # config.model_dir. model_function = self._augment_model_fn(model_fn, batch_axis) + # Overwrite log_step_count_steps to disable TensorLoggingHook and + # StepCounterHook from being created in Estimator. TPUEstimator already + # added equivalent hooks in _augment_model_fn above. + self._log_every_n_steps = config.log_step_count_steps + config = config.replace(log_step_count_steps=None) + # Passing non-None params as wrapped model_fn has it. params = params or {} super(TPUEstimator, self).__init__( @@ -2039,39 +2049,50 @@ class TPUEstimator(estimator_lib.Estimator): host_ops = host_call.create_tpu_hostcall() if host_ops is None: host_ops = [] - shutdown_hooks = [] if os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN', '0') != '0': shutdown_hooks.append(session_support.GracefulShutdownHook()) - - hooks = [ + with ops.control_dependencies([loss]): + global_step = array_ops.identity(training.get_global_step()) + hooks = input_hooks + shutdown_hooks + logging_hook_frequency = ( # Divide and round up + (self._log_every_n_steps + + self._config.tpu_config.iterations_per_loop - 1) // + self._config.tpu_config.iterations_per_loop) + hooks.extend([ TPUInfeedOutfeedSessionHook( ctx, enqueue_ops, host_ops, run_infeed_loop_on_coordinator=( run_infeed_loop_on_coordinator)), - ExamplesPerSecondHook( - ctx.global_batch_size, output_dir=self.model_dir), InstallSignalHandlerHook(), training.LoggingTensorHook( { 'loss': array_ops.identity(loss), - 'step': training.get_global_step() + 'step': global_step, }, - every_n_secs=30) - ] + input_hooks + shutdown_hooks + every_n_iter=logging_hook_frequency) + ]) + examples_hook = ExamplesPerSecondHook( + ctx.global_batch_size, + output_dir=self.model_dir, + every_n_steps=self._log_every_n_steps) + examples_hook._set_steps_per_run( # pylint: disable=protected-access + self._config.tpu_config.iterations_per_loop) + hooks.append(examples_hook) chief_hooks = [] if (self._config.save_checkpoints_secs or self._config.save_checkpoints_steps): - chief_hooks.append( - training.CheckpointSaverHook( - self.model_dir, - save_secs=self._config.save_checkpoints_secs, - save_steps=self._config.save_checkpoints_steps, - steps_per_run=self._config.tpu_config.iterations_per_loop, - scaffold=scaffold)) + checkpoint_hook = training.CheckpointSaverHook( + self.model_dir, + save_secs=self._config.save_checkpoints_secs, + save_steps=self._config.save_checkpoints_steps, + scaffold=scaffold) + checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access + self._config.tpu_config.iterations_per_loop) + chief_hooks.append(checkpoint_hook) summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) with ops.control_dependencies([loss]): update_ops = _sync_variables_ops() diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 3691c99dda..946f093ba7 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -994,15 +994,18 @@ class Estimator(object): summary.scalar('loss', estimator_spec.loss) ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss) worker_hooks.extend(hooks) - worker_hooks.extend([ - training.NanTensorHook(estimator_spec.loss), - training.LoggingTensorHook( - { - 'loss': estimator_spec.loss, - 'step': global_step_tensor - }, - every_n_iter=self._config.log_step_count_steps) - ]) + worker_hooks.append( + training.NanTensorHook(estimator_spec.loss) + ) + if self._config.log_step_count_steps is not None: + worker_hooks.append( + training.LoggingTensorHook( + { + 'loss': estimator_spec.loss, + 'step': global_step_tensor + }, + every_n_iter=self._config.log_step_count_steps) + ) worker_hooks.extend(estimator_spec.training_hooks) if not (estimator_spec.scaffold.saver or diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index d1cc7d8ce3..abcf76a220 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -380,8 +380,7 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): saver=None, checkpoint_basename="model.ckpt", scaffold=None, - listeners=None, - steps_per_run=1): + listeners=None): """Initializes a `CheckpointSaverHook`. Args: @@ -394,9 +393,6 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): listeners: List of `CheckpointSaverListener` subclass instances. Used for callbacks that run immediately before or after this hook saves the checkpoint. - steps_per_run: `int`, number of steps that occur between each invocation - of the hook. Primarily used for TPU workloads which run multiple steps - in a while loop in a single Session.run. Raises: ValueError: One of `save_steps` or `save_secs` should be set. @@ -412,6 +408,9 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook): self._timer = SecondOrStepTimer(every_secs=save_secs, every_steps=save_steps) self._listeners = listeners or [] + self._steps_per_run = 1 + + def _set_steps_per_run(self, steps_per_run): self._steps_per_run = steps_per_run def begin(self): @@ -522,6 +521,10 @@ class StepCounterHook(session_run_hook.SessionRunHook): self._output_dir = output_dir self._last_global_step = None self._global_step_check_count = 0 + self._steps_per_run = 1 + + def _set_steps_per_run(self, steps_per_run): + self._steps_per_run = steps_per_run def begin(self): if self._summary_writer is None and self._output_dir: @@ -547,7 +550,8 @@ class StepCounterHook(session_run_hook.SessionRunHook): _ = run_context stale_global_step = run_values.results - if self._timer.should_trigger_for_step(stale_global_step+1): + if self._timer.should_trigger_for_step( + stale_global_step + self._steps_per_run): # get the real value after train op. global_step = run_context.session.run(self._global_step_tensor) if self._timer.should_trigger_for_step(global_step): diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py index 31898562f8..7344ce2758 100644 --- a/tensorflow/python/training/basic_session_run_hooks_test.py +++ b/tensorflow/python/training/basic_session_run_hooks_test.py @@ -764,8 +764,8 @@ class CheckpointSaverHookMultiStepTest(test.TestCase): hook = basic_session_run_hooks.CheckpointSaverHook( self.model_dir, save_steps=2*self.steps_per_run, - scaffold=self.scaffold, - steps_per_run=self.steps_per_run) + scaffold=self.scaffold) + hook._set_steps_per_run(self.steps_per_run) hook.begin() self.scaffold.finalize() with session_lib.Session() as sess: @@ -781,8 +781,8 @@ class CheckpointSaverHookMultiStepTest(test.TestCase): hook = basic_session_run_hooks.CheckpointSaverHook( self.model_dir, save_steps=2*self.steps_per_run, - scaffold=self.scaffold, - steps_per_run=self.steps_per_run) + scaffold=self.scaffold) + hook._set_steps_per_run(self.steps_per_run) hook.begin() self.scaffold.finalize() with session_lib.Session() as sess: @@ -823,8 +823,8 @@ class CheckpointSaverHookMultiStepTest(test.TestCase): hook = basic_session_run_hooks.CheckpointSaverHook( self.model_dir, save_steps=2*self.steps_per_run, - scaffold=self.scaffold, - steps_per_run=self.steps_per_run) + scaffold=self.scaffold) + hook._set_steps_per_run(self.steps_per_run) hook.begin() self.scaffold.finalize() with session_lib.Session() as sess: @@ -997,6 +997,87 @@ class StepCounterHookTest(test.TestCase): 'global step.*has not been increased') hook.end(sess) + def _setup_steps_per_run_test(self, + every_n_steps, + steps_per_run, + graph, + sess): + variables.get_or_create_global_step() + self.train_op = training_util._increment_global_step(steps_per_run) + self.summary_writer = fake_summary_writer.FakeSummaryWriter( + self.log_dir, graph) + self.hook = basic_session_run_hooks.StepCounterHook( + summary_writer=self.summary_writer, every_n_steps=every_n_steps) + self.hook._set_steps_per_run(steps_per_run) + self.hook.begin() + sess.run(variables_lib.global_variables_initializer()) + self.mon_sess = monitored_session._HookedSession(sess, [self.hook]) + + def test_steps_per_run_less_than_every_n_steps(self): + with ops.Graph().as_default() as g, session_lib.Session() as sess: + self._setup_steps_per_run_test(10, 5, g, sess) + + # Logs at 15, 25 + for _ in range(5): + time.sleep(0.01) + self.mon_sess.run(self.train_op) + + self.hook.end(sess) + self.summary_writer.assert_summaries( + test_case=self, + expected_logdir=self.log_dir, + expected_graph=g, + expected_summaries={}) + self.assertItemsEqual([15, 25], self.summary_writer.summaries.keys()) + for step in [15, 25]: + summary_value = self.summary_writer.summaries[step][0].value[0] + self.assertEqual('global_step/sec', summary_value.tag) + self.assertGreater(summary_value.simple_value, 0) + + def test_steps_per_run_equal_every_n_steps(self): + with ops.Graph().as_default() as g, session_lib.Session() as sess: + self._setup_steps_per_run_test(5, 5, g, sess) + + # Logs at 10, 15, 20, 25 + for _ in range(5): + time.sleep(0.01) + self.mon_sess.run(self.train_op) + + self.hook.end(sess) + self.summary_writer.assert_summaries( + test_case=self, + expected_logdir=self.log_dir, + expected_graph=g, + expected_summaries={}) + self.assertItemsEqual([10, 15, 20, 25], + self.summary_writer.summaries.keys()) + for step in [10, 15, 20, 25]: + summary_value = self.summary_writer.summaries[step][0].value[0] + self.assertEqual('global_step/sec', summary_value.tag) + self.assertGreater(summary_value.simple_value, 0) + + def test_steps_per_run_greater_than_every_n_steps(self): + with ops.Graph().as_default() as g, session_lib.Session() as sess: + self._setup_steps_per_run_test(5, 10, g, sess) + + # Logs at 20, 30, 40, 50 + for _ in range(5): + time.sleep(0.01) + self.mon_sess.run(self.train_op) + + self.hook.end(sess) + self.summary_writer.assert_summaries( + test_case=self, + expected_logdir=self.log_dir, + expected_graph=g, + expected_summaries={}) + self.assertItemsEqual([20, 30, 40, 50], + self.summary_writer.summaries.keys()) + for step in [20, 30, 40, 50]: + summary_value = self.summary_writer.summaries[step][0].value[0] + self.assertEqual('global_step/sec', summary_value.tag) + self.assertGreater(summary_value.simple_value, 0) + class SummarySaverHookTest(test.TestCase): diff --git a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-hook.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-hook.pbtxt index 327799729c..c3037baa8c 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-hook.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint-saver-hook.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\', \'steps_per_run\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\', \'1\'], " + argspec: "args=[\'self\', \'checkpoint_dir\', \'save_secs\', \'save_steps\', \'saver\', \'checkpoint_basename\', \'scaffold\', \'listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'model.ckpt\', \'None\', \'None\'], " } member_method { name: "after_create_session" -- GitLab From 1d92d5037e1cec1a5099234a5568b68c7e675576 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 2 May 2018 15:05:58 -0700 Subject: [PATCH 046/839] [TF:XLA] Bump open source llvm revision to r331338 PiperOrigin-RevId: 195158710 --- tensorflow/workspace.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index f4f7bc4615..94cac4f8fa 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -452,11 +452,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/068c967842b83d22007eee4515b57e8d9aaccb82.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/068c967842b83d22007eee4515b57e8d9aaccb82.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/a5108a08ceab35886a7df07c86f96aedd3d94bb7.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/a5108a08ceab35886a7df07c86f96aedd3d94bb7.tar.gz", ], - sha256 = "4950432fb5cc68e5bf1f87a30b17dfdc69a5b93dac1e89d5274242d3ce7dae7c", - strip_prefix = "llvm-068c967842b83d22007eee4515b57e8d9aaccb82", + sha256 = "79cae03ebbdfd812bb69c460e1325ca069b5c576f7c7071f8216cf2b0975e36f", + strip_prefix = "llvm-a5108a08ceab35886a7df07c86f96aedd3d94bb7", build_file = clean_dep("//third_party/llvm:llvm.BUILD"), ) -- GitLab From 9180cc254dff42368af126aa68eb82823ef67736 Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Wed, 2 May 2018 15:14:08 -0700 Subject: [PATCH 047/839] [XLA] BF16 propagation: do not change if propagation is confined inside a fusion. We now use a set to track all the potential changes, and do the actual changes on the HLOs at the end. This also makes the boolean return value (whether anything is changed) correct. PiperOrigin-RevId: 195160025 --- .../xla/service/bfloat16_propagation.cc | 421 ++++++++++++------ .../xla/service/bfloat16_propagation.h | 93 +++- .../xla/service/bfloat16_propagation_test.cc | 31 ++ 3 files changed, 387 insertions(+), 158 deletions(-) diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 43ebe92c5e..ed0746980f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -33,7 +33,7 @@ BFloat16Propagation::BFloat16Propagation( const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} -void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( +void BFloat16Propagation::DetermineFusionComputationPrecision( HloInstruction* fusion) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) { @@ -48,15 +48,13 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( auto root = fusion->fused_instructions_computation()->root_instruction(); // Adjust root's element types according to the fusion's output shape. - ShapeUtil::ForEachMutableSubshape( - root->mutable_shape(), [&](Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() != F32) { + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() != F32) { return; } - if (ShapeUtil::GetSubshape(fusion->shape(), index).element_type() == - BF16) { - subshape->set_element_type(BF16); - changed_ = true; + if (OutputTypeAfterChange(fusion, index) == BF16) { + AddToOrRemoveFromBF16ChangeSet(root, index, BF16); VLOG(2) << "Fused root " << root->ToString() << " at shape index " << index << " changed to BF16 precision for fusion " << fusion->ToString(); @@ -67,13 +65,101 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( auto insts = fusion->fused_instructions_computation()->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); } - computations_visited_in_mutation_pass_.insert( + computations_visited_in_backward_pass_.insert( fusion->fused_instructions_computation()); + + RevertIfFusionInternalBF16Changes(fusion); +} + +void BFloat16Propagation::RevertIfFusionInternalBF16Changes( + HloInstruction* fusion) { + auto has_changes = [this](HloInstruction* inst) { + auto it = changes_to_bf16_.find(inst); + return it != changes_to_bf16_.end() && !it->second.empty(); + }; + + auto root = fusion->fused_instructions_computation()->root_instruction(); + tensorflow::gtl::FlatSet changed_root_buffers; + + auto root_changes_it = changes_to_bf16_.find(root); + if (root_changes_it != changes_to_bf16_.end()) { + for (const auto& index : root_changes_it->second) { + for (const HloValue* value : + dataflow_->GetValueSet(root, index).values()) { + changed_root_buffers.insert(value); + } + } + } + + auto aliases_changed_root_buffer = + [this, &changed_root_buffers](const HloInstruction* inst) { + bool aliasing = false; + ShapeUtil::ForEachSubshape( + inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (aliasing) { + // Skip if aliasing is already found. + return; + } + // Only F32 buffers are considered for changing to BF16 in this + // pass. + if (subshape.element_type() != F32) { + return; + } + for (const HloValue* value : + dataflow_->GetValueSet(inst, index).values()) { + if (ContainsKey(changed_root_buffers, value)) { + aliasing = true; + break; + } + } + }); + return aliasing; + }; + + for (auto inst : + fusion->fused_instructions_computation()->MakeInstructionPostOrder()) { + if (inst->opcode() == HloOpcode::kParameter) { + continue; + } + if (aliases_changed_root_buffer(inst)) { + continue; + } + if (inst->opcode() == HloOpcode::kFusion) { + bool parameter_reverted = false; + for (int64 i = 0; i < inst->operand_count(); ++i) { + if (has_changes(inst->mutable_operand(i))) { + // Changes on the operand have not been reverted. + continue; + } + auto* fused_parameter = inst->fused_parameter(i); + if (has_changes(fused_parameter)) { + changes_to_bf16_.erase(fused_parameter); + parameter_reverted = true; + } + } + if (parameter_reverted) { + RevertIfFusionInternalBF16Changes(inst); + } + } + if (!has_changes(inst)) { + continue; + } + bool revert_changes = true; + for (auto operand : inst->operands()) { + if (has_changes(operand)) { + revert_changes = false; + break; + } + } + if (revert_changes) { + changes_to_bf16_.erase(inst); + } + } } -void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( +void BFloat16Propagation::DetermineWhileComputationsPrecision( HloInstruction* while_hlo) { CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile); @@ -86,16 +172,14 @@ void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( auto body_root = body->root_instruction(); HloComputation* condition = while_hlo->while_condition(); - ShapeUtil::ForEachMutableSubshape( - body_root->mutable_shape(), - [this, while_hlo, body_root](Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() != F32) { + ShapeUtil::ForEachSubshape( + body_root->shape(), [this, while_hlo, body_root]( + const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() != F32) { return; } - if (ShapeUtil::GetSubshape(while_hlo->shape(), index).element_type() == - BF16) { - subshape->set_element_type(BF16); - changed_ = true; + if (OutputTypeAfterChange(while_hlo, index) == BF16) { + AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16); VLOG(2) << "While body root " << body_root->ToString() << " at shape index " << index << " changed to BF16 precision for while " @@ -106,30 +190,30 @@ void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( auto body_insts = body->MakeInstructionPostOrder(); for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); } - computations_visited_in_mutation_pass_.insert(body); + computations_visited_in_backward_pass_.insert(body); auto condition_insts = condition->MakeInstructionPostOrder(); for (auto inst_it = condition_insts.rbegin(); inst_it != condition_insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); } - computations_visited_in_mutation_pass_.insert(condition); + computations_visited_in_backward_pass_.insert(condition); } bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const { - auto value_set = dataflow_->GetValueSet(&hlo, index); + auto& value_set = dataflow_->GetValueSet(&hlo, index); for (const HloValue* value : value_set.values()) { if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { return false; } - if (value->shape().element_type() == BF16) { + if (ValueTypeAfterChange(value) == BF16) { continue; } for (const HloUse& use : value->uses()) { - if (!ContainsKey(instructions_visited_in_mutation_pass_, + if (!ContainsKey(instructions_visited_in_backward_pass_, use.instruction)) { // We don't know yet whether use.instruction will consume BF16 since it // hasn't been visited. Although we visit instructions in reverse @@ -145,26 +229,23 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, // precision, or a called computation's parameters have been changed to // BF16 for fusions or whiles. if (use.instruction->opcode() == HloOpcode::kFusion) { - const auto* fused_parameter = + auto* fused_parameter = use.instruction->fused_parameter(use.operand_number); - if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index) - .element_type() != BF16) { + if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) { return false; } continue; } else if (use.instruction->opcode() == HloOpcode::kWhile) { - const auto* cond_parameter = + auto* cond_parameter = use.instruction->while_condition()->parameter_instruction( use.operand_number); - if (ShapeUtil::GetSubshape(cond_parameter->shape(), use.operand_index) - .element_type() != BF16) { + if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) { return false; } - const auto* body_parameter = + auto* body_parameter = use.instruction->while_body()->parameter_instruction( use.operand_number); - if (ShapeUtil::GetSubshape(body_parameter->shape(), use.operand_index) - .element_type() != BF16) { + if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) { return false; } continue; @@ -174,19 +255,20 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, continue; } // If the op propagates precision and it outputs a BF16, then it's OK to - // supply BF16 also as the input. In the backward mutation pass, the users - // shapes should have already been processed. + // supply BF16 also as the input. In the backward pass, the users shapes + // should have already been processed. PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID; if (use.instruction->opcode() == HloOpcode::kTuple || (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && ShapeUtil::IsTuple(use.instruction->shape()))) { - user_output_type = ShapeUtil::GetSubshape( - ShapeUtil::GetSubshape(use.instruction->shape(), - {use.operand_number}), - use.operand_index) - .element_type(); + ShapeIndex use_output_index{use.operand_number}; + for (int64 i : use.operand_index) { + use_output_index.push_back(i); + } + user_output_type = + OutputTypeAfterChange(use.instruction, use_output_index); } else { - user_output_type = use.instruction->shape().element_type(); + user_output_type = OutputTypeAfterChange(use.instruction, {}); } if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( *use.instruction, use.operand_number) && @@ -199,8 +281,8 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, return true; } -void BFloat16Propagation::DetermineAndMutateInstructionPrecision( - HloInstruction* hlo, bool skip_parameters) { +void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, + bool skip_parameters) { // We handle any fusion computation or while body/condition after the // instruction is handled, because we need to know the output shape of a // fusion or while before propagating inside its computations. @@ -209,12 +291,12 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision( [this, hlo, &postpone_processing_called_computations] { if (!postpone_processing_called_computations) { if (hlo->opcode() == HloOpcode::kFusion) { - DetermineAndMutateFusionComputationPrecision(hlo); + DetermineFusionComputationPrecision(hlo); } else if (hlo->opcode() == HloOpcode::kWhile) { - DetermineAndMutateWhileComputationsPrecision(hlo); + DetermineWhileComputationsPrecision(hlo); } } - instructions_visited_in_mutation_pass_.insert(hlo); + instructions_visited_in_backward_pass_.insert(hlo); }); if (hlo->opcode() == HloOpcode::kWhile && @@ -245,9 +327,9 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision( CHECK(hlo->parent() != nullptr); if (hlo == hlo->parent()->root_instruction()) { if (!hlo->parent()->IsFusionComputation()) { - ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& subshape, + ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */, const ShapeIndex& index) { - if (subshape.element_type() != F32) { + if (OutputTypeAfterChange(hlo, index) != F32) { return; } for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { @@ -269,13 +351,12 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision( return; } - ShapeUtil::ForEachMutableSubshape( - hlo->mutable_shape(), - [hlo, this](Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() == F32 && + ShapeUtil::ForEachSubshape( + hlo->shape(), + [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) { + if (OutputTypeAfterChange(hlo, index) == F32 && AllUsersConsumeBF16(*hlo, index)) { - subshape->set_element_type(BF16); - changed_ = true; + AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16); VLOG(2) << "HloInstruction output at shape index " << index << " changed to BF16 precision: " << hlo->ToString(); } @@ -308,26 +389,24 @@ void BFloat16Propagation::AdjustCalledComputationParameters( CHECK_EQ(operands.size(), computation->num_parameters()); for (int64 i = 0; i < operands.size(); ++i) { auto parameter = computation->parameter_instruction(i); - ShapeUtil::ForEachMutableSubshape( - parameter->mutable_shape(), - [this, i, hlo, &operands, parameter](Shape* subshape, + ShapeUtil::ForEachSubshape( + parameter->shape(), + [this, i, hlo, &operands, parameter](const Shape& /* subshape */, const ShapeIndex& index) { if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) { return; } PrimitiveType operand_type = - ShapeUtil::GetSubshape(operands[i]->shape(), index) - .element_type(); - if (subshape->element_type() == operand_type) { + OutputTypeAfterChange(operands[i], index); + if (OutputTypeAfterChange(parameter, index) == operand_type) { return; } - CHECK(operand_type == F32 || operand_type == BF16); - subshape->set_element_type(operand_type); - changed_ = true; + AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type); VLOG(2) << "Called computation parameter " << parameter->ToString() << " at shape index " << index - << " adjusted to match operand in HLO " - << hlo->ToString(); + << " adjusted to " + << (operand_type == BF16 ? "BF16" : "F32") + << " to match operand in HLO " << hlo->ToString(); }); } }; @@ -348,51 +427,48 @@ void BFloat16Propagation::AdjustCalledComputationParameters( void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { auto adjust_computation = [this, hlo](HloComputation* computation, - const Shape& output_shape) { + HloInstruction* output) { // Adjust root. HloInstruction* root = computation->root_instruction(); - ShapeUtil::ForEachMutableSubshape( - root->mutable_shape(), [this, hlo, root, &output_shape]( - Shape* subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) { - return; - } - const PrimitiveType output_type = - ShapeUtil::GetSubshape(output_shape, index).element_type(); - if (subshape->element_type() == output_type) { - return; - } - CHECK(output_type == F32 || output_type == BF16); - subshape->set_element_type(output_type); - // It's possible that output_type is F32, but the root instruction's - // type is BF16; e.g., a fusion node's output was changed to BF16 - // initially but then adjusted back to F32, and the fusion computation - // is now being adjusted after the fusion node. - if (output_type == F32) { - for (const auto* value : - dataflow_->GetValueSet(root, index).values()) { - // We rely on the fact that this adjustment works in reverse - // topological order so that called computation will be - // processed later. Adding the value to - // values_that_must_be_kept_as_f32_ will ensure the - // correctness of the adjustment for HLOs that will be - // processed later. - values_that_must_be_kept_as_f32_.insert(value); - } - } - changed_ = true; - VLOG(2) << "Called computation root " << root->ToString() - << " at shape index " << index - << " adjusted to match output shape of " << hlo->ToString(); - }); + ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output]( + const Shape& /* subshape */, + const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) { + return; + } + const PrimitiveType output_type = OutputTypeAfterChange(output, index); + if (OutputTypeAfterChange(root, index) == output_type) { + return; + } + AddToOrRemoveFromBF16ChangeSet(root, index, output_type); + // It's possible that output_type is F32, but the root instruction's + // type is BF16; e.g., a fusion node's output was changed to BF16 + // initially but then adjusted back to F32, and the fusion computation + // is now being adjusted after the fusion node. + if (output_type == F32) { + for (const auto* value : dataflow_->GetValueSet(root, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order so that called computation will be + // processed later. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the + // correctness of the adjustment for HLOs that will be + // processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + VLOG(2) << "Called computation root " << root->ToString() + << " at shape index " << index << " adjusted to " + << (output_type == BF16 ? "BF16" : "F32") + << " to match output shape of " << hlo->ToString(); + }); }; switch (hlo->opcode()) { case HloOpcode::kFusion: - adjust_computation(hlo->fused_instructions_computation(), hlo->shape()); + adjust_computation(hlo->fused_instructions_computation(), hlo); break; case HloOpcode::kWhile: - adjust_computation(hlo->while_body(), hlo->shape()); + adjust_computation(hlo->while_body(), hlo); break; default: break; @@ -409,16 +485,19 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { auto hlo = *inst_it; auto adjust_hlo_output = [this, hlo, ¶meter_changed]( - Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() != F32 && subshape->element_type() != BF16) { + const Shape& /* subshape */, + const ShapeIndex& index) { + auto output_type = OutputTypeAfterChange(hlo, index); + if (output_type != F32 && output_type != BF16) { return; } PrimitiveType type = BF16; for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { - if (value->shape().element_type() == BF16) { + auto value_type = ValueTypeAfterChange(value); + if (value_type == BF16) { continue; } - CHECK_EQ(value->shape().element_type(), F32); + CHECK_EQ(value_type, F32); type = F32; break; } @@ -437,16 +516,17 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( values_that_must_be_kept_as_f32_.insert(value); } } - if (type != subshape->element_type()) { - subshape->set_element_type(type); + if (type != output_type) { + AddToOrRemoveFromBF16ChangeSet(hlo, index, type); VLOG(2) << "HloInstruction output at shape index " << index - << " adjusted to " << *subshape << ": " << hlo->ToString(); + << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": " + << hlo->ToString(); if (hlo->opcode() == HloOpcode::kParameter) { parameter_changed = true; } } }; - ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_hlo_output); + ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); AdjustCalledComputationRoot(hlo); if (hlo->opcode() == HloOpcode::kWhile) { // We need to run on the while body and condition repeatedly until a fixed @@ -463,8 +543,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), &visited_in_while)) { visited_in_while.clear(); - ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), - adjust_hlo_output); + ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); AdjustCalledComputationRoot(hlo); } visited_computations->insert(visited_in_while.begin(), @@ -478,7 +557,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( return parameter_changed; } -Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( +void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { std::list computations_topological_order = module->MakeComputationPostOrder(); @@ -490,7 +569,9 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( } ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved); } +} +Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { // We could have changed a fusion computation's root shape to have a different // precision than the fusion node's output, if the fusion root does not // define a buffer (e.g., a tuple). Now we add conversions after such fusion @@ -517,7 +598,7 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( // (2) after adding conversion // (3) after tuple simplifier and DCE. bool needs_tuple_simplifier = false; - for (auto computation : computations_topological_order) { + for (auto computation : module->MakeComputationPostOrder()) { auto insts = computation->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { auto hlo = *inst_it; @@ -587,7 +668,14 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape()); } } + if (needs_tuple_simplifier) { + TupleSimplifier tuple_simplifier; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + } + return Status::OK(); +} +Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { // We may have converted some constants from F32 to BF16, so adjust the // constant literals in such cases. We do this here instead of when the // constant node's is changed because 1) the HloInstruction interface does not @@ -598,8 +686,7 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( // can avoid repeated conversions. // // TODO(b/73833576): Consider resetting literal in HloInstruction. - bool needs_dce = needs_tuple_simplifier; - for (auto computation : computations_topological_order) { + for (auto computation : module->MakeComputationPostOrder()) { for (auto hlo : computation->MakeInstructionPostOrder()) { if (hlo->opcode() != HloOpcode::kConstant) { continue; @@ -612,23 +699,13 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); - needs_dce = true; } } } - - if (needs_tuple_simplifier) { - TupleSimplifier tuple_simplifier; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); - } - if (needs_dce) { - HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module).status()); - } return Status::OK(); } -Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { +Status BFloat16Propagation::SkipNoopConversions(HloModule* module) { for (auto computation : module->computations()) { for (auto hlo : computation->MakeInstructionPostOrder()) { if (hlo->opcode() != HloOpcode::kConvert) { @@ -643,7 +720,6 @@ Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { if (is_root) { computation->set_root_instruction(source); } - TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(hlo)); } } return Status::OK(); @@ -652,8 +728,18 @@ Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { // The algorithm first does a forward pass (parameters to root) to determine a // set of instructions to consider using bfloat16, then does a backward pass to // determine the precisions of those instructions according to the need of -// their users. +// their users. During the backward pass, the potential changes are stored in +// changes_to_bf16_ which are subject to further adjustments then applied to the +// HLOs. StatusOr BFloat16Propagation::Run(HloModule* module) { + consider_using_bfloat16_.clear(); + instructions_visited_in_backward_pass_.clear(); + computations_visited_in_backward_pass_.clear(); + values_that_must_be_kept_as_f32_.clear(); + caller_counts_.clear(); + changes_to_bf16_.clear(); + changed_ = false; + TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); std::list computations_topological_order = @@ -686,8 +772,24 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { } auto insts = (*comp_it)->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, - /*skip_parameters=*/true); + DetermineInstructionPrecision(*inst_it, + /*skip_parameters=*/true); + } + } + + // It's possible that an instruction does not define a buffer, but the + // defining instruction's shape has changed. So we need to adjust the output + // shapes of instructions according to the HLO values they refer to. + ResolveInconsistencyOfAliasingBuffers(module); + + // Apply the changes in changes_to_bf16_. + for (auto& change : changes_to_bf16_) { + auto shape = change.first->mutable_shape(); + for (const auto& index : change.second) { + auto subshape = ShapeUtil::GetMutableSubshape(shape, index); + CHECK_EQ(subshape->element_type(), F32); + subshape->set_element_type(BF16); + changed_ = true; } } @@ -695,15 +797,56 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { return false; } - // It's possible that an instruction does not define a buffer, but the - // defining instruction's shape has changed. So we need to adjust the output - // shapes of instructions according to the HLO values they refer to. - TF_RETURN_IF_ERROR(ResolveInconsistencyOfAliasingBuffers(module)); + TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module)); + TF_RETURN_IF_ERROR(ResolveConvertedConstants(module)); // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 -> - // BF16), so we remove them now. - TF_RETURN_IF_ERROR(RemoveNoopConversions(module)); + // BF16), so we skip them now. + TF_RETURN_IF_ERROR(SkipNoopConversions(module)); + + { + // We may have dead HLOs after ResolveInconsistentFusions, + // ResolveConvertedConstants and SkipNoopConversions. + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } return true; } +PrimitiveType BFloat16Propagation::OutputTypeAfterChange( + HloInstruction* hlo, const ShapeIndex& index) const { + PrimitiveType type_on_hlo = + ShapeUtil::GetSubshape(hlo->shape(), index).element_type(); + if (type_on_hlo != F32) { + return type_on_hlo; + } + auto it = changes_to_bf16_.find(hlo); + if (it == changes_to_bf16_.end()) { + return type_on_hlo; + } + return ContainsKey(it->second, index) ? BF16 : F32; +} + +PrimitiveType BFloat16Propagation::ValueTypeAfterChange( + const HloValue* value) const { + auto hlo = value->defining_instruction(); + const auto& position = value->defining_position(); + return OutputTypeAfterChange(hlo, position.index); +} + +void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet( + HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) { + if (target_type == BF16) { + auto& entry = changes_to_bf16_[hlo]; + entry.insert(index); + } else { + CHECK_EQ(target_type, F32); + auto it = changes_to_bf16_.find(hlo); + if (it == changes_to_bf16_.end()) { + return; + } + it->second.erase(index); + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 1744e9db90..de0355ddfc 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -85,30 +86,39 @@ class BFloat16Propagation : public HloPassInterface { tensorflow::gtl::FlatSet consider_using_bfloat16_; // *************************** - // Functions called and state produced by the backward mutation pass (from - // root to parameters). + // Functions called and state produced by the backward pass (from root to + // parameters) that finds opportunities to use BF16. - // Determines the precision for the given instruction in the mutation pass. - void DetermineAndMutateInstructionPrecision(HloInstruction* hlo, - bool skip_parameters); + // Determines the precision for the given instruction in the + // opportunity-finding pass. + void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters); - // Special handling in the mutation pass for fusion computations. + // Special handling in the opportunity-finding pass for fusion computations. // // Precondition: hlo->opcode() == kFusion - void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion); + void DetermineFusionComputationPrecision(HloInstruction* fusion); - // Special handling in the mutation pass for while computations. + // Reverts changes to BF16 that will not propagate outside a fusion + // computation. This avoids BF16 casts overhead inside a fusion which won't + // save memory bandwidth. + // + // Precondition: hlo->opcode() == kFusion + void RevertIfFusionInternalBF16Changes(HloInstruction* fusion); + + // Special handling in the opportunity-finding pass for while computations. // // Precondition: hlo->opcode() == kWhile - void DetermineAndMutateWhileComputationsPrecision(HloInstruction* while_hlo); + void DetermineWhileComputationsPrecision(HloInstruction* while_hlo); - // The set of HloInstructions that have been visited in the mutation pass. + // The set of HloInstructions that have been visited in the + // opportunity-finding pass. tensorflow::gtl::FlatSet - instructions_visited_in_mutation_pass_; + instructions_visited_in_backward_pass_; - // The set of HloComputations that have been visited in the mutation pass. + // The set of HloComputations that have been visited in the + // opportunity-finding pass. tensorflow::gtl::FlatSet - computations_visited_in_mutation_pass_; + computations_visited_in_backward_pass_; // *************************** // Functions called by the final inconsistency resolving pass. @@ -116,7 +126,7 @@ class BFloat16Propagation : public HloPassInterface { // Adjusts the output shapes of HloInstructions such that if two // HloInstructions have aliasing buffers in their outputs, they must have the // same precision. - Status ResolveInconsistencyOfAliasingBuffers(HloModule* module); + void ResolveInconsistencyOfAliasingBuffers(HloModule* module); // Resolves inconsistency of aliasing buffers for the given computation, and // recursively runs on a while instruction's condition and body until a fixed @@ -134,9 +144,19 @@ class BFloat16Propagation : public HloPassInterface { void AdjustCalledComputationRoot(HloInstruction* hlo); // *************************** - // Removes no-op conversions (same source and target shapes) that can be - // produced this pass. - Status RemoveNoopConversions(HloModule* module); + // Functions called after changes in changes_to_bf16_ are applied. + + // Resolves inconsistencies introduced by this pass for fusions with + // tuple-type output. + Status ResolveInconsistentFusions(HloModule* module); + + // Converts the literals in kConstant HLOs which have their types changed to + // BF16 by this pass. + Status ResolveConvertedConstants(HloModule* module); + + // Skips no-op conversions (same source and target shapes) that can be + // produced this pass, i.e., replaces them in their uses with their operands. + Status SkipNoopConversions(HloModule* module); // *************************** // Functions called and state used by two or more passes. @@ -146,6 +166,23 @@ class BFloat16Propagation : public HloPassInterface { bool AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const; + // The output element type of the HLO at the given shape index after changes + // in changes_to_bf16_ are applied. + PrimitiveType OutputTypeAfterChange(HloInstruction* hlo, + const ShapeIndex& index) const; + + // The element type of the HLO value after changes in changes_to_bf16_ are + // applied. + PrimitiveType ValueTypeAfterChange(const HloValue* value) const; + + // If target_type == BF16, adds the HLO at the given index to + // changes_to_bf16_; otherwise, target_type must be F32 and this function + // removes the HLO at the given index from changes_to_bf16_ if it was earlier + // added. + void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo, + const ShapeIndex& index, + PrimitiveType target_type); + // The set of F32 HLO values that must be kept in F32. tensorflow::gtl::FlatSet values_that_must_be_kept_as_f32_; @@ -153,10 +190,28 @@ class BFloat16Propagation : public HloPassInterface { // module. Populated at the beginning of this pass. tensorflow::gtl::FlatMap caller_counts_; + // We first store the potential F32-to-BF16 changes to changes_to_bf16_, which + // are subject to further adjustment, then finally applied to the HLOs. This + // avoids setting changed_ to true but all changes are reverted during + // adjustment. + struct IndexHasher { + int64 operator()(const ShapeIndex& index) const { + int64 hash = 0; + for (int64 i : index) { + hash = tensorflow::Hash64Combine(hash, std::hash()(i)); + } + return hash; + } + }; + tensorflow::gtl::FlatMap> + changes_to_bf16_; + + // Whether the last processed HLO module has been changed by this pass. + bool changed_ = false; + const BFloat16Support* bfloat16_support_; std::unique_ptr dataflow_; - - bool changed_ = false; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 183db1652e..313910a861 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -323,6 +323,37 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { EXPECT_TRUE(OutputsBF16(b_f1)); } +// Tests that changes to BF16 that cannot be propagated outside a fusion are +// discarded. +TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f = HloComputation::Builder("fusion"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add_f = builder_f.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); + HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f)); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), fusion); +} + // Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion // outputs are only used by a dot, and 3) one element of the tuple is used by // an add in the fusion computation, then the propagation pass should create a -- GitLab From 4704ae7af1918755d72f159f49d98d35da6eb6fa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 15:21:17 -0700 Subject: [PATCH 048/839] Optimize LogicalOr and LogicalAnd with all true or false inputs: LogicalOr(x, true) = true LogicalOr(x, false) = x LogicalAnd(x, true) = x LogicalAnd(x, false) = false and similar if the first argument is constant. PiperOrigin-RevId: 195161140 --- .../grappler/optimizers/constant_folding.cc | 113 +++++++++++------- .../grappler/optimizers/constant_folding.h | 13 +- .../optimizers/constant_folding_test.cc | 50 ++++++-- .../feature_column/feature_column_test.py | 20 +++- 4 files changed, 132 insertions(+), 64 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 4801f18619..47d8827686 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -866,6 +866,25 @@ Status CreateConstantTensorAttrValue(DataType type, double value, } #undef SET_TENSOR_CAL_CASE + +DataType GetDataTypeFromNodeOrProps(const NodeDef& node, + const GraphProperties& properties) { + DataType dtype = DT_INVALID; + if (node.attr().count("T") == 1) { + dtype = node.attr().at("T").type(); + } else if (node.attr().count("dtype") == 1) { + dtype = node.attr().at("dtype").type(); + } else if (IsLogicalOr(node) || IsLogicalAnd(node)) { + dtype = DT_BOOL; + } else { + auto output_props = properties.GetOutputProperties(node.name()); + if (!output_props.empty()) { + dtype = output_props[0].dtype(); + } + } + return dtype; +} + } // namespace // static @@ -1412,6 +1431,7 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const { } const auto dtype = node.attr().at("dtype").type(); switch (dtype) { + IS_ONES_CASE(DT_BOOL); IS_ONES_CASE(DT_HALF); IS_ONES_CASE(DT_BFLOAT16); IS_ONES_CASE(DT_FLOAT); @@ -1447,6 +1467,7 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const { } const auto dtype = node.attr().at("dtype").type(); switch (dtype) { + IS_ZEROS_CASE(DT_BOOL); IS_ZEROS_CASE(DT_HALF); IS_ZEROS_CASE(DT_BFLOAT16); IS_ZEROS_CASE(DT_FLOAT); @@ -1466,14 +1487,15 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const { return false; } -void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward, - NodeDef* node, - GraphDef* graph) { +void ConstantFolding::ReplaceOperationWithIdentity( + int input_to_forward, const GraphProperties& properties, NodeDef* node, + GraphDef* graph) { + const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties); + if (dtype == DT_INVALID) return; + node->set_op("Identity"); - DataType dtype = node->attr().at("T").type(); node->clear_attr(); (*node->mutable_attr())["T"].set_type(dtype); - // Propagate the designated input through the identity. node->mutable_input()->SwapElements(0, input_to_forward); // Add all other inputs as control dependencies. @@ -1489,14 +1511,15 @@ void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward, graph_modified_ = true; } -void ConstantFolding::ReplaceOperationWithSnapshot(int input_to_forward, - NodeDef* node, - GraphDef* graph) { +void ConstantFolding::ReplaceOperationWithSnapshot( + int input_to_forward, const GraphProperties& properties, NodeDef* node, + GraphDef* graph) { + const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties); + if (dtype == DT_INVALID) return; + node->set_op("Snapshot"); - DataType dtype = node->attr().at("T").type(); node->clear_attr(); (*node->mutable_attr())["T"].set_type(dtype); - // Propagate the designated input through the Snapshot. node->mutable_input()->SwapElements(0, input_to_forward); // Add all other inputs as control dependencies. @@ -1535,15 +1558,18 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node, } Status ConstantFolding::ReplaceOperationWithConstant( - double value, const AttrValue& dtype_attr, const TensorShapeProto& shape, - NodeDef* node, GraphDef* graph) { + double value, const GraphProperties& properties, + const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) { + const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties); + if (dtype == DT_INVALID) return Status::OK(); + AttrValue tensor_attr; - TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(dtype_attr.type(), value, - shape, &tensor_attr)); + TF_RETURN_IF_ERROR( + CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr)); + node->set_op("Const"); node->clear_attr(); - node->mutable_attr()->insert({"dtype", dtype_attr}); + (*node->mutable_attr())["dtype"].set_type(dtype); node->mutable_attr()->insert({"value", tensor_attr}); - node->set_op("Const"); // Convert all inputs to control dependencies. for (int i = 0; i < node->input_size(); ++i) { if (IsControlInput(node->input(i))) { @@ -1566,12 +1592,12 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, NodeDef* node = optimized_graph->mutable_node(i); if (IsSplit(*node) && node->attr().at("num_split").i() == 1) { - ReplaceOperationWithIdentity(1, node, optimized_graph); + ReplaceOperationWithIdentity(1, *properties, node, optimized_graph); continue; } if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) { - ReplaceOperationWithIdentity(0, node, optimized_graph); + ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); continue; } @@ -1611,7 +1637,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, replaceable &= shape.dim(j).size() == 1 || j == permutation[j]; } if (replaceable) { - ReplaceOperationWithIdentity(0, node, optimized_graph); + ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); continue; } } @@ -1626,7 +1652,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, // unknown_rank == false && (dim_size == 0 || first dim is of size 1) if (!shape.unknown_rank() && (shape.dim_size() == 0 || shape.dim(0).size() == 1)) { - ReplaceOperationWithIdentity(0, node, optimized_graph); + ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); continue; } } @@ -1651,11 +1677,11 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, for (int j = 0; j < axis.NumElements(); ++j) { // value of axis can be negative. if (axis.dtype() == DT_INT64) { - target_axes.insert( - (axis.vec()(j) + shape.dim_size()) % shape.dim_size()); + target_axes.insert((axis.vec()(j) + shape.dim_size()) % + shape.dim_size()); } else { - target_axes.insert( - (axis.vec()(j) + shape.dim_size()) % shape.dim_size()); + target_axes.insert((axis.vec()(j) + shape.dim_size()) % + shape.dim_size()); } } @@ -1669,7 +1695,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, target_axes.find(j) == target_axes.end(); } if (replaceable) { - ReplaceOperationWithIdentity(0, node, optimized_graph); + ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); continue; } } @@ -1711,7 +1737,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, } } if (replaceable) { - ReplaceOperationWithIdentity(0, node, optimized_graph); + ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); continue; } } @@ -1740,7 +1766,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, } } if (replaceable) { - ReplaceOperationWithIdentity(0, node, optimized_graph); + ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); continue; } } @@ -1764,7 +1790,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, replaceable &= flatten(j) == 0; } if (replaceable) { - ReplaceOperationWithIdentity(0, node, optimized_graph); + ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); continue; } } @@ -1784,7 +1810,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, replaceable &= shape.dim(j).size() > 1; } if (replaceable) { - ReplaceOperationWithIdentity(0, node, optimized_graph); + ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); continue; } } @@ -1996,9 +2022,9 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, continue; } - const bool is_mul = IsMul(*node); + const bool is_mul = IsMul(*node) || IsLogicalAnd(*node); const bool is_matmul = IsMatMul(*node); - const bool is_add = IsAdd(*node) || IsBiasAdd(*node); + const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node); const bool is_sub = IsSub(*node); const bool is_any_div = IsAnyDiv(*node); // Simplify arithmetic operations with ones or zeros. @@ -2025,7 +2051,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, if (y_matches_output_shape && ((is_mul && x_is_one) || (is_add && x_is_zero))) { // 1 * y = y or 0 + y = y. - ReplaceOperationWithSnapshot(1, node, optimized_graph); + ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph); continue; } @@ -2052,10 +2078,18 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) || ((is_add || is_sub) && y_is_zero))) { // x * 1 = x or x / 1 = x or x +/- 0 = x - ReplaceOperationWithSnapshot(0, node, optimized_graph); + ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph); continue; } + // x OR true = true OR y = true. + const PartialTensorShape shp(output_shape); + if (shp.IsFullyDefined() && IsLogicalOr(*node) && + (y_is_one || x_is_one)) { + TF_RETURN_IF_ERROR(ReplaceOperationWithConstant( + 1, *properties, output_shape, node, optimized_graph)); + } + // Simplify multiplication and matmul by zeros. // Also optimize zeros divided by a tensor, but only if we are in // aggressive mode, since we might get rid of divisions by zero. @@ -2063,26 +2097,19 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, is_any_div && x_is_zero && is_aggressive; if ((x_is_zero || y_is_zero) && (is_mul || is_matmul || optimize_zeros_divided_by_y)) { - const PartialTensorShape shp(output_shape); if (shp.IsFullyDefined()) { - AttrValue dtype_attr; - if (node->op() == "SparseMatMul") { - dtype_attr.set_type(DT_FLOAT); - } else { - dtype_attr = node->attr().at("T"); - } TF_RETURN_IF_ERROR(ReplaceOperationWithConstant( - 0, dtype_attr, output_shape, node, optimized_graph)); + 0, *properties, output_shape, node, optimized_graph)); continue; } // Even if an input shape is only partially known, we may known that it // matches the output shape and thus forward the corresponding zero // input. if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) { - ReplaceOperationWithIdentity(0, node, optimized_graph); + ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); continue; } else if (is_mul && y_is_zero && y_matches_output_shape) { - ReplaceOperationWithIdentity(1, node, optimized_graph); + ReplaceOperationWithIdentity(1, *properties, node, optimized_graph); continue; } } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index eb06cd081f..a694f1721a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -78,12 +78,15 @@ class ConstantFolding : public GraphOptimizer { bool IsOnes(const NodeDef& node) const; bool IsZeros(const NodeDef& node) const; - void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node, - GraphDef* graph); - void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node, - GraphDef* graph); + void ReplaceOperationWithIdentity(int input_to_forward, + const GraphProperties& properties, + NodeDef* node, GraphDef* graph); + void ReplaceOperationWithSnapshot(int input_to_forward, + const GraphProperties& properties, + NodeDef* node, GraphDef* graph); void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph); - Status ReplaceOperationWithConstant(double value, const AttrValue& dtype_attr, + Status ReplaceOperationWithConstant(double value, + const GraphProperties& properties, const TensorShapeProto& shape, NodeDef* node, GraphDef* graph); void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 306ddd22d7..f018b217e6 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -47,18 +47,30 @@ class ConstantFoldingTest : public GrapplerTest { } Output zeros = ops::Const(s.WithOpName("zeros"), zeros_t); Output ones = ops::Const(s.WithOpName("ones"), ones_t); - Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros); - Output mul2 = ops::Mul(s.WithOpName("mul2"), x, ones); - + Output mul1; + Output mul2; + Output add1; + Output add2; + if (DTYPE == DT_BOOL) { + mul1 = ops::LogicalAnd(s.WithOpName("mul1"), x, zeros); + mul2 = ops::LogicalAnd(s.WithOpName("mul2"), x, ones); + add1 = ops::LogicalOr(s.WithOpName("add1"), x, zeros); + add2 = ops::LogicalOr(s.WithOpName("add2"), x, ones); + } else { + mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros); + mul2 = ops::Mul(s.WithOpName("mul2"), x, ones); + add1 = ops::Add(s.WithOpName("add1"), x, zeros); + add1 = ops::Add(s.WithOpName("add2"), x, ones); + } GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"mul1", "mul2"}; + item.fetch = {"mul1", "mul2", "add1", "add2"}; ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - LOG(INFO) << output.DebugString(); - EXPECT_EQ(5, output.node_size()); + + EXPECT_EQ(7, output.node_size()); for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); const string& name = node.name(); @@ -70,14 +82,27 @@ class ConstantFoldingTest : public GrapplerTest { EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^ones", node.input(1)); + } else if (name == "add1") { + EXPECT_EQ("Snapshot", node.op()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^zeros", node.input(1)); + } else if (name == "add2") { + if (DTYPE == DT_BOOL) { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^x", node.input(0)); + EXPECT_EQ("^ones", node.input(1)); + } else { + EXPECT_EQ("Add", node.op()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("ones", node.input(1)); + } } } - auto tensors_expected = - EvaluateNodes(item.graph, {"mul1", "mul2"}, {{"x", x_t}}); - auto tensors = EvaluateNodes(output, {"mul1", "mul2"}, {{"x", x_t}}); - EXPECT_EQ(2, tensors_expected.size()); - EXPECT_EQ(2, tensors.size()); - for (int i = 0; i < 2; ++i) { + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}}); + auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}}); + EXPECT_EQ(4, tensors_expected.size()); + EXPECT_EQ(4, tensors.size()); + for (int i = 0; i < item.fetch.size(); ++i) { test::ExpectTensorEqual(tensors_expected[i], tensors[i]); } } @@ -393,6 +418,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) { } TEST_F(ConstantFoldingTest, NeutralElement_ShortFloats) { + SimpleNeutralElementTest(); SimpleNeutralElementTest(); SimpleNeutralElementTest(); } diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index d963dd9b55..b06540489f 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -25,6 +25,8 @@ import numpy as np from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context @@ -54,8 +56,8 @@ from tensorflow.python.training import coordinator from tensorflow.python.training import queue_runner_impl -def _initialized_session(): - sess = session.Session() +def _initialized_session(config=None): + sess = session.Session(config=config) sess.run(variables_lib.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) return sess @@ -6191,7 +6193,12 @@ class WeightedCategoricalColumnTest(test.TestCase): 'values': ((.5,), (1.,)) }, (column,), sparse_combiner='mean') - with _initialized_session(): + # Disabling the constant folding optimizer here since it changes the + # error message differently on CPU and GPU. + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + with _initialized_session(config): with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'): predictions.eval() @@ -6284,7 +6291,12 @@ class WeightedCategoricalColumnTest(test.TestCase): 'values': ((.5,), (1.,)) }, (column,), sparse_combiner='mean') - with _initialized_session(): + # Disabling the constant folding optimizer here since it changes the + # error message differently on CPU and GPU. + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + with _initialized_session(config): with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'): predictions.eval() -- GitLab From 5e9e6967b47989aa9084fb328f5ef0c40fc146ef Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 2 May 2018 16:22:41 -0700 Subject: [PATCH 049/839] Replaced calls to tensorflow::StringPiece::ToString with std::string conversions. That is, instances of sp.ToString() are replaced with std::string(sp). This will allow tensorflow::StringPiece::ToString to be removed, which is necessary before it can be replaced with absl::string_view. PiperOrigin-RevId: 195162126 --- tensorflow/c/c_api.cc | 2 +- tensorflow/c/c_api_test.cc | 4 ++-- tensorflow/c/checkpoint_reader.cc | 6 +++--- tensorflow/compiler/xla/shape_util.cc | 8 ++++---- tensorflow/compiler/xla/text_literal_reader.cc | 10 +++++----- tensorflow/compiler/xla/text_literal_writer.cc | 2 +- tensorflow/core/common_runtime/bfc_allocator.cc | 2 +- tensorflow/core/common_runtime/graph_runner.cc | 4 ++-- tensorflow/core/common_runtime/session_state.cc | 2 +- .../core/common_runtime/step_stats_collector.cc | 6 +++--- .../core/lib/monitoring/collection_registry.cc | 8 ++++---- .../core/lib/monitoring/collection_registry.h | 4 ++-- tensorflow/core/lib/monitoring/metric_def.h | 4 ++-- .../core/platform/cloud/curl_http_request.cc | 4 ++-- .../core/platform/cloud/gcs_file_system.cc | 14 +++++++------- tensorflow/core/platform/cloud/oauth_client.cc | 4 ++-- .../core/platform/cloud/oauth_client_test.cc | 8 ++++---- tensorflow/stream_executor/lib/env.h | 2 +- tensorflow/stream_executor/lib/path.cc | 2 +- tensorflow/stream_executor/lib/str_util.h | 2 +- .../freeze_requantization_ranges.cc | 5 ++--- .../graph_transforms/sparsify_gather_test.cc | 4 ++-- .../tools/graph_transforms/transform_graph.cc | 16 ++++++++-------- .../tools/graph_transforms/transform_utils.cc | 2 +- 24 files changed, 62 insertions(+), 63 deletions(-) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 18eeb28168..b86b277ac3 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2097,7 +2097,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, for (int i = 0; i < size; ++i) { TensorId id = results.missing_unused_input_map_keys[i]; - tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); + tf_results->missing_unused_key_names_data.push_back(std::string(id.first)); tf_results->missing_unused_key_names[i] = tf_results->missing_unused_key_names_data.back().c_str(); tf_results->missing_unused_key_indexes[i] = id.second; diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 9b86425aa5..577f10c5e6 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -1368,7 +1368,7 @@ TEST(CAPI, SavedModel) { } const tensorflow::string input_op_name = - tensorflow::ParseTensorName(input_name).first.ToString(); + std::string(tensorflow::ParseTensorName(input_name).first); TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); @@ -1376,7 +1376,7 @@ TEST(CAPI, SavedModel) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); const tensorflow::string output_op_name = - tensorflow::ParseTensorName(output_name).first.ToString(); + std::string(tensorflow::ParseTensorName(output_name).first); TF_Operation* output_op = TF_GraphOperationByName(graph, output_op_name.c_str()); ASSERT_TRUE(output_op != nullptr); diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index b1f7bdaa54..74bc25a491 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() { const auto& slice_proto = entry.slices(i); CHECK(filtered_keys .insert(EncodeTensorNameSlice( - v2_reader_->key().ToString() /* full var's name */, + std::string(v2_reader_->key()) /* full var's name */, TensorSlice(slice_proto))) .second); } @@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() { new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { - if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; + if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - string key = v2_reader_->key().ToString(); + string key = std::string(v2_reader_->key()); (*var_to_shape_map)[key] = TensorShape(entry.shape()); (*var_to_data_type_map)[key] = DataType(entry.dtype()); } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index c330473cda..7a897f6f8f 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -511,7 +511,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { break; } else if (must_end) { return InvalidArgument("Expected end of tuple; got: \"%s\"", - s->ToString().c_str()); + std::string(*s).c_str()); } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); @@ -541,7 +541,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { return InvalidArgument( "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - input.c_str(), s->ToString().c_str()); + input.c_str(), std::string(*s).c_str()); } return element; }; @@ -594,7 +594,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } return InvalidArgument("Invalid shape string to parse: \"%s\"", - s->ToString().c_str()); + std::string(*s).c_str()); } } // namespace @@ -603,7 +603,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); if (!s.empty()) { return InvalidArgument("Invalid shape string to parse: \"%s\"", - s.ToString().c_str()); + std::string(s).c_str()); } return shape; } diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 44f874cd2a..56702feab9 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -42,7 +42,7 @@ StatusOr> TextLiteralReader::ReadPath( << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; Status s = - tensorflow::Env::Default()->NewRandomAccessFile(path.ToString(), &file); + tensorflow::Env::Default()->NewRandomAccessFile(std::string(path), &file); if (!s.ok()) { return s; } @@ -92,7 +92,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { tensorflow::StringPiece sp(shape_string); if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) { - string tmp = sp.ToString(); + string tmp = std::string(sp); shape_string = tmp; } TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); @@ -124,10 +124,10 @@ StatusOr> TextLiteralReader::ReadAllLines() { line.c_str()); } float value; - if (!tensorflow::strings::safe_strtof(value_string.ToString().c_str(), + if (!tensorflow::strings::safe_strtof(std::string(value_string).c_str(), &value)) { return InvalidArgument("could not parse value as float: \"%s\"", - value_string.ToString().c_str()); + std::string(value_string).c_str()); } SplitByDelimToStringPieces(coordinates_string, ',', &coordinates); coordinate_values.clear(); @@ -136,7 +136,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) { return InvalidArgument( "could not parse coordinate member as int64: \"%s\"", - piece.ToString().c_str()); + std::string(piece).c_str()); } coordinate_values.push_back(coordinate_value); } diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 3fee467594..6e3061b78a 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -33,7 +33,7 @@ namespace xla { /* static */ tensorflow::Status TextLiteralWriter::WriteToPath( const Literal& literal, tensorflow::StringPiece path) { std::unique_ptr f; - auto s = tensorflow::Env::Default()->NewWritableFile(path.ToString(), &f); + auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); if (!s.ok()) { return s; } diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index e9f839289a..8f2a419756 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -616,7 +616,7 @@ string BFCAllocator::RenderOccupancy() { region_offset += region.memory_size(); } - return StringPiece(rendered, resolution).ToString(); + return std::string(rendered, resolution); } void BFCAllocator::DumpMemoryLog(size_t num_bytes) { diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 790f2eaa1e..adf2ef6f44 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -56,7 +56,7 @@ class SimpleRendezvous : public Rendezvous { } mutex_lock l(mu_); - string edge_name = parsed.edge_name.ToString(); + string edge_name = std::string(parsed.edge_name); if (table_.count(edge_name) > 0) { return errors::Internal("Send of an already sent tensor"); } @@ -69,7 +69,7 @@ class SimpleRendezvous : public Rendezvous { Tensor tensor; Status status = Status::OK(); { - string key = parsed.edge_name.ToString(); + string key = std::string(parsed.edge_name); mutex_lock l(mu_); if (table_.count(key) <= 0) { status = errors::Internal("Did not find key ", key); diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc index 6befa53dff..65ff356e73 100644 --- a/tensorflow/core/common_runtime/session_state.cc +++ b/tensorflow/core/common_runtime/session_state.cc @@ -70,7 +70,7 @@ Status TensorStore::SaveTensors(const std::vector& output_names, // Save only the tensors in output_names in the session. for (const string& name : output_names) { TensorId id(ParseTensorName(name)); - const string& op_name = id.first.ToString(); + const string& op_name = std::string(id.first); auto it = tensors_.find(op_name); if (it != tensors_.end()) { // Save the tensor to the session state. diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index f21536d586..af6880c6b3 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -94,7 +94,7 @@ static int ExtractGpuWithStreamAll(string device_name) { } else { // Convert the captured string into an integer. But first we need to put // the digits back in order - string ordered_capture = capture.ToString(); + string ordered_capture = std::string(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); @@ -123,7 +123,7 @@ static int ExtractGpuWithoutStream(string device_name) { } else { // Convert the captured string into an integer. But first we need to put // the digits back in order - string ordered_capture = capture.ToString(); + string ordered_capture = std::string(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); @@ -170,7 +170,7 @@ void StepStatsCollector::BuildCostModel( for (auto& itr : per_device_stats) { const StringPiece device_name = itr.first; - const int gpu_id = ExtractGpuWithoutStream(device_name.ToString()); + const int gpu_id = ExtractGpuWithoutStream(std::string(device_name)); if (gpu_id >= 0) { // Reference the gpu hardware stats in addition to the regular stats // for this gpu device if they're available. diff --git a/tensorflow/core/lib/monitoring/collection_registry.cc b/tensorflow/core/lib/monitoring/collection_registry.cc index d3fd7132de..8c28620ff9 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.cc +++ b/tensorflow/core/lib/monitoring/collection_registry.cc @@ -38,15 +38,15 @@ void Collector::CollectMetricDescriptor( mutex_lock l(mu_); return collected_metrics_->metric_descriptor_map .insert(std::make_pair( - metric_def->name().ToString(), + std::string(metric_def->name()), std::unique_ptr(new MetricDescriptor()))) .first->second.get(); }(); - metric_descriptor->name = metric_def->name().ToString(); - metric_descriptor->description = metric_def->description().ToString(); + metric_descriptor->name = std::string(metric_def->name()); + metric_descriptor->description = std::string(metric_def->description()); for (const StringPiece label_name : metric_def->label_descriptions()) { - metric_descriptor->label_names.push_back(label_name.ToString()); + metric_descriptor->label_names.push_back(std::string(label_name)); } metric_descriptor->metric_kind = metric_def->kind(); diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h index 63cc0f550d..20f0444f8b 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.h +++ b/tensorflow/core/lib/monitoring/collection_registry.h @@ -72,7 +72,7 @@ class MetricCollector { registration_time_millis_(registration_time_millis), collector_(collector), point_set_(point_set) { - point_set_->metric_name = metric_def->name().ToString(); + point_set_->metric_name = std::string(metric_def->name()); } const MetricDef* const metric_def_; @@ -261,7 +261,7 @@ class Collector { auto* const point_set = [&]() { mutex_lock l(mu_); return collected_metrics_->point_set_map - .insert(std::make_pair(metric_def->name().ToString(), + .insert(std::make_pair(std::string(metric_def->name()), std::unique_ptr(new PointSet()))) .first->second.get(); }(); diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h index 5ecadcc427..6f94685665 100644 --- a/tensorflow/core/lib/monitoring/metric_def.h +++ b/tensorflow/core/lib/monitoring/metric_def.h @@ -98,8 +98,8 @@ class AbstractMetricDef { const std::vector& label_descriptions) : kind_(kind), value_type_(value_type), - name_(name.ToString()), - description_(description.ToString()), + name_(std::string(name)), + description_(std::string(description)), label_descriptions_(std::vector(label_descriptions.begin(), label_descriptions.end())) {} diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index 1ac6a7531b..081d4cf043 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -407,9 +407,9 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size, .StopCapture() .OneLiteral(": ") .GetResult(&value, &name)) { - string str_value = value.ToString(); + string str_value = std::string(value); str_util::StripTrailingWhitespace(&str_value); - that->response_headers_[name.ToString()] = str_value; + that->response_headers_[std::string(name)] = str_value; } return size * nmemb; } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 2d9c99c124..f1e18403ec 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -167,13 +167,13 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket, return errors::InvalidArgument("GCS path doesn't start with 'gs://': ", fname); } - *bucket = bucketp.ToString(); + *bucket = std::string(bucketp); if (bucket->empty() || *bucket == ".") { return errors::InvalidArgument("GCS path doesn't contain a bucket name: ", fname); } str_util::ConsumePrefix(&objectp, "/"); - *object = objectp.ToString(); + *object = std::string(objectp); if (!empty_object_ok && object->empty()) { return errors::InvalidArgument("GCS path doesn't contain an object name: ", fname); @@ -212,7 +212,7 @@ std::set AddAllSubpaths(const std::vector& paths) { for (const string& path : paths) { StringPiece subpath = io::Dirname(path); while (!subpath.empty()) { - result.emplace(subpath.ToString()); + result.emplace(std::string(subpath)); subpath = io::Dirname(subpath); } } @@ -704,7 +704,7 @@ GcsFileSystem::GcsFileSystem() if (!header_name.empty() && !header_value.empty()) { additional_header_.reset(new std::pair( - header_name.ToString(), header_value.ToString())); + std::string(header_name), std::string(header_value))); VLOG(1) << "GCS additional header ENABLED. " << "Name: " << additional_header_->first << ", " @@ -1095,7 +1095,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern, // Find the fixed prefix by looking for the first wildcard. const string& fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\")); - const string& dir = io::Dirname(fixed_prefix).ToString(); + const string& dir = std::string(io::Dirname(fixed_prefix)); if (dir.empty()) { return errors::InvalidArgument( "A GCS pattern doesn't have a bucket name: ", pattern); @@ -1192,7 +1192,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, " doesn't match the prefix ", object_prefix)); } if (!relative_path.empty() || include_self_directory_marker) { - result->emplace_back(relative_path.ToString()); + result->emplace_back(std::string(relative_path)); } if (++retrieved_results >= max_results) { return Status::OK(); @@ -1220,7 +1220,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, "Unexpected response: the returned folder name ", prefix_str, " doesn't match the prefix ", object_prefix); } - result->emplace_back(relative_path.ToString()); + result->emplace_back(std::string(relative_path)); if (++retrieved_results >= max_results) { return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc index 06849f9093..59ad3cbcc2 100644 --- a/tensorflow/core/platform/cloud/oauth_client.cc +++ b/tensorflow/core/platform/cloud/oauth_client.cc @@ -216,7 +216,7 @@ Status OAuthClient::GetTokenFromServiceAccountJson( // Send the request to the Google OAuth 2.0 server to get the token. std::unique_ptr request(http_request_factory_->Create()); std::vector response_buffer; - request->SetUri(oauth_server_uri.ToString()); + request->SetUri(std::string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); @@ -248,7 +248,7 @@ Status OAuthClient::GetTokenFromRefreshTokenJson( std::unique_ptr request(http_request_factory_->Create()); std::vector response_buffer; - request->SetUri(oauth_server_uri.ToString()); + request->SetUri(std::string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc index ad569758cc..4ffa72288b 100644 --- a/tensorflow/core/platform/cloud/oauth_client_test.cc +++ b/tensorflow/core/platform/cloud/oauth_client_test.cc @@ -124,11 +124,11 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) { .OneLiteral("&assertion=") .GetResult(&assertion, &grant_type)); EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", - grant_type.ToString()); + grant_type); - int last_dot = assertion.ToString().find_last_of("."); - string header_dot_claim = assertion.ToString().substr(0, last_dot); - string signature_encoded = assertion.ToString().substr(last_dot + 1); + int last_dot = std::string(assertion).find_last_of("."); + string header_dot_claim = std::string(assertion.substr(0, last_dot)); + string signature_encoded = std::string(assertion.substr(last_dot + 1)); // Check that 'signature' signs 'header_dot_claim'. diff --git a/tensorflow/stream_executor/lib/env.h b/tensorflow/stream_executor/lib/env.h index 776eba0408..3ef8deb72e 100644 --- a/tensorflow/stream_executor/lib/env.h +++ b/tensorflow/stream_executor/lib/env.h @@ -32,7 +32,7 @@ inline Status FileExists(const string& filename) { } inline Status FileExists(const port::StringPiece& filename) { - return Env::Default()->FileExists(filename.ToString()); + return Env::Default()->FileExists(std::string(filename)); } } // namespace port diff --git a/tensorflow/stream_executor/lib/path.cc b/tensorflow/stream_executor/lib/path.cc index 56e08c316f..58a862206c 100644 --- a/tensorflow/stream_executor/lib/path.cc +++ b/tensorflow/stream_executor/lib/path.cc @@ -33,7 +33,7 @@ string JoinPathImpl(std::initializer_list paths) { if (path.empty()) continue; if (result.empty()) { - result = path.ToString(); + result = std::string(path); continue; } diff --git a/tensorflow/stream_executor/lib/str_util.h b/tensorflow/stream_executor/lib/str_util.h index a81c666818..b02fe4f56f 100644 --- a/tensorflow/stream_executor/lib/str_util.h +++ b/tensorflow/stream_executor/lib/str_util.h @@ -31,7 +31,7 @@ inline string StripSuffixString(port::StringPiece str, port::StringPiece suffix) if (tensorflow::str_util::EndsWith(str, suffix)) { str.remove_suffix(suffix.size()); } - return str.ToString(); + return std::string(str); } using tensorflow::str_util::Lowercase; diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc index f401723808..c8dc2a7c4d 100644 --- a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc +++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc @@ -92,9 +92,8 @@ Status ExtractMinMaxRecords(const string& log_file_name, if (!str_util::EndsWith(name_string, print_suffix)) { continue; } - string name = - name_string.substr(0, name_string.size() - print_suffix.size()) - .ToString(); + string name = std::string( + name_string.substr(0, name_string.size() - print_suffix.size())); records->push_back({name, min, max}); } return Status::OK(); diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc index d41321c9a6..dd95779a1f 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc @@ -42,8 +42,8 @@ class SparsifyGatherTest : public ::testing::Test { const std::vector& inputs, GraphDef* graph_def, bool control_dep = false) { NodeDef* node_def = graph_def->add_node(); - node_def->set_name(name.ToString()); - node_def->set_op(op.ToString()); + node_def->set_name(std::string(name)); + node_def->set_op(std::string(op)); if (!control_dep) { std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) { node_def->add_input(input->name()); diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc index 8ce8f5e24b..3b9dd3dd2d 100644 --- a/tensorflow/tools/graph_transforms/transform_graph.cc +++ b/tensorflow/tools/graph_transforms/transform_graph.cc @@ -65,19 +65,19 @@ Status ParseTransformParameters(const string& transforms_string, .GetResult(&remaining, &transform_name); if (!found_transform_name) { return errors::InvalidArgument("Looking for transform name, but found ", - remaining.ToString().c_str()); + std::string(remaining).c_str()); } if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) { state = TRANSFORM_PARAM_NAME; } else { // Add a transform with no parameters. - params_list->push_back({transform_name.ToString(), func_parameters}); + params_list->push_back({std::string(transform_name), func_parameters}); transform_name = ""; state = TRANSFORM_NAME; } } else if (state == TRANSFORM_PARAM_NAME) { if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) { - params_list->push_back({transform_name.ToString(), func_parameters}); + params_list->push_back({std::string(transform_name), func_parameters}); transform_name = ""; state = TRANSFORM_NAME; } else { @@ -92,13 +92,13 @@ Status ParseTransformParameters(const string& transforms_string, if (!found_parameter_name) { return errors::InvalidArgument( "Looking for parameter name, but found ", - remaining.ToString().c_str()); + std::string(remaining).c_str()); } if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) { state = TRANSFORM_PARAM_VALUE; } else { return errors::InvalidArgument("Looking for =, but found ", - remaining.ToString().c_str()); + std::string(remaining).c_str()); } } } else if (state == TRANSFORM_PARAM_VALUE) { @@ -120,10 +120,10 @@ Status ParseTransformParameters(const string& transforms_string, } if (!found_parameter_value) { return errors::InvalidArgument("Looking for parameter name, but found ", - remaining.ToString().c_str()); + std::string(remaining).c_str()); } - func_parameters[parameter_name.ToString()].push_back( - parameter_value.ToString()); + func_parameters[std::string(parameter_name)].push_back( + std::string(parameter_value)); // Eat up any trailing quotes. Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match); Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match); diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index 367048965d..af17fd75bc 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -93,7 +93,7 @@ void NodeNamePartsFromInput(const string& input_name, string* prefix, } else { *prefix = ""; } - *node_name = node_name_piece.ToString(); + *node_name = std::string(node_name_piece); } string NodeNameFromInput(const string& input_name) { -- GitLab From 4c256cda4f29ce5f634be44628e2c4c639974dc3 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Wed, 2 May 2018 15:30:30 -0700 Subject: [PATCH 050/839] Add prefetching to one device distribution strategy. PiperOrigin-RevId: 195162570 --- .../contrib/distribute/python/one_device_strategy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 646d2a5c3b..64aa369201 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -36,9 +36,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): # doing something that won't work with other DistributionStrategy # implementations? - def __init__(self, device): + def __init__(self, device, prefetch_on_device=None): super(OneDeviceStrategy, self).__init__() self._device = device + self._prefetch_on_device = prefetch_on_device def _create_variable(self, next_creator, *args, **kwargs): # No need to distinguish tower-local variables when not mirroring, @@ -61,7 +62,9 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): return next_creator(*args, **kwargs) def distribute_dataset(self, dataset_fn): - return self._call_dataset_fn(dataset_fn) + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), [self._device], + self._prefetch_on_device) def _broadcast(self, tensor, destinations): return tensor -- GitLab From 85566b2420833a4ba59241330eeceedea4f98e3c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 15:35:11 -0700 Subject: [PATCH 051/839] Adding a version of rolled triangular solver code for the right-multiply case, which is used in Cholesky decomposition. Replacing the unrolled version with a While loop drastically reduces XLA compilation times which allows much larger models to be run on TPU. PiperOrigin-RevId: 195163298 --- .../compiler/tf2xla/lib/triangular_solve.cc | 179 +++++++++++++++--- .../compiler/tf2xla/lib/triangular_solve.h | 6 + tensorflow/compiler/tf2xla/lib/util.cc | 7 + tensorflow/compiler/tf2xla/lib/util.h | 5 + 4 files changed, 173 insertions(+), 24 deletions(-) diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index d0279d4412..b4503601f9 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -82,13 +82,6 @@ xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, block_size); } - // Applies a complex conjugation operation if `a` is complex and `conjugate_a` - // is true, otherwise returns its argument. - auto maybe_conj = [&](xla::XlaBuilder* builder, xla::XlaOp x) { - auto perform_conj = a_shape.element_type() == xla::C64 && conjugate_a; - return perform_conj ? builder->Conj(x) : x; - }; - std::map base_computations; auto get_base_triangular_solve = [&](int k) -> xla::StatusOr { @@ -117,16 +110,21 @@ xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), "b"); - // We use a left-looking subroutine on the block diagonal in some common - // cases, while falling back to a recursive call in unsupported cases. The - // left-looking subroutine is written with a While loop and so yields much - // faster compile times. Moreover, the left-looking variant can give - // higher performance on smaller (sub)problems. + // We use a left-looking or right-looking subroutine on the block diagonal + // in the lower=true cases, while falling back to a recursive call in + // others. The left-looking and right-looking subroutines are written with + // a While loop and so yields much faster compile times. Moreover, they + // can give higher performance on smaller (sub)problems. if (left_side && lower) { TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, b_param, transpose_a, conjugate_a) .status()); + } else if (!left_side && lower) { + TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param, + b_param, transpose_a, + conjugate_a) + .status()); } else { TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, left_side, lower, transpose_a, @@ -169,7 +167,9 @@ xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {0, i})); @@ -219,7 +219,9 @@ xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); @@ -268,7 +270,9 @@ xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {0, i})); @@ -318,7 +322,9 @@ xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); @@ -371,11 +377,6 @@ xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, batch_dimensions.push_back(a_size); } - auto maybe_conj = [&](xla::XlaBuilder* builder, xla::XlaOp x) { - auto perform_conj = a_shape.element_type() == xla::C64 && conjugate_a; - return perform_conj ? builder->Conj(x) : x; - }; - // The main computation is performed in a While loop. // Allocate the output and set its first or last row, @@ -391,7 +392,9 @@ xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); - auto update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + auto update = builder->Div(b_slice, a_slice_conj); TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); } @@ -493,7 +496,9 @@ xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, {i, i}, {1, 1})); - auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt)); + TF_ASSIGN_OR_RETURN(auto a_elt_conj, + MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); + auto div_result = bodyb->Div(result_row, a_elt_conj); TF_ASSIGN_OR_RETURN(body_out, DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, div_result, {i, zero})); @@ -513,4 +518,130 @@ xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, return builder->GetTupleElement(triangular_solve_left_looking_while, 1); } +xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + batch_dimensions.push_back(a_size); + } + + // The main computation is performed in a While loop. + xla::XlaOp output = Zeros(builder, b_shape); + + // Construct the initial loop carry tuple, + // if transpose_a: + // init = (0, output, a, b) + // else: + // init = (n-1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + b_shape, + // The coefficient matrix a is a loop invariant. + a_shape, + // The right-hand-side matrix b is a loop invariant. + b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = builder->ConstantR0(transpose_a ? 0 : n - 1); + auto init = builder->Tuple({init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i < n if transpose_a else i >= 0 + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); + { + auto i = condb->GetTupleElement( + condb->Parameter(0, tuple_shape, + "TriangularSolveRightLookingWhileTuple"), + 0); + if (transpose_a) { + condb->Lt(i, condb->ConstantR0(n)); + } else { + condb->Ge(i, condb->ConstantR0(0)); + } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) + // else: + // a_row = a[..., :, i:i+1] + // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) + // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop counter. + std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); + { + auto input_tuple = bodyb->Parameter( + 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = bodyb->GetTupleElement(input_tuple, 0); + auto body_out = bodyb->GetTupleElement(input_tuple, 1); + auto body_a = bodyb->GetTupleElement(input_tuple, 2); + auto body_b = bodyb->GetTupleElement(input_tuple, 3); + auto zero = bodyb->ConstantR0(0); + + // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, + // i:i+1]) But since we can't have intermediate array sizes depend on the + // loop counter, we instead exploit the fact that we initialized the output + // to all zeros and use that as zero-padding (doing unnecessary FLOPs). + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a)); + // result = b - np.matmul(output, a) + auto result = bodyb->Sub(body_b, b_update); + // result_row = result[..., :, i:i+1] + TF_ASSIGN_OR_RETURN( + auto result_row, + DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1})); + + // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a, + {i, i}, {1, 1})); + TF_ASSIGN_OR_RETURN(auto a_ii_conj, + MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); + auto div_result = bodyb->Div(result_row, a_ii_conj); + TF_ASSIGN_OR_RETURN(body_out, + DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, + div_result, {zero, i})); + + // if transpose_a: + // return (i + 1, body_out, a, b) + // else: + // return (i - 1, body_out, a, b) + auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? 1 : -1)); + bodyb->Tuple({next_i, body_out, body_a, body_b}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = builder->While(cond, body, init); + return builder->GetTupleElement(triangular_solve_left_looking_while, 1); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index fd8f2489d1..540c26b247 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -69,6 +69,12 @@ xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, bool transpose_a, bool conjugate_a); +xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index cc7b13571c..d9ff7e6259 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -230,4 +230,11 @@ xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, return builder->Transpose(x, permutation); } +xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, + const xla::XlaOp& x, bool conjugate) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == xla::C64 && conjugate; + return perform_conj ? builder->Conj(x) : x; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 3df44ef035..3c120a2548 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -85,6 +85,11 @@ xla::StatusOr DynamicUpdateSliceInMinorDims( xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, const xla::XlaOp& x); +// Applies a complex conjugation operation if `a` is complex and `conjugate_a` +// is true, otherwise returns its argument. +xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, + const xla::XlaOp& x, bool conjugate); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ -- GitLab From 0237e86297087ba3e700ac9218f846e6e662c60f Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Wed, 2 May 2018 15:36:54 -0700 Subject: [PATCH 052/839] Adds the EvalListener support for run_local. PiperOrigin-RevId: 195163507 --- tensorflow/python/estimator/training.py | 10 +++ tensorflow/python/estimator/training_test.py | 67 ++++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 534c357067..41ffa371aa 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -656,6 +656,11 @@ class _TrainingExecutor(object): max_steps=self._train_spec.max_steps, hooks=train_hooks) + if not self._continuous_eval_listener.before_eval(): + logging.info('Exiting training and evaluation lopp, as requested by ' + '_ContinuousEvalListener.before_eval.') + break + # Final export signal: For any eval result with global_step >= train # max_steps, the evaluator will send the final export signal. The # _should_stop_local_train will then end the while True as the stopping @@ -669,6 +674,11 @@ class _TrainingExecutor(object): raise RuntimeError('There was no new checkpoint after the training. ' 'Eval status: {}'.format(eval_result.status)) + if not self._continuous_eval_listener.after_eval(eval_result): + logging.info('Exiting evaluation, as requested by ' + '_ContinuousEvalListener.after_eval.') + break + if _should_stop_local_train( eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]): break diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index c04905ae65..3b6f5e18cb 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -1628,6 +1628,73 @@ class TrainingExecutorRunLocalTest(test.TestCase): self.assertEqual(3, mock_est.times_export_was_called) self.assertEqual(1, mock_est.times_final_export_was_called) + def test_runs_with_eval_listener_before_eval(self): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') + mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + + train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) + eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=100) + # should be called 2 times without the evallistener + mock_est.evaluate.side_effect = [{ + _GLOBAL_STEP_KEY: train_spec.max_steps - 50 + }, { + _GLOBAL_STEP_KEY: train_spec.max_steps + }] + + class _Listener(training._ContinuousEvalListener): + + def __init__(self): + self.call_count = 0 + + def before_eval(self): + self.call_count += 1 + return False # Will stop the run_local before first eval. + + listener = _Listener() + + executor = training._TrainingExecutor( + mock_est, train_spec, eval_spec, continuous_eval_listener=listener) + executor.run_local() + + self.assertEqual(1, mock_est.train.call_count) + self.assertEqual(0, mock_est.evaluate.call_count) + self.assertEqual(1, listener.call_count) + + def test_runs_with_eval_listener_after_eval(self): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') + mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + + train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300) + eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=100) + # should be called 2 times without the evallistener + mock_est.evaluate.side_effect = [{ + _GLOBAL_STEP_KEY: train_spec.max_steps - 50 + }, { + _GLOBAL_STEP_KEY: train_spec.max_steps + }] + + class _Listener(training._ContinuousEvalListener): + + def __init__(self, test_case): + self.call_count = 0 + self._test_case = test_case + + def after_eval(self, eval_result): + self.call_count += 1 + self._test_case.assertEqual( + train_spec.max_steps - 50, eval_result.metrics[_GLOBAL_STEP_KEY]) + return False # Will stop the run_local after first eval. + + listener = _Listener(test_case=self) + + executor = training._TrainingExecutor( + mock_est, train_spec, eval_spec, continuous_eval_listener=listener) + executor.run_local() + + self.assertEqual(1, mock_est.train.call_count) + self.assertEqual(1, mock_est.evaluate.call_count) + self.assertEqual(1, listener.call_count) + def test_handles_no_new_checkpoint_found(self): mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') mock_est.latest_checkpoint.return_value = ( -- GitLab From 49f2afe21e3cada8951205d00e877c873a33754c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 15:51:16 -0700 Subject: [PATCH 053/839] Allow evaluation and prediction through warm-starting (no current checkpoint / model_dir). PiperOrigin-RevId: 195165732 --- tensorflow/python/estimator/BUILD | 1 + tensorflow/python/estimator/estimator.py | 20 ++++- tensorflow/python/estimator/estimator_test.py | 89 +++++++++++++++++++ 3 files changed, 106 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index c6bb9b9be7..56dec1eaa1 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -478,6 +478,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/data", "//tensorflow/python/saved_model:builder", + "//tensorflow/python/saved_model:constants", "//tensorflow/python/saved_model:tag_constants", "//third_party/py/numpy", "@six_archive//:six", diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 946f093ba7..530a4a24ef 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -50,6 +50,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder +from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer_cache @@ -493,7 +494,6 @@ class Estimator(object): if not checkpoint_path: logging.info('Could not find trained model in model_dir: {}, running ' 'initialization to predict.'.format(self._model_dir)) - with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) self._create_and_assert_global_step(g) @@ -501,6 +501,10 @@ class Estimator(object): input_fn, model_fn_lib.ModeKeys.PREDICT) estimator_spec = self._call_model_fn( features, None, model_fn_lib.ModeKeys.PREDICT, self.config) + + # Call to warm_start has to be after model_fn is called. + self._maybe_warm_start(checkpoint_path) + predictions = self._extract_keys( estimator_spec.predictions, predict_keys) all_hooks = list(input_hooks) @@ -982,9 +986,7 @@ class Estimator(object): if self._warm_start_settings: logging.info('Warm-starting with WarmStartSettings: %s' % (self._warm_start_settings,)) - # pylint: disable=protected-access warm_starting_util.warm_start(*self._warm_start_settings) - # pylint: enable=protected-access # Check if the user created a loss summary, and add one if they didn't. # We assume here that the summary is called 'loss'. If it is not, we will # make another one with the name 'loss' to ensure it shows up in the right @@ -1089,6 +1091,9 @@ class Estimator(object): estimator_spec = self._call_model_fn( features, labels, model_fn_lib.ModeKeys.EVAL, self.config) + # Call to warm_start has to be after model_fn is called. + self._maybe_warm_start(checkpoint_path) + if model_fn_lib.LOSS_METRIC_KEY in estimator_spec.eval_metric_ops: raise ValueError( 'Metric with name "%s" is not allowed, because Estimator ' % ( @@ -1126,6 +1131,12 @@ class Estimator(object): return eval_results + def _maybe_warm_start(self, checkpoint_path): + if not checkpoint_path and self._warm_start_settings: + logging.info('Warm-starting with WarmStartSettings: %s' % + (self._warm_start_settings,)) + warm_starting_util.warm_start(*self._warm_start_settings) + def create_per_tower_ready_op(scaffold): """Create a Scaffold.ready_op inside a tower.""" @@ -1525,7 +1536,8 @@ def _get_default_warm_start_settings(warm_start_from): logging.info('Warm-starting from a SavedModel') return WarmStartSettings(ckpt_to_initialize_from=os.path.join( compat.as_bytes(warm_start_from), - compat.as_bytes('variables/variables'))) + compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY, + constants.VARIABLES_FILENAME)))) return WarmStartSettings(ckpt_to_initialize_from=warm_start_from) elif isinstance(warm_start_from, WarmStartSettings): return warm_start_from diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 4d958f8b43..76b45b7f57 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -1116,6 +1116,52 @@ class EstimatorEvaluateTest(test.TestCase): # initialized (since there is no checkpoint). self.assertEqual(3., metrics['metric']) + def test_no_checkpoint_uses_init_with_warm_starting(self): + def _make_model_fn(x): + def _variable_creating_and_export_model_fn(features, labels, mode): + _, _ = features, labels + x_var = variable_scope.get_variable('x', initializer=x) + global_step = training.get_global_step() + return model_fn_lib.EstimatorSpec( + mode, + predictions={'y': constant_op.constant(1.0)}, + loss=constant_op.constant(1.), + eval_metric_ops={'metric': metrics_lib.mean(x_var + 1)}, + train_op=state_ops.assign_add(global_step, 1), + export_outputs={'test': export_output.ClassificationOutput( + constant_op.constant([4.2]), constant_op.constant(['label']))}) + return _variable_creating_and_export_model_fn + + first_est = estimator.Estimator(model_fn=_make_model_fn(42.)) + first_est.train(dummy_input_fn, steps=10) + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + tmpdir = tempfile.mkdtemp() + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + exported_path = first_est.export_savedmodel(export_dir_base, + serving_input_receiver_fn) + + # Test that we can pass either warm_start_from as an external checkpoint + # or an exported SavedModel. + est = estimator.Estimator(model_fn=_make_model_fn(52.), + warm_start_from=exported_path) + metrics = est.evaluate(dummy_input_fn, steps=1) + # Metric value here is set to 1 + the value of the Variable that is + # warm-started from the SavedModel of the first model (42.), as opposed to + # the initialization in the new model_fn (52.). + self.assertEqual(43., metrics['metric']) + + est = estimator.Estimator(model_fn=_make_model_fn(62.), + warm_start_from=first_est.model_dir) + metrics = est.evaluate(dummy_input_fn, steps=1) + # Metric value here is set to 1 + the value of the Variable that is + # warm-started from a checkpoint of the first model (42.), as opposed to + # the initialization in the new model_fn (52.). + self.assertEqual(43., metrics['metric']) + def test_scores(self): est = estimator.Estimator( model_fn=_model_fn_with_eval_metric_ops, @@ -1384,6 +1430,49 @@ class EstimatorPredictTest(test.TestCase): # initialized (since there is no checkpoint). self.assertEqual(4., next(est.predict(dummy_input_fn))) + def test_no_checkpoint_uses_init_with_warm_starting(self): + def _make_model_fn(x): + def _variable_creating_and_export_model_fn(features, labels, mode): + _, _ = features, labels + x_var = variables.Variable([[x]], name='x') + return model_fn_lib.EstimatorSpec( + mode, + predictions=math_ops.add(x_var, 1.), + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + export_outputs={'test': export_output.ClassificationOutput( + constant_op.constant([4.2]), + constant_op.constant(['label']))}) + return _variable_creating_and_export_model_fn + + first_est = estimator.Estimator(model_fn=_make_model_fn(3.)) + first_est.train(dummy_input_fn, steps=10) + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + tmpdir = tempfile.mkdtemp() + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + exported_path = first_est.export_savedmodel(export_dir_base, + serving_input_receiver_fn) + + # Test that we can pass either warm_start_from as an external checkpoint + # or an exported SavedModel. + est = estimator.Estimator(model_fn=_make_model_fn(30.), + warm_start_from=exported_path) + # Prediction here is set to 1 + the value of the Variable that is + # warm-started from the SavedModel of the first model (3.), as opposed to + # the initialization in the new model_fn (30.). + self.assertEqual(4., next(est.predict(dummy_input_fn))) + + est = estimator.Estimator(model_fn=_make_model_fn(40.), + warm_start_from=first_est.model_dir) + # Prediction here is set to 1 + the value of the Variable that is + # warm-started from a checkpoint of the first model (3.), as opposed to + # the initialization in the new model_fn (40.). + self.assertEqual(4., next(est.predict(dummy_input_fn))) + def test_no_trained_model_invalid_checkpoint_path(self): est = estimator.Estimator(model_fn=model_fn_global_step_incrementer) with self.assertRaises(ValueError): -- GitLab From 30927ec6b625121bae1b89b07f9faeaebaed321f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 16:04:09 -0700 Subject: [PATCH 054/839] Mark all nodes processed by AddOpsRewrite/MinBCast stages with a tag. PiperOrigin-RevId: 195167597 --- ...direct_session_with_tracking_alloc_test.cc | 4 +- .../optimizers/arithmetic_optimizer.cc | 77 +++++++++++-------- .../grappler/optimizers/meta_optimizer.cc | 2 +- .../python/grappler/layout_optimizer_test.py | 8 +- 4 files changed, 53 insertions(+), 38 deletions(-) diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index b4dd521bbc..695423b2cb 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -102,9 +102,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { EXPECT_EQ(2, shape.dim(0).size()); EXPECT_EQ(1, shape.dim(1).size()); if (node->name() == y->name()) { - EXPECT_EQ(7, cm->AllocationId(node, 0)); + EXPECT_EQ(9, cm->AllocationId(node, 0)); } else { - EXPECT_EQ(8, cm->AllocationId(node, 0)); + EXPECT_EQ(10, cm->AllocationId(node, 0)); } } EXPECT_LE(0, cm->MaxExecutionTime(node)); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index bf59b25449..d6510ba681 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" @@ -49,6 +50,12 @@ namespace tensorflow { namespace grappler { namespace { +// Mark nodes created or optimized by a stage with a tag. +constexpr char kAddOpsRewriteTag[] = + "_grappler:ArithmeticOptimizer:AddOpsRewriteStage"; +constexpr char kMinimizeBroadcastsTag[] = + "_grappler:ArithmeticOptimizer:MinimizeBroadcasts"; + // Extract values from a Const op to `values`. Returns true if succeeds. template bool ValuesFromConstNode(const NodeDef& node, std::vector* values) { @@ -142,18 +149,6 @@ bool MaybeAddControlInput(const string& new_input, NodeDef* node, return !already_exists; } -int CopyControlInputs(const NodeDef& from, NodeDef* to, GraphDef* graph, - NodeMap* node_map) { - int num_copied = 0; - for (const string& input : from.input()) { - if (IsControlInput(input) && - MaybeAddControlInput(input, to, graph, node_map)) { - ++num_copied; - } - } - return num_copied; -} - void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) { (*node->mutable_attr())[attr_name].set_type(dtype); } @@ -326,7 +321,7 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { explicit ArithmeticNodesGroupOptimizerStage( const string& name, const GraphOptimizerContext& ctx, const ArithmeticOptimizerContext ctx_ext) - : ArithmeticOptimizerStage(name, ctx, ctx_ext), optimized_nodes_{} {} + : ArithmeticOptimizerStage(name, ctx, ctx_ext) {} ~ArithmeticNodesGroupOptimizerStage() override = default; // Input name with a statically inferred shape from GraphProperties @@ -465,13 +460,16 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { return signature; } - void AddToOptimizedNodes(const NodeDef* node) { - optimized_nodes_.insert(node->name()); + void MarkWithTag(const StringPiece tag, NodeDef* node) { + AddNodeAttr(tag, true, node); } - void AddAllMembersToOptimizedNodes(const OptimizedNodesGroup& group) { - AddToOptimizedNodes(group.root_node); - for (const NodeDef* opt : group.optimized_nodes) AddToOptimizedNodes(opt); + void MarkAllMembersWithTag(const OptimizedNodesGroup& group, + const StringPiece tag) const { + AddNodeAttr(tag, true, group.root_node); + for (NodeDef* optimized_node : group.optimized_nodes) { + AddNodeAttr(tag, true, optimized_node); + } } bool IsOnTheSameDevice(const OptimizedNodesGroup& group, @@ -479,13 +477,19 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage { return group.root_node->device() == node.device(); } - bool IsAlreadyOptimized(const NodeDef& node) const { - return optimized_nodes_.find(node.name()) != optimized_nodes_.end(); + bool IsInPreserveSet(const NodeDef& node) const { + return ctx().nodes_to_preserve->find(node.name()) != + ctx().nodes_to_preserve->end(); } - private: - // set of nodes already processed by this optimizer stage - std::unordered_set optimized_nodes_; + bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const { + return HasNodeAttr(node, tag); + } + + bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1, + const StringPiece tag2) const { + return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2); + } }; // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the @@ -561,7 +565,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { if (!IsAdd(node) && !IsAddN(node)) { return false; } - if (IsInPreserveSet(node) || IsAlreadyOptimized(node)) { + if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) { return false; } // TODO(ezhulenev): relax this condition for root node @@ -579,7 +583,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { << " num_inputs=" << group.inputs.size(); // Do not optimize any of the nodes that are part of this group. - AddAllMembersToOptimizedNodes(group); + MarkAllMembersWithTag(group, kAddOpsRewriteTag); // All new nodes will be placed under the scope of a root node. auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name()); @@ -688,7 +692,7 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { node->add_input(inputAndShape.input); } - AddToOptimizedNodes(node); + MarkWithTag(kAddOpsRewriteTag, node); return InputAndShape(node_name, shape); } @@ -705,14 +709,13 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage { node->set_op("Add"); node->set_device(root_node.device()); (*node->mutable_attr())["T"].set_type(dtype); + node->add_input(left.input); + node->add_input(right.input); ctx().node_map->AddOutput(left.input, node_name); ctx().node_map->AddOutput(right.input, node_name); - node->add_input(left.input); - node->add_input(right.input); - - AddToOptimizedNodes(node); + MarkWithTag(kAddOpsRewriteTag, node); return InputAndShape( node_name, TensorShapeProto()); // shape is not important at this point } @@ -960,7 +963,9 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { bool IsSupported(const NodeDef* node) const override { if (!IsBinaryAssociative(*node)) return false; - if (IsAlreadyOptimized(*node)) return false; + + if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) + return false; // has a symbolically defined shape with broadcastable inputs OpInfo::TensorProperties properties; @@ -984,7 +989,11 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { if (!IsSameOp(group, node)) { return false; } - if (IsInPreserveSet(node) || IsAlreadyOptimized(node)) { + if (IsInPreserveSet(node)) { + return false; + } + // Nodes optimized by AddOpsRewrite already have optimal broadcasts. + if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) { return false; } if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) { @@ -1019,7 +1028,7 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage { << " num_optimized_nodes=" << group.optimized_nodes.size(); // Do not optimize any of the nodes that are part of this group. - AddAllMembersToOptimizedNodes(group); + MarkAllMembersWithTag(group, kMinimizeBroadcastsTag); if (CountUniqueShapes(group.inputs) <= 1) { VLOG(3) << "Skip min-bcast group with single unique shape"; @@ -1905,6 +1914,8 @@ void ArithmeticOptimizer::DedupComputations() { FeedsInPlaceOp(graph_view, *node)) { continue; } + VLOG(3) << "Remove duplicated node: node=" << node->name() + << " representative=" << rep->name(); const std::set& fanouts = node_map_->GetOutputs(node->name()); for (NodeDef* fanout : fanouts) { for (int i = 0; i < fanout->input_size(); ++i) { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 5230177dca..0c8e18d7ab 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -65,7 +65,7 @@ int NumIterations(const RewriterConfig& cfg) { // Check if optimizer is allowed to run only once. bool IsRunOnceOptimizer(const string& name) { return name == "layout" || name == "memory_optimizer" || - name == "arithmetic_optimizer" || name == "loop_optimizer"; + name == "loop_optimizer"; } } // namespace diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index e3dd4b0bdf..2d6925d1a8 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -150,10 +150,14 @@ def _loop_with_vec_and_4d(): def _get_config(layout_optimizer=True): if layout_optimizer: rewrite_options = rewriter_config_pb2.RewriterConfig( - layout_optimizer=rewriter_config_pb2.RewriterConfig.ON) + layout_optimizer=rewriter_config_pb2.RewriterConfig.ON, + # do not remove duplicated nodes + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) else: rewrite_options = rewriter_config_pb2.RewriterConfig( - layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF) + layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, + # do not remove duplicated nodes + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) graph_options = config_pb2.GraphOptions( rewrite_options=rewrite_options, build_cost_model=1) config = config_pb2.ConfigProto(graph_options=graph_options) -- GitLab From 1f4efb78320e1406c0cc9ce4b8753f3d2511048e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 16:05:43 -0700 Subject: [PATCH 055/839] Add RNNEstimator which takes in arbitrary heads. PiperOrigin-RevId: 195167853 --- tensorflow/contrib/estimator/BUILD | 4 + tensorflow/contrib/estimator/__init__.py | 1 + .../contrib/estimator/python/estimator/rnn.py | 164 ++++++++++++++++-- .../estimator/python/estimator/rnn_test.py | 119 ++++++++----- 4 files changed, 237 insertions(+), 51 deletions(-) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index b473de86ee..41a817673d 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -452,18 +452,22 @@ py_test( "notsan", ], deps = [ + ":head", ":rnn", + "//tensorflow/contrib/data", "//tensorflow/core:protos_all_py", "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", "//tensorflow/python:math_ops", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variables", "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:parsing_utils", "//tensorflow/python/feature_column", "//third_party/py/numpy", "@six_archive//:six", diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index f66d844660..d43b3ea6bf 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -55,6 +55,7 @@ _allowed_symbols = [ 'replicate_model_fn', 'TowerOptimizer', 'RNNClassifier', + 'RNNEstimator', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index b475c12f5a..7f385fd76e 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -328,6 +328,19 @@ def _rnn_model_fn(features, logits=logits) +def _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type): + """Assert arguments are valid and return rnn_cell_fn.""" + if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT): + raise ValueError( + 'num_units and cell_type must not be specified when using rnn_cell_fn' + ) + if not rnn_cell_fn: + if cell_type == USE_DEFAULT: + cell_type = 'basic_rnn' + rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type) + return rnn_cell_fn + + class RNNClassifier(estimator.Estimator): """A classifier for TensorFlow RNN models. @@ -341,8 +354,8 @@ class RNNClassifier(estimator.Estimator): token_emb = embedding_column(categorical_column=token_sequence, ...) estimator = RNNClassifier( - num_units=[32, 16], cell_type='lstm', - sequence_feature_columns=[token_emb]) + sequence_feature_columns=[token_emb], + num_units=[32, 16], cell_type='lstm') # Input builders def input_fn_train: # returns x, y @@ -438,8 +451,8 @@ class RNNClassifier(estimator.Estimator): encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` or string specifying optimizer + type. Defaults to Adagrad optimizer. input_layer_partitioner: Optional. Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. @@ -448,14 +461,7 @@ class RNNClassifier(estimator.Estimator): ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not compatible. """ - if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT): - raise ValueError( - 'num_units and cell_type must not be specified when using rnn_cell_fn' - ) - if not rnn_cell_fn: - if cell_type == USE_DEFAULT: - cell_type = 'basic_rnn' - rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type) + rnn_cell_fn = _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type) if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access @@ -479,3 +485,137 @@ class RNNClassifier(estimator.Estimator): config=config) super(RNNClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) + + +class RNNEstimator(estimator.Estimator): + """An Estimator for TensorFlow RNN models with user-specified head. + + Example: + + ```python + token_sequence = sequence_categorical_column_with_hash_bucket(...) + token_emb = embedding_column(categorical_column=token_sequence, ...) + + estimator = RNNEstimator( + head=tf.contrib.estimator.regression_head(), + sequence_feature_columns=[token_emb], + num_units=[32, 16], cell_type='lstm') + + # Or with custom RNN cell: + def rnn_cell_fn(mode): + cells = [ tf.contrib.rnn.LSTMCell(size) for size in [32, 16] ] + if mode == tf.estimator.ModeKeys.TRAIN: + cells = [ tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=0.5) + for cell in cells ] + return tf.contrib.rnn.MultiRNNCell(cells) + + estimator = RNNEstimator( + head=tf.contrib.estimator.regression_head(), + sequence_feature_columns=[token_emb], + rnn_cell_fn=rnn_cell_fn) + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if the head's `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + * for each `column` in `sequence_feature_columns`: + - a feature with `key=column.name` whose `value` is a `SparseTensor`. + * for each `column` in `context_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss and predicted output are determined by the specified head. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + head, + sequence_feature_columns, + context_feature_columns=None, + num_units=None, + cell_type=USE_DEFAULT, + rnn_cell_fn=None, + model_dir=None, + optimizer='Adagrad', + input_layer_partitioner=None, + config=None): + """Initializes a `RNNClassifier` instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. This specifies the model's + output and loss function to be optimized. + sequence_feature_columns: An iterable containing the `FeatureColumn`s + that represent sequential input. All items in the set should either be + sequence columns (e.g. `sequence_numeric_column`) or constructed from + one (e.g. `embedding_column` with `sequence_categorical_column_*` as + input). + context_feature_columns: An iterable containing the `FeatureColumn`s + for contextual input. The data represented by these columns will be + replicated and given to the RNN at each timestep. These columns must be + instances of classes derived from `_DenseColumn` such as + `numeric_column`, not the sequential variants. + num_units: Iterable of integer number of hidden units per RNN layer. If + set, `cell_type` must also be specified and `rnn_cell_fn` must be + `None`. + cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying + the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and + `'gru'`. If set, `num_units` must also be specified and `rnn_cell_fn` + must be `None`. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell` that will be used to + construct the RNN. If set, `num_units` and `cell_type` cannot be set. + This is for advanced users who need additional customization beyond + `num_units` and `cell_type`. Note that `tf.nn.rnn_cell.MultiRNNCell` is + needed for stacked RNNs. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + optimizer: An instance of `tf.Optimizer` or string specifying optimizer + type. Defaults to Adagrad optimizer. + input_layer_partitioner: Optional. Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not + compatible. + """ + rnn_cell_fn = _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type) + + def _model_fn(features, labels, mode, config): + return _rnn_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + rnn_cell_fn=rnn_cell_fn, + sequence_feature_columns=tuple(sequence_feature_columns or []), + context_feature_columns=tuple(context_feature_columns or []), + optimizer=optimizer, + input_layer_partitioner=input_layer_partitioner, + config=config) + super(RNNEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py index 393f94f5c7..959b40371a 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py @@ -25,12 +25,15 @@ import tempfile import numpy as np import six +from tensorflow.contrib.data.python.ops import readers +from tensorflow.contrib.estimator.python.estimator import head as head_lib from tensorflow.contrib.estimator.python.estimator import rnn from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.canned import parsing_utils from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io @@ -38,9 +41,9 @@ from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.lib.io import python_io from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import state_ops @@ -50,7 +53,6 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import input as input_lib from tensorflow.python.training import monitored_session from tensorflow.python.training import optimizer from tensorflow.python.training import training_util @@ -984,7 +986,10 @@ class RNNClassifierPredictionTest(test.TestCase): predictions[prediction_keys.PredictionKeys.CLASSES]) -class RNNClassifierIntegrationTest(test.TestCase): +class BaseRNNClassificationIntegrationTest(object): + + def __init__(self, _create_estimator_fn): + self._create_estimator_fn = _create_estimator_fn def setUp(self): self._model_dir = tempfile.mkdtemp() @@ -994,20 +999,11 @@ class RNNClassifierIntegrationTest(test.TestCase): writer_cache.FileWriterCache.clear() shutil.rmtree(self._model_dir) - def _test_complete_flow( - self, train_input_fn, eval_input_fn, predict_input_fn, n_classes, - batch_size): - col = seq_fc.sequence_categorical_column_with_hash_bucket( - 'tokens', hash_bucket_size=10) - embed = fc.embedding_column(col, dimension=2) - feature_columns = [embed] - + def _test_complete_flow(self, feature_columns, train_input_fn, eval_input_fn, + predict_input_fn, n_classes, batch_size): cell_units = [4, 2] - est = rnn.RNNClassifier( - num_units=cell_units, - sequence_feature_columns=feature_columns, - n_classes=n_classes, - model_dir=self._model_dir) + est = self._create_estimator_fn(feature_columns, n_classes, cell_units, + self._model_dir) # TRAIN num_steps = 10 @@ -1026,10 +1022,10 @@ class RNNClassifierIntegrationTest(test.TestCase): self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) # EXPORT - feature_spec = { - 'tokens': parsing_ops.VarLenFeature(dtypes.string), - 'label': parsing_ops.FixedLenFeature([1], dtypes.int64), - } + feature_spec = parsing_utils.classifier_parse_example_spec( + feature_columns, + label_key='label', + label_dtype=dtypes.int64) serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( feature_spec) export_dir = est.export_savedmodel(tempfile.mkdtemp(), @@ -1069,7 +1065,13 @@ class RNNClassifierIntegrationTest(test.TestCase): batch_size=batch_size, shuffle=False) + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + feature_columns = [embed] + self._test_complete_flow( + feature_columns=feature_columns, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn, predict_input_fn=predict_input_fn, @@ -1082,7 +1084,8 @@ class RNNClassifierIntegrationTest(test.TestCase): batch_size = 10 words = [b'dog', b'cat', b'bird', b'the', b'a', b'sat', b'flew', b'slept'] - serialized_examples = [] + _, examples_file = tempfile.mkstemp() + writer = python_io.TFRecordWriter(examples_file) for _ in range(batch_size): sequence_length = random.randint(1, len(words)) sentence = random.sample(words, sequence_length) @@ -1096,30 +1099,36 @@ class RNNClassifierIntegrationTest(test.TestCase): feature_pb2.Feature(int64_list=feature_pb2.Int64List( value=[label])), })) - serialized_examples.append(example.SerializeToString()) + writer.write(example.SerializeToString()) + writer.close() + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + feature_columns = [embed] + feature_spec = parsing_utils.classifier_parse_example_spec( + feature_columns, + label_key='label', + label_dtype=dtypes.int64) - feature_spec = { - 'tokens': parsing_ops.VarLenFeature(dtypes.string), - 'label': parsing_ops.FixedLenFeature([1], dtypes.int64), - } def _train_input_fn(): - features = parsing_ops.parse_example(serialized_examples, feature_spec) - labels = features.pop('label') - return features, labels + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec) + return dataset.map(lambda features: (features, features.pop('label'))) def _eval_input_fn(): - features = parsing_ops.parse_example( - input_lib.limit_epochs(serialized_examples, num_epochs=1), - feature_spec) - labels = features.pop('label') - return features, labels + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec, num_epochs=1) + return dataset.map(lambda features: (features, features.pop('label'))) def _predict_input_fn(): - features = parsing_ops.parse_example( - input_lib.limit_epochs(serialized_examples, num_epochs=1), - feature_spec) - features.pop('label') - return features, None + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec, num_epochs=1) + def features_fn(features): + features.pop('label') + return features + return dataset.map(features_fn) self._test_complete_flow( + feature_columns=feature_columns, train_input_fn=_train_input_fn, eval_input_fn=_eval_input_fn, predict_input_fn=_predict_input_fn, @@ -1127,5 +1136,37 @@ class RNNClassifierIntegrationTest(test.TestCase): batch_size=batch_size) +def _rnn_classifier_fn(feature_columns, n_classes, cell_units, model_dir): + return rnn.RNNClassifier( + num_units=cell_units, + sequence_feature_columns=feature_columns, + n_classes=n_classes, + model_dir=model_dir) + + +class RNNClassifierIntegrationTest(BaseRNNClassificationIntegrationTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + BaseRNNClassificationIntegrationTest.__init__(self, _rnn_classifier_fn) + + +def _rnn_estimator_fn(feature_columns, n_classes, cell_units, model_dir): + return rnn.RNNEstimator( + head=head_lib.multi_class_head(n_classes=n_classes), + num_units=cell_units, + sequence_feature_columns=feature_columns, + model_dir=model_dir) + + +class RNNEstimatorIntegrationTest(BaseRNNClassificationIntegrationTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + BaseRNNClassificationIntegrationTest.__init__(self, _rnn_estimator_fn) + + if __name__ == '__main__': test.main() -- GitLab From c7a5787fef8daf3e44313cbd48591464f9643f56 Mon Sep 17 00:00:00 2001 From: Ayush Dubey Date: Wed, 2 May 2018 16:13:06 -0700 Subject: [PATCH 056/839] Enable reshape of _ScopedAllocatorConcat output. The _ScopedAllocatorConcat kernel outputs the backing tensor after performing runtime bounds checks. However, the shape of the backing tensor may not match the desired output shape of the concat operation. This change adds a "reshape" boolean attribute to _ScopedAllocatorConcat kernel. When this attribute is set to true, the kernel outputs a reshaped backing tensor according to the "shape" attribute. PiperOrigin-RevId: 195169105 --- .../core/kernels/scoped_allocator_ops.cc | 39 +++++++---- .../core/kernels/scoped_allocator_ops_test.cc | 64 ++++++++++++++++--- tensorflow/core/ops/scoped_allocator_ops.cc | 11 ++-- 3 files changed, 89 insertions(+), 25 deletions(-) diff --git a/tensorflow/core/kernels/scoped_allocator_ops.cc b/tensorflow/core/kernels/scoped_allocator_ops.cc index d7b25ffad0..1800ee8c1f 100644 --- a/tensorflow/core/kernels/scoped_allocator_ops.cc +++ b/tensorflow/core/kernels/scoped_allocator_ops.cc @@ -94,7 +94,8 @@ class ScopedAllocatorConcatOp : public OpKernel { : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_)); - // This stuff is just for debugging + OP_REQUIRES_OK(context, context->GetAttr("reshape", &reshape_)); + // These attributes are just for debugging. OP_REQUIRES_OK(context, context->GetAttr("sa_name", &name_)); OP_REQUIRES_OK(context, context->GetAttr("id", &id_)); device_ = context->device(); @@ -114,11 +115,14 @@ class ScopedAllocatorConcatOp : public OpKernel { backing_tensor.NumElements(), " is not equal to expected ", shape_.num_elements())); - VLOG(1) << "_ScopedAllocatorConcatOp outputting backing tensor at " - << DMAHelper::base(&backing_tensor); - Tensor backing_copy(backing_tensor); - context->set_output(0, backing_copy); - const TensorBuffer* backing_buf = DMAHelper::buffer(&backing_copy); + Tensor output(dtype_); + if (reshape_) { + CHECK(output.CopyFrom(backing_tensor, shape_)); + } else { + CHECK(output.CopyFrom(backing_tensor, backing_tensor.shape())); + } + context->set_output(0, output); + const TensorBuffer* backing_buf = DMAHelper::buffer(&output); const void* backing_tensor_lb = backing_buf->data(); const void* backing_tensor_ub = static_cast( static_cast(backing_tensor_lb) + backing_buf->size()); @@ -126,17 +130,27 @@ class ScopedAllocatorConcatOp : public OpKernel { for (int i = 1; i < context->num_inputs(); ++i) { const TensorBuffer* input_buf = DMAHelper::buffer(&context->input(i)); const void* input_lb = input_buf->data(); - OP_REQUIRES( - context, input_lb >= backing_tensor_lb, - errors::InvalidArgument("Lower bound check fail for input ", i, - " to node ", context->op_kernel().name())); const void* input_ub = static_cast( static_cast(input_lb) + input_buf->size()); + OP_REQUIRES( + context, input_lb >= backing_tensor_lb, + errors::InvalidArgument( + "Lower bound check fail for input ", i, " from node ", + context->op_kernel().requested_input(i), " to node ", + context->op_kernel().name(), " input bounds = [", input_lb, ", ", + input_ub, "]", " backing_tensor bounds = [", backing_tensor_lb, + ", ", backing_tensor_ub, "]")); OP_REQUIRES( context, input_ub <= backing_tensor_ub, - errors::InvalidArgument("Upper bound check fail for input ", i, - " to node ", context->op_kernel().name())); + errors::InvalidArgument( + "Upper bound check fail for input ", i, " from node ", + context->op_kernel().requested_input(i), " to node ", + context->op_kernel().name(), " input bounds = [", input_lb, ", ", + input_ub, "]", " backing_tensor bounds = [", backing_tensor_lb, + ", ", backing_tensor_ub, "]")); } + VLOG(1) << "_ScopedAllocatorConcatOp outputting backing tensor at " + << backing_buf; } private: @@ -144,6 +158,7 @@ class ScopedAllocatorConcatOp : public OpKernel { DataType dtype_; string name_; int32 id_; + bool reshape_; DeviceBase* device_; }; diff --git a/tensorflow/core/kernels/scoped_allocator_ops_test.cc b/tensorflow/core/kernels/scoped_allocator_ops_test.cc index 3d36c8b7d4..019c6619ee 100644 --- a/tensorflow/core/kernels/scoped_allocator_ops_test.cc +++ b/tensorflow/core/kernels/scoped_allocator_ops_test.cc @@ -120,18 +120,43 @@ void PrepOp(DataType dtype, int32 id, class ScopedAllocatorConcatOpTest : public OpsTestBase { protected: - void MakeOp(const TensorShape& shape, DataType dtype, const string& name, - int32 id, int32 num_tensors) { + void BuildNodeDef(const TensorShape& shape, DataType dtype, + const string& name, int32 id, int32 num_tensors) { + TF_EXPECT_OK( + NodeDefBuilder("scoped_allocator_concat_op", "_ScopedAllocatorConcat") + .Attr("shape", shape) + .Attr("T", dtype) + .Attr("N", num_tensors) + .Attr("sa_name", name) + .Attr("id", id) + .Input(FakeInput(dtype)) // backing tensor + .Input(FakeInput(num_tensors, dtype)) // list of tensors + .Finalize(node_def())); + shape_ = shape; + reshape_ = false; + } + + void BuildNodeDefWithReshape(const TensorShape& shape, DataType dtype, + bool reshape, const string& name, int32 id, + int32 num_tensors) { TF_EXPECT_OK( NodeDefBuilder("scoped_allocator_concat_op", "_ScopedAllocatorConcat") .Attr("shape", shape) .Attr("T", dtype) + .Attr("reshape", reshape) .Attr("N", num_tensors) .Attr("sa_name", name) .Attr("id", id) .Input(FakeInput(dtype)) // backing tensor .Input(FakeInput(num_tensors, dtype)) // list of tensors .Finalize(node_def())); + shape_ = shape; + reshape_ = reshape; + } + + void MakeOp(const TensorShape& shape, DataType dtype, bool reshape, + const string& name, int32 id, int32 num_tensors) { + BuildNodeDefWithReshape(shape, dtype, reshape, name, id, num_tensors); TF_EXPECT_OK(InitOp()); } @@ -141,7 +166,7 @@ class ScopedAllocatorConcatOpTest : public OpsTestBase { std::vector tensors; std::vector fields; PrepOp(dtype, id, fields_shapes, &fields, &backing_tensor, allocator(), - device_->GetScopedAllocatorMgr(), "split", &tensors, &inputs_, + device_->GetScopedAllocatorMgr(), "concat", &tensors, &inputs_, input_types_); TF_ASSERT_OK(RunOpKernel()); @@ -155,34 +180,55 @@ class ScopedAllocatorConcatOpTest : public OpsTestBase { CHECK_EQ(DMAHelper::base(&input), DMAHelper::base(&output)); CHECK_EQ(input.dtype(), output.dtype()); CHECK_EQ(input.NumElements(), output.NumElements()); + if (reshape_) { + CHECK_EQ(shape_, output.shape()); + } else { + TensorShape expected_shape({input.NumElements()}); + CHECK_EQ(expected_shape, output.shape()); + } // Free the backing tensor which was allocated in PrepOp. delete backing_tensor; } + + private: + TensorShape shape_; + bool reshape_; }; TEST_F(ScopedAllocatorConcatOpTest, Success1) { - MakeOp({32}, DT_FLOAT, "test", 120, 2); + MakeOp({32}, DT_FLOAT, false, "test", 120, 2); ExecOp(DT_FLOAT, 120, {{16}, {16}}); } TEST_F(ScopedAllocatorConcatOpTest, Success2) { - MakeOp({2, 2, 2}, DT_DOUBLE, "test", 120, 2); + MakeOp({2, 2, 2}, DT_DOUBLE, false, "test", 120, 2); ExecOp(DT_DOUBLE, 120, {{2, 2}, {2, 2}}); } TEST_F(ScopedAllocatorConcatOpTest, Success3) { - MakeOp({3, 3, 3}, DT_HALF, "test", 120, 3); + MakeOp({3, 3, 3}, DT_HALF, false, "test", 120, 3); ExecOp(DT_HALF, 120, {{3, 3}, {3, 3}, {3, 3}}); } +TEST_F(ScopedAllocatorConcatOpTest, Reshape) { + MakeOp({2, 2, 2}, DT_DOUBLE, true, "test", 120, 2); + ExecOp(DT_DOUBLE, 120, {{2, 2}, {2, 2}}); +} + +TEST_F(ScopedAllocatorConcatOpTest, NoReshapeAttr) { + BuildNodeDef({3, 4, 4}, DT_HALF, "test", 120, 3); + TF_EXPECT_OK(InitOp()); + ExecOp(DT_HALF, 120, {{4, 4}, {4, 4}, {4, 4}}); +} + TEST_F(ScopedAllocatorConcatOpTest, FailDtypeCheck) { - MakeOp({8}, DT_FLOAT, "test", 120, 2); + MakeOp({8}, DT_FLOAT, false, "test", 120, 2); EXPECT_DEATH(ExecOp(DT_DOUBLE, 120, {{4}, {4}}), ""); } TEST_F(ScopedAllocatorConcatOpTest, FailNumElementsCheck) { - MakeOp({32}, DT_FLOAT, "test", 120, 2); + MakeOp({32}, DT_FLOAT, false, "test", 120, 2); AddInputFromArray({8}, {0, 1, 2, 3, 4, 5, 6, 7}); AddInputFromArray({4}, {0, 1, 2, 3}); AddInputFromArray({4}, {4, 5, 6, 7}); @@ -193,7 +239,7 @@ TEST_F(ScopedAllocatorConcatOpTest, FailNumElementsCheck) { // This test should fail because the backing tensor and the input tensors are // unrelated, i.e. the inputs are not slices of the backing tensor. TEST_F(ScopedAllocatorConcatOpTest, FailBounds) { - MakeOp({8}, DT_DOUBLE, "test", 120, 2); + MakeOp({8}, DT_DOUBLE, false, "test", 120, 2); AddInputFromArray({8}, {0, 1, 2, 3, 4, 5, 6, 7}); AddInputFromArray({4}, {0, 1, 2, 3}); AddInputFromArray({4}, {4, 5, 6, 7}); diff --git a/tensorflow/core/ops/scoped_allocator_ops.cc b/tensorflow/core/ops/scoped_allocator_ops.cc index f053a53f4c..1e0dcdac96 100644 --- a/tensorflow/core/ops/scoped_allocator_ops.cc +++ b/tensorflow/core/ops/scoped_allocator_ops.cc @@ -43,6 +43,7 @@ REGISTER_OP("_ScopedAllocatorConcat") .Input("inputs: N * T") .Attr("shape: shape") .Attr("T: type") + .Attr("reshape: bool = false") .Attr("sa_name: string") .Attr("id: int") .Attr("N: int >= 2") @@ -69,10 +70,12 @@ REGISTER_OP("_ScopedAllocatorSplit") .SetIsStateful() .SetShapeFn(shape_inference::ExplicitShape) .Doc(R"doc( -Acts like a Concat Op that merges multple tensors into one, however it must -only be used in conjunction with a ScopedAllocator which is backing the memory -of all of its input tensors so that actually it just outputs a read-only -reference to that ScopedAllocator's backing tensor. +Acts roughly like a SplitV Op that splits one tensor into multiple tensors +but must only be used in conjunction with corresponding ScopedAllocator +and ScopedAllocatorConcat instances. In practice it is provided as inputs +the backing tensor as first input, which contains the concatenated values, +and a list of alias tensors as its other input and it simply outputs that +second list. This is an experimental op for internal use only. It is possible to use this op in unsafe ways. -- GitLab From 8a022d3d0e1bc521ccfee74174e75821bdc1bfa9 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 2 May 2018 16:58:35 -0700 Subject: [PATCH 057/839] Allow `Layer.add_loss` to receive non-tensor; fixes error triggered when using a weight regularizer of factor 0. PiperOrigin-RevId: 195175909 --- tensorflow/python/keras/_impl/keras/engine/base_layer.py | 3 +++ tensorflow/python/keras/_impl/keras/regularizers_test.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py index a3e78c95dc..3af4eaabe9 100644 --- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py +++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py @@ -29,6 +29,7 @@ from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.keras._impl.keras import backend from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers @@ -390,6 +391,8 @@ class Layer(checkpointable.CheckpointableBase): raise RuntimeError('Layer.add_loss not supported in Eager mode.') losses = generic_utils.to_list(losses) + losses = [ops.convert_to_tensor(loss, dtype=backend.floatx()) + if not tensor_util.is_tensor(loss) else loss for loss in losses] self._losses += losses if inputs is None: for loss in losses: diff --git a/tensorflow/python/keras/_impl/keras/regularizers_test.py b/tensorflow/python/keras/_impl/keras/regularizers_test.py index 9a1612b777..c4f04833ba 100644 --- a/tensorflow/python/keras/_impl/keras/regularizers_test.py +++ b/tensorflow/python/keras/_impl/keras/regularizers_test.py @@ -71,6 +71,11 @@ class KerasRegularizersTest(test.TestCase): model.fit(x_train, y_train, batch_size=10, epochs=1, verbose=0) + def test_zero_regularization(self): + inputs = keras.backend.ones(shape=(10, 10)) + layer = keras.layers.Dense(3, kernel_regularizer=keras.regularizers.l2(0)) + layer(inputs) + if __name__ == '__main__': test.main() -- GitLab From dde83d4bee2c524cb5bd0adc4f702c9fc5ac6f3f Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 2 May 2018 17:00:16 -0700 Subject: [PATCH 058/839] Handle negative values when slicing symbolic shapes PiperOrigin-RevId: 195176133 --- tensorflow/core/grappler/costs/graph_properties.cc | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 23d25cba8d..eaf7634daa 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -804,11 +804,16 @@ class SymbolicShapeRefiner { int64 start = slice_offset->dtype() == DT_INT32 ? slice_offset->flat()(0) : slice_offset->flat()(0); - int64 end = start + (slice_size->dtype() == DT_INT32 - ? slice_size->flat()(0) - : slice_size->flat()(0)); + int64 size = + (slice_size->dtype() == DT_INT32 ? slice_size->flat()(0) + : slice_size->flat()(0)); ShapeHandle result; - TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result)); + if (size == -1) { + TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result)); + } else { + int64 end = start + size; + TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result)); + } c->output_tensors_as_shapes.resize(1); c->output_tensors_as_shapes[0] = result; } -- GitLab From a1ef905926d12b0362c0dcf6d669e1c3d2ffcf70 Mon Sep 17 00:00:00 2001 From: Jeremy Lau Date: Wed, 2 May 2018 17:25:10 -0700 Subject: [PATCH 059/839] BufferValue is a new base class for LogicalBuffer and HloValue. This makes it easier to migrate from TuplePointsToAnalysis/LogicalBuffer to HloDataflowAnalysis/HloValue. No functional changes. PiperOrigin-RevId: 195179676 --- tensorflow/compiler/xla/service/BUILD | 22 +++ .../xla/service/buffer_assignment_test.cc | 3 +- .../compiler/xla/service/buffer_value.cc | 66 +++++++ .../compiler/xla/service/buffer_value.h | 177 ++++++++++++++++++ tensorflow/compiler/xla/service/compiler.h | 5 +- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../compiler/xla/service/gpu/hlo_schedule.cc | 3 +- .../xla/service/heap_simulator_test.cc | 5 +- .../xla/service/hlo_rematerialization.cc | 3 +- .../xla/service/hlo_scheduling_test.cc | 6 +- tensorflow/compiler/xla/service/hlo_value.cc | 8 +- tensorflow/compiler/xla/service/hlo_value.h | 51 ++--- .../compiler/xla/service/logical_buffer.cc | 42 +---- .../compiler/xla/service/logical_buffer.h | 123 +----------- 14 files changed, 318 insertions(+), 197 deletions(-) create mode 100644 tensorflow/compiler/xla/service/buffer_value.cc create mode 100644 tensorflow/compiler/xla/service/buffer_value.h diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 17964cdd59..0b8b22b44c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -780,6 +780,7 @@ cc_library( srcs = ["compiler.cc"], hdrs = ["compiler.h"], deps = [ + ":buffer_value", ":executable", ":hlo", ":hlo_module_config", @@ -1014,6 +1015,7 @@ tf_cc_test( srcs = ["buffer_assignment_test.cc"], deps = [ ":buffer_assignment", + ":buffer_value", ":call_graph", ":computation_tracker", ":copy_insertion", @@ -1095,6 +1097,7 @@ tf_cc_test( name = "heap_simulator_test", srcs = ["heap_simulator_test.cc"], deps = [ + ":buffer_value", ":heap_simulator", ":hlo", ":hlo_ordering", @@ -1163,6 +1166,7 @@ tf_cc_test( name = "hlo_scheduling_test", srcs = ["hlo_scheduling_test.cc"], deps = [ + ":buffer_value", ":hlo", ":hlo_ordering", ":hlo_scheduling", @@ -1749,11 +1753,27 @@ tf_cc_test( ], ) +cc_library( + name = "buffer_value", + srcs = ["buffer_value.cc"], + hdrs = ["buffer_value.h"], + deps = [ + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + cc_library( name = "logical_buffer", srcs = ["logical_buffer.cc"], hdrs = ["logical_buffer.h"], deps = [ + ":buffer_value", ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:shape_util", @@ -1769,6 +1789,7 @@ cc_library( srcs = ["hlo_value.cc"], hdrs = ["hlo_value.h"], deps = [ + ":buffer_value", ":hlo", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", @@ -2066,6 +2087,7 @@ cc_library( hdrs = ["hlo_rematerialization.h"], deps = [ ":buffer_liveness", + ":buffer_value", ":call_graph", ":flatten_call_graph", ":hlo", diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index f6d6b5c36a..a4fb0eefac 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" @@ -1684,7 +1685,7 @@ class WhileBufferAssignmentTest : public HloTestBase { .ConsumeValueOrDie(); } - static int64 ByteSizeOf(const LogicalBuffer& buffer) { + static int64 ByteSizeOf(const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*)); } diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc new file mode 100644 index 0000000000..df1a5ca435 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/buffer_value.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +BufferValue::BufferValue(HloInstruction* instruction, const ShapeIndex& index, + Id id) + : id_(id) { + const Shape& shape = ShapeUtil::GetSubshape(instruction->shape(), index); + is_array_ = ShapeUtil::IsArray(shape); + is_tuple_ = ShapeUtil::IsTuple(shape); +} + +BufferValue::~BufferValue() {} + +std::ostream& operator<<(std::ostream& out, const BufferValue& buffer) { + out << buffer.ToString(); + return out; +} + +/*static*/ LogicalBufferProto::Location BufferValue::ToLocationProto( + const HloInstruction& instruction, const ShapeIndex& index) { + LogicalBufferProto::Location proto; + proto.set_computation_name(instruction.parent()->name()); + proto.set_instruction_name(instruction.name()); + for (const int64 index_entry : index) { + proto.add_shape_index(index_entry); + } + return proto; +} + +LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const { + LogicalBufferProto proto; + proto.set_id(id()); + proto.set_size(size_fn(*this)); + LogicalBufferProto::Location proto_location = + ToLocationProto(*instruction(), index()); + proto.mutable_defined_at()->Swap(&proto_location); + proto.set_color(color().value()); + return proto; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h new file mode 100644 index 0000000000..f4be16e084 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/int_type.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Abstract class describing a value used by one of the dataflow analyses - +// TuplePointsToAnalysis or HloDataflowAnalysis. +// TODO(b/78906445) Delete this class when TuplePointsToAnalysis is unused. +// +// XLA arrays are trivially a single BufferValue. Tuples are made up of more +// than one BufferValue: an BufferValue for the pointer vector, and an +// BufferValue for each child element. +// +// Every BufferValue is defined by a particular instruction and most +// instructions define only a single BufferValue. Instructions which define a +// single BufferValue include array-shaped instructions such as Add but also +// includes Tuple-shaped instructions such as Tuple. The Tuple instruction +// defines a single BufferValue which is a vector of pointers to the values +// containing the Tuple instruction's operands. Though the result of the Tuple +// instruction includes multiple values only the top-level BufferValue (the +// vector of pointers) is defined by the Tuple instruction. The values +// containing the tuple elements are defined by earlier instructions, usually +// the operands of the Tuple instruction. +// +// Instructions which construct both the tuple *and* the tuple elements define +// more than one BufferValue. This includes (at least) tuple-shaped Constant, +// Parameter, Infeed and While instructions. These tuple-shaped instructions do +// not assemble a tuple from existing BufferValues like the Tuple instruction +// does, but rather define all the BufferValues in the tuple. +// +// Some instructions, such as Bitcast, define no buffers. These instructions +// simply forward buffers from their operands. +// +// The BufferValue object describes which HLO instruction defines a buffer and +// where within that instruction's output shape the buffer is defined. The +// location within the output shape is indicated by BufferValue::index() which +// is defined identically to the index used in ShapeUtil::GetSubshape(). +// Examples: +// +// %add = Add(%foo, %bar) +// %tuple_constant = Constant({1, {42, 43}}) +// +// %add defines a single array-shaped buffer BufferValue(%add, {}) which holds +// the array result of the add operation. The nested-tuple-shaped +// %tuple_constant defines 5 buffers described by the following BufferValue +// objects: +// +// BufferValue(%tuple_constant, {}) // "Top-level" buffer: vector of +// // pointers to BufferValues at +// // indices {0} and {1} +// BufferValue(%tuple_constant, {0}) // Holds value "1" +// BufferValue(%tuple_constant, {1}) // Holds nested tuple: vector of +// // pointers to BufferValues at +// // indices {1, 0} and {1, 1} +// BufferValue(%tuple_constant, {1, 0}) // Holds value "42" +// BufferValue(%tuple_constant, {1, 1}) // Holds value "43" + +class BufferValue { + public: + TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); + + // Id is a unique identifier for the BufferValue to facilitate efficient + // collections of BufferValues with stable iteration order. + using Id = int64; + + // Functions which return the size and alignment of a logical buffer in bytes. + using SizeFunction = std::function; + using AlignmentFunction = std::function; + + virtual ~BufferValue(); + + Id id() const { return id_; } + + // Return the instruction that defines the buffer. + virtual HloInstruction* instruction() const = 0; + + // Return the index within the output of the instruction where the buffer is + // defined. Index used defined as in ShapeUtil::GetSubshape() + virtual const ShapeIndex& index() const = 0; + + // Return the color of the BufferValue. Differently colored buffers can not be + // parts of the same allocation. + Color color() const { + CHECK_NE(color_, kInvalidColor) + << "Should not query the color of a buffer that was never colored"; + return color_; + } + + void set_color(Color color) { + CHECK_NE(color, kInvalidColor) + << "Should not set the color of a buffer to the invalid color"; + color_ = color; + } + + bool has_color() const { return color_ != kInvalidColor; } + + // Return the shape of the buffer. This reference points into the shape field + // of the instruction defining the buffer. Therefore, the returned shape will + // contain the layout of instruction, if any. + virtual const Shape& shape() const = 0; + + // Returns true if this buffer is the top-level output buffer of the defining + // HLO instruction. This is equivalent to index == {}. + bool IsTopLevel() const { return index().empty(); } + + // Whether this buffer contains a tuple. + bool IsTuple() const { return is_tuple_; } + + // Whether this buffer contains an array. + bool IsArray() const { return is_array_; } + + // operator< is required for std::set. + bool operator<(const BufferValue& other) const { return id_ < other.id_; } + + virtual string ToString() const = 0; + + // TODO(lauj) rename LogicalBufferProto to BufferValueProto. + LogicalBufferProto ToProto(const SizeFunction& size_fn) const; + + // Returns the LogicalBufferProto::Location that serializes the given + // instruction and index. + static LogicalBufferProto::Location ToLocationProto( + const HloInstruction& instruction, const ShapeIndex& index); + + const Color kInvalidColor = Color(-1); + + protected: + BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id); + + private: + // The definining instruction and index are not stored here; they can be found + // in the LogicalBuffer and HloValue subclasses. This class exists only to + // support migrations from TuplePointsToAnalysis to HloDataflowAnalysis, by + // allowing abstract use of LogicalBuffer or HloValue. After those migrations + // are complete, this class should be deleted (b/78906445). Because we plan to + // delete LogicalBuffer and this class, we don't refactor all the shared + // features from LogicalBuffer and HloValue into this class. + Id id_ : 62; + bool is_array_ : 1; + bool is_tuple_ : 1; + Color color_ = kInvalidColor; +}; + +std::ostream& operator<<(std::ostream& out, const BufferValue& buffer); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 5c14591d93..a4b59d1ba9 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" @@ -181,9 +182,9 @@ class Compiler { // Returns a function that computes the size in bytes of a given // logical buffer. - std::function BufferSizeBytesFunction() { + std::function BufferSizeBytesFunction() { HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction(); - return [shape_size](const LogicalBuffer& buffer) { + return [shape_size](const BufferValue& buffer) { return shape_size(buffer.shape()); }; } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index f1707442fe..7cb7f55073 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -620,6 +620,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 42c1539e86..f766f96882 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" @@ -199,7 +200,7 @@ StatusOr> HloSchedule::Build( TF_ASSIGN_OR_RETURN( schedule->thunk_launch_order_, CreateMemoryMinimizingSequence( - *entry_computation, [pointer_size](const LogicalBuffer& buffer) { + *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); } else { diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index e983fd11d4..fd56a603bb 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -85,7 +86,7 @@ class HeapSimulatorTracker { // size of the buffers doesn't matter, so we always return 0. We rely on // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by // buffer id, for determinism in the tests. - auto zero_size = [](const LogicalBuffer& buffer) { return 0; }; + auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = MakeUnique( MakeUnique(&actual_calls_)); result_ = HeapSimulator::Run( @@ -119,7 +120,7 @@ class HeapSimulatorTracker { // the sequence. This lets us ensure the Alloc calls are in the sequence // order. The Free calls are sorted by LogicalBuffer.id, which is at least // deterministic. - auto size_fn = [&reverse_position](const LogicalBuffer& buffer) { + auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; }; auto algorithm = MakeUnique( diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index b063244893..b171d41a31 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -1216,7 +1217,7 @@ StatusOr HloRematerialization::Run( // Create initial sequence of HLO instructions. TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( *module, - [this](const LogicalBuffer& buffer) { + [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, scheduler_algorithm_)); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 74544c4a67..92df7c1427 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -77,7 +77,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - auto size_fn = [](const LogicalBuffer& buffer) { + auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; @@ -124,7 +124,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { + CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. @@ -160,7 +160,7 @@ ENTRY root { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, tools::Parse(module_str)); - auto size_fn = [](const LogicalBuffer& buffer) { + auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 05b7dce3d1..7b27dbfec3 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -29,9 +29,11 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -69,7 +71,7 @@ std::ostream& operator<<(std::ostream& out, const HloUse& use) { HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi) - : id_(id), is_phi_(is_phi) { + : BufferValue(instruction, index, id), is_phi_(is_phi) { // The defining position is always the first element in the positions_ vector. positions_.push_back(HloPosition{instruction, index}); } @@ -90,8 +92,8 @@ string HloValue::ToShortString() const { string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) ? defining_index().ToString() : ""; - return StrCat(id_, " ", is_phi_ ? "PHI " : "", defining_instruction()->name(), - index_str); + return StrCat(id(), " ", is_phi_ ? "PHI " : "", + defining_instruction()->name(), index_str); } string HloValue::ToString(int indent) const { diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index 2a711e8b42..a1151f65e0 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -16,16 +16,20 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ -#include +#include #include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -80,30 +84,9 @@ struct HloUse { std::ostream& operator<<(std::ostream& out, const HloUse& use); -// Class describing a value used by the dataflow analysis. XLA arrays are -// trivially a single HloValue. Tuples are made up of more than one HloValue: an -// HloValue for the pointer vector, and an HloValue for each child element. -// -// Every HloValue is defined by a particular instruction and most instructions -// define only a single HloValue. Instructions which define a single HloValue -// include array-shaped instructions such as Add but also includes Tuple-shaped -// instructions such as Tuple. The Tuple instruction defines a single HloValue -// which is a vector of pointers to the values containing the Tuple -// instruction's operands. Though the result of the Tuple instruction includes -// multiple values only the top-level HloValue (the vector of pointers) is -// defined by the Tuple instruction. The values containing the tuple elements -// are defined by earlier instructions, usually the operands of the Tuple -// instruction. -// -// Instructions which construct both the tuple *and* the tuple elements define -// more than one HloValue. This includes (at least) tuple-shaped Constant, -// Parameter, Infeed and While instructions. These tuple-shaped instructions do -// not assemble a tuple from existing HloValues like the Tuple instruction does, -// but rather define all the HloValues in the tuple. -class HloValue { +// HloDataflowAnalysis uses this subclass of BufferValue. +class HloValue : public BufferValue { public: - using Id = int64; - // Predicate comparing HloValues by increasing id, useful for std::sort. static bool IdLessThan(const HloValue* a, const HloValue* b) { return a->id() < b->id(); @@ -120,6 +103,7 @@ class HloValue { // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); + ~HloValue() override {} // Sets the positions in the module at which the HloValue appears. Updates // uses. Should be called once and only once. The defining position should not @@ -127,10 +111,6 @@ class HloValue { void SetPositionsAndComputeUses( tensorflow::gtl::ArraySlice positions); - // Return a unique identifier for this HloValue. This value is used for stable - // sorting and iteration - Id id() const { return id_; } - // Returns whether this value is a phi value. bool is_phi() const { return is_phi_; } @@ -142,12 +122,18 @@ class HloValue { return defining_position().instruction; } + HloInstruction* instruction() const override { + return defining_instruction(); + } + // Return the shape index at which this HloValue is defined in the output of // its defining instruction. const ShapeIndex& defining_index() const { return defining_position().index; } + const ShapeIndex& index() const override { return defining_index(); } + // Return the shape of this HloValue. - const Shape& shape() const { return defining_position().shape(); } + const Shape& shape() const override { return defining_position().shape(); } // Return all positions of the HloValue in the module. const std::vector& positions() const { return positions_; } @@ -164,12 +150,11 @@ class HloValue { // Return a single-line string representation of the value. string ToShortString() const; - string ToString(int indent = 0) const; + string ToString(int indent) const; - private: - // Unique identifier for this HloValue. Used for stable sorting and iteration. - const Id id_; + string ToString() const override { return ToString(0); } + private: // Whether this instruction is a phi value. const bool is_phi_; diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index 68553bed12..1b3de8ad17 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -15,9 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include -#include - #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -28,43 +25,16 @@ namespace xla { LogicalBuffer::LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id) - : instruction_(instruction), id_(id), color_(kInvalidColor), index_(index) { - const auto& s = shape(); - is_array_ = ShapeUtil::IsArray(s); - is_tuple_ = ShapeUtil::IsTuple(s); -} + : BufferValue(instruction, index, id), + instruction_(instruction), + index_(index) {} + +LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { return tensorflow::strings::StrCat(instruction_->name(), "[", tensorflow::str_util::Join(index_, ","), - "](#", id_, " @", color_.value(), ")"); -} - -std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer) { - out << buffer.ToString(); - return out; -} - -/*static*/ LogicalBufferProto::Location LogicalBuffer::ToLocationProto( - const HloInstruction& instruction, const ShapeIndex& index) { - LogicalBufferProto::Location proto; - proto.set_computation_name(instruction.parent()->name()); - proto.set_instruction_name(instruction.name()); - for (const int64 index_entry : index) { - proto.add_shape_index(index_entry); - } - return proto; -} - -LogicalBufferProto LogicalBuffer::ToProto(const SizeFunction& size_fn) const { - LogicalBufferProto proto; - proto.set_id(id_); - proto.set_size(size_fn(*this)); - LogicalBufferProto::Location proto_location = - ToLocationProto(*instruction_, index_); - proto.mutable_defined_at()->Swap(&proto_location); - proto.set_color(color_.value()); - return proto; + "](#", id(), " @", color().value(), ")"); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h index 67b205e289..f9ba5a5547 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -16,11 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ -#include -#include #include -#include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -33,133 +31,30 @@ limitations under the License. namespace xla { -// Class describing a contiguous sequence of elements (ie, C array) which form -// the components of Shaped values in XLA. XLA arrays are trivially a -// single LogicalBuffer. Tuple values are made up of more than one -// LogicalBuffer: a LogicalBuffer for the pointers to elements, and a -// LogicalBuffer for each child element. -// -// Every buffer is defined by a particular instruction and most instructions -// define only a single buffer. Instructions which define a single buffer -// include array-shaped instructions such as Add but also includes Tuple-shaped -// instructions such as Tuple. The Tuple instruction defines a single buffer -// which is a vector of pointers to the buffers containing the Tuple -// instruction's operands. Though the result of the Tuple instruction includes -// multiple buffers only the top-level buffer (the vector of pointers) is -// defined by the Tuple instruction. The buffers containing the tuple elements -// are defined by earlier instructions, usually the operands of the Tuple -// instruction. -// -// Instructions which construct both the tuple *and* the tuple elements define -// more than one buffer. This includes (at least) tuple-shaped Constant, -// Parameter, Infeed and While instructions. The tuple-shaped instructions do -// not assemble a tuple from existing buffers like the Tuple instruction does, -// but rather define the entire tuple. -// -// Some instructions, such as Bitcast, define no buffers. These instructions -// simply forward buffers from their operands. -// -// The LogicalBuffer object describes which HLO instruction defines a buffer and -// where within that instruction's output shape the buffer is defined. The -// location within the output shape is indicated by LogicalBuffer::index() which -// is defined identically to the index used in -// ShapeUtil::GetSubshape(). Examples: -// -// %add = Add(%foo, %bar) -// %tuple_constant = Constant({1, {42, 43}}) -// -// %add defines a single array-shaped buffer LogicalBuffer(%add, {}) which holds -// the array result of the add operation. The nested-tuple-shaped -// %tuple_constant defines 5 buffers described by the following LogicalBuffer -// objects: -// -// LogicalBuffer(%tuple_constant, {}) // "Top-level" buffer: vector of -// // pointers to LogicalBuffers at -// // indices {0} and {1} -// LogicalBuffer(%tuple_constant, {0}) // Holds value "1" -// LogicalBuffer(%tuple_constant, {1}) // Holds nested tuple: vector of -// // pointers to LogicalBuffers at -// // indices {1, 0} and {1, 1} -// LogicalBuffer(%tuple_constant, {1, 0}) // Holds value "42" -// LogicalBuffer(%tuple_constant, {1, 1}) // Holds value "43" -class LogicalBuffer { +// TuplePointsToAnalysis uses this subclass of BufferValue. +class LogicalBuffer : public BufferValue { public: - TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); - - // Id is a unique identifier for the LogicalBuffer to facilitate efficient - // collections of LogicalBuffers with stable iteration order. - // LogicalBuffers are typically created and accessed through - // TuplePointsToAnalysis, and points-to analysis assigns each LogicalBuffer a - // unique value. - using Id = int64; - - // Functions which return the size and alignment of a logical buffer in bytes. - using SizeFunction = std::function; - using AlignmentFunction = std::function; - LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id); - - Id id() const { return id_; } + ~LogicalBuffer() override; // Return the instruction that defines the buffer. - HloInstruction* instruction() const { return instruction_; } + HloInstruction* instruction() const override { return instruction_; } // Return the index within the output of the instruction where the buffer is // defined. Index used defined as in ShapeUtil::GetSubshape() - const ShapeIndex& index() const { return index_; } - - // Return the color of the logical buffer. Differently colored buffers can - // not be parts of the same allocation. - Color color() const { - CHECK_NE(color_, kInvalidColor) - << "Should not query the color of a buffer that was never colored"; - return color_; - } - - void set_color(Color color) { - CHECK_NE(color, kInvalidColor) - << "Should not set the color of a buffer to the invalid color"; - color_ = color; - } - - bool has_color() const { return color_ != kInvalidColor; } + const ShapeIndex& index() const override { return index_; } // Return the shape of the buffer. This reference points into the shape field // of the instruction defining the buffer. Therefore, the returned shape will // contain the layout of instruction, if any. - const Shape& shape() const { + const Shape& shape() const override { return ShapeUtil::GetSubshape(instruction_->shape(), index_); } - // Returns true if this buffer is the top-level output buffer of the defining - // HLO instruction. This is equivalent to index == {}. - bool IsTopLevel() const { return index_.empty(); } - - // Whether this buffer contains a tuple. - bool IsTuple() const { return is_tuple_; } - - // Whether this buffer contains an array. - bool IsArray() const { return is_array_; } - - // operator< is required for std::set. - bool operator<(const LogicalBuffer& other) const { return id_ < other.id_; } - - string ToString() const; - LogicalBufferProto ToProto(const SizeFunction& size_fn) const; - - // Returns the LogicalBufferProto::Location that serializes the given - // instruction and index. - static LogicalBufferProto::Location ToLocationProto( - const HloInstruction& instruction, const ShapeIndex& index); - - const Color kInvalidColor = Color(-1); + string ToString() const override; private: HloInstruction* instruction_; - Id id_ : 62; - bool is_array_ : 1; - bool is_tuple_ : 1; - Color color_; ShapeIndex index_; // Similar to HLO constructs (HloInstruction, etc), pointers are used for @@ -167,8 +62,6 @@ class LogicalBuffer { TF_DISALLOW_COPY_AND_ASSIGN(LogicalBuffer); }; -std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ -- GitLab From 7833890a0da5226e4c409b1020155f1718c0edb2 Mon Sep 17 00:00:00 2001 From: Anna R Date: Wed, 2 May 2018 17:41:26 -0700 Subject: [PATCH 060/839] Add a collect_trace option to run_op_benchmark for cases when callers just want to pass RunOptions.FULL_TRACE but don't want to store trace in extras. PiperOrigin-RevId: 195181533 --- tensorflow/python/kernel_tests/benchmark_test.py | 12 ++++++++---- tensorflow/python/platform/benchmark.py | 14 ++++++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py index 623343602d..78b6e38d94 100644 --- a/tensorflow/python/kernel_tests/benchmark_test.py +++ b/tensorflow/python/kernel_tests/benchmark_test.py @@ -67,7 +67,7 @@ class TestReportingBenchmark(test.Benchmark): with session.Session() as sess: a = constant_op.constant(0.0) a_plus_a = a + a - self.run_op_benchmark( + return self.run_op_benchmark( sess, a_plus_a, min_iters=1000, store_trace=True, name="op_benchmark") @@ -148,7 +148,7 @@ class BenchmarkTest(test.TestCase): reporting = TestReportingBenchmark() reporting.benchmarkReport1() # This should write reporting.benchmarkReport2() # This should write - reporting.benchmark_times_an_op() # This should write + benchmark_values3 = reporting.benchmark_times_an_op() # This should write # Check the files were written self.assertTrue(gfile.Exists(expected_output_file)) @@ -186,8 +186,12 @@ class BenchmarkTest(test.TestCase): self.assertEquals(expected_3.name, read_benchmark_3.name) self.assertEquals(expected_3.iters, read_benchmark_3.iters) self.assertGreater(read_benchmark_3.wall_time, 0) - full_trace = read_benchmark_3.extras["full_trace_chrome_format"] - json_trace = json.loads(full_trace.string_value) + + # Trace is not stored in benchmark entry. Instead we get it from + # return value of `run_op_benchmark` call. + full_trace = benchmark_values3["extras"]["full_trace_chrome_format"] + json_trace = json.loads(full_trace) + self.assertTrue(isinstance(json_trace, dict)) self.assertTrue("traceEvents" in json_trace.keys()) allocator_keys = [k for k in read_benchmark_3.extras.keys() diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py index 12dae94a64..eba2baaf6f 100644 --- a/tensorflow/python/platform/benchmark.py +++ b/tensorflow/python/platform/benchmark.py @@ -213,9 +213,10 @@ class TensorFlowBenchmark(Benchmark): burn_iters: Number of burn-in iterations to run. min_iters: Minimum number of iterations to use for timing. store_trace: Boolean, whether to run an extra untimed iteration and - store the trace of iteration in the benchmark report. + store the trace of iteration in returned extras. The trace will be stored as a string in Google Chrome trace format - in the extras field "full_trace_chrome_format". + in the extras field "full_trace_chrome_format". Note that trace + will not be stored in test_log_pb2.TestResults proto. store_memory_usage: Boolean, whether to run an extra untimed iteration, calculate memory usage, and store that in extras fields. name: (optional) Override the BenchmarkEntry name with `name`. @@ -227,7 +228,9 @@ class TensorFlowBenchmark(Benchmark): Returns: A `dict` containing the key-value pairs that were passed to - `report_benchmark`. + `report_benchmark`. If `store_trace` option is used, then + `full_chrome_trace_format` will be included in return dictionary even + though it is not passed to `report_benchmark` with `extras`. """ for _ in range(burn_iters): sess.run(op_or_tensor, feed_dict=feed_dict) @@ -242,6 +245,7 @@ class TensorFlowBenchmark(Benchmark): deltas[i] = delta extras = extras if extras is not None else {} + unreported_extras = {} if store_trace or store_memory_usage: run_options = config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE) @@ -251,7 +255,8 @@ class TensorFlowBenchmark(Benchmark): tl = timeline.Timeline(run_metadata.step_stats) if store_trace: - extras["full_trace_chrome_format"] = tl.generate_chrome_trace_format() + unreported_extras["full_trace_chrome_format"] = ( + tl.generate_chrome_trace_format()) if store_memory_usage: step_stats_analysis = tl.analyze_step_stats(show_memory=True) @@ -277,6 +282,7 @@ class TensorFlowBenchmark(Benchmark): "throughput": mbs / median_delta } self.report_benchmark(**benchmark_values) + benchmark_values["extras"].update(unreported_extras) return benchmark_values -- GitLab From df5ae5ac2a58131737a11e417ac34a663efb3574 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Wed, 2 May 2018 17:52:38 -0700 Subject: [PATCH 061/839] Add some todo's --- tensorflow/contrib/tensorboard/db/summary_db_writer.cc | 1 + tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index 046a2d3884..630c0607ae 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -1183,6 +1183,7 @@ class SummaryDbWriter : public SummaryWriterInterface { Tensor t{DT_DOUBLE, {k, 3}}; auto data = t.flat(); for (int i = 0, j = 0; i < k; ++i) { + // TODO(nickfelt): reconcile with TensorBoard's data_compat.py // From summary.proto // Parallel arrays encoding the bucket boundaries and the bucket values. // bucket(i) is the count for the bucket i. The range for diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index cb51325d15..2044692b6e 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -131,6 +131,7 @@ TEST_F(SummaryDbWriterTest, WriteHistogram_VerifyTensorValues) { writer_->Unref(); writer_ = nullptr; + // TODO(nickfelt): implement QueryTensor() to encapsulate this // Verify the data string result = QueryString("SELECT data FROM Tensors"); const double* val = reinterpret_cast(result.data()); -- GitLab From 8f0a90b711480c12716d1a3b1094cc8b34939f2d Mon Sep 17 00:00:00 2001 From: RJ Ryan Date: Wed, 2 May 2018 17:57:27 -0700 Subject: [PATCH 062/839] Add complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. NumPy automatically upcasts to complex128 when computing FFTs, leading to issues like: #10749 This change allows users to choose between 32-bit and 64-bit precision FFTs on CPU and GPU. PiperOrigin-RevId: 195183206 --- tensorflow/compiler/tf2xla/kernels/fft_ops.cc | 17 +- tensorflow/core/kernels/fft_ops.cc | 78 +++++++--- tensorflow/core/ops/spectral_ops.cc | 30 ++-- .../python/kernel_tests/fft_ops_test.py | 145 +++++++++++------- .../linalg/linear_operator_circulant_test.py | 6 +- tensorflow/python/ops/spectral_grad.py | 30 ++-- 6 files changed, 196 insertions(+), 110 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index fcb927dab0..933924cad1 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -81,9 +81,11 @@ class FFTOp : public GenericFftOp { explicit FFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("FFT"), FFTOp<1>); -REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>); -REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>); +REGISTER_XLA_OP(Name("FFT").TypeConstraint("Tcomplex", DT_COMPLEX64), FFTOp<1>); +REGISTER_XLA_OP(Name("FFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64), + FFTOp<2>); +REGISTER_XLA_OP(Name("FFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64), + FFTOp<3>); template class IFFTOp : public GenericFftOp { @@ -91,9 +93,12 @@ class IFFTOp : public GenericFftOp { explicit IFFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>); -REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>); -REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>); +REGISTER_XLA_OP(Name("IFFT").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<1>); +REGISTER_XLA_OP(Name("IFFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<2>); +REGISTER_XLA_OP(Name("IFFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<3>); template class RFFTOp : public GenericFftOp { diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index 661bf5fc5f..d7105a71bb 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -129,13 +129,23 @@ class FFTCPU : public FFTBase { auto device = ctx->eigen_device(); if (!IsReal()) { - auto input = Tensor(in).flat_inner_dims(); - // Compute the FFT using eigen. - auto output = out->flat_inner_dims(); + // Compute the FFT using Eigen. constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE; - output.device(device) = - input.template fft(axes); + if (in.dtype() == DT_COMPLEX64) { + DCHECK_EQ(out->dtype(), DT_COMPLEX64); + auto input = Tensor(in).flat_inner_dims(); + auto output = out->flat_inner_dims(); + output.device(device) = + input.template fft(axes); + } else { + DCHECK_EQ(DT_COMPLEX128, in.dtype()); + DCHECK_EQ(DT_COMPLEX128, out->dtype()); + auto input = Tensor(in).flat_inner_dims(); + auto output = out->flat_inner_dims(); + output.device(device) = + input.template fft(axes); + } } else { if (IsForward()) { auto input = Tensor(in).flat_inner_dims(); @@ -392,10 +402,16 @@ class FFTGPUBase : public FFTBase { } constexpr bool kInPlaceFft = false; + const bool is_complex128 = in.dtype() == DT_COMPLEX128; + // complex128 real FFT is not supported yet. + DCHECK(!IsReal() || !is_complex128); + const auto kFftType = IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R) - : (IsForward() ? se::fft::Type::kC2CForward - : se::fft::Type::kC2CInverse); + : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward + : se::fft::Type::kC2CForward) + : (is_complex128 ? se::fft::Type::kZ2ZInverse + : se::fft::Type::kC2CInverse)); CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx); auto plan = @@ -428,20 +444,42 @@ class FFTGPUBase : public FFTBase { input_shape.DebugString())); } } else { - auto src = AsDeviceMemory(in.flat().data()); - auto dst = AsDeviceMemory(out->flat().data()); - OP_REQUIRES( - ctx, stream->ThenFft(plan.get(), src, &dst).ok(), - errors::Internal("fft failed : type=", static_cast(kFftType), - " in.shape=", input_shape.DebugString())); - if (!IsForward()) { - auto alpha = complex64(1.f / output_distance); + if (!is_complex128) { + DCHECK_EQ(in.dtype(), DT_COMPLEX64); + DCHECK_EQ(out->dtype(), DT_COMPLEX64); + auto src = AsDeviceMemory(in.flat().data()); + auto dst = AsDeviceMemory(out->flat().data()); OP_REQUIRES( - ctx, - stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1) - .ok(), - errors::Internal("BlasScal failed : in.shape=", - input_shape.DebugString())); + ctx, stream->ThenFft(plan.get(), src, &dst).ok(), + errors::Internal("fft failed : type=", static_cast(kFftType), + " in.shape=", input_shape.DebugString())); + if (!IsForward()) { + float alpha = 1.f / output_distance; + OP_REQUIRES( + ctx, + stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1) + .ok(), + errors::Internal("BlasScal failed : in.shape=", + input_shape.DebugString())); + } + } else { + DCHECK_EQ(in.dtype(), DT_COMPLEX128); + DCHECK_EQ(out->dtype(), DT_COMPLEX128); + auto src = AsDeviceMemory(in.flat().data()); + auto dst = AsDeviceMemory(out->flat().data()); + OP_REQUIRES( + ctx, stream->ThenFft(plan.get(), src, &dst).ok(), + errors::Internal("fft failed : type=", static_cast(kFftType), + " in.shape=", input_shape.DebugString())); + if (!IsForward()) { + double alpha = 1.0 / output_distance; + OP_REQUIRES( + ctx, + stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1) + .ok(), + errors::Internal("BlasScal failed : in.shape=", + input_shape.DebugString())); + } } } } diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc index 2790aee37e..b1ae7040f0 100644 --- a/tensorflow/core/ops/spectral_ops.cc +++ b/tensorflow/core/ops/spectral_ops.cc @@ -25,43 +25,49 @@ using shape_inference::InferenceContext; using shape_inference::ShapeHandle; REGISTER_OP("FFT") - .Input("input: complex64") - .Output("output: complex64") + .Input("input: Tcomplex") + .Output("output: Tcomplex") + .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64") .SetShapeFn([](InferenceContext* c) { return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); }); REGISTER_OP("IFFT") - .Input("input: complex64") - .Output("output: complex64") + .Input("input: Tcomplex") + .Output("output: Tcomplex") + .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64") .SetShapeFn([](InferenceContext* c) { return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); }); REGISTER_OP("FFT2D") - .Input("input: complex64") - .Output("output: complex64") + .Input("input: Tcomplex") + .Output("output: Tcomplex") + .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64") .SetShapeFn([](InferenceContext* c) { return shape_inference::UnchangedShapeWithRankAtLeast(c, 2); }); REGISTER_OP("IFFT2D") - .Input("input: complex64") - .Output("output: complex64") + .Input("input: Tcomplex") + .Output("output: Tcomplex") + .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64") .SetShapeFn([](InferenceContext* c) { return shape_inference::UnchangedShapeWithRankAtLeast(c, 2); }); REGISTER_OP("FFT3D") - .Input("input: complex64") - .Output("output: complex64") + .Input("input: Tcomplex") + .Output("output: Tcomplex") + .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64") .SetShapeFn([](InferenceContext* c) { return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); }); REGISTER_OP("IFFT3D") - .Input("input: complex64") - .Output("output: complex64") + .Input("input: Tcomplex") + .Output("output: Tcomplex") + .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64") .SetShapeFn([](InferenceContext* c) { return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); }); diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py index b9e2aa1f3a..629acedda5 100644 --- a/tensorflow/python/kernel_tests/fft_ops_test.py +++ b/tensorflow/python/kernel_tests/fft_ops_test.py @@ -38,11 +38,13 @@ VALID_FFT_RANKS = (1, 2, 3) class BaseFFTOpsTest(test.TestCase): - def _compare(self, x, rank, fft_length=None, use_placeholder=False): - self._compareForward(x, rank, fft_length, use_placeholder) - self._compareBackward(x, rank, fft_length, use_placeholder) + def _compare(self, x, rank, fft_length=None, use_placeholder=False, + rtol=1e-4, atol=1e-4): + self._compareForward(x, rank, fft_length, use_placeholder, rtol, atol) + self._compareBackward(x, rank, fft_length, use_placeholder, rtol, atol) - def _compareForward(self, x, rank, fft_length=None, use_placeholder=False): + def _compareForward(self, x, rank, fft_length=None, use_placeholder=False, + rtol=1e-4, atol=1e-4): x_np = self._npFFT(x, rank, fft_length) if use_placeholder: x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype)) @@ -50,9 +52,10 @@ class BaseFFTOpsTest(test.TestCase): else: x_tf = self._tfFFT(x, rank, fft_length) - self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4) + self.assertAllClose(x_np, x_tf, rtol=rtol, atol=atol) - def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False): + def _compareBackward(self, x, rank, fft_length=None, use_placeholder=False, + rtol=1e-4, atol=1e-4): x_np = self._npIFFT(x, rank, fft_length) if use_placeholder: x_ph = array_ops.placeholder(dtype=dtypes.as_dtype(x.dtype)) @@ -60,7 +63,7 @@ class BaseFFTOpsTest(test.TestCase): else: x_tf = self._tfIFFT(x, rank, fft_length) - self.assertAllClose(x_np, x_tf, rtol=1e-4, atol=1e-4) + self.assertAllClose(x_np, x_tf, rtol=rtol, atol=atol) def _checkMemoryFail(self, x, rank): config = config_pb2.ConfigProto() @@ -68,7 +71,8 @@ class BaseFFTOpsTest(test.TestCase): with self.test_session(config=config, force_gpu=True): self._tfFFT(x, rank, fft_length=None) - def _checkGradComplex(self, func, x, y, result_is_complex=True): + def _checkGradComplex(self, func, x, y, result_is_complex=True, + rtol=1e-2, atol=1e-2): with self.test_session(use_gpu=True): inx = ops.convert_to_tensor(x) iny = ops.convert_to_tensor(y) @@ -85,10 +89,10 @@ class BaseFFTOpsTest(test.TestCase): x_init_value=[x, y], delta=1e-2) - self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2) - self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=1e-2) + self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol) + self.assertAllClose(y_jacob_t, y_jacob_n, rtol=rtol, atol=atol) - def _checkGradReal(self, func, x): + def _checkGradReal(self, func, x, rtol=1e-2, atol=1e-2): with self.test_session(use_gpu=True): inx = ops.convert_to_tensor(x) # func is a forward RFFT function (batched or unbatched). @@ -98,7 +102,7 @@ class BaseFFTOpsTest(test.TestCase): x_jacob_t, x_jacob_n = test.compute_gradient( inx, list(x.shape), loss, [1], x_init_value=x, delta=1e-2) - self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=1e-2) + self.assertAllClose(x_jacob_t, x_jacob_n, rtol=rtol, atol=atol) class FFTOpsTest(BaseFFTOpsTest): @@ -155,27 +159,30 @@ class FFTOpsTest(BaseFFTOpsTest): def testEmpty(self): with spectral_ops_test_util.fft_kernel_label_map(): - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 3): - x = np.zeros((0,) * dims).astype(np.complex64) - self.assertEqual(x.shape, self._tfFFT(x, rank).shape) - self.assertEqual(x.shape, self._tfIFFT(x, rank).shape) + for np_type in (np.complex64, np.complex128): + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 3): + x = np.zeros((0,) * dims).astype(np_type) + self.assertEqual(x.shape, self._tfFFT(x, rank).shape) + self.assertEqual(x.shape, self._tfIFFT(x, rank).shape) def testBasic(self): with spectral_ops_test_util.fft_kernel_label_map(): - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 3): - self._compare( - np.mod(np.arange(np.power(4, dims)), 10).reshape( - (4,) * dims).astype(np.complex64), rank) + for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-8)): + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 3): + self._compare( + np.mod(np.arange(np.power(4, dims)), 10).reshape( + (4,) * dims).astype(np_type), rank, rtol=tol, atol=tol) def testLargeBatch(self): if test.is_gpu_available(cuda_only=True): rank = 1 for dims in xrange(rank, rank + 3): - self._compare( - np.mod(np.arange(np.power(128, dims)), 10).reshape( - (128,) * dims).astype(np.complex64), rank) + for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-5)): + self._compare( + np.mod(np.arange(np.power(128, dims)), 10).reshape( + (128,) * dims).astype(np_type), rank, rtol=tol, atol=tol) # TODO(yangzihao): Disable before we can figure out a way to # properly test memory fail for large batch fft. @@ -189,27 +196,49 @@ class FFTOpsTest(BaseFFTOpsTest): def testBasicPlaceholder(self): with spectral_ops_test_util.fft_kernel_label_map(): - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 3): - self._compare( - np.mod(np.arange(np.power(4, dims)), 10).reshape( - (4,) * dims).astype(np.complex64), - rank, - use_placeholder=True) + for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-8)): + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 3): + self._compare( + np.mod(np.arange(np.power(4, dims)), 10).reshape( + (4,) * dims).astype(np_type), + rank, use_placeholder=True, rtol=tol, atol=tol) def testRandom(self): with spectral_ops_test_util.fft_kernel_label_map(): - np.random.seed(12345) + for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 5e-6)): + def gen(shape): + n = np.prod(shape) + re = np.random.uniform(size=n) + im = np.random.uniform(size=n) + return (re + im * 1j).reshape(shape) - def gen(shape): - n = np.prod(shape) - re = np.random.uniform(size=n) - im = np.random.uniform(size=n) - return (re + im * 1j).reshape(shape) + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 3): + self._compare(gen((4,) * dims).astype(np_type), rank, + rtol=tol, atol=tol) - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 3): - self._compare(gen((4,) * dims), rank) + def testRandom1D(self): + with spectral_ops_test_util.fft_kernel_label_map(): + for np_type in (np.complex64, np.complex128): + has_gpu = test.is_gpu_available(cuda_only=True) + tol = {(np.complex64, True): 1e-4, + (np.complex64, False): 1e-2, + (np.complex128, True): 1e-4, + (np.complex128, False): 1e-2}[(np_type, has_gpu)] + def gen(shape): + n = np.prod(shape) + re = np.random.uniform(size=n) + im = np.random.uniform(size=n) + return (re + im * 1j).reshape(shape) + + # Check a variety of power-of-2 FFT sizes. + for dim in (128, 256, 512, 1024): + self._compare(gen((dim,)).astype(np_type), 1, rtol=tol, atol=tol) + + # Check a variety of non-power-of-2 FFT sizes. + for dim in (127, 255, 511, 1023): + self._compare(gen((dim,)).astype(np_type), 1, rtol=tol, atol=tol) def testError(self): for rank in VALID_FFT_RANKS: @@ -224,22 +253,27 @@ class FFTOpsTest(BaseFFTOpsTest): def testGrad_Simple(self): with spectral_ops_test_util.fft_kernel_label_map(): - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 2): - re = np.ones(shape=(4,) * dims, dtype=np.float32) / 10.0 - im = np.zeros(shape=(4,) * dims, dtype=np.float32) - self._checkGradComplex(self._tfFFTForRank(rank), re, im) - self._checkGradComplex(self._tfIFFTForRank(rank), re, im) + for np_type, tol in ((np.float32, 1e-4), (np.float64, 1e-10)): + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 2): + re = np.ones(shape=(4,) * dims, dtype=np_type) / 10.0 + im = np.zeros(shape=(4,) * dims, dtype=np_type) + self._checkGradComplex(self._tfFFTForRank(rank), re, im, + rtol=tol, atol=tol) + self._checkGradComplex(self._tfIFFTForRank(rank), re, im, + rtol=tol, atol=tol) def testGrad_Random(self): with spectral_ops_test_util.fft_kernel_label_map(): - np.random.seed(54321) - for rank in VALID_FFT_RANKS: - for dims in xrange(rank, rank + 2): - re = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1 - im = np.random.rand(*((3,) * dims)).astype(np.float32) * 2 - 1 - self._checkGradComplex(self._tfFFTForRank(rank), re, im) - self._checkGradComplex(self._tfIFFTForRank(rank), re, im) + for np_type, tol in ((np.float32, 1e-2), (np.float64, 1e-10)): + for rank in VALID_FFT_RANKS: + for dims in xrange(rank, rank + 2): + re = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1 + im = np.random.rand(*((3,) * dims)).astype(np_type) * 2 - 1 + self._checkGradComplex(self._tfFFTForRank(rank), re, im, + rtol=tol, atol=tol) + self._checkGradComplex(self._tfIFFTForRank(rank), re, im, + rtol=tol, atol=tol) class RFFTOpsTest(BaseFFTOpsTest): @@ -395,8 +429,6 @@ class RFFTOpsTest(BaseFFTOpsTest): def testRandom(self): with spectral_ops_test_util.fft_kernel_label_map(): - np.random.seed(12345) - def gen_real(shape): n = np.prod(shape) re = np.random.uniform(size=n) @@ -491,7 +523,6 @@ class RFFTOpsTest(BaseFFTOpsTest): def testGrad_Random(self): with spectral_ops_test_util.fft_kernel_label_map(): - np.random.seed(54321) for rank in VALID_FFT_RANKS: # rfft3d/irfft3d do not have gradients yet. if rank == 3: diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py index e7f2f1c12b..5713d16969 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py @@ -73,7 +73,7 @@ class LinearOperatorCirculantBaseTest(object): x = np.zeros([domain_dimension]) # x is a basis vector. x[m] = 1.0 - fft_x = math_ops.fft(x) + fft_x = math_ops.fft(x.astype(np.complex64)) h_convolve_x = math_ops.ifft(spectrum * fft_x) matrix_rows.append(h_convolve_x) matrix = array_ops.stack(matrix_rows, axis=-1) @@ -91,7 +91,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator( @property def _dtypes_to_test(self): - # This operator will always be complex because, although the specturm is + # This operator will always be complex because, although the spectrum is # real, the matrix will not be real. return [dtypes.complex64] @@ -408,7 +408,7 @@ class LinearOperatorCirculant2DBaseTest(object): x = np.zeros(block_shape) # x is a basis vector. x[n0, n1] = 1.0 - fft_x = math_ops.fft2d(x) + fft_x = math_ops.fft2d(x.astype(np.complex64)) h_convolve_x = math_ops.ifft2d(spectrum * fft_x) # We want the flat version of the action of the operator on a basis # vector, not the block version. diff --git a/tensorflow/python/ops/spectral_grad.py b/tensorflow/python/ops/spectral_grad.py index deb0a57178..0af24114ac 100644 --- a/tensorflow/python/ops/spectral_grad.py +++ b/tensorflow/python/ops/spectral_grad.py @@ -32,38 +32,44 @@ def _FFTSizeForGrad(grad, rank): @ops.RegisterGradient("FFT") def _FFTGrad(_, grad): - size = math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32) - return spectral_ops.ifft(grad) * math_ops.complex(size, 0.) + size = math_ops.cast(_FFTSizeForGrad(grad, 1), grad.dtype) + return spectral_ops.ifft(grad) * size @ops.RegisterGradient("IFFT") def _IFFTGrad(_, grad): - rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32) - return spectral_ops.fft(grad) * math_ops.complex(rsize, 0.) + rsize = math_ops.cast( + 1. / math_ops.cast(_FFTSizeForGrad(grad, 1), grad.dtype.real_dtype), + grad.dtype) + return spectral_ops.fft(grad) * rsize @ops.RegisterGradient("FFT2D") def _FFT2DGrad(_, grad): - size = math_ops.cast(_FFTSizeForGrad(grad, 2), dtypes.float32) - return spectral_ops.ifft2d(grad) * math_ops.complex(size, 0.) + size = math_ops.cast(_FFTSizeForGrad(grad, 2), grad.dtype) + return spectral_ops.ifft2d(grad) * size @ops.RegisterGradient("IFFT2D") def _IFFT2DGrad(_, grad): - rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 2), dtypes.float32) - return spectral_ops.fft2d(grad) * math_ops.complex(rsize, 0.) + rsize = math_ops.cast( + 1. / math_ops.cast(_FFTSizeForGrad(grad, 2), grad.dtype.real_dtype), + grad.dtype) + return spectral_ops.fft2d(grad) * rsize @ops.RegisterGradient("FFT3D") def _FFT3DGrad(_, grad): - size = math_ops.cast(_FFTSizeForGrad(grad, 3), dtypes.float32) - return spectral_ops.ifft3d(grad) * math_ops.complex(size, 0.) + size = math_ops.cast(_FFTSizeForGrad(grad, 3), grad.dtype) + return spectral_ops.ifft3d(grad) * size @ops.RegisterGradient("IFFT3D") def _IFFT3DGrad(_, grad): - rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 3), dtypes.float32) - return spectral_ops.fft3d(grad) * math_ops.complex(rsize, 0.) + rsize = math_ops.cast( + 1. / math_ops.cast(_FFTSizeForGrad(grad, 3), grad.dtype.real_dtype), + grad.dtype) + return spectral_ops.fft3d(grad) * rsize def _RFFTGradHelper(rank, irfft_fn): -- GitLab From db329cfe2dee382033ad3b3f5e1d906ff489a24d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 18:11:25 -0700 Subject: [PATCH 063/839] Automated g4 rollback of changelist 195091587 PiperOrigin-RevId: 195184798 --- tensorflow/contrib/lite/toco/BUILD | 1 + .../contrib/lite/toco/model_flags.proto | 3 +- tensorflow/contrib/lite/toco/tooling_util.cc | 79 ++++++++++++------- 3 files changed, 53 insertions(+), 30 deletions(-) diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index f16225fd66..ce0a74724a 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -397,6 +397,7 @@ cc_library( ":types_proto_cc", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", "@protobuf_archive//:protobuf_headers", ], ) diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto index d23e80c464..6c1c53658c 100644 --- a/tensorflow/contrib/lite/toco/model_flags.proto +++ b/tensorflow/contrib/lite/toco/model_flags.proto @@ -96,8 +96,9 @@ message RnnState { // model that does not already contain such MinMax information. message ArraysExtraInfo { message Entry { - // Next ID to use: 7. + // Next ID to use: 8. optional string name = 1; + optional string name_regexp = 7; optional double min = 2; optional double max = 3; optional IODataType data_type = 4; diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index f334c51bbb..11293a5fe5 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" +#include "re2/re2.h" #include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" @@ -1983,38 +1984,58 @@ void FinishBuildingRNNStates(Model* model) { } } +// Returns the array names that match the ArraysExtraInfo's name and +// name_regexp. The regexp match is for a full match. +std::unordered_set ScanArrayNames( + const Model& model, const toco::ArraysExtraInfo_Entry& entry) { + std::unordered_set matches; + if (model.HasArray(entry.name())) { + matches.insert(entry.name()); + } + if (!entry.name_regexp().empty()) { + const auto& arrays = model.GetArrayMap(); + const RE2 name_regexp = {entry.name_regexp()}; + for (auto it = arrays.begin(); it != arrays.end(); ++it) { + if (RE2::FullMatch(it->first, name_regexp)) { + matches.insert(it->first); + } + } + } + return matches; +} + void UseArraysExtraInfo(Model* model, bool quantize_output) { for (const auto& entry : model->flags.arrays_extra_info().entries()) { - if (!model->HasArray(entry.name())) { - continue; - } - auto& array = model->GetArray(entry.name()); - if (entry.has_min() || entry.has_max()) { - CHECK_EQ(entry.has_min(), entry.has_max()); - auto& minmax = array.GetOrCreateMinMax(); - minmax.min = entry.min(); - minmax.max = entry.max(); - } - if (entry.has_data_type() && quantize_output) { - array.final_data_type = - ConvertIODataTypeToArrayDataType(entry.data_type()); - } - if (entry.has_shape()) { - array.clear_shape(); - // Make sure to create the shape even if there are no dims, to - // correctly record 0-D shapes. - array.mutable_shape(); - for (int dim : entry.shape().dims()) { - array.mutable_shape()->mutable_dims()->push_back(dim); + const auto matches = ScanArrayNames(*model, entry); + for (const auto& matched_name : matches) { + auto& array = model->GetArray(matched_name); + if (entry.has_min() || entry.has_max()) { + CHECK_EQ(entry.has_min(), entry.has_max()); + auto& minmax = array.GetOrCreateMinMax(); + minmax.min = entry.min(); + minmax.max = entry.max(); } - } - if (entry.has_constant_float_value()) { - CHECK(array.has_shape()); - if (array.data_type == ArrayDataType::kFloat) { - auto& data = array.GetMutableBuffer().data; - data.resize(RequiredBufferSizeForShape(array.shape())); - for (float& f : data) { - f = entry.constant_float_value(); + if (entry.has_data_type() && quantize_output) { + array.final_data_type = + ConvertIODataTypeToArrayDataType(entry.data_type()); + } + if (entry.has_shape()) { + array.clear_shape(); + // Make sure to create the shape even if there are no dims, to + // correctly record 0-D shapes. + array.mutable_shape(); + for (int dim : entry.shape().dims()) { + array.mutable_shape()->mutable_dims()->push_back(dim); + } + } + if (entry.has_constant_float_value()) { + CHECK(array.has_shape()); + if (array.data_type == ArrayDataType::kFloat) { + auto& data = array.GetMutableBuffer().data; + data.resize(RequiredBufferSizeForShape(array.shape())); + for (float& f : data) { + f = entry.constant_float_value(); + } } } } -- GitLab From 1a4f746ffd82376f6e9ad420d96943ff89e7013a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 18:19:16 -0700 Subject: [PATCH 064/839] Remove duplicated emplace_back floor operator. PiperOrigin-RevId: 195185567 --- tensorflow/contrib/lite/toco/tflite/operator.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index fce3bad326..d2e14ac5e0 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -901,8 +901,6 @@ std::vector> BuildOperatorList() { "MINIMUM", OperatorType::kTensorFlowMinimum)); ops.emplace_back(new SimpleOperator( "LESS", OperatorType::kTensorFlowLess)); - ops.emplace_back( - new SimpleOperator("FLOOR", OperatorType::kFloor)); return ops; } -- GitLab From 2b1a03c2ad502329a1f2b1368a40913ef21e97a0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 18:35:55 -0700 Subject: [PATCH 065/839] Compute shape of segment_ids dynamically in _unsorted_segment_N PiperOrigin-RevId: 195186950 --- tensorflow/python/ops/math_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 7ac3bd8091..ab5997e85c 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2515,7 +2515,8 @@ def _unsorted_segment_N(data, segment_ids, num_segments): of segment entries with 0-entries set to 1 to allow division by N. """ # bincount doesn't support negative indices so we use unsorted_segment_sum - ones_tensor = array_ops.ones(segment_ids.shape, dtype=data.dtype) + segment_ids_shape = array_ops.shape_internal(segment_ids) + ones_tensor = array_ops.ones(segment_ids_shape, dtype=data.dtype) N = gen_math_ops.unsorted_segment_sum(ones_tensor, segment_ids, num_segments) # add dimensions for all non-reduced axes ndims_output = data.shape.ndims - segment_ids.shape.ndims -- GitLab From 223be4abe74592a781735a6b66e12cb0146f0830 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 18:52:02 -0700 Subject: [PATCH 066/839] Replaced calls to tensorflow::StringPiece::ToString with std::string conversions. That is, instances of sp.ToString() are replaced with std::string(sp). This will allow tensorflow::StringPiece::ToString to be removed, which is necessary before it can be replaced with absl::string_view. PiperOrigin-RevId: 195188185 --- tensorflow/cc/framework/cc_op_gen.cc | 2 +- tensorflow/cc/framework/scope.cc | 2 +- tensorflow/compiler/tf2xla/tf2xla_util.cc | 2 +- tensorflow/compiler/tf2xla/xla_op_registry.cc | 14 +++++++------- .../compiler/xla/service/llvm_ir/llvm_loop.cc | 4 ++-- .../compiler/xla/service/llvm_ir/llvm_loop.h | 2 +- .../compiler/xla/tools/parser/hlo_lexer.cc | 2 +- .../compiler/xla/tools/parser/hlo_parser.cc | 2 +- tensorflow/core/debug/debug_graph_utils.cc | 7 +++---- tensorflow/core/debug/debug_io_utils.cc | 10 +++++----- .../core/distributed_runtime/master_session.cc | 6 +++--- .../core/distributed_runtime/remote_device.cc | 2 +- tensorflow/core/grappler/utils.h | 2 +- .../core/kernels/hexagon/graph_transferer.cc | 2 +- .../kernels/hexagon/hexagon_control_wrapper.cc | 2 +- tensorflow/core/lib/io/path.cc | 6 +++--- tensorflow/core/lib/io/table_test.cc | 6 +++--- .../core/util/tensor_bundle/tensor_bundle.cc | 16 ++++++++-------- .../util/tensor_bundle/tensor_bundle_test.cc | 2 +- tensorflow/stream_executor/kernel.cc | 2 +- tensorflow/stream_executor/kernel_spec.cc | 6 +++--- 21 files changed, 49 insertions(+), 50 deletions(-) diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d73121c7b7..d6a4f141b6 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -440,7 +440,7 @@ string AvoidCPPKeywords(StringPiece name) { if (IsCPPKeyword(name)) { return strings::StrCat(name, "_"); } - return name.ToString(); + return std::string(name); } void InferArgAttributes(const OpDef::ArgDef& arg, diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index c143b97833..62a889181e 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -220,7 +220,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( for (const string& entry : node_constraints) { StringPiece s(entry); if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { - current_constraints.insert(s.ToString()); + current_constraints.insert(std::string(s)); } } } else { diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 7ec85aa3cd..9203e8d9e6 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -232,7 +232,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, // Push input nodes of the currently visited node to name_queue. for (const string& in_edge : map_entry.second->input()) { auto id = ParseTensorName(in_edge); - const string node_name = id.first.ToString(); + const string node_name = std::string(id.first); if (feed_tensors.find(std::make_pair(node_name, id.second)) == feed_tensors.end()) { name_queue.push(node_name); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index bbe808595d..e309cb1e34 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -311,7 +311,7 @@ XlaOpRegistry& XlaOpRegistry::Instance() { XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { registration_.reset(new XlaOpRegistry::OpRegistration); - registration_->name = name.ToString(); + registration_->name = std::string(name); } XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { @@ -323,14 +323,14 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( gtl::ArraySlice devices) { registration_->has_device_whitelist = true; for (StringPiece device : devices) { - registration_->device_whitelist.insert(device.ToString()); + registration_->device_whitelist.insert(std::string(device)); } return *this; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { registration_->has_device_whitelist = true; - registration_->device_whitelist.insert(device.ToString()); + registration_->device_whitelist.insert(std::string(device)); return *this; } @@ -347,7 +347,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, DataType allowed) { std::set& types = - registration_->type_constraints[attr_name.ToString()]; + registration_->type_constraints[std::string(attr_name)]; types.insert(allowed); return *this; } @@ -355,7 +355,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, gtl::ArraySlice allowed) { std::set& types = - registration_->type_constraints[attr_name.ToString()]; + registration_->type_constraints[std::string(attr_name)]; for (DataType t : allowed) { types.insert(t); } @@ -364,7 +364,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( StringPiece input_name) { - registration_->compile_time_constant_inputs.insert(input_name.ToString()); + registration_->compile_time_constant_inputs.insert(std::string(input_name)); return *this; } @@ -394,7 +394,7 @@ XlaBackendRegistrar::XlaBackendRegistrar( StringPiece name, gtl::ArraySlice types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); - registry.RegisterBackend(name.ToString(), types, op_filter); + registry.RegisterBackend(std::string(name), types, op_filter); } } // namespace tensorflow diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 7b227ce294..497b48ff22 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -36,8 +36,8 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, bool prevent_unrolling, bool prevent_vectorization) - : prefix_(prefix.ToString()), - suffix_(suffix.ToString()), + : prefix_(std::string(prefix)), + suffix_(std::string(suffix)), start_index_(start_index), end_index_(end_index), step_(step), diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 20069ce5a2..d915f95db1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -174,7 +174,7 @@ class ForLoopNest { : ForLoopNest(/*name=*/"", ir_builder) {} ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) - : name_(name.ToString()), + : name_(std::string(name)), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index fc0e444452..350db12653 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -230,7 +230,7 @@ TokKind HloLexer::LexIdentifier() { } } - str_val_ = identifier.ToString(); + str_val_ = std::string(identifier); return TokKind::kIdent; } diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 1bb31ddb7b..3a945fb3b1 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -242,7 +242,7 @@ bool HloParser::Error(LocTy loc, StringPiece msg) { std::vector error_lines; error_lines.push_back( StrCat("was parsing ", line, ":", col, ": error: ", msg)); - error_lines.push_back(lexer_.GetLine(loc).ToString()); + error_lines.push_back(std::string(lexer_.GetLine(loc))); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc index 4539ea5c0c..7641edea52 100644 --- a/tensorflow/core/debug/debug_graph_utils.cc +++ b/tensorflow/core/debug/debug_graph_utils.cc @@ -356,10 +356,9 @@ Status DebugNodeInserter::ParseDebugOpName( "Malformed attributes in debug op name \"", debug_op_name, "\""); } - const string key = seg.substr(0, eq_index).ToString(); - const string value = - seg.substr(eq_index + 1, attribute_seg.size() - eq_index - 1) - .ToString(); + const string key = std::string(seg.substr(0, eq_index)); + const string value = std::string( + seg.substr(eq_index + 1, attribute_seg.size() - eq_index - 1)); if (key.empty() || value.empty()) { return errors::InvalidArgument( "Malformed attributes in debug op name \"", debug_op_name, "\""); diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index baa8c08fdf..4998a7acfe 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -399,8 +399,8 @@ Status DebugIO::PublishDebugMetadata( strings::Printf("%.14lld", session_run_index))), Env::Default()->NowMicros()); status.Update(DebugFileIO::DumpEventProtoToFile( - event, io::Dirname(core_metadata_path).ToString(), - io::Basename(core_metadata_path).ToString())); + event, std::string(io::Dirname(core_metadata_path)), + std::string(io::Basename(core_metadata_path)))); } } @@ -632,8 +632,8 @@ Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key, std::vector events; TF_RETURN_IF_ERROR( WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events)); - return DumpEventProtoToFile(events[0], io::Dirname(file_path).ToString(), - io::Basename(file_path).ToString()); + return DumpEventProtoToFile(events[0], std::string(io::Dirname(file_path)), + std::string(io::Basename(file_path))); } Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) { @@ -642,7 +642,7 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) { return Status::OK(); } - string parent_dir = io::Dirname(dir).ToString(); + string parent_dir = std::string(io::Dirname(dir)); if (!env->FileExists(parent_dir).ok()) { // The parent path does not exist yet, create it first. Status s = RecursiveCreateDir(env, parent_dir); // Recursive call diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 83afc5b1a4..08fbe8b144 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -606,7 +606,7 @@ Status MasterSession::ReffedClientGraph::RunPartitionsHelper( // inadvertently slowing down the normal run path. if (is_partial_) { for (const auto& name_index : feeds) { - const auto iter = part.feed_key.find(name_index.first.ToString()); + const auto iter = part.feed_key.find(std::string(name_index.first)); if (iter == part.feed_key.end()) { // The provided feed must be for a different partition. continue; @@ -950,7 +950,7 @@ Status MasterSession::ReffedClientGraph::CheckFetches( // Skip if already fed. if (input.second) continue; TensorId id(ParseTensorName(input.first)); - const Node* n = execution_state->get_node_by_name(id.first.ToString()); + const Node* n = execution_state->get_node_by_name(std::string(id.first)); if (n == nullptr) { return errors::NotFound("Feed ", input.first, ": not found"); } @@ -966,7 +966,7 @@ Status MasterSession::ReffedClientGraph::CheckFetches( for (size_t i = 0; i < req.num_fetches(); ++i) { const string& fetch = req.fetch_name(i); const TensorId id(ParseTensorName(fetch)); - const Node* n = execution_state->get_node_by_name(id.first.ToString()); + const Node* n = execution_state->get_node_by_name(std::string(id.first)); if (n == nullptr) { return errors::NotFound("Fetch ", fetch, ": not found"); } diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc index ec26ac44b5..15e5919c54 100644 --- a/tensorflow/core/distributed_runtime/remote_device.cc +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -37,7 +37,7 @@ string GetLocalDeviceName(StringPiece fullname) { auto pos = fullname.rfind('/'); CHECK_NE(pos, StringPiece::npos); fullname.remove_prefix(pos + 1); - return fullname.ToString(); + return std::string(fullname); } class RemoteDevice : public Device { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 9776e99f20..b87ae05546 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -139,7 +139,7 @@ inline StringPiece ParseNodeNameAsStringPiece(const string& name, // Returns the node name and position in a single call. inline string ParseNodeName(const string& name, int* position) { - return ParseNodeNameAsStringPiece(name, position).ToString(); + return std::string(ParseNodeNameAsStringPiece(name, position)); } // Add a prefix to a node name with a custom delimiter. diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index 7960cb4b05..e05de3fe8e 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -161,7 +161,7 @@ Status GraphTransferer::LoadGraphFromProto( for (const string& output_node_name : output_node_names) { const TensorId tid = ParseTensorName(output_node_name); - const string node_name = tid.first.ToString(); + const string node_name = std::string(tid.first); const int port = tid.second; const int node_id = node_name_to_id_cache_map_.at(node_name); const Node* node = node_name_cache_list_.at(node_id); diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc index 3810cbe5b5..1580b72605 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc @@ -168,7 +168,7 @@ bool HexagonControlWrapper::SetupGraph() { new_output_node_info.set_output_count(0); const TensorId tid = ParseTensorName(graph_output.name()); - const string node_name = tid.first.ToString(); + const string node_name = std::string(tid.first); const int port = tid.second; // Register node input for the new output node const GraphTransferNodeInfo* node_info = diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc index 996fbf62e5..b62206012c 100644 --- a/tensorflow/core/lib/io/path.cc +++ b/tensorflow/core/lib/io/path.cc @@ -42,7 +42,7 @@ string JoinPathImpl(std::initializer_list paths) { if (path.empty()) continue; if (result.empty()) { - result = path.ToString(); + result = std::string(path); continue; } @@ -124,7 +124,7 @@ StringPiece Extension(StringPiece path) { } string CleanPath(StringPiece unclean_path) { - string path = unclean_path.ToString(); + string path = std::string(unclean_path); const char* src = path.c_str(); string::iterator dst = path.begin(); @@ -237,7 +237,7 @@ void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host, string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) { if (scheme.empty()) { - return path.ToString(); + return std::string(path); } return strings::StrCat(scheme, "://", host, path); } diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc index 78a3fa501c..9e3309f0a7 100644 --- a/tensorflow/core/lib/io/table_test.cc +++ b/tensorflow/core/lib/io/table_test.cc @@ -147,7 +147,7 @@ class Constructor { virtual ~Constructor() {} void Add(const string& key, const StringPiece& value) { - data_[key] = value.ToString(); + data_[key] = std::string(value); } // Finish constructing the data structure with all the keys that have @@ -188,7 +188,7 @@ class BlockConstructor : public Constructor { builder.Add(it->first, it->second); } // Open the block - data_ = builder.Finish().ToString(); + data_ = std::string(builder.Finish()); BlockContents contents; contents.data = data_; contents.cachable = false; @@ -515,7 +515,7 @@ TEST_F(Harness, Randomized) { for (int e = 0; e < num_entries; e++) { string v; Add(test::RandomKey(&rnd, rnd.Skewed(4)), - test::RandomString(&rnd, rnd.Skewed(5), &v).ToString()); + std::string(test::RandomString(&rnd, rnd.Skewed(5), &v))); } Test(&rnd); } diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 0426fee0e2..7190614706 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -370,14 +370,14 @@ Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) { BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options) : env_(env), options_(options), - prefix_(prefix.ToString()), + prefix_(std::string(prefix)), tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate", random::New64())), tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate", random::New64())), out_(nullptr), size_(0) { - status_ = env_->CreateDir(io::Dirname(prefix_).ToString()); + status_ = env_->CreateDir(std::string(io::Dirname(prefix_))); if (!status_.ok() && !errors::IsAlreadyExists(status_)) { return; } @@ -394,7 +394,7 @@ BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options) Status BundleWriter::Add(StringPiece key, const Tensor& val) { if (!status_.ok()) return status_; CHECK_NE(key, kHeaderEntryKey); - const string key_string = key.ToString(); + const string key_string = std::string(key); if (entries_.find(key_string) != entries_.end()) { status_ = errors::InvalidArgument("Adding duplicate key: ", key); return status_; @@ -445,7 +445,7 @@ Status BundleWriter::AddSlice(StringPiece full_tensor_key, // In the case of a sharded save, MergeBundles() is responsible for merging // the "slices" field of multiple metadata entries corresponding to the same // full tensor. - const string full_tensor_key_string = full_tensor_key.ToString(); + const string full_tensor_key_string = std::string(full_tensor_key); BundleEntryProto* full_entry = &entries_[full_tensor_key_string]; if (full_entry->dtype() != DT_INVALID) { CHECK_EQ(full_entry->dtype(), slice_tensor.dtype()); @@ -600,7 +600,7 @@ static Status MergeOneBundle(Env* env, StringPiece prefix, // Loops through the non-header to-merge entries. BundleEntryProto to_merge_entry; for (; iter->Valid(); iter->Next()) { - const string key = iter->key().ToString(); + const string key = std::string(iter->key()); const auto entry_iter = merge_state->entries.find(key); // Illegal: the duplicated entry is a non-slice tensor. @@ -649,7 +649,7 @@ Status MergeBundles(Env* env, gtl::ArraySlice prefixes, // Merges all metadata tables. // TODO(zhifengc): KeyValue sorter if it becomes too big. MergeState merge; - Status status = env->CreateDir(io::Dirname(merged_prefix).ToString()); + Status status = env->CreateDir(std::string(io::Dirname(merged_prefix))); if (!status.ok() && !errors::IsAlreadyExists(status)) return status; for (int i = 0; i < prefixes.size(); ++i) { TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge)); @@ -697,7 +697,7 @@ Status MergeBundles(Env* env, gtl::ArraySlice prefixes, BundleReader::BundleReader(Env* env, StringPiece prefix) : env_(env), - prefix_(prefix.ToString()), + prefix_(std::string(prefix)), metadata_(nullptr), table_(nullptr), iter_(nullptr) { @@ -919,7 +919,7 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key, const TensorShape full_shape(TensorShape(full_tensor_entry.shape())); std::vector> details; - const string full_tensor_key_string = full_tensor_key.ToString(); + const string full_tensor_key_string = std::string(full_tensor_key); const TensorSliceSet* tss = gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string); diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index 7f166f0ec0..92ce8ae00e 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -107,7 +107,7 @@ std::vector AllTensorKeys(BundleReader* reader) { reader->Seek(kHeaderEntryKey); reader->Next(); for (; reader->Valid(); reader->Next()) { - ret.push_back(reader->key().ToString()); + ret.push_back(std::string(reader->key())); } return ret; } diff --git a/tensorflow/stream_executor/kernel.cc b/tensorflow/stream_executor/kernel.cc index d1aa596b73..7c1923da51 100644 --- a/tensorflow/stream_executor/kernel.cc +++ b/tensorflow/stream_executor/kernel.cc @@ -94,7 +94,7 @@ KernelCacheConfig KernelBase::GetPreferredCacheConfig() const { static const char *kStubPrefix = "__device_stub_"; void KernelBase::set_name(port::StringPiece name) { - name_ = name.ToString(); + name_ = std::string(name); port::StringPiece stubless_name = name; if (tensorflow::str_util::StartsWith(name, kStubPrefix)) { stubless_name.remove_prefix(strlen(kStubPrefix)); diff --git a/tensorflow/stream_executor/kernel_spec.cc b/tensorflow/stream_executor/kernel_spec.cc index 6a1f0a591f..f0a5785b72 100644 --- a/tensorflow/stream_executor/kernel_spec.cc +++ b/tensorflow/stream_executor/kernel_spec.cc @@ -18,11 +18,11 @@ limitations under the License. namespace stream_executor { KernelLoaderSpec::KernelLoaderSpec(port::StringPiece kernelname) - : kernelname_(kernelname.ToString()) {} + : kernelname_(std::string(kernelname)) {} OnDiskKernelLoaderSpec::OnDiskKernelLoaderSpec(port::StringPiece filename, port::StringPiece kernelname) - : KernelLoaderSpec(kernelname), filename_(filename.ToString()) {} + : KernelLoaderSpec(kernelname), filename_(std::string(filename)) {} CudaPtxOnDisk::CudaPtxOnDisk(port::StringPiece filename, port::StringPiece kernelname) @@ -161,7 +161,7 @@ OpenCLTextOnDisk::OpenCLTextOnDisk(port::StringPiece filename, OpenCLTextInMemory::OpenCLTextInMemory(port::StringPiece text, port::StringPiece kernelname) - : KernelLoaderSpec(kernelname), text_(text.ToString()) {} + : KernelLoaderSpec(kernelname), text_(std::string(text)) {} OpenCLBinaryOnDisk::OpenCLBinaryOnDisk(port::StringPiece filename, port::StringPiece kernelname) -- GitLab From ebad5d624c6c08a3fdb4ffac6051b4888fc36790 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 2 May 2018 19:19:12 -0700 Subject: [PATCH 067/839] Update ops-related pbtxt files. PiperOrigin-RevId: 195190335 --- .../core/ops/compat/ops_history.v1.pbtxt | 144 ++++++++++++++++++ tensorflow/core/ops/ops.pbtxt | 102 +++++++++++-- 2 files changed, 234 insertions(+), 12 deletions(-) diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index cb466ef817..3db00d8180 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -20998,6 +20998,30 @@ op { type: DT_COMPLEX64 } } +op { + name: "FFT" + input_arg { + name: "input" + type_attr: "Tcomplex" + } + output_arg { + name: "output" + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} op { name: "FFT2D" input_arg { @@ -21009,6 +21033,30 @@ op { type: DT_COMPLEX64 } } +op { + name: "FFT2D" + input_arg { + name: "input" + type_attr: "Tcomplex" + } + output_arg { + name: "output" + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} op { name: "FFT3D" input_arg { @@ -21020,6 +21068,30 @@ op { type: DT_COMPLEX64 } } +op { + name: "FFT3D" + input_arg { + name: "input" + type_attr: "Tcomplex" + } + output_arg { + name: "output" + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} op { name: "FIFOQueue" output_arg { @@ -24711,6 +24783,30 @@ op { type: DT_COMPLEX64 } } +op { + name: "IFFT" + input_arg { + name: "input" + type_attr: "Tcomplex" + } + output_arg { + name: "output" + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} op { name: "IFFT2D" input_arg { @@ -24722,6 +24818,30 @@ op { type: DT_COMPLEX64 } } +op { + name: "IFFT2D" + input_arg { + name: "input" + type_attr: "Tcomplex" + } + output_arg { + name: "output" + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} op { name: "IFFT3D" input_arg { @@ -24733,6 +24853,30 @@ op { type: DT_COMPLEX64 } } +op { + name: "IFFT3D" + input_arg { + name: "input" + type_attr: "Tcomplex" + } + output_arg { + name: "output" + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} op { name: "IRFFT" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 207dd1c3d7..7156440b46 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -9709,33 +9709,72 @@ op { name: "FFT" input_arg { name: "input" - type: DT_COMPLEX64 + type_attr: "Tcomplex" } output_arg { name: "output" - type: DT_COMPLEX64 + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } } } op { name: "FFT2D" input_arg { name: "input" - type: DT_COMPLEX64 + type_attr: "Tcomplex" } output_arg { name: "output" - type: DT_COMPLEX64 + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } } } op { name: "FFT3D" input_arg { name: "input" - type: DT_COMPLEX64 + type_attr: "Tcomplex" } output_arg { name: "output" - type: DT_COMPLEX64 + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } } } op { @@ -11877,33 +11916,72 @@ op { name: "IFFT" input_arg { name: "input" - type: DT_COMPLEX64 + type_attr: "Tcomplex" } output_arg { name: "output" - type: DT_COMPLEX64 + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } } } op { name: "IFFT2D" input_arg { name: "input" - type: DT_COMPLEX64 + type_attr: "Tcomplex" } output_arg { name: "output" - type: DT_COMPLEX64 + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } } } op { name: "IFFT3D" input_arg { name: "input" - type: DT_COMPLEX64 + type_attr: "Tcomplex" } output_arg { name: "output" - type: DT_COMPLEX64 + type_attr: "Tcomplex" + } + attr { + name: "Tcomplex" + type: "type" + default_value { + type: DT_COMPLEX64 + } + allowed_values { + list { + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } } } op { -- GitLab From 71f97c8cd9304a8e1cf8e309e15000d5831b212a Mon Sep 17 00:00:00 2001 From: Mostafa Alaa Date: Wed, 2 May 2018 19:53:18 -0700 Subject: [PATCH 068/839] Fix tf.variable_scope unique name after entering root scope Closes #18702. PiperOrigin-RevId: 195192460 --- tensorflow/python/ops/variable_scope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index ba213ef884..adb0f59948 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1175,7 +1175,7 @@ class _VariableScopeStore(threading.local): def close_variable_subscopes(self, scope_name): for k in list(self.variable_scopes_count.keys()): - if not scope_name or k.startswith(scope_name + "/"): + if scope_name is None or k.startswith(scope_name + "/"): self.variable_scopes_count[k] = 0 def variable_scope_count(self, scope_name): -- GitLab From f6000468263c5db7befbf5c320e8b3af7d90b819 Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Wed, 2 May 2018 21:15:01 -0700 Subject: [PATCH 069/839] Expose Interpreter to tensorflow.contrib.lite PiperOrigin-RevId: 195198645 --- tensorflow/contrib/lite/BUILD | 2 + tensorflow/contrib/lite/python/BUILD | 1 + tensorflow/contrib/lite/python/interpreter.py | 18 ++++- .../interpreter_wrapper.cc | 2 +- tensorflow/contrib/lite/python/lite.py | 2 + .../tools/pip_package/pip_smoke_test.py | 73 +++++++++++++------ 6 files changed, 70 insertions(+), 28 deletions(-) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 1534f97d76..10065e894c 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -92,6 +92,8 @@ cc_library( deps = [":context"], ) +exports_files(["builtin_ops.h"]) + cc_library( name = "string", hdrs = [ diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index e6dcc7aa09..4920e83970 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -44,6 +44,7 @@ py_library( deps = [ ":convert", ":convert_saved_model", + ":interpreter", ":op_hint", ], ) diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index cb9c0d3121..5fbc551452 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -17,7 +17,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.lite.python.interpreter_wrapper import tensorflow_wrap_interpreter_wrapper as interpreter_wrapper +from tensorflow.python.util.lazy_loader import LazyLoader + +# Lazy load since some of the performance benchmark skylark rules +# break dependencies. Must use double quotes to match code internal rewrite +# rule. +# pylint: disable=g-inconsistent-quotes +_interpreter_wrapper = LazyLoader( + "_interpreter_wrapper", globals(), + "tensorflow.contrib.lite.python.interpreter_wrapper." + "tensorflow_wrap_interpreter_wrapper") +# pylint: enable=g-inconsistent-quotes + +del LazyLoader class Interpreter(object): @@ -35,13 +47,13 @@ class Interpreter(object): """ if model_path and not model_content: self._interpreter = ( - interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile( + _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile( model_path)) if not self._interpreter: raise ValueError('Failed to open {}'.format(model_path)) elif model_content and not model_path: self._interpreter = ( - interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( + _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( model_content, len(model_content))) if not self._interpreter: raise ValueError( diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 04fc098129..16f4f30b94 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -116,7 +116,7 @@ PyObject* PyArrayFromIntVector(const int* data, npy_intp size) { PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) { PyObject* result = PyTuple_New(2); PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale)); - PyTuple_SET_ITEM(result, 1, PyInt_FromLong(param.zero_point)); + PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point)); return result; } diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 4ea40201f7..86b25e68ac 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -19,6 +19,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. @@toco_convert @@toco_convert_protos @@tflite_from_saved_model +@@Interpreter @@OpHint @@convert_op_hints_to_stubs @@ -31,6 +32,7 @@ from __future__ import print_function from tensorflow.contrib.lite.python.convert import toco_convert from tensorflow.contrib.lite.python.convert import toco_convert_protos from tensorflow.contrib.lite.python.convert_saved_model import tflite_from_saved_model +from tensorflow.contrib.lite.python.interpreter import Interpreter from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: enable=unused-import diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index b23dde2019..401f833dbd 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -30,15 +30,42 @@ os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))) PIP_PACKAGE_QUERY_EXPRESSION = ( "deps(//tensorflow/tools/pip_package:build_pip_package)") -# pylint: disable=g-backslash-continuation -PY_TEST_QUERY_EXPRESSION = 'deps(\ - filter("^((?!benchmark).)*$",\ - kind(py_test,\ - //tensorflow/python/... \ - + //tensorflow/contrib/... \ - - //tensorflow/contrib/tensorboard/... \ - - attr(tags, "manual|no_pip", //tensorflow/...))), 1)' -# pylint: enable=g-backslash-continuation + +def GetBuild(dir_base): + """Get the list of BUILD file all targets recursively startind at dir_base.""" + items = [] + for root, _, files in os.walk(dir_base): + for name in files: + if (name == "BUILD" and + root.find("tensorflow/contrib/lite/examples/android") == -1): + items.append("//" + root + ":all") + return items + + +def BuildPyTestDependencies(): + python_targets = GetBuild("tensorflow/python") + contrib_targets = GetBuild("tensorflow/contrib") + tensorboard_targets = GetBuild("tensorflow/contrib/tensorboard") + tensorflow_targets = GetBuild("tensorflow") + # Build list of test targets, + # python + contrib - tensorboard - attr(manual|pno_pip) + targets = " + ".join(python_targets) + for t in contrib_targets: + targets += " + " + t + for t in tensorboard_targets: + targets += " - " + t + targets += ' - attr(tags, "manual|no_pip", %s)' % " + ".join( + tensorflow_targets) + query_kind = "kind(py_test, %s)" % targets + # Skip benchmarks etc. + query_filter = 'filter("^((?!benchmark).)*$", %s)' % query_kind + # Get the dependencies + query_deps = "deps(%s, 1)" % query_filter + + return python_targets, query_deps + + +PYTHON_TARGETS, PY_TEST_QUERY_EXPRESSION = BuildPyTestDependencies() # Hard-coded blacklist of files if not included in pip package # TODO(amitpatankar): Clean up blacklist. @@ -79,16 +106,6 @@ BLACKLIST = [ ] -def bazel_query(query_target): - """Run bazel query on target.""" - try: - output = subprocess.check_output( - ["bazel", "query", "--keep_going", query_target]) - except subprocess.CalledProcessError as e: - output = e.output - return output - - def main(): """This script runs the pip smoke test. @@ -103,14 +120,22 @@ def main(): """ # pip_package_dependencies_list is the list of included files in pip packages - pip_package_dependencies = bazel_query(PIP_PACKAGE_QUERY_EXPRESSION) + pip_package_dependencies = subprocess.check_output( + ["bazel", "cquery", PIP_PACKAGE_QUERY_EXPRESSION]) pip_package_dependencies_list = pip_package_dependencies.strip().split("\n") + pip_package_dependencies_list = [ + x.split()[0] for x in pip_package_dependencies_list + ] print("Pip package superset size: %d" % len(pip_package_dependencies_list)) # tf_py_test_dependencies is the list of dependencies for all python # tests in tensorflow - tf_py_test_dependencies = bazel_query(PY_TEST_QUERY_EXPRESSION) + tf_py_test_dependencies = subprocess.check_output( + ["bazel", "cquery", PY_TEST_QUERY_EXPRESSION]) tf_py_test_dependencies_list = tf_py_test_dependencies.strip().split("\n") + tf_py_test_dependencies_list = [ + x.split()[0] for x in tf_py_test_dependencies.strip().split("\n") + ] print("Pytest dependency subset size: %d" % len(tf_py_test_dependencies_list)) missing_dependencies = [] @@ -141,9 +166,9 @@ def main(): for missing_dependency in missing_dependencies: print("\nMissing dependency: %s " % missing_dependency) print("Affected Tests:") - rdep_query = ("rdeps(kind(py_test, //tensorflow/python/...), %s)" % - missing_dependency) - affected_tests = bazel_query(rdep_query) + rdep_query = ("rdeps(kind(py_test, %s), %s)" % + (" + ".join(PYTHON_TARGETS), missing_dependency)) + affected_tests = subprocess.check_output(["bazel", "cquery", rdep_query]) affected_tests_list = affected_tests.split("\n")[:-2] print("\n".join(affected_tests_list)) -- GitLab From 985351dc1ab33cedbfd7790dd9cccc36d2d4b150 Mon Sep 17 00:00:00 2001 From: Tony Wang Date: Wed, 2 May 2018 21:37:04 -0700 Subject: [PATCH 070/839] Simplify getter and setter method for GraphOptimizationPass::name_ PiperOrigin-RevId: 195199912 --- .../core/common_runtime/optimization_registry.cc | 6 ++---- .../core/common_runtime/optimization_registry.h | 12 ++++++------ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/common_runtime/optimization_registry.cc b/tensorflow/core/common_runtime/optimization_registry.cc index bf49a758b2..6ac047295d 100644 --- a/tensorflow/core/common_runtime/optimization_registry.cc +++ b/tensorflow/core/common_runtime/optimization_registry.cc @@ -36,8 +36,7 @@ Status OptimizationPassRegistry::RunGrouping( for (auto& phase : group->second) { VLOG(1) << "Running optimization phase " << phase.first; for (auto& pass : phase.second) { - VLOG(1) << "Running optimization pass: " - << pass->GetOptimizationPassName(); + VLOG(1) << "Running optimization pass: " << pass->name(); Status s = pass->Run(options); if (!s.ok()) return s; } @@ -52,8 +51,7 @@ void OptimizationPassRegistry::LogGrouping(Grouping grouping, int vlog_level) { for (auto& phase : group->second) { for (auto& pass : phase.second) { VLOG(vlog_level) << "Registered optimization pass grouping " << grouping - << " phase " << phase.first << ": " - << pass->GetOptimizationPassName(); + << " phase " << phase.first << ": " << pass->name(); } } } diff --git a/tensorflow/core/common_runtime/optimization_registry.h b/tensorflow/core/common_runtime/optimization_registry.h index 1b535faf19..f5d265aa24 100644 --- a/tensorflow/core/common_runtime/optimization_registry.h +++ b/tensorflow/core/common_runtime/optimization_registry.h @@ -65,13 +65,13 @@ class GraphOptimizationPass { public: virtual ~GraphOptimizationPass() {} virtual Status Run(const GraphOptimizationPassOptions& options) = 0; - void SetOptimizationPassName(string name) { _optimization_pass_name = name; } - string GetOptimizationPassName() { return _optimization_pass_name; } + void set_name(const string& name) { name_ = name; } + string name() const { return name_; } private: - // The name of the opitmization pass, which is the same as the inherited class - // name. - string _optimization_pass_name; + // The name of the opitimization pass, which is the same as the inherited + // class name. + string name_; }; // The key is a 'phase' number. Phases are executed in increasing @@ -118,7 +118,7 @@ class OptimizationPassRegistration { int phase, std::unique_ptr pass, string optimization_pass_name) { - pass->SetOptimizationPassName(optimization_pass_name); + pass->set_name(optimization_pass_name); OptimizationPassRegistry::Global()->Register(grouping, phase, std::move(pass)); } -- GitLab From 283e8fe7e191f8e0e2ca6ea62b8b4553c30a6286 Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Thu, 3 May 2018 00:16:09 -0700 Subject: [PATCH 071/839] Use tensorflow size to determine number of elements instead of the static shape, which can sometimes be missing. PiperOrigin-RevId: 195209826 --- tensorflow/contrib/quantize/python/fold_batch_norms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index 1f286bc39a..76f695dce0 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -414,7 +414,8 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1, unused_2): x = op.inputs[0] - n = x.get_shape().num_elements() / grad_mean.get_shape().num_elements() + n = math_ops.cast( + array_ops.size(x) / array_ops.size(grad_mean), dtypes.float32) dmean_dx = grad_mean / n dvar_dx = 2 * grad_var * (x - op.outputs[1]) / (n - 1) return (dmean_dx + dvar_dx), None, None, None, None -- GitLab From 090d21c18f16303e740136e8a4e0f62c63df4e10 Mon Sep 17 00:00:00 2001 From: wangsiyu Date: Thu, 3 May 2018 18:31:29 +0800 Subject: [PATCH 072/839] fix bug of declaring regularization loss mutiple times when reusing partitioned variables in tf.layers --- tensorflow/python/layers/base.py | 13 ++++++++++++- tensorflow/python/layers/base_test.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 64db49c900..c050e6be04 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -233,7 +233,8 @@ class Layer(base_layer.Layer): getter=vs.get_variable) if regularizer: - if context.executing_eagerly() or variable not in existing_variables: + if context.executing_eagerly() or _should_add_regularizer( + variable, existing_variables): self._handle_weight_regularization(name, variable, regularizer) if init_graph is not None: @@ -354,3 +355,13 @@ def _add_elements_to_collection(elements, collection_list): if element not in collection_set: collection.append(element) +def _should_add_regularizer(variable, existing_variable_set): + result = True + if isinstance(variable, tf_variables.PartitionedVariable): + for var in variable._get_variable_list(): + if var in existing_variable_set: + result = False + break + else: + result = variable not in existing_variable_set + return result diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index f08b552840..361e3de7aa 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -30,6 +30,7 @@ from tensorflow.python.layers import core as core_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope @@ -95,6 +96,20 @@ class BaseLayerTest(test.TestCase): regularizer=regularizer) self.assertEqual(len(layer.losses), 1) + def testReusePartitionedVaraiblesAndRegularizers(self): + regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3 + partitioner = partitioned_variables.fixed_size_partitioner(3) + for i in xrange(2): + with variable_scope.variable_scope(variable_scope.get_variable_scope(), + partitioner=partitioner, + reuse=False if i == 0 else True): + layer = base_layers.Layer(name='my_layer') + variable = layer.add_variable( + 'reg_part_var', [4, 4], + initializer=init_ops.zeros_initializer(), + regularizer=regularizer) + self.assertEqual(len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 3) + def testNoEagerActivityRegularizer(self): with context.eager_mode(): with self.assertRaisesRegexp(ValueError, 'activity_regularizer'): -- GitLab From a2d35bddc7a1ab58b859ef396501472d7986ff0f Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 07:50:13 -0700 Subject: [PATCH 073/839] Fix build dependency; add missing OpKernel; fix some formatting issues --- .../tensorrt/custom_plugin_examples/BUILD | 105 ++++++++---------- .../tensorrt/custom_plugin_examples/inc_op.py | 5 +- .../inc_op_kernel.cu.cc | 42 +++++++ .../custom_plugin_examples/inc_op_kernel.h | 2 +- .../{inc_op_plugin.cu.cc => inc_op_plugin.cc} | 13 ++- .../custom_plugin_examples/inc_op_plugin.h | 18 +-- .../custom_plugin_examples/plugin_test.py | 10 +- .../contrib/tensorrt/kernels/trt_engine_op.cc | 2 +- tensorflow/contrib/tensorrt/log/trt_logger.h | 2 +- .../tensorrt/plugin/trt_plugin_factory.cc | 6 + .../tensorrt/plugin/trt_plugin_factory.h | 12 +- .../tensorrt/plugin/trt_plugins_test.cc | 47 +++++--- 12 files changed, 162 insertions(+), 102 deletions(-) rename tensorflow/contrib/tensorrt/custom_plugin_examples/{inc_op_plugin.cu.cc => inc_op_plugin.cc} (90%) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 3b1a7fb6f3..a45d4423bb 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -8,74 +8,39 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load( "//tensorflow:tensorflow.bzl", + "tf_copts", "tf_custom_op_library", - "tf_cuda_library", + "tf_custom_op_library_additional_deps", "tf_gen_op_libs", "tf_gen_op_wrapper_py", - "tf_py_wrap_cc", - "tf_copts", - "tf_py_test", + "tf_kernel_library", ) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", ) -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") - -tf_kernel_library( - name = "_inc_op_plugin_kernel", - gpu_srcs = [ - "inc_op_kernel.cu.cc", - "inc_op_kernel.h", - "inc_op_plugin.cu.cc", - "inc_op_plugin.h", - ], - deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) tf_gen_op_libs( op_lib_names = [ "inc_op", ], - deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), ) tf_gen_op_wrapper_py( name = "inc_op", - gen_locally = True, deps = [ ":inc_op_op_lib", ], ) -tf_py_wrap_cc( - name = "plugin_wrap", - srcs = [ - "plugin_wrap.i", - ], - copts = tf_copts(), - deps = [ - ":_inc_op_plugin_kernel", - "//tensorflow/core:framework_lite", - "//util/python:python_headers", - ], -) - tf_custom_op_library( name = "_inc_op.so", srcs = ["ops/inc_op.cc"], deps = [ "//tensorflow/core:lib_proto_parsing", - "//tensorflow/contrib/tensorrt:trt_plugins", ], ) @@ -85,6 +50,10 @@ tf_custom_op_py_library( dso = [ ":_inc_op.so", ], + kernels = [ + ":inc_op_op_lib", + ":inc_op_plugin_kernel", + ], srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_for_generated_wrappers", @@ -101,30 +70,54 @@ py_library( ], ) -tf_py_test( - name = "plugin_test", - size = "small", - srcs = [ - "plugin_test.py", +tf_kernel_library( + name = "inc_op_plugin_kernel", + srcs = ["inc_op_plugin.cc"], + hdrs = [ + "inc_op_plugin.h", ], - additional_deps = [ - ":init_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/contrib/tensorrt:init_py", - "//tensorflow/python:platform", - "//tensorflow/python:client_testlib", - "//tensorflow/python:tf_optimizer", + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc" + ], + deps = [ + "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), +) + +tf_py_wrap_cc( + name = "plugin_wrap", + srcs = ["plugin_wrap.i"], + copts = tf_copts(), + deps = [ + ":inc_op_plugin_kernel", + "//tensorflow/core:framework_lite", + "//util/python:python_headers", ], ) py_library( name = "init_py", - srcs = [ - "__init__.py", - ], + srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ ":inc_op_py", ":plugin_wrap", ], ) + +tf_py_test( + name = "plugin_test", + size = "small", + srcs = ["plugin_test.py"], + additional_deps = [ + ":init_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/contrib/tensorrt:init_py", + "//tensorflow/python:platform", + "//tensorflow/python:client_testlib", + "//tensorflow/python:tf_optimizer", + ], +) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py index ef8e26fbde..47fd55e2f6 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py @@ -18,13 +18,14 @@ from __future__ import division from __future__ import print_function import platform -import os if platform.system() != "Windows": + # pylint: disable=g-import-not-at-top from tensorflow.contrib.util import loader from tensorflow.python.platform import resource_loader + # pylint: enable=g-import-not-at-top _inc_op = loader.load_op_library( - os.path.join(os.path.dirname(os.path.realpath(__file__)),"_inc_op.so")) + resource_loader.get_path_to_datafile("_inc_op.so")) else: raise RuntimeError("Windows not supported") diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 38e1e01d95..ee9fbe0ea1 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -15,8 +15,14 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/stream_executor.h" + #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" namespace tensorflow { namespace tensorrt { @@ -35,6 +41,42 @@ void IncrementKernel(const float* d_input, float inc, float* d_output, d_output, count); } +// Note: this kernel definition is not needed in the plugin_test rule, but it is +// required for correctness of the TF program, i.e. if not using plugin or when +// run with trt optimization pass, the test should work. +class IncPluginTRT : public OpKernel { + public: + explicit IncPluginTRT(OpKernelConstruction* context) : OpKernel(context) { + std::vector inc_list; + OP_REQUIRES_OK(context, context->GetAttr("inc", &inc_list)); + OP_REQUIRES(context, inc_list.size() == 1, + errors::InvalidArgument( + "The increment list should contain single element.")); + inc_ = inc_list[0]; + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_tensor = context->input(0); + const TensorShape& input_shape = input_tensor.shape(); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &output_tensor)); + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast(context->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + IncrementKernel(input_tensor.flat().data(), inc_, + output_tensor->flat().data(), + input_shape.num_elements(), *stream); + } + + private: + float inc_; +}; + +REGISTER_KERNEL_BUILDER(Name("IncPluginTRT").Device(DEVICE_GPU), IncPluginTRT); + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h index 13156dad8f..1d0ec0b6b0 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -18,11 +18,11 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" namespace tensorflow { namespace tensorrt { -__global__ void VecInc(float* vec, float inc, float* dest, int n); void IncrementKernel(const float* d_input, float inc, float* d_output, int count, cudaStream_t stream); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc similarity index 90% rename from tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc rename to tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 508ced587b..489bc15def 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" -#include #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA @@ -24,7 +23,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; +const char* kPluginName = "IncPluginTRT"; IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); } @@ -33,14 +32,16 @@ IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { } bool RegisterIncOpPlugin() { - if (PluginFactoryTensorRT::GetInstance()->IsPlugin(IncOpPlugin::plugin_name_)) + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(kPluginName)) return false; return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( - IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); + kPluginName, CreateIncPluginDeserialize, CreateIncPlugin); } +IncOpPlugin::IncOpPlugin() : plugin_name_(kPluginName) {} + IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) - : PluginTensorRT(serialized_data, length) { + : PluginTensorRT(serialized_data, length), plugin_name_(kPluginName) { // account for the consumed pointer. size_t consumed_data = PluginTensorRT::getSerializationSize(); assert(length - consumed_data >= sizeof(float)); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 87404a755c..0676abe768 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -29,13 +29,17 @@ namespace tensorrt { class IncOpPlugin : public PluginTensorRT { public: - static const string plugin_name_; - IncOpPlugin() {}; + IncOpPlugin(); + IncOpPlugin(const void* serialized_data, size_t length); + const string& GetPluginName() const override { return plugin_name_; }; + bool Finalize() override { return true; }; + bool SetAttribute(const string& key, const void* ptr, const size_t size) override; + bool GetAttribute(const string& key, const void** ptr, size_t* size) const override; @@ -71,14 +75,11 @@ class IncOpPlugin : public PluginTensorRT { } void serialize(void* buffer) override { - // serializa parent stuff - // OpName + // Serialize parent data. PluginTensorRT::serialize(buffer); - - // incremented buffer after parent serialization; + // Incremented buffer after parent serialization. buffer = static_cast(buffer) + PluginTensorRT::getSerializationSize(); - std::memcpy(buffer, &inc_, sizeof(float)); buffer = static_cast(buffer) + sizeof(float); } @@ -86,6 +87,9 @@ class IncOpPlugin : public PluginTensorRT { protected: float inc_; nvinfer1::Dims dim_; + + private: + const string plugin_name_; }; IncOpPlugin* CreateIncPlugin(); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index 9f773c66a9..d1815fdf33 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -39,6 +39,7 @@ import numpy # the python api handles registration to the plugin factory from tensorflow.contrib.tensorrt import custom_plugin_examples + def get_plugin_graph_def(): """Create a simple graph and return its graph_def.""" g = ops.Graph() @@ -49,15 +50,16 @@ def get_plugin_graph_def(): v = nn_ops.max_pool( relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - # insert custom_op in the graph + # insert custom_op in the graph v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") - v = v*2.0 + v = v * 2.0 v = nn.relu(v) v = nn.relu(v) array_ops.squeeze(v, name="output") return g.as_graph_def() + def run_graph(gdef, dumm_inp): """Run given graphdef once.""" gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) @@ -74,6 +76,7 @@ def run_graph(gdef, dumm_inp): val = sess.run(out, {inp: dumm_inp}) return val + if "__main__" in __name__: inp_dims = (5, 24, 24, 2) dummy_input = numpy.ones(inp_dims).astype(numpy.float32) @@ -88,8 +91,7 @@ if "__main__" in __name__: max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="FP32", - minimum_segment_size=2 - ) + minimum_segment_size=2) o2 = run_graph(trt_graph, dummy_input) if o2.reshape([-1])[0] == 35: print("pass") diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index c39bc12f73..71453631e2 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h index 3495dc6318..96ccacb791 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -28,7 +28,7 @@ namespace tensorrt { // Logger for GIE info/warning/errors class Logger : public nvinfer1::ILogger { public: - Logger(string name = "DefaultLogger") : name_(name) {}; + Logger(string name = "DefaultLogger") : name_(name) {} void log(nvinfer1::ILogger::Severity severity, const char* msg) override; private: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 736a1321fe..b608e602a7 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -21,6 +21,12 @@ limitations under the License. namespace tensorflow { namespace tensorrt { +PluginFactoryTensorRT* PluginFactoryTensorRT::GetInstance() { + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); + return factory_instance; +} + PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 4e4a3af4ca..a088ffb842 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -31,19 +31,15 @@ namespace tensorrt { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { public: - // deserialization method + static PluginFactoryTensorRT* GetInstance(); + + // Deserialization method PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) override; - // plugin construction, PluginFactoryTensorRT owns the plugin; + // Plugin construction, PluginFactoryTensorRT owns the plugin. PluginTensorRT* CreatePlugin(const string& op_name); - static PluginFactoryTensorRT* GetInstance() { - static PluginFactoryTensorRT* factory_instance = - new PluginFactoryTensorRT(); - return factory_instance; - } - bool RegisterPlugin(const string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc index b834c5511f..ae5a3e8742 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" @@ -20,8 +22,6 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorrt/include/NvInfer.h" namespace tensorflow { @@ -30,34 +30,49 @@ namespace test { class StubPlugin : public PluginTensorRT { public: - static const string plugin_name_; - StubPlugin() {}; + static const char* kPluginName; + + StubPlugin() : plugin_name_(kPluginName) {} + StubPlugin(const void* serialized_data, size_t length) - : PluginTensorRT(serialized_data, length) {}; - const string& GetPluginName() override { return plugin_name_; }; - virtual bool Finalize() { return true; }; + : PluginTensorRT(serialized_data, length) {} + + const string& GetPluginName() override { return plugin_name_; } + + virtual bool Finalize() { return true; } + virtual bool SetAttribute(const string& key, const void* ptr, const size_t size) { return true; - }; + } + virtual bool GetAttribute(const string& key, const void* ptr, size_t& size) { return true; - }; + } + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override { return inputs[0]; } + int initialize() override { return 0; } + void terminate() override {} + size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } + int enqueue(int batch_size, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override { return 0; } + + private: + const string plugin_name_; }; -const string StubPlugin::plugin_name_ = "StubPlugin"; +const char* StubPlugin::kPluginName = "StubPlugin"; StubPlugin* CreateStubPlugin() { return new StubPlugin(); } @@ -70,32 +85,32 @@ class PluginTest : public ::testing::Test { public: bool RegisterStubPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin( - StubPlugin::plugin_name_)) { + StubPlugin::kPluginName)) { return true; } return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( - StubPlugin::plugin_name_, CreateStubPluginDeserialize, + StubPlugin::kPluginName, CreateStubPluginDeserialize, CreateStubPlugin); } }; TEST_F(PluginTest, Registration) { EXPECT_FALSE( - PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); EXPECT_TRUE(RegisterStubPlugin()); ASSERT_TRUE( - PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); } TEST_F(PluginTest, CreationDeletion) { EXPECT_TRUE(RegisterStubPlugin()); ASSERT_TRUE( - PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); ASSERT_TRUE(PluginFactoryTensorRT::GetInstance()->CreatePlugin( - StubPlugin::plugin_name_)); + StubPlugin::kPluginName)); ASSERT_EQ(1, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); ASSERT_EQ(0, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); -- GitLab From 03de4a4a6cbfab49c2921d0cac5ccac31c0815f8 Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 08:20:51 -0700 Subject: [PATCH 074/839] Move/rename the plugin factory test file; delete duplicate test file; fix minor formatting issues. --- tensorflow/contrib/tensorrt/BUILD | 10 ++- .../tensorrt/custom_plugin_examples/BUILD | 6 +- .../custom_plugin_examples/plugin_test.py | 5 -- .../contrib/tensorrt/plugin/trt_plugin.h | 6 +- .../tensorrt/plugin/trt_plugin_factory.h | 9 +- ...ins_test.cc => trt_plugin_factory_test.cc} | 6 +- .../tensorrt/plugin/trt_plugin_utils.h | 1 + tensorflow/contrib/tensorrt/plugin_test.py | 88 ------------------- 8 files changed, 26 insertions(+), 105 deletions(-) rename tensorflow/contrib/tensorrt/plugin/{trt_plugins_test.cc => trt_plugin_factory_test.cc} (96%) delete mode 100644 tensorflow/contrib/tensorrt/plugin_test.py diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 5fda11eccb..79e525edae 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -282,7 +282,7 @@ tf_cc_test( ], ) -# Library for the plugin factory +# Library for the plugin factory tf_cuda_library( name = "trt_plugins", srcs = [ @@ -304,9 +304,13 @@ tf_cuda_library( ) tf_cuda_cc_test( - name = "trt_plugins_test", + name = "trt_plugin_factory_test", size = "small", - srcs = ["plugin/trt_plugins_test.cc"], + srcs = ["plugin/trt_plugin_factory_test.cc"], + tags = [ + "manual", + "notap", + ], deps = [ ":trt_plugins", "//tensorflow/core:test", diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index a45d4423bb..c68e69457d 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -78,7 +78,7 @@ tf_kernel_library( ], gpu_srcs = [ "inc_op_kernel.h", - "inc_op_kernel.cu.cc" + "inc_op_kernel.cu.cc", ], deps = [ "//tensorflow/contrib/tensorrt:trt_plugins", @@ -120,4 +120,8 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:tf_optimizer", ], + tags = [ + "manual", + "notap", + ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index d1815fdf33..cb40e08493 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -18,11 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# normally we should do import tensorflow as tf and then -# tf.placeholder, tf.constant, tf.nn.conv2d etc but -# it looks like internal builds don't like it so -# importing every module individually - from tensorflow.contrib import tensorrt from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index dca377c2d2..d80ec44372 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include + #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA @@ -35,9 +36,11 @@ namespace tensorrt { // PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: - PluginTensorRT() {}; + PluginTensorRT() {} PluginTensorRT(const void* serialized_data, size_t length); + virtual const string& GetPluginName() const = 0; + virtual bool Finalize() = 0; virtual bool SetAttribute(const string& key, const void* ptr, @@ -53,6 +56,7 @@ class PluginTensorRT : public nvinfer1::IPlugin { const size_t size); virtual size_t getSerializationSize() override; + virtual void serialize(void* buffer) override; protected: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index a088ffb842..6d2992bbbb 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -19,8 +19,9 @@ limitations under the License. #include #include #include -#include "trt_plugin.h" -#include "trt_plugin_utils.h" + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -54,12 +55,12 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { protected: std::unordered_map > + std::pair> plugin_registry_; // TODO(jie): Owned plugin should be associated with different sessions; // should really hand ownership of plugins to resource management; - std::vector > owned_plugins_; + std::vector> owned_plugins_; std::mutex instance_m_; }; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc rename to tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc index ae5a3e8742..c5b0e75eb1 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc @@ -81,7 +81,7 @@ StubPlugin* CreateStubPluginDeserialize(const void* serialized_data, return new StubPlugin(serialized_data, length); } -class PluginTest : public ::testing::Test { +class TrtPluginFactoryTest : public ::testing::Test { public: bool RegisterStubPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin( @@ -94,7 +94,7 @@ class PluginTest : public ::testing::Test { } }; -TEST_F(PluginTest, Registration) { +TEST_F(TrtPluginFactoryTest, Registration) { EXPECT_FALSE( PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); EXPECT_TRUE(RegisterStubPlugin()); @@ -103,7 +103,7 @@ TEST_F(PluginTest, Registration) { PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); } -TEST_F(PluginTest, CreationDeletion) { +TEST_F(TrtPluginFactoryTest, CreationDeletion) { EXPECT_TRUE(RegisterStubPlugin()); ASSERT_TRUE( PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index a94c67bba0..4ff6fbedb4 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS #include + #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/contrib/tensorrt/plugin_test.py b/tensorflow/contrib/tensorrt/plugin_test.py deleted file mode 100644 index 7c3e765bff..0000000000 --- a/tensorflow/contrib/tensorrt/plugin_test.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Script to show usage of TensorRT custom op & plugin.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import tensorrt -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import importer -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -import numpy as np - -# import custom_op as plugin op -# the python api handles registration to the plugin factory -from tensorflow.contrib.tensorrt import custom_plugin_examples - -def get_plugin_graph_def(): - """Create a simple graph and return its graph_def.""" - g = ops.Graph() - with g.as_default(): - a = array_ops.placeholder( - dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") - relu = nn.relu(a, "relu") - v = nn_ops.max_pool( - relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - - # insert custom_op in the graph - v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") - - v = v*2.0 - v = nn.relu(v) - v = nn.relu(v) - array_ops.squeeze(v, name="output") - return g.as_graph_def() - -def run_graph(gdef, dumm_inp): - """Run given graphdef once.""" - gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - - with session.Session( - config=config_pb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: - val = sess.run(out, {inp: dumm_inp}) - return val - -if "__main__" in __name__: - inp_dims = (5, 24, 24, 2) - dummy_input = np.ones(inp_dims).astype(np.float32) - orig_graph = get_plugin_graph_def() # graph with plugin node - - # trigger conversion. - # plugin nodes have been registered during import, converter will be able to - # create corresponding plugin layer during conversion. - trt_graph = tensorrt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - precision_mode="FP32", - minimum_segment_size=2 - ) - o2 = run_graph(trt_graph, dummy_input) - print (o2) -- GitLab From a88a7e312581ba0c2173188019a420c888df9a10 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 09:10:06 -0700 Subject: [PATCH 075/839] Post-transform pass to dedupe large constant arrays. PiperOrigin-RevId: 195260578 --- tensorflow/contrib/lite/toco/args.h | 1 + .../contrib/lite/toco/toco_cmdline_flags.cc | 6 + tensorflow/contrib/lite/toco/toco_flags.proto | 6 +- tensorflow/contrib/lite/toco/toco_tooling.cc | 5 + tensorflow/contrib/lite/toco/tooling_util.cc | 130 ++++++++++++++++++ tensorflow/contrib/lite/toco/tooling_util.h | 16 +++ 6 files changed, 163 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index fe30b88344..6c0311af0a 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -241,6 +241,7 @@ struct ParsedTocoFlags { Arg drop_control_dependency = Arg(false); Arg propagate_fake_quant_num_bits = Arg(false); Arg allow_nudging_weights_to_use_fast_gemm_kernel = Arg(false); + Arg dedupe_array_min_size_bytes = Arg(64); }; } // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index 1611c4d0c0..7786a4ada3 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -148,6 +148,11 @@ bool ParseTocoFlagsFromCommandLineFlags( "Some fast uint8 GEMM kernels require uint8 weights to avoid the " "value 0. This flag allows nudging them to 1 to allow proceeding, " "with moderate inaccuracy."), + Flag("dedupe_array_min_size_bytes", + parsed_flags.dedupe_array_min_size_bytes.bind(), + parsed_flags.dedupe_array_min_size_bytes.default_value(), + "Minimum size of constant arrays to deduplicate; arrays smaller " + "will not be deduplicated."), }; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); @@ -239,6 +244,7 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone); READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel, FlagRequirement::kNone); + READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone); // Deprecated flag handling. if (parsed_toco_flags.input_type.specified()) { diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index a04017a6bf..253f022e6b 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 18. +// Next ID to use: 19. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -161,4 +161,8 @@ message TocoFlags { // This flag allows nudging them to 1 to allow proceeding, with moderate // inaccuracy. optional bool allow_nudging_weights_to_use_fast_gemm_kernel = 17; + + // Minimum size of constant arrays to deduplicate; arrays smaller will not be + // deduplicated. + optional int64 dedupe_array_min_size_bytes = 18 [default = 64]; } diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 7252ec2ea4..6973b22c5a 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -345,6 +345,11 @@ void Transform(const TocoFlags& toco_flags, Model* model) { EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model); } + // Deduplicate large constant arrays. + if (toco_flags.has_dedupe_array_min_size_bytes()) { + DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes()); + } + LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model); if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) { diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 11293a5fe5..86ee1f3761 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -260,6 +260,23 @@ Operator* GetFirstOpWithInput(const Model& model, const string& array_name) { return it == model.operators.end() ? nullptr : it->get(); } +void ReplaceArrayUsage(Model* model, const string& old_array_name, + const string& new_array_name) { + for (auto& op_it : model->operators) { + Operator* op = op_it.get(); + for (size_t i = 0; i < op->inputs.size(); ++i) { + if (op->inputs[i] == old_array_name) { + op->inputs[i] = new_array_name; + } + } + for (size_t i = 0; i < op->outputs.size(); ++i) { + if (op->outputs[i] == old_array_name) { + op->outputs[i] = new_array_name; + } + } + } +} + string FormatArraysList(const Model& model, const std::vector& list) { if (list.empty()) { return "[]"; @@ -648,6 +665,65 @@ bool IsConstantParameterArray(const Model& model, const string& name) { return !!model.GetArray(name).buffer; } +namespace { +template +bool CompareArrayBuffers(const Array& lhs_array, const Array& rhs_array) { + CHECK(lhs_array.data_type == rhs_array.data_type) << "Data types must match"; + CHECK(lhs_array.buffer) << "LHS must be constant"; + CHECK(rhs_array.buffer) << "RHS must be constant"; + const auto& lhs_data = lhs_array.GetBuffer().data; + const auto& rhs_data = rhs_array.GetBuffer().data; + CHECK_EQ(lhs_data.size(), rhs_data.size()) + << "Buffer sizes must match in element count"; + for (int i = 0; i < lhs_data.size(); ++i) { + if (lhs_data[i] != rhs_data[i]) { + return false; + } + } + return true; +} +} // namespace + +bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array) { + bool attrs_equal = + lhs_array.shape() == rhs_array.shape() && + lhs_array.data_type == rhs_array.data_type && + lhs_array.final_data_type == rhs_array.final_data_type && + lhs_array.minmax == rhs_array.minmax && + lhs_array.quantization_params == rhs_array.quantization_params; + if (!attrs_equal) { + return false; + } + switch (lhs_array.data_type) { + case ArrayDataType::kBool: + return CompareArrayBuffers(lhs_array, rhs_array); + case ArrayDataType::kFloat: + return CompareArrayBuffers(lhs_array, rhs_array); + case ArrayDataType::kInt8: + return CompareArrayBuffers(lhs_array, rhs_array); + case ArrayDataType::kUint8: + return CompareArrayBuffers(lhs_array, rhs_array); + case ArrayDataType::kInt16: + return CompareArrayBuffers(lhs_array, rhs_array); + case ArrayDataType::kUint16: + return CompareArrayBuffers(lhs_array, rhs_array); + case ArrayDataType::kInt32: + return CompareArrayBuffers(lhs_array, rhs_array); + case ArrayDataType::kUint32: + return CompareArrayBuffers(lhs_array, rhs_array); + case ArrayDataType::kInt64: + return CompareArrayBuffers(lhs_array, rhs_array); + case ArrayDataType::kUint64: + return CompareArrayBuffers(lhs_array, rhs_array); + case ArrayDataType::kString: + return CompareArrayBuffers(lhs_array, rhs_array); + default: + LOG(FATAL) << "Unsupported data type: " + << ArrayDataTypeName(lhs_array.data_type); + return false; + } +} + namespace { // Take an array name, which may be something like "name:3_5" and make it // acceptable as a TF node name, say "name_3_5"; @@ -1072,6 +1148,60 @@ void FixEdgeArrays(Model* model) { } } +void DedupeConstantArrays(Model* model, size_t min_size) { + // Walk all 0..N and compare with the remaining n+1..N. + // This lets us avoid N^2 comparisions and erase duplicate arrays while + // iterating. + const auto& array_map = model->GetArrayMap(); + for (auto lhs_array_it = array_map.begin(); lhs_array_it != array_map.end(); + ++lhs_array_it) { + const auto& lhs_array_name = lhs_array_it->first; + const auto& lhs_array = *lhs_array_it->second; + if (!IsConstantParameterArray(*model, lhs_array_name)) { + // Not a constant array; skip. + continue; + } + ArrayDataType final_data_type = + lhs_array.final_data_type != ArrayDataType::kNone + ? lhs_array.final_data_type + : lhs_array.data_type; + size_t array_byte_size = + lhs_array.buffer->Length() * ElementSize(final_data_type); + if (array_byte_size < min_size) { + // Too small; skip. + continue; + } + + auto next_lhs_array_it = lhs_array_it; + ++next_lhs_array_it; + for (auto rhs_array_it = next_lhs_array_it; + rhs_array_it != array_map.end();) { + const auto& rhs_array_name = rhs_array_it->first; + const auto& rhs_array = *rhs_array_it->second; + ++rhs_array_it; + if (!IsConstantParameterArray(*model, rhs_array_name)) { + // Not a constant array; skip. + continue; + } + if (!IsDiscardableArray(*model, rhs_array_name)) { + // Can't remove the array as it's not discardable (such as an IO edge). + continue; + } + if (!CompareConstantArrays(lhs_array, rhs_array)) { + // Arrays aren't equal; skip. + continue; + } + + // Arrays can be deduped! + VLOG(1) << "Deduplicating arrays; using " << lhs_array_name + << " in place of " << rhs_array_name; + ReplaceArrayUsage(model, rhs_array_name, lhs_array_name); + // Note: rhs_array_it above is already incremented so this is safe. + model->EraseArray(rhs_array_name); + } + } +} + namespace { void CopyArrayAttribs(const Array& source_array, Array* target_array) { target_array->data_type = source_array.data_type; diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index f5b596df0f..1f596ca8e5 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -88,6 +88,10 @@ std::vector>::iterator FindOpWithInput( Operator* GetOpWithInput(const Model& model, const string& array_name); Operator* GetFirstOpWithInput(const Model& model, const string& array_name); +// Replaces all uses of the |old_array_name| with the |new_array_name|. +void ReplaceArrayUsage(Model* model, const string& old_array_name, + const string& new_array_name); + std::vector>::const_iterator FindOp( const Model& model, const Operator* op); std::vector>::iterator FindOp(Model& model, @@ -138,6 +142,9 @@ int RequiredBufferSizeForShape(const Shape& shape); bool IsConstantParameterArray(const Model& model, const string& name); +// Compares two constant parameter arrays for exact equality. +bool CompareConstantArrays(const Array& lhs_array, const Array& rhs_array); + void CheckNoMissingArray(const Model& model); void CheckInvariants(const Model& model); @@ -150,6 +157,15 @@ void FixNoOrphanedArray(Model* model); // Fixes input/output arrays that may have issues during export or inference. void FixEdgeArrays(Model* model); +// Finds and deduplicates large constant arrays in the model. +// After constant propagation runs it's possible to end up with several of the +// same large array (whether they be zeros or otherwise). +// +// |min_size| is used to adjust the minimum size in bytes of an array before +// it's considered for deduping. As deduping can make the graphs more difficult +// to read this helps prevent small arrays from spidering out. +void DedupeConstantArrays(Model* model, size_t min_size); + // Copies the contents of an array into another. // Expects that the shape and data type match. template -- GitLab From eb88fd1ef5505e3f8617cc7105052fbce0e4af9e Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 11:17:10 -0700 Subject: [PATCH 076/839] Add a macro for registering the plugin so we don't need to depend on swig; remove the swig file; fix build dependencies; fix tf_custom_op_library by adding GOOGLE_TENSORRT macro when gpu_srcs is not empty. --- .../tensorrt/custom_plugin_examples/BUILD | 69 +++++++++---------- .../custom_plugin_examples/__init__.py | 2 - .../inc_op_kernel.cu.cc | 1 + .../custom_plugin_examples/inc_op_plugin.cc | 7 +- .../custom_plugin_examples/inc_op_plugin.h | 5 +- .../custom_plugin_examples/plugin_wrap.i | 31 --------- .../tensorrt/plugin/trt_plugin_factory.h | 25 +++++++ tensorflow/tensorflow.bzl | 2 +- 8 files changed, 61 insertions(+), 81 deletions(-) delete mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index c68e69457d..e623b54781 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -24,24 +24,48 @@ load( ) tf_gen_op_libs( - op_lib_names = [ - "inc_op", - ], + op_lib_names = ["inc_op"], ) tf_gen_op_wrapper_py( name = "inc_op", - deps = [ - ":inc_op_op_lib", - ], + deps = [":inc_op_op_lib"], ) tf_custom_op_library( name = "_inc_op.so", - srcs = ["ops/inc_op.cc"], + srcs = [ + "inc_op_kernel.h", + "inc_op_plugin.cc", + "inc_op_plugin.h", + "ops/inc_op.cc", + ], + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc", + ], deps = [ - "//tensorflow/core:lib_proto_parsing", + "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_kernel_library( + name = "inc_op_plugin_kernel", + srcs = ["inc_op_plugin.cc"], + hdrs = [ + "inc_op_plugin.h", + ], + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc", ], + deps = [ + "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), ) tf_custom_op_py_library( @@ -70,41 +94,12 @@ py_library( ], ) -tf_kernel_library( - name = "inc_op_plugin_kernel", - srcs = ["inc_op_plugin.cc"], - hdrs = [ - "inc_op_plugin.h", - ], - gpu_srcs = [ - "inc_op_kernel.h", - "inc_op_kernel.cu.cc", - ], - deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), -) - -tf_py_wrap_cc( - name = "plugin_wrap", - srcs = ["plugin_wrap.i"], - copts = tf_copts(), - deps = [ - ":inc_op_plugin_kernel", - "//tensorflow/core:framework_lite", - "//util/python:python_headers", - ], -) - py_library( name = "init_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ ":inc_op_py", - ":plugin_wrap", ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py index e4cd0ae8a0..e06904ab56 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py @@ -19,8 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.tensorrt.custom_plugin_examples.ops import gen_inc_op -from tensorflow.contrib.tensorrt.custom_plugin_examples.plugin_wrap import inc_op_register from tensorflow.contrib.tensorrt.custom_plugin_examples import inc_op as import_inc_op_so inc_op = gen_inc_op.inc_plugin_trt -inc_op_register() diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index ee9fbe0ea1..abbc0c5680 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -24,6 +24,7 @@ limitations under the License. #if GOOGLE_TENSORRT #include "cuda/include/cuda_runtime_api.h" + namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 489bc15def..d56aedc6d4 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -31,12 +31,7 @@ IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { return new IncOpPlugin(buffer, length); } -bool RegisterIncOpPlugin() { - if (PluginFactoryTensorRT::GetInstance()->IsPlugin(kPluginName)) - return false; - return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( - kPluginName, CreateIncPluginDeserialize, CreateIncPlugin); -} +REGISTER_TRT_PLUGIN(kPluginName, CreateIncPluginDeserialize, CreateIncPlugin); IncOpPlugin::IncOpPlugin() : plugin_name_(kPluginName) {} diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 0676abe768..60153546d2 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -18,6 +18,7 @@ limitations under the License. #include #include + #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA @@ -92,10 +93,6 @@ class IncOpPlugin : public PluginTensorRT { const string plugin_name_; }; -IncOpPlugin* CreateIncPlugin(); -IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); -bool RegisterIncOpPlugin(); - } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i deleted file mode 100644 index 9882daa842..0000000000 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/* Wrap inc_op_plugin */ -%module inc_op_plugin -%{ -#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" -extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); -%} - -%{ -bool inc_op_register() { - return tensorflow::tensorrt::RegisterIncOpPlugin(); -} -%} - -extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); - -bool inc_op_register(); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 6d2992bbbb..54fbca5930 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -64,6 +66,29 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { std::mutex instance_m_; }; +class TrtPluginRegistrar { + public: + TrtPluginRegistrar(const string& name, + PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func) { + auto factory = PluginFactoryTensorRT::GetInstance(); + QCHECK(factory->RegisterPlugin(name, deserialize_func, construct_func)) + << "Failed to register plugin: " << name; + } +}; + +#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \ + REGISTER_TRT_PLUGIN_UNIQ_HELPER( \ + __COUNTER__, name, deserialize_func, construct_func) +#define REGISTER_TRT_PLUGIN_UNIQ_HELPER( \ + ctr, name, deserialize_func, construct_func) \ + REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) +#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \ + static ::tensorflow::tensorrt::TrtPluginRegistrar \ + trt_plugin_registrar##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::tensorrt::TrtPluginRegistrar( \ + name, deserialize_func, construct_func) + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index e5cc886b32..c27f894365 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1309,7 +1309,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): native.cc_library( name=basename + "_gpu", srcs=gpu_srcs, - copts=_cuda_copts(), + copts=_cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]), deps=deps + if_cuda(cuda_deps)) cuda_deps.extend([":" + basename + "_gpu"]) -- GitLab From 85a47596caf89705aae8ffcb57fcdaecb22fe356 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 11:16:06 -0700 Subject: [PATCH 077/839] [XLA] Redesign: add ExecuteGraph to grpc service. PiperOrigin-RevId: 195281004 --- tensorflow/compiler/xla/rpc/BUILD | 2 +- tensorflow/compiler/xla/rpc/grpc_client_test.cc | 4 ++-- tensorflow/compiler/xla/rpc/grpc_service.cc | 7 +++++++ tensorflow/compiler/xla/rpc/grpc_service.h | 4 ++++ third_party/libxsmm.BUILD | 2 +- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 977f863787..0d56a9a477 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -55,7 +55,7 @@ tf_cc_test( deps = [ ":grpc_stub", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index b559ee4b5a..10997c0719 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "grpc++/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" @@ -84,7 +84,7 @@ TEST_F(GRPCClientTestBase, ItsAlive) { } TEST_F(GRPCClientTestBase, AxpyTenValues) { - ComputationBuilder builder(client_.get(), "axpy_10"); + XlaBuilder builder("axpy_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1( {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 0b100bd108..ffb72fc73c 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -75,6 +75,13 @@ namespace xla { [this, arg, result]() { return service_->Execute(arg, result); }); } +::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, + const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); +} + ::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context, const ExecuteAsyncRequest* arg, ExecuteAsyncResponse* result) { diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index fad74375bd..50f02796f2 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -54,6 +54,10 @@ class GRPCService : public grpc::XlaService::Service { const ExecuteRequest* arg, ExecuteResponse* result) override; + ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, + const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; + ::grpc::Status ExecuteAsync(::grpc::ServerContext* context, const ExecuteAsyncRequest* arg, ExecuteAsyncResponse* result) override; diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD index 4124f2db63..78ed1f4e16 100644 --- a/third_party/libxsmm.BUILD +++ b/third_party/libxsmm.BUILD @@ -38,8 +38,8 @@ genrule( ":libxsmm_interface", ], visibility = [ - "//tensorflow/core/kernels:__pkg__", "//third_party/eigen3:__pkg__", + "//tensorflow/core/kernels:__pkg__", ], ) -- GitLab From a16ba4fc0d3faec077c689f3f361264978a2d3cb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 12:00:57 -0700 Subject: [PATCH 078/839] Do not delegate temporary tensors to NNAPI. - also added delegation for MUL, and set the default scale to be 0.0f. PiperOrigin-RevId: 195288948 --- tensorflow/contrib/lite/nnapi_delegate.cc | 39 +++++++++++++++++++---- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 6a78f30fd1..e1895dd38e 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -72,11 +72,23 @@ NNAPIDelegate::~NNAPIDelegate() { // Adds the tensors of the interpreter to the NN API model. // Returns the number of operands added. uint32_t addTensorOperands(tflite::Interpreter* interpreter, - ANeuralNetworksModel* nn_model) { + ANeuralNetworksModel* nn_model, + const std::vector& skip_list) { uint32_t next_id = 0; for (size_t i = 0; i < interpreter->tensors_size(); i++) { + // skip temporaries tensors. + bool shouldSkip = false; + for (auto skip_idx : skip_list) { + if (i == skip_idx) { + shouldSkip = true; + break; + } + } + if (shouldSkip) continue; + int32_t nn_type = 0; - float scale = 1.0f; + // NNAPI requires 32-bit float scale to be zero, tflite doesn't care + float scale = 0.0f; int32_t zeroPoint = 0; TfLiteTensor* tensor = interpreter->tensor(i); switch (tensor->type) { @@ -116,11 +128,11 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, if (const NNAPIAllocation* alloc = dynamic_cast( static_cast(tensor->allocation))) { CHECK_NN(ANeuralNetworksModel_setOperandValueFromMemory( - nn_model, i, alloc->memory(), alloc->offset(tensor->data.raw), + nn_model, next_id, alloc->memory(), alloc->offset(tensor->data.raw), tensor->bytes)); } else { CHECK_NN(ANeuralNetworksModel_setOperandValue( - nn_model, i, tensor->data.raw, tensor->bytes)); + nn_model, next_id, tensor->data.raw, tensor->bytes)); } } ++next_id; @@ -253,6 +265,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, nn_op_type = ANEURALNETWORKS_ADD; add_add_params(); break; + case tflite::BuiltinOperator_MUL: + nn_op_type = ANEURALNETWORKS_MUL; + add_add_params(); + break; case tflite::BuiltinOperator_AVERAGE_POOL_2D: add_pooling_params(node.builtin_data); nn_op_type = ANEURALNETWORKS_AVERAGE_POOL_2D; @@ -330,7 +346,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_L2_NORMALIZATION: case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: - case tflite::BuiltinOperator_MUL: case tflite::BuiltinOperator_PAD: case tflite::BuiltinOperator_RESIZE_BILINEAR: case tflite::BuiltinOperator_CALL: @@ -381,7 +396,19 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { if (!nn_model_) { CHECK_NN(ANeuralNetworksModel_create(&nn_model_)); - uint32_t next_id = addTensorOperands(interpreter, nn_model_); + // Find all the temporary tensors and put them in a skip_list. + std::vector skip_list; + for (size_t i = 0; i < interpreter->nodes_size(); i++) { + const auto* node_and_registration = interpreter->node_and_registration(i); + const TfLiteNode& node = node_and_registration->first; + if (node.temporaries != nullptr) { + for (int j = 0; j < node.temporaries->size; j++) { + skip_list.push_back(static_cast(node.temporaries->data[j])); + } + } + } + + uint32_t next_id = addTensorOperands(interpreter, nn_model_, skip_list); AddOpsAndParams(interpreter, nn_model_, next_id); CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs( nn_model_, static_cast(interpreter->inputs().size()), -- GitLab From e5854637cc3f8099586f18ed144fd6d4f90a6fc7 Mon Sep 17 00:00:00 2001 From: Yao Zhang Date: Thu, 3 May 2018 12:19:01 -0700 Subject: [PATCH 079/839] Simplify file reading and support SavedModel. PiperOrigin-RevId: 195291836 --- .../python/grappler/cost_analyzer_tool.py | 75 ++++++++++--------- 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py index 0853db2524..e6229e1856 100644 --- a/tensorflow/python/grappler/cost_analyzer_tool.py +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -21,11 +21,13 @@ from __future__ import print_function import argparse import sys +from google.protobuf import message from google.protobuf import text_format from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.grappler import cost_analyzer @@ -37,33 +39,42 @@ from tensorflow.python.training import saver def get_metagraph(): """Constructs and returns a MetaGraphDef from the input file.""" - if FLAGS.metagraphdef: - with gfile.GFile(FLAGS.metagraphdef) as meta_file: - metagraph = meta_graph_pb2.MetaGraphDef() - if FLAGS.metagraphdef.endswith(".pbtxt"): - text_format.Merge(meta_file.read(), metagraph) - else: - metagraph.ParseFromString(meta_file.read()) - if FLAGS.fetch is not None: - fetch_collection = meta_graph_pb2.CollectionDef() - for fetch in FLAGS.fetch.split(","): - fetch_collection.node_list.value.append(fetch) - metagraph.collection_def["train_op"].CopyFrom(fetch_collection) - else: - with gfile.GFile(FLAGS.graphdef) as graph_file: - graph_def = graph_pb2.GraphDef() - if FLAGS.graphdef.endswith(".pbtxt"): - text_format.Merge(graph_file.read(), graph_def) - else: - graph_def.ParseFromString(graph_file.read()) - importer.import_graph_def(graph_def, name="") - graph = ops.get_default_graph() - for fetch in FLAGS.fetch.split(","): - fetch_op = graph.get_operation_by_name(fetch) - graph.add_to_collection("train_op", fetch_op) - metagraph = saver.export_meta_graph( - graph_def=graph.as_graph_def(), graph=graph) - return metagraph + with gfile.GFile(FLAGS.input) as input_file: + input_data = input_file.read() + try: + saved_model = saved_model_pb2.SavedModel() + text_format.Merge(input_data, saved_model) + meta_graph = saved_model.meta_graphs[0] + except text_format.ParseError: + try: + saved_model.ParseFromString(input_data) + meta_graph = saved_model.meta_graphs[0] + except message.DecodeError: + try: + meta_graph = meta_graph_pb2.MetaGraphDef() + text_format.Merge(input_data, meta_graph) + except text_format.ParseError: + try: + meta_graph.ParseFromString(input_data) + except message.DecodeError: + try: + graph_def = graph_pb2.GraphDef() + text_format.Merge(input_data, graph_def) + except text_format.ParseError: + try: + graph_def.ParseFromString(input_data) + except message.DecodeError: + raise ValueError("Invalid input file.") + importer.import_graph_def(graph_def, name="") + graph = ops.get_default_graph() + meta_graph = saver.export_meta_graph( + graph_def=graph.as_graph_def(), graph=graph) + if FLAGS.fetch is not None: + fetch_collection = meta_graph_pb2.CollectionDef() + for fetch in FLAGS.fetch.split(","): + fetch_collection.node_list.value.append(fetch) + meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) + return meta_graph def main(_): @@ -85,15 +96,11 @@ def main(_): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--metagraphdef", + "--input", type=str, default=None, - help="Input .meta MetaGraphDef file path.") - parser.add_argument( - "--graphdef", - type=str, - default=None, - help="Input .pb GraphDef file path.") + help="Input file path. Accept SavedModel, MetaGraphDef, and GraphDef in " + "either binary or text format.") parser.add_argument( "--fetch", type=str, -- GitLab From 4b767a835b61061ef4d167dc1ee935f2f85a3e87 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 3 May 2018 12:53:47 -0700 Subject: [PATCH 080/839] Small fix for an eager colab notebook. PiperOrigin-RevId: 195296384 --- .../contrib/eager/python/examples/notebooks/1_basics.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb index 0279db80fa..9fd2d8d125 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb @@ -478,7 +478,7 @@ "source": [ "# Time GPU-based matrix multiplications.\n", "\n", - "if is_gpu_available:\n", + "if tf.test.is_gpu_available():\n", " # First use of the GPU will be slow:\n", " print(\"Time to conduct first matmul on GPU:\")\n", " %time tf.matmul(gpu_tensor, gpu_tensor)\n", -- GitLab From 775d1c03c1772c0c2e10e5884af8d9363cfdf314 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 3 May 2018 12:59:33 -0700 Subject: [PATCH 081/839] [TF:XLA] Bump open source llvm revision to r331442 PiperOrigin-RevId: 195297133 --- tensorflow/workspace.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 94cac4f8fa..8b6ad0a138 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -452,11 +452,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/a5108a08ceab35886a7df07c86f96aedd3d94bb7.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/a5108a08ceab35886a7df07c86f96aedd3d94bb7.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/b3f6a6a61625296bb532a65c0bf51b91b05b3361.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/b3f6a6a61625296bb532a65c0bf51b91b05b3361.tar.gz", ], - sha256 = "79cae03ebbdfd812bb69c460e1325ca069b5c576f7c7071f8216cf2b0975e36f", - strip_prefix = "llvm-a5108a08ceab35886a7df07c86f96aedd3d94bb7", + sha256 = "93895b289a78a47a1e75652e12a1b9a6c119f086a509b00e0084cf2bb944b709", + strip_prefix = "llvm-b3f6a6a61625296bb532a65c0bf51b91b05b3361", build_file = clean_dep("//third_party/llvm:llvm.BUILD"), ) -- GitLab From ceda30408f66a7eea86dc359164deb662d5a32d0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 13:00:56 -0700 Subject: [PATCH 082/839] Enable unary chain hoisting optimization for concat/split/splitv by default. PiperOrigin-RevId: 195297330 --- tensorflow/core/grappler/op_types.cc | 38 ++++++++++++------- tensorflow/core/grappler/op_types.h | 4 ++ .../optimizers/arithmetic_optimizer.cc | 18 ++++++--- .../optimizers/arithmetic_optimizer.h | 2 +- .../optimizers/arithmetic_optimizer_test.cc | 16 ++++---- 5 files changed, 51 insertions(+), 27 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 7c936dfca1..c48dc00941 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -476,28 +476,40 @@ bool IsInvolution(const NodeDef& node) { return involution_ops->count(node.op()) > 0; } +bool IsValueAndOrderAndShapePreserving(const NodeDef& node) { + if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { + return true; + } + static const std::unordered_set* + value_and_order_and_shape_preserving_ops = + CHECK_NOTNULL((new const std::unordered_set{ + "CheckNumerics", + "DebugGradientIdentity", + "DeepCopy" + "Enter", + "Exit", + "Identity", + "IdentityN", + "PreventGradient", + "Print", + "Snapshot", + "StopGradient", + })); + return value_and_order_and_shape_preserving_ops->count(node.op()) > 0; +} + bool IsValueAndOrderPreserving(const NodeDef& node) { if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { return true; } static const std::unordered_set* value_and_order_preserving_ops = CHECK_NOTNULL((new const std::unordered_set{ - "CheckNumerics", - "DebugGradientIdentity", - "DeepCopy" - "Enter", - "Exit", "ExpandDims", - "Identity", - "IdentityN", - "PreventGradient", - "Print", - "Reshape", "Snapshot", "Squeeze", - "StopGradient", })); - return value_and_order_preserving_ops->count(node.op()) > 0; + return value_and_order_preserving_ops->count(node.op()) > 0 || + IsValueAndOrderAndShapePreserving(node); } bool IsValuePreserving(const NodeDef& node) { @@ -564,7 +576,7 @@ bool IsUnaryElementWise(const NodeDef& node) { "Tanh", })); return element_wise_ops->count(node.op()) > 0 || - (!IsIdentityN(node) && IsValueAndOrderPreserving(node)); + (!IsIdentityN(node) && IsValueAndOrderAndShapePreserving(node)); } bool HasOpDef(const NodeDef& node) { diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 7a1b438768..e33dd21538 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -174,6 +174,10 @@ bool ModifiesInputsInPlace(const NodeDef& node); // own inverse such that f(f(x)) == x. bool IsInvolution(const NodeDef& node); +// Returns true if the op preserves the order and value of elements +// and shape of its first input tensor. +bool IsValueAndOrderAndShapePreserving(const NodeDef& node); + // Returns true if the op preserves the order and value of elements in its // first input tensor and possible changes its shape. bool IsValueAndOrderPreserving(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d6510ba681..2a5654f752 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1400,6 +1400,11 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { return n > 1; } else if (IsSplit(*node) || IsSplitV(*node)) { const int num_split = node->attr().at("num_split").i(); + if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) { + // TODO(rmlarsen): Remove this constraint when we have optimizations + // in place for merging slices into splits. + return false; + } return num_split > 1 && !IsAlreadyOptimized(*node); } return false; @@ -1458,13 +1463,13 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { if (tails.empty()) { return Status::OK(); } - AddControlInputs(ctrl_inputs, root_node); AddToOptimizationQueue(root_node); optimized_nodes_.insert(root_node->name()); if (node_is_concat_) { + AddControlInputs(ctrl_inputs, root_node); return HoistChainForConcat(prefix_length, tails, root_node); } else { - return HoistChainForSplit(prefix_length, tails, root_node); + return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node); } } @@ -1542,9 +1547,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { IsInPreserveSet(*op)) { return false; } - if (node_is_concat_ && - ctx().node_map->GetOutputs(op->name()).size() > 1) { - // TODO(rmlarsen): Allow and hoist outgoing control edges. + if (ctx().node_map->GetOutputs(op->name()).size() > 1) { + // TODO(rmlarsen): Allow outgoing control edges. return false; } } @@ -1612,6 +1616,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { } Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails, + std::set* ctrl_inputs, NodeDef* split_node) { // Create a new chain before the split node to process the input tensor. const string& split_name = split_node->name(); @@ -1646,6 +1651,9 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { cur_copy->add_input(orig_input); ctx().node_map->UpdateOutput(NodeName(orig_input), split_name, cur_copy->name()); + // Make sure all the control inputs are satisfied before running the first + // node in the new chain. + AddControlInputs(ctrl_inputs, cur_copy); // Connect all consumers of the tail nodes directly to the // output port of Split from which the chain started. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 3b297ec0aa..6309dc1a33 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -65,7 +65,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_redundant_bitcast = true; bool remove_redundant_cast = true; bool remove_negation = true; - bool hoist_cwise_unary_chains = false; + bool hoist_cwise_unary_chains = true; bool convert_sqrt_div_to_rsqrt_mul = false; // Choose which arithmetic optimizer stages will be enabled for a given diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index f903f53a35..d32743f3f2 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -2320,16 +2320,16 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) { EXPECT_NE(node.name(), "cos_exp_b2"); if (node.name() == "split1") { - EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("axis", node.input(0)); EXPECT_EQ("ArithmeticOptimizer/_sin_a_split1", node.input(1)); - EXPECT_EQ("^ctrl1", node.input(2)); found++; } if (node.name() == "ArithmeticOptimizer/_sin_a_split1") { EXPECT_EQ("Sin", node.op()); - EXPECT_EQ(1, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^ctrl1", node.input(1)); found++; } if (node.name() == "id_a") { @@ -2349,8 +2349,11 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) { } if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") { EXPECT_EQ("Exp", node.op()); - EXPECT_EQ(1, node.input_size()); + EXPECT_EQ(4, node.input_size()); EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^ctrl1", node.input(1)); + EXPECT_EQ("^ctrl2", node.input(2)); + EXPECT_EQ("^ctrl3", node.input(3)); found++; } if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") { @@ -2360,13 +2363,10 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) { found++; } if (node.name() == "split2") { - EXPECT_EQ(6, node.input_size()); + EXPECT_EQ(3, node.input_size()); EXPECT_EQ("ArithmeticOptimizer/_cos_exp_a2_split2", node.input(0)); EXPECT_EQ("size_splits2", node.input(1)); EXPECT_EQ("axis", node.input(2)); - EXPECT_EQ("^ctrl1", node.input(3)); - EXPECT_EQ("^ctrl2", node.input(4)); - EXPECT_EQ("^ctrl3", node.input(5)); found++; } if (node.name() == "id_a2") { -- GitLab From fded0f901c99087b100191273e28692f9b4569ee Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Thu, 3 May 2018 13:03:48 -0700 Subject: [PATCH 083/839] Change all std::bind usages in GCS to lambdas. Fix the wrong #define Guard name in retrying_file_system.h. PiperOrigin-RevId: 195297877 --- .../core/platform/cloud/gcs_dns_cache.cc | 4 +- .../core/platform/cloud/gcs_file_system.cc | 5 +- .../platform/cloud/retrying_file_system.h | 81 ++++++++++--------- .../core/platform/cloud/retrying_utils.cc | 6 +- 4 files changed, 52 insertions(+), 44 deletions(-) diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.cc b/tensorflow/core/platform/cloud/gcs_dns_cache.cc index 4d9aff4d24..f2e64662a9 100644 --- a/tensorflow/core/platform/cloud/gcs_dns_cache.cc +++ b/tensorflow/core/platform/cloud/gcs_dns_cache.cc @@ -71,8 +71,8 @@ void GcsDnsCache::AnnotateRequest(HttpRequest* request) { addresses_ = ResolveNames(kCachedDomainNames); // Note: we opt to use a thread instead of a delayed closure. - worker_.reset(env_->StartThread( - {}, "gcs_dns_worker", std::bind(&GcsDnsCache::WorkerThread, this))); + worker_.reset(env_->StartThread({}, "gcs_dns_worker", + [this]() { return WorkerThread(); })); started_ = true; } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index f1e18403ec..488f9cc75d 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -1397,8 +1397,7 @@ Status GcsFileSystem::RenameObject(const string& src, const string& target) { // on the server side, we can't just retry the whole RenameFile operation // because the source object is already gone. return RetryingUtils::DeleteWithRetries( - std::bind(&GcsFileSystem::DeleteFile, this, src), - initial_retry_delay_usec_); + [this, &src]() { return DeleteFile(src); }, initial_retry_delay_usec_); } Status GcsFileSystem::IsDirectory(const string& fname) { @@ -1454,7 +1453,7 @@ Status GcsFileSystem::DeleteRecursively(const string& dirname, // and therefore RetryingFileSystem won't pay attention to the failures, // we need to make sure these failures are properly retried. const auto& delete_file_status = RetryingUtils::DeleteWithRetries( - std::bind(&GcsFileSystem::DeleteFile, this, full_path), + [this, &full_path]() { return DeleteFile(full_path); }, initial_retry_delay_usec_); if (!delete_file_status.ok()) { if (IsDirectory(full_path).ok()) { diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h index 399a21617e..92aa72be89 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.h +++ b/tensorflow/core/platform/cloud/retrying_file_system.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_PLATFORM_RETRYING_FILE_SYSTEM_H_ -#define TENSORFLOW_CORE_PLATFORM_RETRYING_FILE_SYSTEM_H_ +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_FILE_SYSTEM_H_ #include #include @@ -54,74 +54,80 @@ class RetryingFileSystem : public FileSystem { Status FileExists(const string& fname) override { return RetryingUtils::CallWithRetries( - std::bind(&FileSystem::FileExists, base_file_system_.get(), fname), + [this, &fname]() { return base_file_system_->FileExists(fname); }, initial_delay_microseconds_); } Status GetChildren(const string& dir, std::vector* result) override { return RetryingUtils::CallWithRetries( - std::bind(&FileSystem::GetChildren, base_file_system_.get(), dir, - result), + [this, &dir, result]() { + return base_file_system_->GetChildren(dir, result); + }, initial_delay_microseconds_); } Status GetMatchingPaths(const string& pattern, std::vector* result) override { return RetryingUtils::CallWithRetries( - std::bind(&FileSystem::GetMatchingPaths, base_file_system_.get(), - pattern, result), + [this, &pattern, result]() { + return base_file_system_->GetMatchingPaths(pattern, result); + }, initial_delay_microseconds_); } Status Stat(const string& fname, FileStatistics* stat) override { return RetryingUtils::CallWithRetries( - std::bind(&FileSystem::Stat, base_file_system_.get(), fname, stat), + [this, &fname, stat]() { return base_file_system_->Stat(fname, stat); }, initial_delay_microseconds_); } Status DeleteFile(const string& fname) override { return RetryingUtils::DeleteWithRetries( - std::bind(&FileSystem::DeleteFile, base_file_system_.get(), fname), + [this, &fname]() { return base_file_system_->DeleteFile(fname); }, initial_delay_microseconds_); } Status CreateDir(const string& dirname) override { return RetryingUtils::CallWithRetries( - std::bind(&FileSystem::CreateDir, base_file_system_.get(), dirname), + [this, &dirname]() { return base_file_system_->CreateDir(dirname); }, initial_delay_microseconds_); } Status DeleteDir(const string& dirname) override { return RetryingUtils::DeleteWithRetries( - std::bind(&FileSystem::DeleteDir, base_file_system_.get(), dirname), + [this, &dirname]() { return base_file_system_->DeleteDir(dirname); }, initial_delay_microseconds_); } Status GetFileSize(const string& fname, uint64* file_size) override { return RetryingUtils::CallWithRetries( - std::bind(&FileSystem::GetFileSize, base_file_system_.get(), fname, - file_size), + [this, &fname, file_size]() { + return base_file_system_->GetFileSize(fname, file_size); + }, initial_delay_microseconds_); } Status RenameFile(const string& src, const string& target) override { return RetryingUtils::CallWithRetries( - std::bind(&FileSystem::RenameFile, base_file_system_.get(), src, - target), + [this, &src, &target]() { + return base_file_system_->RenameFile(src, target); + }, initial_delay_microseconds_); } Status IsDirectory(const string& dirname) override { return RetryingUtils::CallWithRetries( - std::bind(&FileSystem::IsDirectory, base_file_system_.get(), dirname), + [this, &dirname]() { return base_file_system_->IsDirectory(dirname); }, initial_delay_microseconds_); } Status DeleteRecursively(const string& dirname, int64* undeleted_files, int64* undeleted_dirs) override { return RetryingUtils::DeleteWithRetries( - std::bind(&FileSystem::DeleteRecursively, base_file_system_.get(), - dirname, undeleted_files, undeleted_dirs), + [this, &dirname, undeleted_files, undeleted_dirs]() { + return base_file_system_->DeleteRecursively(dirname, undeleted_files, + undeleted_dirs); + }, initial_delay_microseconds_); } @@ -148,8 +154,9 @@ class RetryingRandomAccessFile : public RandomAccessFile { Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { return RetryingUtils::CallWithRetries( - std::bind(&RandomAccessFile::Read, base_file_.get(), offset, n, result, - scratch), + [this, offset, n, result, scratch]() { + return base_file_->Read(offset, n, result, scratch); + }, initial_delay_microseconds_); } @@ -172,23 +179,20 @@ class RetryingWritableFile : public WritableFile { Status Append(const StringPiece& data) override { return RetryingUtils::CallWithRetries( - std::bind(&WritableFile::Append, base_file_.get(), data), + [this, &data]() { return base_file_->Append(data); }, initial_delay_microseconds_); } Status Close() override { return RetryingUtils::CallWithRetries( - std::bind(&WritableFile::Close, base_file_.get()), - initial_delay_microseconds_); + [this]() { return base_file_->Close(); }, initial_delay_microseconds_); } Status Flush() override { return RetryingUtils::CallWithRetries( - std::bind(&WritableFile::Flush, base_file_.get()), - initial_delay_microseconds_); + [this]() { return base_file_->Flush(); }, initial_delay_microseconds_); } Status Sync() override { return RetryingUtils::CallWithRetries( - std::bind(&WritableFile::Sync, base_file_.get()), - initial_delay_microseconds_); + [this]() { return base_file_->Sync(); }, initial_delay_microseconds_); } private: @@ -203,8 +207,9 @@ Status RetryingFileSystem::NewRandomAccessFile( const string& filename, std::unique_ptr* result) { std::unique_ptr base_file; TF_RETURN_IF_ERROR(RetryingUtils::CallWithRetries( - std::bind(&FileSystem::NewRandomAccessFile, base_file_system_.get(), - filename, &base_file), + [this, &filename, &base_file]() { + return base_file_system_->NewRandomAccessFile(filename, &base_file); + }, initial_delay_microseconds_)); result->reset(new retrying_internals::RetryingRandomAccessFile( std::move(base_file), initial_delay_microseconds_)); @@ -216,8 +221,9 @@ Status RetryingFileSystem::NewWritableFile( const string& filename, std::unique_ptr* result) { std::unique_ptr base_file; TF_RETURN_IF_ERROR(RetryingUtils::CallWithRetries( - std::bind(&FileSystem::NewWritableFile, base_file_system_.get(), filename, - &base_file), + [this, &filename, &base_file]() { + return base_file_system_->NewWritableFile(filename, &base_file); + }, initial_delay_microseconds_)); result->reset(new retrying_internals::RetryingWritableFile( std::move(base_file), initial_delay_microseconds_)); @@ -229,8 +235,9 @@ Status RetryingFileSystem::NewAppendableFile( const string& filename, std::unique_ptr* result) { std::unique_ptr base_file; TF_RETURN_IF_ERROR(RetryingUtils::CallWithRetries( - std::bind(&FileSystem::NewAppendableFile, base_file_system_.get(), - filename, &base_file), + [this, &filename, &base_file]() { + return base_file_system_->NewAppendableFile(filename, &base_file); + }, initial_delay_microseconds_)); result->reset(new retrying_internals::RetryingWritableFile( std::move(base_file), initial_delay_microseconds_)); @@ -241,11 +248,13 @@ template Status RetryingFileSystem::NewReadOnlyMemoryRegionFromFile( const string& filename, std::unique_ptr* result) { return RetryingUtils::CallWithRetries( - std::bind(&FileSystem::NewReadOnlyMemoryRegionFromFile, - base_file_system_.get(), filename, result), + [this, &filename, result]() { + return base_file_system_->NewReadOnlyMemoryRegionFromFile(filename, + result); + }, initial_delay_microseconds_); } } // namespace tensorflow -#endif // TENSORFLOW_CORE_PLATFORM_RETRYING_FILE_SYSTEM_H_ +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_RETRYING_FILE_SYSTEM_H_ diff --git a/tensorflow/core/platform/cloud/retrying_utils.cc b/tensorflow/core/platform/cloud/retrying_utils.cc index 99691ecfb9..d2df422024 100644 --- a/tensorflow/core/platform/cloud/retrying_utils.cc +++ b/tensorflow/core/platform/cloud/retrying_utils.cc @@ -44,9 +44,9 @@ bool IsRetriable(error::Code code) { Status RetryingUtils::CallWithRetries(const std::function& f, const int64 initial_delay_microseconds) { - return CallWithRetries(f, initial_delay_microseconds, - std::bind(&Env::SleepForMicroseconds, Env::Default(), - std::placeholders::_1)); + return CallWithRetries(f, initial_delay_microseconds, [](int64 micros) { + return Env::Default()->SleepForMicroseconds(micros); + }); } Status RetryingUtils::CallWithRetries( -- GitLab From 278e68cedbb80c6f3342856bfccf688a808e461a Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 3 May 2018 13:09:28 -0700 Subject: [PATCH 084/839] Simplified the implementation of shape_n since the optimized code path isn't needed anymore and can be incorrect in some rare cases. PiperOrigin-RevId: 195298813 --- tensorflow/python/ops/array_ops.py | 10 +--------- tensorflow/python/profiler/model_analyzer_test.py | 2 +- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index e235047aff..96df15684b 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -263,15 +263,7 @@ def shape_n(input, out_type=dtypes.int32, name=None): type `out_type`. """ - output = gen_array_ops.shape_n(input, out_type=out_type, name=name) - if not context.executing_eagerly(): - for i, input_tensor in enumerate(input): - input_tensor = ops.convert_to_tensor(input_tensor) - input_shape = input_tensor.get_shape() - if input_shape.is_fully_defined(): - output[i] = constant( - input_shape.as_list(), dtype=out_type, name=name) - return output + return gen_array_ops.shape_n(input, out_type=out_type, name=name) @tf_export("size") diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py index 04ba28c219..75580fc630 100644 --- a/tensorflow/python/profiler/model_analyzer_test.py +++ b/tensorflow/python/profiler/model_analyzer_test.py @@ -232,7 +232,7 @@ class PrintModelAnalysisTest(test.TestCase): self.assertLess(0, tfprof_node.total_exec_micros) self.assertEqual(2844, tfprof_node.total_parameters) - self.assertLess(168800, tfprof_node.total_float_ops) + self.assertLess(145660, tfprof_node.total_float_ops) self.assertEqual(8, len(tfprof_node.children)) self.assertEqual('_TFProfRoot', tfprof_node.name) self.assertEqual( -- GitLab From 41dcb67efd272e9ce0e5071433f42a9d540ec6dc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 13:09:30 -0700 Subject: [PATCH 085/839] Fix bugs in model pruner. PiperOrigin-RevId: 195298816 --- tensorflow/core/grappler/optimizers/model_pruner.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index 3311e97010..36eab4999d 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -70,6 +70,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, } // Try to keep the nodes ordered somewhat topologically since this helps // further optimizations perform better. + runnable_item.graph.mutable_node()->Reserve(keep.size()); for (int i = keep.size() - 1; i >= 0; --i) { *runnable_item.graph.add_node() = *keep[i]; } @@ -113,6 +114,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, } } + pruned_graph->Clear(); *pruned_graph->mutable_library() = item.graph.library(); *pruned_graph->mutable_versions() = item.graph.versions(); @@ -122,6 +124,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, } const bool fetches_are_known = !item.fetch.empty(); + pruned_graph->mutable_node()->Reserve(runnable_item.graph.node_size()); for (auto& node : runnable_item.graph.node()) { if (!fetches_are_known || nodes_to_delete.find(&node) == nodes_to_delete.end()) { @@ -134,6 +137,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, VLOG(1) << "Pruned " << nodes_to_delete.size() << " nodes from the graph. The graph now contains " << pruned_graph->node_size() << " nodes."; + CHECK_LE(pruned_graph->node_size(), item.graph.node_size()); return Status::OK(); } -- GitLab From 5a64e609d0eb94244067f5d7514605863c9f37c3 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Thu, 3 May 2018 13:22:33 -0700 Subject: [PATCH 086/839] Checkpointable: Utilities to read object metadata Useful for inspecting checkpoints programatically (e.g. in unit tests). PiperOrigin-RevId: 195300780 --- tensorflow/contrib/checkpoint/__init__.py | 4 ++ .../contrib/checkpoint/python/visualize.py | 16 +------- .../eager/python/examples/spinn/spinn_test.py | 10 ++--- .../python/training/checkpointable_utils.py | 38 +++++++++++++++++++ .../training/checkpointable_utils_test.py | 16 ++++++++ 5 files changed, 64 insertions(+), 20 deletions(-) diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index 1192cc44a1..d2c30f1215 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -16,7 +16,9 @@ For creating and managing dependencies: +@@CheckpointableObjectGraph @@dot_graph_from_checkpoint +@@object_metadata @@split_dependency """ @@ -26,6 +28,8 @@ from __future__ import print_function from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint +from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.python.training.checkpointable_utils import object_metadata from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py index 86fbdb41d2..9a3b23bb2c 100644 --- a/tensorflow/contrib/checkpoint/python/visualize.py +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -17,10 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.core.protobuf import checkpointable_object_graph_pb2 from tensorflow.python import pywrap_tensorflow -from tensorflow.python.framework import errors_impl from tensorflow.python.training import checkpointable +from tensorflow.python.training import checkpointable_utils def dot_graph_from_checkpoint(save_path): @@ -52,20 +51,9 @@ def dot_graph_from_checkpoint(save_path): A graph in DOT format as a string. """ reader = pywrap_tensorflow.NewCheckpointReader(save_path) - try: - object_graph_string = reader.get_tensor( - checkpointable.OBJECT_GRAPH_PROTO_KEY) - except errors_impl.NotFoundError: - raise ValueError( - ('The specified checkpoint "%s" does not appear to be object-based (it ' - 'is missing the key "%s"). Likely it was created with a name-based ' - 'saver and does not contain an object dependency graph.') % ( - save_path, checkpointable.OBJECT_GRAPH_PROTO_KEY)) + object_graph = checkpointable_utils.object_metadata(save_path) shape_map = reader.get_variable_to_shape_map() dtype_map = reader.get_variable_to_dtype_map() - object_graph = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - object_graph.ParseFromString(object_graph_string) graph = 'digraph {\n' def _escape(name): return name.replace('"', '\\"') diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index f825a2a736..1e4746d01c 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -34,10 +34,10 @@ import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.spinn import data from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util -from tensorflow.core.protobuf import checkpointable_object_graph_pb2 from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training import saver # pylint: enable=g-bad-import-order @@ -421,10 +421,8 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - object_graph_string = checkpoint_utils.load_variable( - config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH") - object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph() - object_graph.ParseFromString(object_graph_string) + object_graph = checkpointable_utils.object_metadata( + saver.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: for attribute in node.attributes: diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py index 9cdd53cbf9..cf4112ff99 100644 --- a/tensorflow/python/training/checkpointable_utils.py +++ b/tensorflow/python/training/checkpointable_utils.py @@ -159,6 +159,44 @@ def add_variable(checkpointable, name, shape=None, dtype=dtypes.float32, initializer=initializer, getter=_default_getter) +def object_metadata(save_path): + """Retrieves information about the objects in a checkpoint. + + Example usage: + + ```python + object_graph = tf.contrib.checkpoint.object_metadata( + tf.train.latest_checkpoint(checkpoint_directory)) + ckpt_variable_names = set() + for node in object_graph.nodes: + for attribute in node.attributes: + ckpt_variable_names.add(attribute.full_name) + ``` + + Args: + save_path: The path to the checkpoint, as returned by `save` or + `tf.train.latest_checkpoint`. + Returns: + A parsed `tf.contrib.checkpoint.CheckpointableObjectGraph` protocol buffer. + Raises: + ValueError: If an object graph was not found in the checkpoint. + """ + reader = pywrap_tensorflow.NewCheckpointReader(save_path) + try: + object_graph_string = reader.get_tensor( + checkpointable_lib.OBJECT_GRAPH_PROTO_KEY) + except errors_impl.NotFoundError: + raise ValueError( + ('The specified checkpoint "%s" does not appear to be object-based (it ' + 'is missing the key "%s"). Likely it was created with a name-based ' + 'saver and does not contain an object dependency graph.') % ( + save_path, checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)) + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + object_graph_proto.ParseFromString(object_graph_string) + return object_graph_proto + + def _breadth_first_checkpointable_traversal(root_checkpointable): """Find shortest paths to all variables owned by dependencies of root.""" bfs_sorted = [] diff --git a/tensorflow/python/training/checkpointable_utils_test.py b/tensorflow/python/training/checkpointable_utils_test.py index 40dfeb28d5..3b8166bf37 100644 --- a/tensorflow/python/training/checkpointable_utils_test.py +++ b/tensorflow/python/training/checkpointable_utils_test.py @@ -155,6 +155,22 @@ class InterfaceTests(test.TestCase): self.assertEqual(dtypes.float64, v2.dtype) self.assertAllEqual([1., 1., 1.], self.evaluate(v2)) + def testObjectMetadata(self): + with context.eager_mode(): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + dense = core.Dense(1) + checkpoint = checkpointable_utils.Checkpoint(dense=dense) + dense(constant_op.constant([[1.]])) + save_path = checkpoint.save(checkpoint_prefix) + + objects = checkpointable_utils.object_metadata(save_path) + all_variable_names = [] + for obj in objects.nodes: + for attribute in obj.attributes: + all_variable_names.append(attribute.full_name) + self.assertIn("dense/kernel", all_variable_names) + class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject): -- GitLab From 7529268d692c1c888f93924e6ca5e10fd3183b80 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Thu, 3 May 2018 13:30:12 -0700 Subject: [PATCH 087/839] tfdbg + tflearn: replace deprecated classes and methods in example & docs * `tf.contrib.learn.Experiment` is deprecated. Remove it from debug_tflearn_iris.py. * Use `tf.estimator.DNNClassifier`, instead of the older one from `tf.contrib.learn`. * Use `train()`, instead of `fit()` of Estimators. * `Estimator.predict()` supports hooks. Add example lines for that. PiperOrigin-RevId: 195301913 --- .../docs_src/programmers_guide/debugger.md | 89 ++++++------------- tensorflow/python/debug/BUILD | 1 - .../debug/examples/debug_tflearn_iris.py | 83 ++++++++--------- 3 files changed, 65 insertions(+), 108 deletions(-) diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md index f7817b06d4..6bd941886d 100644 --- a/tensorflow/docs_src/programmers_guide/debugger.md +++ b/tensorflow/docs_src/programmers_guide/debugger.md @@ -34,7 +34,7 @@ type of bug in TensorFlow model development. The following example is for users who use the low-level [`Session`](https://www.tensorflow.org/api_docs/python/tf/Session) API of TensorFlow. A later section of this document describes how to use **tfdbg** -with a higher-level API, namely tf-learn `Estimator`s and `Experiment`s. +with a higher-level API, namely `Estimator`s. To *observe* such an issue, run the following command without the debugger (the source code can be found [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/debug/examples/debug_mnist.py)): @@ -418,21 +418,20 @@ run -f has_inf_or_nan` Confirm that no tensors are flagged as containing `nan` or `inf` values, and accuracy now continues to rise rather than getting stuck. Success! -## Debugging tf-learn Estimators and Experiments +## Debugging TensorFlow Estimators This section explains how to debug TensorFlow programs that use the `Estimator` -and `Experiment` APIs. Part of the convenience provided by these APIs is that +APIs. Part of the convenience provided by these APIs is that they manage `Session`s internally. This makes the `LocalCLIDebugWrapperSession` described in the preceding sections inapplicable. Fortunately, you can still debug them by using special `hook`s provided by `tfdbg`. -### Debugging tf.contrib.learn Estimators - -Currently, `tfdbg` can debug the -@{tf.contrib.learn.BaseEstimator.fit$`fit()`} -@{tf.contrib.learn.BaseEstimator.evaluate$`evaluate()`} -methods of tf-learn `Estimator`s. To debug `Estimator.fit()`, -create a `LocalCLIDebugHook` and supply it in the `monitors` argument. For example: +`tfdbg` can debug the +@{tf.estimator.Estimator.train$`train()`}, +@{tf.estimator.Estimator.evaluate$`evaluate()`} and +@{tf.estimator.Estimator.predict$`predict()`} +methods of tf-learn `Estimator`s. To debug `Estimator.train()`, +create a `LocalCLIDebugHook` and supply it in the `hooks` argument. For example: ```python # First, let your BUILD target depend on "//tensorflow/python/debug:debug_py" @@ -443,67 +442,33 @@ from tensorflow.python import debug as tf_debug # Create a LocalCLIDebugHook and use it as a monitor when calling fit(). hooks = [tf_debug.LocalCLIDebugHook()] -classifier.fit(x=training_set.data, - y=training_set.target, - steps=1000, - monitors=hooks) +# To debug `train`: +classifier.train(input_fn, + steps=1000, + hooks=hooks) ``` -To debug `Estimator.evaluate()`, assign hooks to the `hooks` parameter, as in -the following example: +Similarly, to debug `Estimator.evaluate()` and `Estimator.predict()`, assign +hooks to the `hooks` parameter, as in the following example: ```python -accuracy_score = classifier.evaluate(x=test_set.data, - y=test_set.target, +# To debug `evaluate`: +accuracy_score = classifier.evaluate(eval_input_fn, hooks=hooks)["accuracy"] -``` +# To debug `predict`: +predict_results = classifier.predict(predict_input_fn, hooks=hooks) +``` [debug_tflearn_iris.py](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_tflearn_iris.py), -based on [tf-learn's iris tutorial](https://www.tensorflow.org/versions/r1.2/get_started/tflearn), contains a full example of how to -use the tfdbg with `Estimator`s. To run this example, do: +based on [tf-learn's iris tutorial](https://www.tensorflow.org/versions/r1.8/get_started/tflearn), +contains a full example of how to use the tfdbg with `Estimator`s. +To run this example, do: ```none python -m tensorflow.python.debug.examples.debug_tflearn_iris --debug ``` -### Debugging tf.contrib.learn Experiments - -`Experiment` is a construct in `tf.contrib.learn` at a higher level than -`Estimator`. -It provides a single interface for training and evaluating a model. To debug -the `train()` and `evaluate()` calls to an `Experiment` object, you can -use the keyword arguments `train_monitors` and `eval_hooks`, respectively, when -calling its constructor. For example: - -```python -# First, let your BUILD target depend on "//tensorflow/python/debug:debug_py" -# (You don't need to worry about the BUILD dependency if you are using a pip -# install of open-source TensorFlow.) -from tensorflow.python import debug as tf_debug - -hooks = [tf_debug.LocalCLIDebugHook()] - -ex = experiment.Experiment(classifier, - train_input_fn=iris_input_fn, - eval_input_fn=iris_input_fn, - train_steps=FLAGS.train_steps, - eval_delay_secs=0, - eval_steps=1, - train_monitors=hooks, - eval_hooks=hooks) - -ex.train() -accuracy_score = ex.evaluate()["accuracy"] -``` - -To build and run the `debug_tflearn_iris` example in the `Experiment` mode, do: - -```none -python -m tensorflow.python.debug.examples.debug_tflearn_iris \ - --use_experiment --debug -``` - The `LocalCLIDebugHook` also allows you to configure a `watch_fn` that can be used to flexibly specify what `Tensor`s to watch on different `Session.run()` calls, as a function of the `fetches` and `feed_dict` and other states. See @@ -573,7 +538,7 @@ Often, your model is running on a remote machine or a process that you don't have terminal access to. To perform model debugging in such cases, you can use the `offline_analyzer` binary of `tfdbg` (described below). It operates on dumped data directories. This can be done to both the lower-level `Session` API -and the higher-level `Estimator` and `Experiment` APIs. +and the higher-level `Estimator` API. ### Debugging Remote tf.Sessions @@ -636,7 +601,7 @@ can be inspected offline. See [the proto definition](https://www.tensorflow.org/code/tensorflow/core/protobuf/debug.proto) for more details. -### Debugging Remotely-Running tf-learn Estimators and Experiments +### Debugging Remotely-Running Estimators If your remote TensorFlow server runs `Estimator`s, you can use the non-interactive `DumpingDebugHook`. For example: @@ -652,8 +617,8 @@ hooks = [tf_debug.DumpingDebugHook("/shared/storage/location/tfdbg_dumps_1")] Then this `hook` can be used in the same way as the `LocalCLIDebugHook` examples described earlier in this document. -As the training and/or evalution of `Estimator` or `Experiment` -happens, tfdbg creates directories having the following name pattern: +As the training, evalution or prediction happens with `Estimator`, +tfdbg creates directories having the following name pattern: `/shared/storage/location/tfdbg_dumps_1/run__`. Each directory corresponds to a `Session.run()` call that underlies the `fit()` or `evaluate()` call. You can load these directories and inspect diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index b5760df1ed..183994ddaa 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -449,7 +449,6 @@ py_binary( deps = [ ":debug_py", "//tensorflow:tensorflow_py", - "//third_party/py/numpy", "@six_archive//:six", ], ) diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py index 4f4666ee4f..00090b21fe 100644 --- a/tensorflow/python/debug/examples/debug_tflearn_iris.py +++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py @@ -22,11 +22,9 @@ import os import sys import tempfile -import numpy as np from six.moves import urllib import tensorflow as tf -from tensorflow.contrib.learn.python.learn import experiment from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.python import debug as tf_debug @@ -82,28 +80,34 @@ def iris_input_fn(): def main(_): # Load datasets. if FLAGS.fake_data: - training_set = tf.contrib.learn.datasets.base.Dataset( - np.random.random([120, 4]), - np.random.random_integers(3, size=[120]) - 1) - test_set = tf.contrib.learn.datasets.base.Dataset( - np.random.random([30, 4]), - np.random.random_integers(3, size=[30]) - 1) + def training_input_fn(): + return ({"features": tf.random_normal([128, 4])}, + tf.random_uniform([128], minval=0, maxval=3, dtype=tf.int32)) + def test_input_fn(): + return ({"features": tf.random_normal([32, 4])}, + tf.random_uniform([32], minval=0, maxval=3, dtype=tf.int32)) + feature_columns = [ + tf.feature_column.numeric_column("features", shape=(4,))] else: training_data_path, test_data_path = maybe_download_data(FLAGS.data_dir) - training_set = tf.contrib.learn.datasets.base.load_csv_with_header( - filename=training_data_path, - target_dtype=np.int, - features_dtype=np.float32) - test_set = tf.contrib.learn.datasets.base.load_csv_with_header( - filename=test_data_path, target_dtype=np.int, features_dtype=np.float32) - - # Specify that all features have real-value data - feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] + column_names = [ + "sepal_length", "sepal_width", "petal_length", "petal_width", "label"] + batch_size = 32 + def training_input_fn(): + return tf.contrib.data.make_csv_dataset( + [training_data_path], batch_size, + column_names=column_names, label_name="label") + def test_input_fn(): + return tf.contrib.data.make_csv_dataset( + [test_data_path], batch_size, + column_names=column_names, label_name="label") + feature_columns = [tf.feature_column.numeric_column(feature) + for feature in column_names[:-1]] # Build 3 layer DNN with 10, 20, 10 units respectively. model_dir = FLAGS.model_dir or tempfile.mkdtemp(prefix="debug_tflearn_iris_") - classifier = tf.contrib.learn.DNNClassifier( + classifier = tf.estimator.DNNClassifier( feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, @@ -121,32 +125,23 @@ def main(_): debug_hook = tf_debug.TensorBoardDebugHook(FLAGS.tensorboard_debug_address) hooks = [debug_hook] - if not FLAGS.use_experiment: - # Fit model. - classifier.fit(x=training_set.data, - y=training_set.target, + # Train model, using tfdbg hook. + classifier.train(training_input_fn, steps=FLAGS.train_steps, - monitors=hooks) + hooks=hooks) - # Evaluate accuracy. - accuracy_score = classifier.evaluate(x=test_set.data, - y=test_set.target, - hooks=hooks)["accuracy"] - else: - ex = experiment.Experiment(classifier, - train_input_fn=iris_input_fn, - eval_input_fn=iris_input_fn, - train_steps=FLAGS.train_steps, - eval_delay_secs=0, - eval_steps=1, - train_monitors=hooks, - eval_hooks=hooks) - ex.train() - accuracy_score = ex.evaluate()["accuracy"] + # Evaluate accuracy, using tfdbg hook. + accuracy_score = classifier.evaluate(test_input_fn, + steps=FLAGS.eval_steps, + hooks=hooks)["accuracy"] print("After training %d steps, Accuracy = %f" % (FLAGS.train_steps, accuracy_score)) + # Make predictions, using tfdbg hook. + predict_results = classifier.predict(test_input_fn, hooks=hooks) + print("A prediction result: %s" % predict_results.next()) + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -165,14 +160,12 @@ if __name__ == "__main__": "--train_steps", type=int, default=10, - help="Number of steps to run trainer.") + help="Number of steps to run training for.") parser.add_argument( - "--use_experiment", - type="bool", - nargs="?", - const=True, - default=False, - help="Use tf.contrib.learn Experiment to run training and evaluation") + "--eval_steps", + type=int, + default=1, + help="Number of steps to run evaluation foir.") parser.add_argument( "--ui_type", type=str, -- GitLab From e629595e8f629f2de7db225463136b0e331bd71c Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 15:00:57 -0700 Subject: [PATCH 088/839] Simplify build dependencies; fix python import order; fix multiple singleton issues by inlining the singleton method. --- tensorflow/contrib/tensorrt/BUILD | 2 -- .../contrib/tensorrt/custom_plugin_examples/BUILD | 12 ++---------- .../tensorrt/custom_plugin_examples/plugin_test.py | 9 +++------ .../contrib/tensorrt/plugin/trt_plugin_factory.cc | 6 ------ .../contrib/tensorrt/plugin/trt_plugin_factory.h | 8 +++++++- 5 files changed, 12 insertions(+), 25 deletions(-) diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 79e525edae..5b56feed0f 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -259,7 +259,6 @@ cc_library( "segment/segment.h", "segment/union_find.h", ], - linkstatic = 1, deps = [ "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", @@ -295,7 +294,6 @@ tf_cuda_library( "plugin/trt_plugin_factory.h", "plugin/trt_plugin_utils.h", ], - linkstatic = 1, deps = [ "//tensorflow/core:platform_base", ] + if_tensorrt([ diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index e623b54781..6f81ac2b44 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -85,21 +85,13 @@ tf_custom_op_py_library( ], ) -py_library( - name = "inc_op_py", - srcs_version = "PY2AND3", - deps = [ - ":inc_op", - ":inc_op_loader", - ], -) - py_library( name = "init_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - ":inc_op_py", + ":inc_op", + ":inc_op_loader", ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index cb40e08493..aedfb16211 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy + from tensorflow.contrib import tensorrt +from tensorflow.contrib.tensorrt import custom_plugin_examples from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -27,12 +30,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops -from tensorflow.python.framework import errors -import numpy - -# import custom_op as plugin op -# the python api handles registration to the plugin factory -from tensorflow.contrib.tensorrt import custom_plugin_examples def get_plugin_graph_def(): diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index b608e602a7..736a1321fe 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -21,12 +21,6 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -PluginFactoryTensorRT* PluginFactoryTensorRT::GetInstance() { - static PluginFactoryTensorRT* factory_instance = - new PluginFactoryTensorRT(); - return factory_instance; -} - PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 54fbca5930..0eee705fb9 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -34,7 +34,13 @@ namespace tensorrt { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { public: - static PluginFactoryTensorRT* GetInstance(); + // TODO(aaroey): this static method has to be inlined to make the singleton a + // unique global symbol. Find a way to fix it. + static PluginFactoryTensorRT* GetInstance() { + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); + return factory_instance; + } // Deserialization method PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, -- GitLab From fe9b2637cfe39cf11eb3d0494948a733b7fc1d7d Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Thu, 29 Mar 2018 05:28:16 +0800 Subject: [PATCH 089/839] Parse op definition and generate a Java Op class. --- tensorflow/java/BUILD | 4 + tensorflow/java/src/gen/cc/java_defs.h | 76 ++-- tensorflow/java/src/gen/cc/op_gen_main.cc | 22 +- tensorflow/java/src/gen/cc/op_generator.cc | 406 +++++++++++++++-- tensorflow/java/src/gen/cc/op_generator.h | 42 +- tensorflow/java/src/gen/cc/op_parser.cc | 417 ++++++++++++++++++ tensorflow/java/src/gen/cc/op_parser.h | 137 ++++++ tensorflow/java/src/gen/cc/source_writer.cc | 127 +++--- tensorflow/java/src/gen/cc/source_writer.h | 55 ++- .../java/src/gen/cc/source_writer_test.cc | 82 ++-- tensorflow/java/src/gen/gen_ops.bzl | 29 +- .../src/gen/resources/license.snippet.java | 14 + 12 files changed, 1201 insertions(+), 210 deletions(-) create mode 100644 tensorflow/java/src/gen/cc/op_parser.cc create mode 100644 tensorflow/java/src/gen/cc/op_parser.h create mode 100644 tensorflow/java/src/gen/resources/license.snippet.java diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index ab7d698a45..635a4e807d 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -70,6 +70,7 @@ filegroup( tf_java_op_gen_srcjar( name = "java_op_gen_sources", + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], gen_base_package = "org.tensorflow.op", gen_tool = "java_op_gen_tool", ops_libs = [ @@ -111,11 +112,13 @@ cc_library( name = "java_op_gen_lib", srcs = [ "src/gen/cc/op_generator.cc", + "src/gen/cc/op_parser.cc", "src/gen/cc/source_writer.cc", ], hdrs = [ "src/gen/cc/java_defs.h", "src/gen/cc/op_generator.h", + "src/gen/cc/op_parser.h", "src/gen/cc/source_writer.h", ], copts = tf_copts(), @@ -124,6 +127,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", ], ) diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 59f8beaee7..2065477f58 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -18,12 +18,15 @@ limitations under the License. #include #include +#include +#include namespace tensorflow { namespace java { // An enumeration of different modifiers commonly used in Java enum Modifier { + PACKAGE = 0, PUBLIC = (1 << 0), PROTECTED = (1 << 1), PRIVATE = (1 << 2), @@ -72,6 +75,12 @@ class Type { // Reflection API does return Type(Type::PRIMITIVE, "void"); } + static Type Generic(const string& name) { + return Type(Type::GENERIC, name); + } + static Type Wildcard() { + return Type(Type::GENERIC, ""); + } static Type Class(const string& name, const string& package = "") { return Type(Type::CLASS, name, package); } @@ -81,9 +90,6 @@ class Type { static Type Enum(const string& name, const string& package = "") { return Type(Type::ENUM, name, package); } - static Type Generic(const string& name = "") { - return Type(Type::GENERIC, name); - } static Type ClassOf(const Type& type) { return Class("Class").add_parameter(type); } @@ -96,11 +102,10 @@ class Type { const Kind& kind() const { return kind_; } const string& name() const { return name_; } const string& package() const { return package_; } - const string& description() const { return description_; } - Type& description(const string& description) { - description_ = description; - return *this; + const string full_name() const { + return package_.empty() ? name_ : package_ + "." + name_; } + bool unknown() const { return name_.empty(); } // only wildcards has no name const std::list& parameters() const { return parameters_; } Type& add_parameter(const Type& parameter) { parameters_.push_back(parameter); @@ -120,14 +125,6 @@ class Type { } return *this; } - // Returns true if "type" is of a known collection type (only a few for now) - bool IsCollection() const { - return name_ == "List" || name_ == "Iterable"; - } - // Returns true if this instance is a wildcard () - bool IsWildcard() const { - return kind_ == GENERIC && name_.empty(); - } protected: Type(Kind kind, const string& name, const string& package = "") @@ -137,7 +134,6 @@ class Type { Kind kind_; string name_; string package_; - string description_; std::list parameters_; std::list annotations_; std::list supertypes_; @@ -180,16 +176,11 @@ class Variable { const string& name() const { return name_; } const Type& type() const { return type_; } bool variadic() const { return variadic_; } - const string& description() const { return description_; } - Variable& description(const string& description) { - description_ = description; - return *this; - } + private: string name_; Type type_; bool variadic_; - string description_; Variable(const string& name, const Type& type, bool variadic) : name_(name), type_(type), variadic_(variadic) {} @@ -210,16 +201,6 @@ class Method { bool constructor() const { return constructor_; } const string& name() const { return name_; } const Type& return_type() const { return return_type_; } - const string& description() const { return description_; } - Method& description(const string& description) { - description_ = description; - return *this; - } - const string& return_description() const { return return_description_; } - Method& return_description(const string& description) { - return_description_ = description; - return *this; - } const std::list& arguments() const { return arguments_; } Method& add_argument(const Variable& var) { arguments_.push_back(var); @@ -235,8 +216,6 @@ class Method { string name_; Type return_type_; bool constructor_; - string description_; - string return_description_; std::list arguments_; std::list annotations_; @@ -244,6 +223,35 @@ class Method { : name_(name), return_type_(return_type), constructor_(constructor) {} }; +// A definition of a documentation bloc for a Java element (JavaDoc) +class Javadoc { + public: + static Javadoc Create(const string& brief = "") { + return Javadoc(brief); + } + const string& brief() const { return brief_; } + const string& details() const { return description_; } + Javadoc& details(const string description) { + description_ = description; + return *this; + } + const std::list> tags() const { return tags_; } + Javadoc& add_tag(const string& tag, const string& text) { + tags_.push_back(std::make_pair(tag, text)); + return *this; + } + Javadoc& add_param_tag(const string& name, const string& text) { + return add_tag("param", name + " " + text); + } + + private: + string brief_; + string description_; + std::list> tags_; + + explicit Javadoc(const string& brief) : brief_(brief) {} +}; + } // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index bea99f3d7f..015200023f 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -48,8 +48,11 @@ const char kUsageHeader[] = "through\n" "the 'org.tensorflow.op.Ops' API as a group until the generated classes " "are compiled using an appropriate annotation processor.\n\n" - "Finally, the '--base_package' overrides the default parent package " - "under which the generated subpackage and classes are to be located.\n\n"; + "The '--base_package' overrides the default parent package under which " + "the generated subpackage and classes are to be located.\n\n" + "Finally, a list of directories of API proto definitions can be provided " + "to override default values found in the ops definitions, ordered by\n" + "priority (the last having precedence over the first).\n\n"; } // namespace java } // namespace tensorflow @@ -60,7 +63,7 @@ int main(int argc, char* argv[]) { tensorflow::string base_package = "org.tensorflow.op"; std::vector flag_list = { tensorflow::Flag("output_dir", &output_dir, - "Root directory into which output files are generated"), + "Root directory into which output files are generated"), tensorflow::Flag( "lib_name", &lib_name, "A name, in snake_case, used to classify this set of operations"), @@ -72,12 +75,15 @@ int main(int argc, char* argv[]) { bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(usage.c_str(), &argc, &argv); QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage; - - tensorflow::java::OpGenerator generator; + std::vector api_dirs; + if (argc > 1) { + api_dirs = tensorflow::str_util::Split(argv[1], ",", + tensorflow::str_util::SkipEmpty()); + } + tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs); tensorflow::OpList ops; - tensorflow::OpRegistry::Global()->Export(true, &ops); - tensorflow::Status status = - generator.Run(ops, lib_name, base_package, output_dir); + tensorflow::OpRegistry::Global()->Export(false, &ops); + tensorflow::Status status = generator.Run(ops, lib_name); TF_QCHECK_OK(status); return 0; diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index def06baf2d..c9b57f5706 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -14,53 +14,409 @@ limitations under the License. ==============================================================================*/ #include +#include +#include +#include +#include +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" +#include "tensorflow/java/src/gen/cc/source_writer.h" +#include "tensorflow/java/src/gen/cc/op_parser.h" #include "tensorflow/java/src/gen/cc/op_generator.h" namespace tensorflow { namespace java { namespace { -string CamelCase(const string& str, char delimiter, bool upper) { - string result; - bool cap = upper; - for (string::const_iterator it = str.begin(); it != str.end(); ++it) { - const char c = *it; - if (c == delimiter) { - cap = true; - } else if (cap) { - result += toupper(c); - cap = false; +const char* kLicenseSnippet = + "tensorflow/java/src/gen/resources/license.snippet.java"; + +const std::map kPrimitiveAttrTypes = { + { "Boolean", Type::Boolean() }, + { "Byte", Type::Byte() }, + { "Character", Type::Byte() }, + { "Float", Type::Float() }, + { "Integer", Type::Long() }, + { "Long", Type::Long() }, + { "Short", Type::Long() }, + { "Double", Type::Float() }, +}; + +enum RenderMode { + DEFAULT, + SINGLE_OUTPUT, + SINGLE_LIST_OUTPUT +}; + +void CollectOpDependencies(const OpSpec& op, RenderMode mode, + std::list* out) { + out->push_back(Type::Class("Operation", "org.tensorflow")); + out->push_back(Type::Class("OperationBuilder", "org.tensorflow")); + out->push_back(Type::Class("Scope", "org.tensorflow.op")); + if (mode == SINGLE_OUTPUT) { + out->push_back(Type::Class("Output", "org.tensorflow")); + } else if (mode == SINGLE_LIST_OUTPUT) { + out->push_back(Type::Interface("Iterator", "java.util")); + } + // Don't pay attention to duplicate types in the dependency list, they will + // be filtered out by the SourceWriter. + for (const OpSpec::Operand& input : op.inputs()) { + out->push_back(input.var().type()); + if (input.iterable()) { + out->push_back(Type::Class("Operands", "org.tensorflow.op")); + } + } + for (const OpSpec::Operand& output : op.outputs()) { + out->push_back(output.var().type()); + if (output.iterable()) { + out->push_back(Type::Class("Arrays", "java.util")); + } + } + for (const OpSpec::Operand& attribute : op.attributes()) { + out->push_back(attribute.var().type()); + if (attribute.var().type().name() == "Class") { + out->push_back(Type::Enum("DataType", "org.tensorflow")); + } + } + for (const OpSpec::Operand& option : op.options()) { + out->push_back(option.var().type()); + } +} + +void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, + SourceWriter* writer) { + string var = optional ? "opts." + attr.var().name() : attr.var().name(); + if (attr.iterable()) { + const Type& type = attr.data_type(); + std::map::const_iterator it = + kPrimitiveAttrTypes.find(type.name()); + if (it != kPrimitiveAttrTypes.end()) { + string array = attr.var().name() + "Array"; + writer->AppendType(it->second) + .Append("[] " + array + " = new ") + .AppendType(it->second) + .Append("[" + var + ".size()];") + .EndLine(); + writer->BeginBlock("for (int i = 0; i < " + array + ".length; ++i)") + .Append(array + "[i] = " + var + ".get(i);") + .EndLine() + .EndBlock() + .Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + array) + .Append(");") + .EndLine(); } else { - result += c; + writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + var) + .Append(".toArray(new ") + .AppendType(type) + .Append("[" + var + ".size()]));") + .EndLine(); } + } else { + Type type = attr.var().type(); + writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", "); + if (type.name() == "Class") { + writer->Append("DataType.fromClass(" + attr.var().name() + "));"); + } else { + writer->Append(var + ");"); + } + writer->EndLine(); } - return result; } -} // namespace +void RenderFactoryMethod(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Method factory = Method::Create("create", op_class); + Javadoc factory_doc = Javadoc::Create( + "Factory method to create a class to wrap a new " + op_class.name() + + " operation to the graph."); + Variable scope = + Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); + factory.add_argument(scope); + factory_doc.add_param_tag(scope.name(), "Current graph scope"); + for (const OpSpec::Operand& input : op.inputs()) { + factory.add_argument(input.var()); + factory_doc.add_param_tag(input.var().name(), input.description()); + } + for (const OpSpec::Operand& attribute : op.attributes()) { + factory.add_argument(attribute.var()); + factory_doc.add_param_tag(attribute.var().name(), attribute.description()); + } + if (!op.options().empty()) { + factory.add_argument(Variable::Varargs("options", Type::Class("Options"))); + factory_doc.add_param_tag("options", "carries optional attributes values"); + } + factory_doc.add_tag("return", "a new instance of " + op_class.name()); + writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc); + writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" + + op.graph_name() + "\", scope.makeOpName(\"" + + op_class.name() + "\"));"); + writer->EndLine(); -OpGenerator::OpGenerator() : env(Env::Default()) {} + for (const OpSpec::Operand& input : op.inputs()) { + if (input.iterable()) { + writer->Append("opBuilder.addInputList(Operands.asOutputs(" + + input.var().name() + "));"); + writer->EndLine(); + } else { + writer->Append("opBuilder.addInput(" + input.var().name() + + ".asOutput());"); + writer->EndLine(); + } + } + for (const OpSpec::Operand& attribute : op.attributes()) { + WriteSetAttrDirective(attribute, false, writer); + } + if (!op.options().empty()) { + writer->BeginBlock("if (options != null)") + .BeginBlock("for (Options opts : options)"); + for (const OpSpec::Operand& option : op.options()) { + writer->BeginBlock("if (opts." + option.var().name() + " != null)"); + WriteSetAttrDirective(option, true, writer); + writer->EndBlock(); + } + writer->EndBlock().EndBlock(); + } + writer->Append("return new ") + .AppendType(op_class) + .Append("(opBuilder.build());") + .EndLine(); + writer->EndMethod(); +} -OpGenerator::~OpGenerator() {} +void RenderConstructor(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Method constructor = Method::ConstructorFor(op_class) + .add_argument( + Variable::Create("operation", + Type::Class("Operation", "org.tensorflow"))); + for (const OpSpec::Operand& output : op.outputs()) { + if (output.iterable() && !output.data_type().unknown()) { + constructor.add_annotation( + Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); + break; + } + } + writer->BeginMethod(constructor, PRIVATE) + .Append("super(operation);") + .EndLine(); + if (op.outputs().size() > 0) { + writer->Append("int outputIdx = 0;") + .EndLine(); + for (const OpSpec::Operand& output : op.outputs()) { + if (output.iterable()) { + string var_length = output.var().name() + "Length"; + writer->Append("int " + var_length) + .Append(" = operation.outputListLength(\"" + output.graph_name() + + "\");") + .EndLine() + .Append(output.var().name() + " = Arrays.asList("); + if (!output.data_type().unknown()) { + writer->Append("(") + .AppendType(output.var().type().parameters().front()) + .Append("[])"); + } + writer->Append("operation.outputList(outputIdx, " + var_length + "));") + .EndLine() + .Append("outputIdx += " + var_length + ";") + .EndLine(); + } else { + writer->Append(output.var().name() + + " = operation.output(outputIdx++);") + .EndLine(); + } + } + } + writer->EndMethod(); +} -Status OpGenerator::Run(const OpList& ops, const string& lib_name, - const string& base_package, const string& output_dir) { - const string package = - base_package + '.' + str_util::StringReplace(lib_name, "_", "", true); - const string package_path = - output_dir + '/' + str_util::StringReplace(package, ".", "/", true); - const string group = CamelCase(lib_name, '_', false); +void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { + for (const OpSpec::Operand& option : op.options()) { + Method setter = Method::Create(option.var().name(), Type::Class("Options")) + .add_argument(option.var()); + Javadoc setter_doc = Javadoc::Create() + .add_param_tag(option.var().name(), option.description()); + writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc) + .Append("return new Options()." + option.var().name() + "(" + + option.var().name() + ");") + .EndLine() + .EndMethod(); + } + for (const OpSpec::Operand& output : op.outputs()) { + Method getter = Method::Create(output.var().name(), output.var().type()); + Javadoc getter_doc = Javadoc::Create(output.description()); + writer->BeginMethod(getter, PUBLIC, &getter_doc) + .Append("return " + output.var().name() + ";") + .EndLine() + .EndMethod(); + } +} + +void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, + SourceWriter* writer) { + OpSpec::Operand output = op.outputs().front(); + + if (mode == SINGLE_OUTPUT) { + bool cast2obj = output.data_type().unknown(); + Type return_type = Type::Class("Output", "org.tensorflow") + .add_parameter(cast2obj ? Type::Class("Object") : output.data_type()); + Method as_output = Method::Create("asOutput", return_type) + .add_annotation(Annotation::Create("Override")); + if (cast2obj) { + as_output.add_annotation( + Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); + } + writer->BeginMethod(as_output, PUBLIC); + if (cast2obj) { + writer->Append("return (").AppendType(return_type).Append(") "); + } else { + writer->Append("return "); + } + writer->Append(output.var().name() + ";") + .EndLine() + .EndMethod(); + + } else if (mode == SINGLE_LIST_OUTPUT) { + Type operand = Type::Interface("Operand", "org.tensorflow"); + if (output.data_type().unknown()) { + operand.add_parameter(Type::Class("Object")); + } else { + operand.add_parameter(output.data_type()); + } + Type return_type = Type::Interface("Iterator", "java.util") + .add_parameter(operand); + Method iterator = Method::Create("iterator", return_type) + .add_annotation(Annotation::Create("Override")) + .add_annotation(Annotation::Create("SuppressWarnings") + .attributes("{\"rawtypes\", \"unchecked\"}")); + // cast the output list using a raw List + writer->BeginMethod(iterator, PUBLIC) + .Append("return (" + return_type.name() + ") ") + .Append(output.var().name() + ".iterator();") + .EndLine() + .EndMethod(); + } +} + +void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) { + Type options_class = Type::Class("Options"); + Javadoc options_doc = Javadoc::Create( + "Class holding optional attributes of this operation"); + writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); + for (const OpSpec::Operand& option : op.options()) { + Method setter = Method::Create(option.var().name(), options_class) + .add_argument(option.var()); + Javadoc setter_doc = Javadoc::Create() + .add_param_tag(option.var().name(), option.description()); + writer->BeginMethod(setter, PUBLIC, &setter_doc) + .Append("this." + option.var().name() + " = " + option.var().name() + + ";") + .EndLine() + .Append("return this;") + .EndLine() + .EndMethod(); + } + writer->EndLine(); + for (const OpSpec::Operand& option : op.options()) { + writer->WriteField(option.var(), PRIVATE); + } + Method constructor = Method::ConstructorFor(options_class); + writer->BeginMethod(constructor, PRIVATE).EndMethod(); + writer->EndType(); +} - if (!env->FileExists(package_path).ok()) { - TF_CHECK_OK(env->RecursivelyCreateDir(package_path)); +void RenderEndpoint(const OpSpec& op, const OpSpec::Endpoint& endpoint, + SourceWriter* writer) { + RenderMode mode = DEFAULT; + if (op.outputs().size() == 1) { + mode = op.outputs().front().iterable() ? SINGLE_LIST_OUTPUT : SINGLE_OUTPUT; + } + std::list dependencies; + CollectOpDependencies(op, mode, &dependencies); + const Type& op_class = endpoint.type(); + writer->WriteFromFile(kLicenseSnippet) + .EndLine() + .Append("// This file is machine generated, DO NOT EDIT!") + .EndLine() + .EndLine() + .BeginType(op_class, PUBLIC|FINAL, &dependencies, &endpoint.javadoc()); + if (!op.options().empty()) { + RenderOptionsClass(op, writer); } + RenderFactoryMethod(op, op_class, writer); + RenderGettersAndSetters(op, writer); + if (mode != DEFAULT) { + RenderInterfaceImpl(op, mode, writer); + } + writer->EndLine(); + for (const OpSpec::Operand& output : op.outputs()) { + writer->WriteField(output.var(), PRIVATE); + } + RenderConstructor(op, op_class, writer); + writer->EndType(); +} + +} // namespace + +OpGenerator::OpGenerator(const string& base_package, const string& output_dir, + const std::vector& api_dirs, Env* env) + : base_package_(base_package), output_dir_(output_dir), api_dirs_(api_dirs), + env_(env) { +} +Status OpGenerator::Run(const OpList& op_list, const string& lib_name) { LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations"; - // TODO(karllessard) generate wrappers from list of ops + ApiDefMap api_map(op_list); + if (!api_dirs_.empty()) { + // Only load api files that correspond to the requested "op_list" + for (const auto& op : op_list.op()) { + for (const auto& api_def_dir : api_dirs_) { + const std::string api_def_file_pattern = + io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); + if (env_->FileExists(api_def_file_pattern).ok()) { + TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)); + } + } + } + } + api_map.UpdateDocs(); + for (const auto& op_def : op_list.op()) { + const ApiDef* api_def = api_map.GetApiDef(op_def.name()); + if (api_def->visibility() != ApiDef::SKIP) { + Status status = GenerateOp(op_def, *api_def, lib_name); + if (status != Status::OK()) { + LOG(ERROR) << "Fail to generate Java wrapper for operation \"" + << op_def.name() << "\""; + } + } + } + return Status::OK(); +} + +Status OpGenerator::GenerateOp(const OpDef& op_def, const ApiDef& api_def, + const string& lib_name) { + std::unique_ptr op; + OpParser op_parser(op_def, api_def, lib_name, base_package_); + op_parser.Parse(&op); + for (const OpSpec::Endpoint& endpoint : op->endpoints()) { + string package_path = io::JoinPath(output_dir_, + str_util::StringReplace(endpoint.type().package(), ".", "/", true)); + if (!env_->FileExists(package_path).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(package_path)); + } + string file_path = + io::JoinPath(package_path, endpoint.type().name() + ".java"); + std::unique_ptr file; + TF_CHECK_OK(env_->NewWritableFile(file_path, &file)); + SourceFileWriter writer(file.get()); + RenderEndpoint(*op, endpoint, &writer); + } return Status::OK(); } diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 4b55ed3ed9..19d8db95fb 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -17,34 +17,42 @@ limitations under the License. #define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_ #include +#include -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace java { -/// \brief A generator of Java operation wrappers. -/// -/// Such generator is normally ran only once per executable, outputting -/// wrappers for the all registered operations it has been compiled with. -/// Nonetheless, it is designed to support multiple runs, giving a different -/// list of operations on each cycle. +// A generator of Java operation wrappers. +// +// Such generator is normally ran only once per executable, outputting +// wrappers for the all registered operations it has been compiled with. +// Nonetheless, it is designed to support multiple runs, giving a different +// list of operations on each cycle. class OpGenerator { public: - OpGenerator(); - virtual ~OpGenerator(); + OpGenerator(const string& base_package, const string& output_dir, + const std::vector& api_dirs, Env* env = Env::Default()); + virtual ~OpGenerator() = default; - /// \brief Generates wrappers for the given list of 'ops'. - /// - /// Output files are generated in //, - /// where 'lib_package' is derived from 'lib_name'. - Status Run(const OpList& ops, const string& lib_name, - const string& base_package, const string& output_dir); + // Generates wrappers for the given list of 'ops'. + // + // Output files are generated in //, + // where 'lib_package' is derived from 'lib_name'. + Status Run(const OpList& op_list, const string& lib_name); private: - Env* env; + string base_package_; + string output_dir_; + std::vector api_dirs_; + Env* env_; + + Status GenerateOp(const OpDef& op_def, const ApiDef& api_def, + const string& lib_name); }; } // namespace java diff --git a/tensorflow/java/src/gen/cc/op_parser.cc b/tensorflow/java/src/gen/cc/op_parser.cc new file mode 100644 index 0000000000..0541e343d8 --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_parser.cc @@ -0,0 +1,417 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/java/src/gen/cc/op_parser.h" + +namespace tensorflow { +namespace java { +namespace { + +string SnakeToCamelCase(const string& str, bool upper = false) { + string result; + bool cap = upper; + for (string::const_iterator it = str.begin(); it != str.end(); ++it) { + const char c = *it; + if (c == '_') { + cap = true; + } else if (cap) { + result += toupper(c); + cap = false; + } else { + result += c; + } + } + return result; +} + +bool IsRealNumber(DataType type) { + for (DataType dt : RealNumberTypes()) { + if (type == dt) { + return true; + } + } + return false; +} + +bool IsRealNumbers(const AttrValue& values) { + if (values.has_list()) { + for (int i = 0; i < values.list().type_size(); ++i) { + if (!IsRealNumber(values.list().type(i))) { + return false; + } + } + return true; + } + return IsRealNumber(values.type()); +} + +string ParseDocumentation(const string& text) { + std::stringstream javadoc_text; + string::const_iterator c_iter = text.cbegin(); + bool code = false; + bool emphasis = false; + bool list = false; + while (c_iter != text.cend()) { + char c = *c_iter++; + int count = 1; + switch (c) { + case '\n': + if (!code) { + // consumes all subsequent newlines, if there are more than one, + // then there are two choices: + // - if the next line starts with an asterisk, we are enumerating + // a list of items + // - otherwise, we are starting a new paragraph + for (; c_iter != text.cend() && *c_iter == '\n'; ++count, ++c_iter) {} + if (c_iter != text.cend()) { + if (count > 1) { + if (*c_iter != '*' && list) { + javadoc_text << "\n\n"; + list = false; + } else if (*c_iter == '*' && !list) { + javadoc_text << "\n
    \n
  • "; + list = true; + c_iter++; + } else { + javadoc_text << "\n

    \n"; + } + } else if (list && *c_iter == '*') { + javadoc_text << "

  • \n
  • "; + c_iter++; + } else { + javadoc_text << '\n'; + } + } + } + break; + case '`': + // consumes all subsequent backquotes, those are use enclose code. + // if there are more than 3, we are dealing with a pre-formatted block, + // otherwise it is a single-line code snippet + for (; c_iter != text.cend() && *c_iter == '`'; ++count, ++c_iter) {} + if (count >= 3) { + javadoc_text << (code ? "\n}" : "
    {@code\n");
    +      } else {
    +        javadoc_text << (code ? "}" : "{@code ");
    +      }
    +      code = !code;
    +      break;
    +    case '*':
    +      if (!code) {
    +        // consumes all subsequent asterisks, if there are more than one, then
    +        // we put the text in bold, otherwise in italic
    +        for (; c_iter != text.cend() && *c_iter == '*'; ++count, ++c_iter) {}
    +        if (count > 1) {
    +          javadoc_text << (emphasis ? "" : "");
    +        } else {
    +          javadoc_text << (emphasis ? "" : "");
    +        }
    +        emphasis = !emphasis;
    +      } else {
    +        javadoc_text << '*';
    +      }
    +      break;
    +    default:
    +      javadoc_text << c;
    +      break;
    +    }
    +  }
    +  return javadoc_text.str();
    +}
    +
    +}  // namespace
    +
    +OpParser::OpParser(const OpDef& op_def, const ApiDef& api_def,
    +    const string& lib_name, const string& base_package)
    +  : op_def_(op_def), op_api_(api_def), lib_name_(lib_name),
    +    base_package_(base_package) {
    +}
    +
    +void OpParser::Parse(std::unique_ptr* op_ptr) {
    +  visited_attrs_.clear();
    +  next_generic_ = 'T';
    +  op_ptr->reset(new OpSpec(op_api_.graph_op_name()));
    +  for (const string& next_input_name : op_api_.arg_order()) {
    +    for (int i = 0; i < op_def_.input_arg().size(); ++i) {
    +      if (op_def_.input_arg(i).name() == next_input_name) {
    +        ParseInput(op_def_.input_arg(i), op_api_.in_arg(i), op_ptr->get());
    +        break;
    +      }
    +    }
    +  }
    +  for (int i = 0; i < op_def_.attr().size(); ++i) {
    +    ParseAttribute(op_def_.attr(i), op_api_.attr(i), op_ptr->get());
    +  }
    +  for (int i = 0; i < op_def_.output_arg().size(); ++i) {
    +    ParseOutput(op_def_.output_arg(i), op_api_.out_arg(i), op_ptr->get());
    +  }
    +  BuildEndpoints(op_ptr->get());
    +}
    +
    +void OpParser::BuildEndpoints(OpSpec* op) {
    +  Javadoc op_doc = Javadoc::Create(ParseDocumentation(op_api_.summary()))
    +    .details(ParseDocumentation(op_api_.description()));
    +  std::vector op_supertypes;
    +  op_supertypes.push_back(Type::Class("PrimitiveOp", "org.tensorflow.op"));
    +  std::map op_generics;
    +  for (const OpSpec::Operand& output : op->outputs()) {
    +    // declare generic output parameters at the Op class level
    +    const Type& data_type = output.data_type();
    +    if (data_type.kind() == Type::GENERIC && !data_type.unknown()
    +        && op_generics.find(data_type.name()) == op_generics.end()) {
    +      op_generics.insert(std::make_pair(data_type.name(), &data_type));
    +      op_doc.add_param_tag("<" + data_type.name() + ">",
    +          "data type of output '" + output.var().name() + "'");
    +    }
    +    // implement the Op as an (iteration of) Operand if it has only one output
    +    if (op->outputs().size() == 1) {
    +      Type operand_inf(Type::Interface("Operand", "org.tensorflow"));
    +      operand_inf.add_parameter(data_type.unknown() ?
    +          Type::Class("Object") : data_type);
    +      op_supertypes.push_back(output.iterable() ?
    +          Type::IterableOf(operand_inf) : operand_inf);
    +    }
    +  }
    +  for (const auto& endpoint_def : op_api_.endpoint()) {
    +    std::vector name_tokens = str_util::Split(endpoint_def.name(), ".");
    +    // if the endpoint specifies a package, use it, otherwise derive it from the
    +    // op library name.
    +    string name;
    +    string package;
    +    if (name_tokens.size() > 1) {
    +      package = str_util::Lowercase(name_tokens.at(0));
    +      name = name_tokens.at(1);
    +    } else {
    +      package = str_util::StringReplace(lib_name_, "_", "", true);
    +      name = name_tokens.at(0);
    +    }
    +    Type endpoint(Type::Class(name, base_package_ + "." + package));
    +    Javadoc endpoint_doc(op_doc);
    +    for (const auto& parameter : op_generics) {
    +      endpoint.add_parameter(*parameter.second);
    +    }
    +    for (const Type& supertype : op_supertypes) {
    +      endpoint.add_supertype(supertype);
    +    }
    +    if (endpoint_def.deprecation_version() > 0) {
    +      string explanation;
    +      if (op_api_.endpoint(0).deprecation_version() == 0) {
    +        explanation = ", use {@link "
    +            + op->endpoints().at(0).type().full_name()
    +            + "} instead";
    +      } else {
    +        explanation = op_def_.deprecation().explanation();
    +      }
    +      endpoint_doc.add_tag("deprecated", explanation);
    +      endpoint.add_annotation(Annotation::Create("Deprecated"));
    +    }
    +    // only visible ops should be annotated for exposure in the Ops Graph API
    +    if (op_api_.visibility() != ApiDef::HIDDEN) {
    +      string group_name = SnakeToCamelCase(lib_name_);
    +      endpoint.add_annotation(
    +          Annotation::Create("Operator", "org.tensorflow.op.annotation")
    +            .attributes("group = \"" + group_name + "\""));
    +    }
    +    op->add_endpoint(endpoint, endpoint_doc);
    +  }
    +}
    +
    +void OpParser::ParseInput(const OpDef_ArgDef& input_def,
    +    const ApiDef::Arg& input_api, OpSpec* op) {
    +  bool iterable = false;
    +  Type data_type = DataTypeOf(input_def, &iterable);
    +  Type type = Type::Interface("Operand", "org.tensorflow")
    +    .add_parameter(data_type);
    +  if (iterable) {
    +    type = Type::IterableOf(type);
    +  }
    +  op->add_input(OpSpec::Operand(input_api.name(),
    +      Variable::Create(SnakeToCamelCase(input_api.rename_to()), type),
    +      data_type,
    +      ParseDocumentation(input_api.description()),
    +      iterable));
    +}
    +
    +void OpParser::ParseOutput(const OpDef_ArgDef& output_def,
    +    const ApiDef::Arg& output_api, OpSpec* op) {
    +  bool iterable = false;
    +  Type data_type = DataTypeOf(output_def, &iterable);
    +  Type type = Type::Class("Output", "org.tensorflow")
    +    .add_parameter(data_type);
    +  if (iterable) {
    +    type = Type::ListOf(type);
    +  }
    +  op->add_output(OpSpec::Operand(output_api.name(),
    +      Variable::Create(SnakeToCamelCase(output_api.rename_to()), type),
    +      data_type,
    +      ParseDocumentation(output_api.description()),
    +      iterable));
    +}
    +
    +void OpParser::ParseAttribute(const OpDef_AttrDef& attr_def,
    +    const ApiDef::Attr& attr_api, OpSpec* op) {
    +  // do not parse attributes already visited, they have probably been inferred
    +  // before as an input argument type
    +  if (visited_attrs_.find(attr_def.name()) != visited_attrs_.cend()) {
    +    return;
    +  }
    +  bool iterable = false;
    +  Type data_type = DataTypeOf(attr_def, &iterable);
    +  // generic attributes should be passed as an explicit type
    +  bool explicit_type = data_type.kind() == Type::GENERIC && !iterable;
    +  Type type = explicit_type ?
    +      Type::Class("Class").add_parameter(data_type) : data_type;
    +  if (iterable) {
    +    type = Type::ListOf(data_type);
    +  }
    +  OpSpec::Operand attr(attr_api.name(),
    +      Variable::Create(SnakeToCamelCase(attr_api.rename_to()), type),
    +      data_type,
    +      ParseDocumentation(attr_api.description()),
    +      iterable);
    +  // attributes with a default value are optional
    +  if (attr_api.has_default_value() && !explicit_type) {
    +    op->add_option(attr);
    +  } else {
    +    op->add_attribute(attr);
    +  }
    +  visited_attrs_.insert(std::make_pair(attr_api.name(), data_type));
    +}
    +
    +Type OpParser::DataTypeOf(const OpDef_ArgDef& arg, bool* iterable_out) {
    +  if (!arg.number_attr().empty()) {
    +    visited_attrs_.insert(std::make_pair(arg.number_attr(), Type::Int()));
    +    *iterable_out = true;
    +  }
    +  if (arg.type() != DataType::DT_INVALID) {
    +    // resolve type from DataType
    +    switch (arg.type()) {
    +      case DataType::DT_BOOL:
    +        return Type::Class("Boolean");
    +
    +      case DataType::DT_STRING:
    +        return Type::Class("String");
    +
    +      case DataType::DT_FLOAT:
    +        return Type::Class("Float");
    +
    +      case DataType::DT_DOUBLE:
    +        return Type::Class("Double");
    +
    +      case DataType::DT_UINT8:
    +        return Type::Class("UInt8", "org.tensorflow.types");
    +
    +      case DataType::DT_INT32:
    +        return Type::Class("Integer");
    +
    +      case DataType::DT_INT64:
    +        return Type::Class("Long");
    +
    +      case DataType::DT_RESOURCE:
    +        // TODO(karllessard) create a Resource utility class that could be
    +        // used to store a resource and its type (passed in a second argument).
    +        // For now, we need to force a wildcard and we will unfortunately lose
    +        // track of the resource type.
    +        return Type::Wildcard();
    +
    +      default:
    +        break;
    +    }
    +  } else {
    +    // resolve type from type attribute
    +    string attr_name = arg.type_attr();
    +    if (attr_name.empty()) {
    +      attr_name = arg.type_list_attr();
    +      if (!attr_name.empty()) {
    +        *iterable_out = true;
    +        Type type = Type::Wildcard();
    +        visited_attrs_.insert(std::make_pair(attr_name, type));
    +        return type;
    +      }
    +    }
    +    for (const auto& attr : op_def_.attr()) {
    +      if (attr.name() == attr_name) {
    +        Type type = DataTypeOf(attr, iterable_out);
    +        visited_attrs_.insert(std::make_pair(attr_name, type));
    +        return type;
    +      }
    +    }
    +  }
    +  LOG(WARNING) << "Data type for arg \"" << arg.name() << "\" is unknown";
    +  return Type::Wildcard();
    +}
    +
    +Type OpParser::DataTypeOf(const OpDef_AttrDef& attr, bool* iterable_out) {
    +  std::map::const_iterator it = visited_attrs_.find(attr.name());
    +  if (it != visited_attrs_.cend()) {
    +    return it->second;
    +  }
    +  string attr_type = attr.type();
    +  if (attr.type().compare(0, 5, "list(") == 0) {
    +    attr_type = attr_type.substr(5, attr.type().find_last_of(')') - 5);
    +    *iterable_out = true;
    +  }
    +  if (attr_type == "type") {
    +    if (*iterable_out) {
    +      return Type::Enum("DataType", "org.tensorflow");
    +    }
    +    return GetNextGenericTensorType(attr.allowed_values());
    +  }
    +  if (attr_type == "string") {
    +    return Type::Class("String");
    +  }
    +  if (attr_type == "int") {
    +    return Type::Class("Integer");
    +  }
    +  if (attr_type == "float") {
    +    return Type::Class("Float");
    +  }
    +  if (attr_type == "bool") {
    +    return Type::Class("Boolean");
    +  }
    +  if (attr_type == "shape") {
    +    return Type::Class("Shape", "org.tensorflow");
    +  }
    +  if (attr_type == "tensor") {
    +    return Type::Class("Tensor", "org.tensorflow")
    +      .add_parameter(Type::Wildcard());
    +  }
    +  LOG(WARNING) << "Data type for attribute \"" << attr_type << "\" is unknown";
    +  return *iterable_out ? Type::Wildcard() : Type::Class("Object");
    +}
    +
    +Type OpParser::GetNextGenericTensorType(const AttrValue& allowed_values)  {
    +  Type generic = Type::Generic(string(1, next_generic_));
    +  next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
    +
    +  // when only real numbers are allowed, enforce that restriction in the Java by
    +  // extending the generic from java.lang.Number
    +  if (IsRealNumbers(allowed_values)) {
    +    generic.add_supertype(Type::Class("Number"));
    +  }
    +  return generic;
    +}
    +
    +}  // namespace java
    +}  // namespace tensorflow
    diff --git a/tensorflow/java/src/gen/cc/op_parser.h b/tensorflow/java/src/gen/cc/op_parser.h
    new file mode 100644
    index 0000000000..42855127cc
    --- /dev/null
    +++ b/tensorflow/java/src/gen/cc/op_parser.h
    @@ -0,0 +1,137 @@
    +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
    +
    +Licensed under the Apache License, Version 2.0 (the "License");
    +you may not use this file except in compliance with the License.
    +You may obtain a copy of the License at
    +
    +    http://www.apache.org/licenses/LICENSE-2.0
    +
    +Unless required by applicable law or agreed to in writing, software
    +distributed under the License is distributed on an "AS IS" BASIS,
    +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +See the License for the specific language governing permissions and
    +limitations under the License.
    +==============================================================================*/
    +
    +#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    +#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    +
    +#include 
    +#include 
    +#include 
    +#include 
    +
    +#include "tensorflow/core/framework/op_def.pb.h"
    +#include "tensorflow/core/framework/api_def.pb.h"
    +#include "tensorflow/java/src/gen/cc/java_defs.h"
    +
    +namespace tensorflow {
    +namespace java {
    +
    +// Specification of a TensorFlow operation to generate.
    +//
    +// This is the result of an operation definition parsing, see OpParser::Parse().
    +class OpSpec {
    + public:
    +  class Endpoint {
    +   public:
    +    Endpoint(const Type& type, const Javadoc& javadoc)
    +      : type_(type), javadoc_(javadoc) {}
    +    const Type& type() const { return type_; }
    +    const Javadoc& javadoc() const { return javadoc_; }
    +
    +   private:
    +    Type type_;
    +    Javadoc javadoc_;
    +  };
    +
    +  class Operand {
    +   public:
    +    Operand(const string& graph_name, const Variable& var,
    +        const Type& data_type, const string& description, bool iterable)
    +     : graph_name_(graph_name), var_(var), data_type_(data_type),
    +       description_(description), iterable_(iterable) {}
    +    const string& graph_name() const { return graph_name_; }
    +    const Variable& var() const { return var_; }
    +    Variable* var_ptr() { return &var_; }
    +    const Type& data_type() const { return data_type_; }
    +    const string& description() const { return description_; }
    +    bool iterable() const { return iterable_; }
    +
    +   private:
    +    string graph_name_;
    +    Variable var_;
    +    Type data_type_;
    +    string description_;
    +    bool iterable_;
    +  };
    +
    +  explicit OpSpec(const string& graph_name) : graph_name_(graph_name) {}
    +  const string& graph_name() const { return graph_name_; }
    +  const std::vector endpoints() const { return endpoints_; }
    +  void add_endpoint(const Type& type, const Javadoc& javadoc) {
    +    endpoints_.push_back(Endpoint(type, javadoc));
    +  }
    +  const std::vector& inputs() const { return inputs_; }
    +  void add_input(const Operand& input) {
    +    inputs_.push_back(input);
    +  }
    +  const std::vector& outputs() const { return outputs_; }
    +  void add_output(const Operand& output) {
    +    outputs_.push_back(output);
    +  }
    +  const std::vector& attributes() const { return attributes_; }
    +  void add_attribute(const Operand& attribute) {
    +    attributes_.push_back(attribute);
    +  }
    +  const std::vector& options() const { return options_; }
    +  void add_option(const Operand& option) {
    +    options_.push_back(option);
    +  }
    +
    + private:
    +  string graph_name_;
    +  std::vector endpoints_;
    +  std::vector inputs_;
    +  std::vector outputs_;
    +  std::vector attributes_;
    +  std::vector options_;
    +};
    +
    +// A parser of ops proto definitions.
    +//
    +// This object parses the definition and the api of an TensorFlow operation to
    +// produce a specification that can be used for Java source code rendering.
    +class OpParser {
    + public:
    +  OpParser(const OpDef& op_def, const ApiDef& api_def, const string& lib_name,
    +      const string& base_package);
    +  virtual ~OpParser() = default;
    +
    +  // Produces an operation specification from its proto definitions.
    +  void Parse(std::unique_ptr* op_ptr);
    +
    + private:
    +  OpDef op_def_;
    +  ApiDef op_api_;
    +  string lib_name_;
    +  string base_package_;
    +  std::map visited_attrs_;
    +  char next_generic_ = 0;
    +
    +  void BuildEndpoints(OpSpec* op);
    +  void ParseInput(const OpDef_ArgDef& input_def,
    +      const ApiDef::Arg& input_api, OpSpec* op);
    +  void ParseOutput(const OpDef_ArgDef& output_def,
    +      const ApiDef::Arg& output_api, OpSpec* op);
    +  void ParseAttribute(const OpDef_AttrDef& attr_def,
    +      const ApiDef::Attr& attr_api, OpSpec* op);
    +  Type DataTypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
    +  Type DataTypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
    +  Type GetNextGenericTensorType(const AttrValue& allowed_values);
    +};
    +
    +}  // namespace java
    +}  // namespace tensorflow
    +
    +#endif  // TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
    index a02f75ad6e..b1de5af6ba 100644
    --- a/tensorflow/java/src/gen/cc/source_writer.cc
    +++ b/tensorflow/java/src/gen/cc/source_writer.cc
    @@ -15,7 +15,7 @@ limitations under the License.
     
     #include 
     #include 
    -#include 
    +#include 
     
     #include "tensorflow/java/src/gen/cc/source_writer.h"
     
    @@ -83,20 +83,20 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
     }
     
     SourceWriter& SourceWriter::AppendType(const Type& type) {
    -  if (type.kind() == Type::Kind::GENERIC && type.name().empty()) {
    +  if (type.unknown()) {
         Append("?");
       } else {
         Append(type.name());
    -  }
    -  if (!type.parameters().empty()) {
    -    Append("<");
    -    for (const Type& t : type.parameters()) {
    -      if (&t != &type.parameters().front()) {
    -        Append(", ");
    +    if (!type.parameters().empty()) {
    +      Append("<");
    +      for (const Type& t : type.parameters()) {
    +        if (&t != &type.parameters().front()) {
    +          Append(", ");
    +        }
    +        AppendType(t);
           }
    -      AppendType(t);
    +      Append(">");
         }
    -    Append(">");
       }
       return *this;
     }
    @@ -107,7 +107,21 @@ SourceWriter& SourceWriter::EndLine() {
       return *this;
     }
     
    -SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
    +SourceWriter& SourceWriter::BeginBlock(const string& expression) {
    +  if (!expression.empty()) {
    +    Append(expression + " {");
    +  } else {
    +    Append(newline_ ? "{" : " {");
    +  }
    +  return EndLine().Indent(2);
    +}
    +
    +SourceWriter& SourceWriter::EndBlock() {
    +  return Indent(-2).Append("}").EndLine();
    +}
    +
    +SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers,
    +    const Javadoc* javadoc) {
       GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
       if (!method.constructor()) {
         generic_namespace->Visit(method.return_type());
    @@ -116,8 +130,9 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
         generic_namespace->Visit(v.type());
       }
       EndLine();
    -  WriteDoc(method.description(), method.return_description(),
    -      &method.arguments());
    +  if (javadoc != nullptr) {
    +    WriteJavadoc(*javadoc);
    +  }
       if (!method.annotations().empty()) {
         WriteAnnotations(method.annotations());
       }
    @@ -145,29 +160,35 @@ SourceWriter& SourceWriter::EndMethod() {
       return *this;
     }
     
    -SourceWriter& SourceWriter::BeginType(const Type& type,
    -    const std::list* dependencies, int modifiers) {
    +SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers,
    +    const std::list* extra_dependencies, const Javadoc* javadoc) {
       if (!type.package().empty()) {
         Append("package ").Append(type.package()).Append(";").EndLine();
       }
    -  if (dependencies != nullptr && !dependencies->empty()) {
    -    TypeImporter type_importer(type.package());
    -    for (const Type& t : *dependencies) {
    +  TypeImporter type_importer(type.package());
    +  type_importer.Visit(type);
    +  if (extra_dependencies != nullptr) {
    +    for (const Type& t : *extra_dependencies) {
           type_importer.Visit(t);
         }
    +  }
    +  if (!type_importer.imports().empty()) {
         EndLine();
         for (const string& s : type_importer.imports()) {
           Append("import ").Append(s).Append(";").EndLine();
         }
       }
    -  return BeginInnerType(type, modifiers);
    +  return BeginInnerType(type, modifiers, javadoc);
     }
     
    -SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers) {
    +SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers,
    +    const Javadoc* javadoc) {
       GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
       generic_namespace->Visit(type);
       EndLine();
    -  WriteDoc(type.description());
    +  if (javadoc != nullptr) {
    +    WriteJavadoc(*javadoc);
    +  }
       if (!type.annotations().empty()) {
         WriteAnnotations(type.annotations());
       }
    @@ -200,14 +221,15 @@ SourceWriter& SourceWriter::EndType() {
       return *this;
     }
     
    -SourceWriter& SourceWriter::WriteFields(const std::list& fields,
    -    int modifiers) {
    -  EndLine();
    -  for (const Variable& v : fields) {
    -    WriteModifiers(modifiers);
    -    AppendType(v.type()).Append(" ").Append(v.name()).Append(";");
    -    EndLine();
    +SourceWriter& SourceWriter::WriteField(const Variable& field, int modifiers,
    +    const Javadoc* javadoc) {
    +  // If present, write field javadoc only as one brief line
    +  if (javadoc != nullptr && !javadoc->brief().empty()) {
    +    Append("/** ").Append(javadoc->brief()).Append(" */").EndLine();
       }
    +  WriteModifiers(modifiers);
    +  AppendType(field.type()).Append(" ").Append(field.name()).Append(";");
    +  EndLine();
       return *this;
     }
     
    @@ -228,39 +250,33 @@ SourceWriter& SourceWriter::WriteModifiers(int modifiers) {
       return *this;
     }
     
    -SourceWriter& SourceWriter::WriteDoc(const string& description,
    -    const string& return_description, const std::list* parameters) {
    -  if (description.empty() && return_description.empty()
    -      && (parameters == nullptr || parameters->empty())) {
    -    return *this;  // no doc to write
    -  }
    +SourceWriter& SourceWriter::WriteJavadoc(const Javadoc& javadoc) {
    +  Append("/**").Prefix(" * ").EndLine();
       bool do_line_break = false;
    -  Append("/**").EndLine().Prefix(" * ");
    -  if (!description.empty()) {
    -    Write(description).EndLine();
    +  if (!javadoc.brief().empty()) {
    +    Write(javadoc.brief()).EndLine();
         do_line_break = true;
       }
    -  if (parameters != nullptr && !parameters->empty()) {
    +  if (!javadoc.details().empty()) {
         if (do_line_break) {
    -      EndLine();
    -      do_line_break = false;
    -    }
    -    for (const Variable& v : *parameters) {
    -      Append("@param ").Append(v.name());
    -      if (!v.description().empty()) {
    -        Append(" ").Write(v.description());
    -      }
    -      EndLine();
    +      Append("

    ").EndLine(); } + Write(javadoc.details()).EndLine(); + do_line_break = true; } - if (!return_description.empty()) { + if (!javadoc.tags().empty()) { if (do_line_break) { EndLine(); - do_line_break = false; } - Append("@return ").Write(return_description).EndLine(); + for (const auto& p : javadoc.tags()) { + Append("@" + p.first); + if (!p.second.empty()) { + Append(" ").Write(p.second); + } + EndLine(); + } } - return Prefix("").Append(" **/").EndLine(); + return Prefix("").Append(" */").EndLine(); } SourceWriter& SourceWriter::WriteAnnotations( @@ -311,20 +327,19 @@ void SourceWriter::PopGenericNamespace() { void SourceWriter::TypeVisitor::Visit(const Type& type) { DoVisit(type); for (const Type& t : type.parameters()) { - DoVisit(t); + Visit(t); } for (const Annotation& t : type.annotations()) { DoVisit(t); } for (const Type& t : type.supertypes()) { - DoVisit(t); + Visit(t); } } void SourceWriter::GenericNamespace::DoVisit(const Type& type) { // ignore non-generic parameters, wildcards and generics already declared - if (type.kind() == Type::GENERIC - && !type.IsWildcard() + if (type.kind() == Type::GENERIC && !type.unknown() && generic_names_.find(type.name()) == generic_names_.end()) { declared_types_.push_back(&type); generic_names_.insert(type.name()); @@ -333,7 +348,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) { void SourceWriter::TypeImporter::DoVisit(const Type& type) { if (!type.package().empty() && type.package() != current_package_) { - imports_.insert(type.package() + '.' + type.name()); + imports_.insert(type.full_name()); } } diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index f011acd30a..1f0febe9a3 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -93,25 +93,22 @@ class SourceWriter { // This method appends a new opening brace to the current data and indent the // next lines according to Google Java Style Guide. The block can optionally // be preceded by an expression (e.g. Append("if(true)").BeginBlock();) - SourceWriter& BeginBlock() { - return Append(newline_ ? "{" : " {").EndLine().Indent(2); - } + SourceWriter& BeginBlock(const string& expr = ""); // Ends the current block of source code. // // This method appends a new closing brace to the current data and outdent the // next lines back to the margin used before BeginBlock() was invoked. - SourceWriter& EndBlock() { - return Indent(-2).Append("}").EndLine(); - } + SourceWriter& EndBlock(); // Begins to write a method. // // This method outputs the signature of the Java method from the data passed - // in the 'method' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // method. - SourceWriter& BeginMethod(const Method& method, int modifiers = 0); + // in the 'method' parameter and starts a new block. Modifiers are also passed + // in parameter to define the access scope of this method and, optionally, + // a Javadoc. + SourceWriter& BeginMethod(const Method& method, int modifiers, + const Javadoc* javadoc = nullptr); // Ends the current method. // @@ -122,22 +119,24 @@ class SourceWriter { // Begins to write the main type of a source file. // // This method outputs the declaration of the Java type from the data passed - // in the 'type' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // type. + // in the 'type' parameter and starts a new block. Modifiers are also passed + // in parameter to define the access scope of this type and, optionally, + // a Javadoc. // - // If not null, all types found in the 'dependencies' list will be imported - // before declaring the new type. - SourceWriter& BeginType(const Type& clazz, - const std::list* dependencies, int modifiers = 0); + // If not null, all types found in the 'extra_dependencies' list will be + // imported before declaring the new type. + SourceWriter& BeginType(const Type& clazz, int modifiers, + const std::list* extra_dependencies = nullptr, + const Javadoc* javadoc = nullptr); // Begins to write a new inner type. // // This method outputs the declaration of the Java type from the data passed - // in the 'type' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // type. - SourceWriter& BeginInnerType(const Type& type, int modifiers = 0); + // in the 'type' parameter and starts a new block. Modifiers are also passed + // in parameter to define the accesses and the scope of this type and, + // optionally, a Javadoc. + SourceWriter& BeginInnerType(const Type& type, int modifiers, + const Javadoc* javadoc = nullptr); // Ends the current type. // @@ -145,13 +144,13 @@ class SourceWriter { // BeginType() or BeginInnerType() prior to this. SourceWriter& EndType(); - // Writes a list of variables as fields of a type. + // Writes a variable as fields of a type. // // This method must be called within the definition of a type (see BeginType() - // or BeginInnerType()). Additional modifiers can also be passed in parameter - // to define the accesses and the scope of those fields. - SourceWriter& WriteFields(const std::list& fields, - int modifiers = 0); + // or BeginInnerType()). Modifiers are also be passed in parameter to define + // the accesses and the scope of this field and, optionally, a Javadoc. + SourceWriter& WriteField(const Variable& field, int modifiers, + const Javadoc* javadoc = nullptr); protected: virtual void DoAppend(const StringPiece& str) = 0; @@ -207,9 +206,7 @@ class SourceWriter { std::stack generic_namespaces_; SourceWriter& WriteModifiers(int modifiers); - SourceWriter& WriteDoc(const string& description, - const string& return_description = "", - const std::list* parameters = nullptr); + SourceWriter& WriteJavadoc(const Javadoc& javadoc); SourceWriter& WriteAnnotations(const std::list& annotations); SourceWriter& WriteGenerics(const std::list& generics); GenericNamespace* PushGenericNamespace(int modifiers); diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index 4bce2fea70..8bd42d9d0e 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -250,7 +250,7 @@ TEST(StreamTest, Types) { .AppendType(generic).Append(", ") .AppendType(Type::ListOf(generic)).Append(", ") .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ") - .AppendType(Type::ListOf(Type::Generic())); + .AppendType(Type::ListOf(Type::Wildcard())); const char* expected = "int, String, T, List, List>, List"; @@ -282,7 +282,7 @@ TEST(WriteType, SimpleClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -300,7 +300,7 @@ TEST(WriteType, SimpleClassWithDependencies) { deps.push_back(Type::Class("SamePackageType", "org.tensorflow")); deps.push_back(Type::Class("NoPackageType")); - writer.BeginType(clazz, &deps, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC, &deps).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -313,18 +313,21 @@ TEST(WriteType, SimpleClassWithDependencies) { TEST(WriteType, AnnotatedAndDocumentedClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - clazz.description("This class has a\n

    \nmultiline description."); + Javadoc clazz_doc; + clazz_doc.brief("Javadoc test") + .details("This is a\nmultiline description."); clazz.add_annotation(Annotation::Create("Bean")); clazz.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC, nullptr, &clazz_doc).EndType(); const char* expected = "package org.tensorflow;\n\n" "/**\n" - " * This class has a\n" + " * Javadoc test\n" " *

    \n" + " * This is a\n" " * multiline description.\n" " **/\n" "@Bean\n" @@ -339,7 +342,7 @@ TEST(WriteType, ParameterizedClass) { clazz.add_parameter(Type::Generic("T")); clazz.add_parameter(Type::Generic("U").add_supertype(Type::Class("Number"))); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -358,7 +361,7 @@ TEST(WriteType, ParameterizedClassAndSupertypes) { clazz.add_supertype(Type::Interface("Runnable")); clazz.add_supertype(Type::Class("SuperTest").add_parameter(type_t)); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -372,24 +375,24 @@ TEST(WriteType, ParameterizedClassFields) { Type clazz = Type::Class("Test", "org.tensorflow"); Type type_t = Type::Generic("T").add_supertype(Type::Class("Number")); clazz.add_parameter(type_t); - std::list static_fields; - static_fields.push_back(Variable::Create("field1", Type::Class("String"))); - std::list member_fields; - member_fields.push_back(Variable::Create("field2", Type::Class("String"))); - member_fields.push_back(Variable::Create("field3", type_t)); - - writer.BeginType(clazz, nullptr, PUBLIC) - .WriteFields(static_fields, STATIC | PUBLIC | FINAL) - .WriteFields(member_fields, PRIVATE) + Variable field1 = Variable::Create("field1", Type::Class("String")); + Variable field2 = Variable::Create("field2", Type::Class("String")); + Variable field3 = Variable::Create("field3", type_t); + Javadoc field3_doc; + field3_doc.brief("This variable is documented"); + + writer.BeginType(clazz, PUBLIC) + .WriteField(field1, STATIC | PUBLIC | FINAL) + .WriteField(field2, PRIVATE) + .WriteField(field3, PRIVATE, &field3_doc) .EndType(); const char* expected = "package org.tensorflow;\n\n" "public class Test {\n" - " \n" " public static final String field1;\n" - " \n" " private String field2;\n" + " /** This variable is documented */\n" " private T field3;\n" "}\n"; ASSERT_STREQ(expected, writer.str().data()); @@ -400,7 +403,7 @@ TEST(WriteType, SimpleInnerClass) { Type clazz = Type::Class("Test", "org.tensorflow"); Type inner_class = Type::Class("InnerTest"); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginInnerType(inner_class, PUBLIC) .EndType() .EndType(); @@ -423,7 +426,7 @@ TEST(WriteType, StaticParameterizedInnerClass) { Type inner_class = Type::Class("InnerTest"); inner_class.add_parameter(type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginInnerType(inner_class, PUBLIC | STATIC) .EndType() .EndType(); @@ -443,7 +446,7 @@ TEST(WriteMethod, SimpleMethod) { Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC).EndMethod() .EndType(); @@ -461,13 +464,15 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - method.description("This method has a\n

    \nmultiline description."); + Javadoc method_doc; + method_doc.brief("Javadoc test") + .details("This method has a\nmultiline description."); method.add_annotation(Annotation::Create("Override")); method.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC).EndMethod() + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC, &method_doc).EndMethod() .EndType(); const char* expected = @@ -475,8 +480,9 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { "public class Test {\n" " \n" " /**\n" - " * This method has a\n" + " * Javadoc test\n" " *

    \n" + " * This method has a\n" " * multiline description.\n" " **/\n" " @Override\n" @@ -490,16 +496,18 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { TEST(WriteMethod, DocumentedMethodWithArguments) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); + Variable reverse = Variable::Create("reverse", Type::Boolean()); Method method = Method::Create("boolToInt", Type::Int()); - method.description("Converts a boolean to an int"); - method.return_description("int value for this boolean"); method.add_argument(Variable::Create("b", Type::Boolean())); - Variable reverse = Variable::Create("reverse", Type::Boolean()); - reverse.description("if true, value is reversed"); method.add_argument(reverse); - - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC) + Javadoc method_doc; + method_doc.brief("Converts a boolean to an int") + .details("This method will convert\na boolean to an int") + .add_param_tag(reverse.name(), "if true, value is reversed") + .add_tag("return", "int value for this boolean"); + + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC, &method_doc) .Append("if (b && !reverse)") .BeginBlock() .Append("return 1;").EndLine() @@ -514,8 +522,10 @@ TEST(WriteMethod, DocumentedMethodWithArguments) { " \n" " /**\n" " * Converts a boolean to an int\n" + " *

    \n" + " * This method will convert\n" + " * a boolean to an int\n" " * \n" - " * @param b\n" " * @param reverse if true, value is reversed\n" " * @return int value for this boolean\n" " **/\n" @@ -536,7 +546,7 @@ TEST(WriteMethod, ParameterizedMethod) { clazz.add_parameter(type_t); Method method = Method::Create("doNothing", type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC) .Append("return null;").EndLine() .EndMethod() @@ -560,7 +570,7 @@ TEST(WriteMethod, StaticParameterizedMethod) { clazz.add_parameter(type_t); Method method = Method::Create("doNothing", type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC | STATIC) .Append("return null;").EndLine() .EndMethod() diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index a6650fc4ea..1e7899cf7a 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -1,9 +1,11 @@ # -*- Python -*- -load("//tensorflow:tensorflow.bzl", - "tf_binary_additional_srcs", - "tf_cc_binary", - "tf_copts") +load( + "//tensorflow:tensorflow.bzl", + "tf_binary_additional_srcs", + "tf_cc_binary", + "tf_copts", +) # Given a list of "ops_libs" (a list of files in the core/ops directory # without their .cc extensions), generate Java wrapper code for all operations @@ -27,16 +29,31 @@ def tf_java_op_gen_srcjar(name, ops_libs_pkg="//tensorflow/core", out_dir="ops/", out_src_dir="src/main/java/", + api_def_srcs=[], visibility=["//tensorflow/java:__pkg__"]): gen_tools = [] gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files + srcs = api_def_srcs[:] # Construct an op generator binary for each ops library. for ops_lib in ops_libs: gen_lib = ops_lib[:ops_lib.rfind("_")] out_gen_tool = out_dir + ops_lib + "_gen_tool" + if not api_def_srcs: + api_def_args_str = "," + else: + api_def_args = [] + for api_def_src in api_def_srcs: + # Add directory of the first ApiDef source to args. + # We are assuming all ApiDefs in a single api_def_src are in the + # same directory. + api_def_args.append( + " $$(dirname $$(echo $(locations " + api_def_src + + ") | cut -d\" \" -f1))") + api_def_args_str = ",".join(api_def_args) + tf_cc_binary( name=out_gen_tool, copts=tf_copts(), @@ -48,7 +65,8 @@ def tf_java_op_gen_srcjar(name, gen_cmds += ["$(location :" + out_gen_tool + ")" + " --output_dir=$(@D)/" + out_src_dir + " --lib_name=" + gen_lib + - " --base_package=" + gen_base_package] + " --base_package=" + gen_base_package + + " " + api_def_args_str] # Generate a source archive containing generated code for these ops. gen_srcjar = out_dir + name + ".srcjar" @@ -57,6 +75,7 @@ def tf_java_op_gen_srcjar(name, gen_tools += tf_binary_additional_srcs() native.genrule( name=name, + srcs=srcs, outs=[gen_srcjar], tools=gen_tools, cmd="&&".join(gen_cmds)) diff --git a/tensorflow/java/src/gen/resources/license.snippet.java b/tensorflow/java/src/gen/resources/license.snippet.java new file mode 100644 index 0000000000..90285ec669 --- /dev/null +++ b/tensorflow/java/src/gen/resources/license.snippet.java @@ -0,0 +1,14 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ -- GitLab From 7e80197f020895fea41eda36b08135b747a9a4f1 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Fri, 6 Apr 2018 08:56:54 -0400 Subject: [PATCH 090/839] Improve Javadoc and include first code review --- tensorflow/java/BUILD | 23 +- tensorflow/java/src/gen/cc/java_defs.h | 12 +- tensorflow/java/src/gen/cc/op_gen_main.cc | 48 +- tensorflow/java/src/gen/cc/op_generator.cc | 224 ++++++---- tensorflow/java/src/gen/cc/op_generator.h | 25 +- tensorflow/java/src/gen/cc/op_parser.cc | 417 ------------------ tensorflow/java/src/gen/cc/op_parser.h | 137 ------ tensorflow/java/src/gen/cc/op_specs.cc | 390 ++++++++++++++++ tensorflow/java/src/gen/cc/op_specs.h | 152 +++++++ tensorflow/java/src/gen/cc/source_writer.cc | 2 +- tensorflow/java/src/gen/cc/source_writer.h | 2 +- .../java/src/gen/cc/source_writer_test.cc | 20 +- tensorflow/java/src/gen/gen_ops.bzl | 68 +-- ...ense.snippet.java => license.java.snippet} | 0 14 files changed, 760 insertions(+), 760 deletions(-) delete mode 100644 tensorflow/java/src/gen/cc/op_parser.cc delete mode 100644 tensorflow/java/src/gen/cc/op_parser.h create mode 100644 tensorflow/java/src/gen/cc/op_specs.cc create mode 100644 tensorflow/java/src/gen/cc/op_specs.h rename tensorflow/java/src/gen/resources/{license.snippet.java => license.java.snippet} (100%) diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 635a4e807d..17566e1a9c 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -68,9 +68,13 @@ filegroup( ], ) +# Build the gen tool as a library, as it will be linked to a core/ops binary +# files before making it an executable. tf_java_op_gen_srcjar( name = "java_op_gen_sources", - api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], + api_def_srcs = [ + "//tensorflow/core/api_def:base_api_def", + ], gen_base_package = "org.tensorflow.op", gen_tool = "java_op_gen_tool", ops_libs = [ @@ -95,30 +99,17 @@ tf_java_op_gen_srcjar( ], ) -# Build the gen tool as a library, as it will be linked to a core/ops binary -# file before making it an executable. See tf_java_op_gen_srcjar(). -cc_library( - name = "java_op_gen_tool", - srcs = [ - "src/gen/cc/op_gen_main.cc", - ], - copts = tf_copts(), - deps = [ - ":java_op_gen_lib", - ], -) - cc_library( name = "java_op_gen_lib", srcs = [ "src/gen/cc/op_generator.cc", - "src/gen/cc/op_parser.cc", + "src/gen/cc/op_specs.cc", "src/gen/cc/source_writer.cc", ], hdrs = [ "src/gen/cc/java_defs.h", "src/gen/cc/op_generator.h", - "src/gen/cc/op_parser.h", + "src/gen/cc/op_specs.h", "src/gen/cc/source_writer.h", ], copts = tf_copts(), diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 2065477f58..81ac67eb2f 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -230,12 +230,12 @@ class Javadoc { return Javadoc(brief); } const string& brief() const { return brief_; } - const string& details() const { return description_; } - Javadoc& details(const string description) { - description_ = description; + const string& details() const { return details_; } + Javadoc& details(const string& details) { + details_ = details; return *this; } - const std::list> tags() const { return tags_; } + const std::list>& tags() const { return tags_; } Javadoc& add_tag(const string& tag, const string& text) { tags_.push_back(std::make_pair(tag, text)); return *this; @@ -246,7 +246,7 @@ class Javadoc { private: string brief_; - string description_; + string details_; std::list> tags_; explicit Javadoc(const string& brief) : brief_(brief) {} diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index 015200023f..458141b877 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,55 +36,41 @@ const char kUsageHeader[] = "Operation wrappers are generated under the path specified by the " "'--output_dir' argument. This path can be absolute or relative to the\n" "current working directory and will be created if it does not exists.\n\n" - "The '--lib_name' argument is used to classify the set of operations. If " - "the chosen name contains more than one word, it must be provided in \n" - "snake_case. This value is declined into other meaningful names, such as " - "the group and package of the generated operations. For example,\n" - "'--lib_name=my_lib' generates the operations under the " - "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n" - "group.\n\n" - "Note that the operator group assigned to the generated wrappers is just " - "an annotation tag at this stage. Operations will not be available " - "through\n" - "the 'org.tensorflow.op.Ops' API as a group until the generated classes " - "are compiled using an appropriate annotation processor.\n\n" + "Note that the operations will not be available through the " + "'org.tensorflow.op.Ops' API until the generated classes are compiled\n" + "using an appropriate annotation processor.\n\n" "The '--base_package' overrides the default parent package under which " "the generated subpackage and classes are to be located.\n\n" - "Finally, a list of directories of API proto definitions can be provided " - "to override default values found in the ops definitions, ordered by\n" - "priority (the last having precedence over the first).\n\n"; + "Finally, the `--api_dirs` argument takes a list of comma-seperated " + "directories of API definitions can be provided to override default\n" + "values found in the ops definitions. Directories are ordered by priority " + "(the last having precedence over the first).\n\n"; } // namespace java } // namespace tensorflow int main(int argc, char* argv[]) { - tensorflow::string lib_name; tensorflow::string output_dir; tensorflow::string base_package = "org.tensorflow.op"; + tensorflow::string api_dirs_str; std::vector flag_list = { tensorflow::Flag("output_dir", &output_dir, "Root directory into which output files are generated"), - tensorflow::Flag( - "lib_name", &lib_name, - "A name, in snake_case, used to classify this set of operations"), - tensorflow::Flag( - "base_package", &base_package, - "Package parent to the generated subpackage and classes")}; + tensorflow::Flag("base_package", &base_package, + "Package parent to the generated subpackage and classes"), + tensorflow::Flag("api_dirs", &api_dirs_str, + "List of directories that contains the ops api definitions")}; tensorflow::string usage = tensorflow::java::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(usage.c_str(), &argc, &argv); - QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage; - std::vector api_dirs; - if (argc > 1) { - api_dirs = tensorflow::str_util::Split(argv[1], ",", - tensorflow::str_util::SkipEmpty()); - } + QCHECK(parsed_flags_ok && !output_dir.empty()) << usage; + std::vector api_dirs = tensorflow::str_util::Split( + api_dirs_str, ",", tensorflow::str_util::SkipEmpty()); tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs); tensorflow::OpList ops; tensorflow::OpRegistry::Global()->Export(false, &ops); - tensorflow::Status status = generator.Run(ops, lib_name); - TF_QCHECK_OK(status); + TF_CHECK_OK(generator.Run(ops)); return 0; } diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index c9b57f5706..c32ad3b109 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -27,15 +28,15 @@ limitations under the License. #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/java/src/gen/cc/java_defs.h" #include "tensorflow/java/src/gen/cc/source_writer.h" -#include "tensorflow/java/src/gen/cc/op_parser.h" #include "tensorflow/java/src/gen/cc/op_generator.h" +#include "tensorflow/java/src/gen/cc/op_specs.h" namespace tensorflow { namespace java { namespace { const char* kLicenseSnippet = - "tensorflow/java/src/gen/resources/license.snippet.java"; + "tensorflow/java/src/gen/resources/license.java.snippet"; const std::map kPrimitiveAttrTypes = { { "Boolean", Type::Boolean() }, @@ -66,34 +67,34 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, } // Don't pay attention to duplicate types in the dependency list, they will // be filtered out by the SourceWriter. - for (const OpSpec::Operand& input : op.inputs()) { + for (const ArgumentSpec& input : op.inputs()) { out->push_back(input.var().type()); if (input.iterable()) { out->push_back(Type::Class("Operands", "org.tensorflow.op")); } } - for (const OpSpec::Operand& output : op.outputs()) { + for (const ArgumentSpec& output : op.outputs()) { out->push_back(output.var().type()); if (output.iterable()) { out->push_back(Type::Class("Arrays", "java.util")); } } - for (const OpSpec::Operand& attribute : op.attributes()) { + for (const AttributeSpec& attribute : op.attributes()) { out->push_back(attribute.var().type()); if (attribute.var().type().name() == "Class") { out->push_back(Type::Enum("DataType", "org.tensorflow")); } } - for (const OpSpec::Operand& option : op.options()) { - out->push_back(option.var().type()); + for (const AttributeSpec& optional_attribute : op.optional_attributes()) { + out->push_back(optional_attribute.var().type()); } } -void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, +void WriteSetAttrDirective(const AttributeSpec& attr, bool optional, SourceWriter* writer) { string var = optional ? "opts." + attr.var().name() : attr.var().name(); if (attr.iterable()) { - const Type& type = attr.data_type(); + const Type& type = attr.type(); std::map::const_iterator it = kPrimitiveAttrTypes.find(type.name()); if (it != kPrimitiveAttrTypes.end()) { @@ -107,11 +108,11 @@ void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, .Append(array + "[i] = " + var + ".get(i);") .EndLine() .EndBlock() - .Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + array) + .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + array) .Append(");") .EndLine(); } else { - writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + var) + writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + var) .Append(".toArray(new ") .AppendType(type) .Append("[" + var + ".size()]));") @@ -119,7 +120,7 @@ void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, } } else { Type type = attr.var().type(); - writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", "); + writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", "); if (type.name() == "Class") { writer->Append("DataType.fromClass(" + attr.var().name() + "));"); } else { @@ -139,26 +140,26 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class, Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); factory.add_argument(scope); factory_doc.add_param_tag(scope.name(), "Current graph scope"); - for (const OpSpec::Operand& input : op.inputs()) { + for (const ArgumentSpec& input : op.inputs()) { factory.add_argument(input.var()); factory_doc.add_param_tag(input.var().name(), input.description()); } - for (const OpSpec::Operand& attribute : op.attributes()) { + for (const AttributeSpec& attribute : op.attributes()) { factory.add_argument(attribute.var()); factory_doc.add_param_tag(attribute.var().name(), attribute.description()); } - if (!op.options().empty()) { + if (!op.optional_attributes().empty()) { factory.add_argument(Variable::Varargs("options", Type::Class("Options"))); factory_doc.add_param_tag("options", "carries optional attributes values"); } factory_doc.add_tag("return", "a new instance of " + op_class.name()); writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc); writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" - + op.graph_name() + "\", scope.makeOpName(\"" + + op.graph_op_name() + "\", scope.makeOpName(\"" + op_class.name() + "\"));"); writer->EndLine(); - for (const OpSpec::Operand& input : op.inputs()) { + for (const ArgumentSpec& input : op.inputs()) { if (input.iterable()) { writer->Append("opBuilder.addInputList(Operands.asOutputs(" + input.var().name() + "));"); @@ -169,15 +170,15 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class, writer->EndLine(); } } - for (const OpSpec::Operand& attribute : op.attributes()) { + for (const AttributeSpec& attribute : op.attributes()) { WriteSetAttrDirective(attribute, false, writer); } - if (!op.options().empty()) { + if (!op.optional_attributes().empty()) { writer->BeginBlock("if (options != null)") .BeginBlock("for (Options opts : options)"); - for (const OpSpec::Operand& option : op.options()) { - writer->BeginBlock("if (opts." + option.var().name() + " != null)"); - WriteSetAttrDirective(option, true, writer); + for (const AttributeSpec& attribute : op.optional_attributes()) { + writer->BeginBlock("if (opts." + attribute.var().name() + " != null)"); + WriteSetAttrDirective(attribute, true, writer); writer->EndBlock(); } writer->EndBlock().EndBlock(); @@ -195,8 +196,8 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, .add_argument( Variable::Create("operation", Type::Class("Operation", "org.tensorflow"))); - for (const OpSpec::Operand& output : op.outputs()) { - if (output.iterable() && !output.data_type().unknown()) { + for (const ArgumentSpec& output : op.outputs()) { + if (output.iterable() && !output.type().unknown()) { constructor.add_annotation( Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); break; @@ -208,15 +209,15 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, if (op.outputs().size() > 0) { writer->Append("int outputIdx = 0;") .EndLine(); - for (const OpSpec::Operand& output : op.outputs()) { + for (const ArgumentSpec& output : op.outputs()) { if (output.iterable()) { string var_length = output.var().name() + "Length"; writer->Append("int " + var_length) - .Append(" = operation.outputListLength(\"" + output.graph_name() + .Append(" = operation.outputListLength(\"" + output.op_def_name() + "\");") .EndLine() .Append(output.var().name() + " = Arrays.asList("); - if (!output.data_type().unknown()) { + if (!output.type().unknown()) { writer->Append("(") .AppendType(output.var().type().parameters().front()) .Append("[])"); @@ -236,18 +237,19 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, } void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { - for (const OpSpec::Operand& option : op.options()) { - Method setter = Method::Create(option.var().name(), Type::Class("Options")) - .add_argument(option.var()); + for (const AttributeSpec& attribute : op.optional_attributes()) { + Method setter = + Method::Create(attribute.var().name(), Type::Class("Options")) + .add_argument(attribute.var()); Javadoc setter_doc = Javadoc::Create() - .add_param_tag(option.var().name(), option.description()); + .add_param_tag(attribute.var().name(), attribute.description()); writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc) - .Append("return new Options()." + option.var().name() + "(" - + option.var().name() + ");") + .Append("return new Options()." + attribute.var().name() + "(" + + attribute.var().name() + ");") .EndLine() .EndMethod(); } - for (const OpSpec::Operand& output : op.outputs()) { + for (const ArgumentSpec& output : op.outputs()) { Method getter = Method::Create(output.var().name(), output.var().type()); Javadoc getter_doc = Javadoc::Create(output.description()); writer->BeginMethod(getter, PUBLIC, &getter_doc) @@ -259,12 +261,12 @@ void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, SourceWriter* writer) { - OpSpec::Operand output = op.outputs().front(); + ArgumentSpec output = op.outputs().front(); if (mode == SINGLE_OUTPUT) { - bool cast2obj = output.data_type().unknown(); + bool cast2obj = output.type().unknown(); Type return_type = Type::Class("Output", "org.tensorflow") - .add_parameter(cast2obj ? Type::Class("Object") : output.data_type()); + .add_parameter(cast2obj ? Type::Class("Object") : output.type()); Method as_output = Method::Create("asOutput", return_type) .add_annotation(Annotation::Create("Override")); if (cast2obj) { @@ -283,10 +285,10 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, } else if (mode == SINGLE_LIST_OUTPUT) { Type operand = Type::Interface("Operand", "org.tensorflow"); - if (output.data_type().unknown()) { + if (output.type().unknown()) { operand.add_parameter(Type::Class("Object")); } else { - operand.add_parameter(output.data_type()); + operand.add_parameter(output.type()); } Type return_type = Type::Interface("Iterator", "java.util") .add_parameter(operand); @@ -308,57 +310,119 @@ void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) { Javadoc options_doc = Javadoc::Create( "Class holding optional attributes of this operation"); writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); - for (const OpSpec::Operand& option : op.options()) { - Method setter = Method::Create(option.var().name(), options_class) - .add_argument(option.var()); + for (const AttributeSpec& attribute : op.optional_attributes()) { + Method setter = Method::Create(attribute.var().name(), options_class) + .add_argument(attribute.var()); Javadoc setter_doc = Javadoc::Create() - .add_param_tag(option.var().name(), option.description()); + .add_param_tag(attribute.var().name(), attribute.description()); writer->BeginMethod(setter, PUBLIC, &setter_doc) - .Append("this." + option.var().name() + " = " + option.var().name() - + ";") + .Append("this." + attribute.var().name() + " = " + + attribute.var().name() + ";") .EndLine() .Append("return this;") .EndLine() .EndMethod(); } writer->EndLine(); - for (const OpSpec::Operand& option : op.options()) { - writer->WriteField(option.var(), PRIVATE); + for (const AttributeSpec& optional_attribute : op.optional_attributes()) { + writer->WriteField(optional_attribute.var(), PRIVATE); } Method constructor = Method::ConstructorFor(options_class); writer->BeginMethod(constructor, PRIVATE).EndMethod(); writer->EndType(); } -void RenderEndpoint(const OpSpec& op, const OpSpec::Endpoint& endpoint, - SourceWriter* writer) { +inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) { + return Type::Class(endpoint.name(), + base_package + "." + str_util::Lowercase(endpoint.package())); +} + +void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, + const string& base_package, const string& output_dir, Env* env) { + Type op_class(ClassOf(endpoint, base_package) + .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); + Javadoc op_javadoc(endpoint.javadoc()); + + // implement Operand (or Iterable) if the op has only one output RenderMode mode = DEFAULT; if (op.outputs().size() == 1) { - mode = op.outputs().front().iterable() ? SINGLE_LIST_OUTPUT : SINGLE_OUTPUT; + const ArgumentSpec& output = op.outputs().front(); + Type operand_type(output.type().unknown() ? + Type::Class("Object") : output.type()); + Type operand_inf(Type::Interface("Operand", "org.tensorflow") + .add_parameter(operand_type)); + if (output.iterable()) { + mode = SINGLE_LIST_OUTPUT; + op_class.add_supertype(Type::IterableOf(operand_inf)); + } else { + mode = SINGLE_OUTPUT; + op_class.add_supertype(operand_inf); + } + } + // declare all outputs generics at the op class level + std::set generics; + for (const ArgumentSpec& output : op.outputs()) { + if (output.type().kind() == Type::GENERIC && !output.type().unknown() + && generics.find(output.type().name()) == generics.end()) { + op_class.add_parameter(output.type()); + op_javadoc.add_param_tag("<" + output.type().name() + ">", + "data type of output {@code " + output.var().name() + "}"); + generics.insert(output.type().name()); + } + } + // handle endpoint deprecation + if (endpoint.deprecated()) { + op_class.add_annotation(Annotation::Create("Deprecated")); + string explanation; + if (!op.endpoints().front().deprecated()) { + explanation = "use {@link " + + ClassOf(op.endpoints().front(), base_package).full_name() + + "} instead"; + } else { + explanation = op.deprecation_explanation(); + } + op_javadoc.add_tag("deprecated", explanation); } + // expose the op in the Ops Graph API only if it is visible + if (!op.hidden()) { + op_class.add_annotation( + Annotation::Create("Operator", "org.tensorflow.op.annotation") + .attributes("group = \"" + endpoint.package() + "\"")); + } + // create op class file + string op_dir = io::JoinPath(output_dir, + str_util::StringReplace(op_class.package(), ".", "/", true)); + if (!env->FileExists(op_dir).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir)); + } + std::unique_ptr op_file; + TF_CHECK_OK(env->NewWritableFile( + io::JoinPath(op_dir, op_class.name() + ".java"), &op_file)); + + // render endpoint source code + SourceFileWriter writer(op_file.get()); std::list dependencies; CollectOpDependencies(op, mode, &dependencies); - const Type& op_class = endpoint.type(); - writer->WriteFromFile(kLicenseSnippet) + writer.WriteFromFile(kLicenseSnippet) .EndLine() .Append("// This file is machine generated, DO NOT EDIT!") .EndLine() .EndLine() - .BeginType(op_class, PUBLIC|FINAL, &dependencies, &endpoint.javadoc()); - if (!op.options().empty()) { - RenderOptionsClass(op, writer); + .BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc); + if (!op.optional_attributes().empty()) { + RenderOptionsClass(op, &writer); } - RenderFactoryMethod(op, op_class, writer); - RenderGettersAndSetters(op, writer); + RenderFactoryMethod(op, op_class, &writer); + RenderGettersAndSetters(op, &writer); if (mode != DEFAULT) { - RenderInterfaceImpl(op, mode, writer); + RenderInterfaceImpl(op, mode, &writer); } - writer->EndLine(); - for (const OpSpec::Operand& output : op.outputs()) { - writer->WriteField(output.var(), PRIVATE); + writer.EndLine(); + for (const ArgumentSpec& output : op.outputs()) { + writer.WriteField(output.var(), PRIVATE); } - RenderConstructor(op, op_class, writer); - writer->EndType(); + RenderConstructor(op, op_class, &writer); + writer.EndType(); } } // namespace @@ -369,8 +433,7 @@ OpGenerator::OpGenerator(const string& base_package, const string& output_dir, env_(env) { } -Status OpGenerator::Run(const OpList& op_list, const string& lib_name) { - LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations"; +Status OpGenerator::Run(const OpList& op_list) { ApiDefMap api_map(op_list); if (!api_dirs_.empty()) { // Only load api files that correspond to the requested "op_list" @@ -388,37 +451,14 @@ Status OpGenerator::Run(const OpList& op_list, const string& lib_name) { for (const auto& op_def : op_list.op()) { const ApiDef* api_def = api_map.GetApiDef(op_def.name()); if (api_def->visibility() != ApiDef::SKIP) { - Status status = GenerateOp(op_def, *api_def, lib_name); - if (status != Status::OK()) { - LOG(ERROR) << "Fail to generate Java wrapper for operation \"" - << op_def.name() << "\""; + OpSpec op(OpSpec::Create(op_def, *api_def)); + for (const EndpointSpec& endpoint : op.endpoints()) { + GenerateOp(op, endpoint, base_package_, output_dir_, env_); } } } return Status::OK(); } -Status OpGenerator::GenerateOp(const OpDef& op_def, const ApiDef& api_def, - const string& lib_name) { - std::unique_ptr op; - OpParser op_parser(op_def, api_def, lib_name, base_package_); - op_parser.Parse(&op); - for (const OpSpec::Endpoint& endpoint : op->endpoints()) { - string package_path = io::JoinPath(output_dir_, - str_util::StringReplace(endpoint.type().package(), ".", "/", true)); - if (!env_->FileExists(package_path).ok()) { - TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(package_path)); - } - string file_path = - io::JoinPath(package_path, endpoint.type().name() + ".java"); - std::unique_ptr file; - TF_CHECK_OK(env_->NewWritableFile(file_path, &file)); - - SourceFileWriter writer(file.get()); - RenderEndpoint(*op, endpoint, &writer); - } - return Status::OK(); -} - } // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 19d8db95fb..06b08e852a 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,36 +23,33 @@ limitations under the License. #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/java/src/gen/cc/op_specs.h" namespace tensorflow { namespace java { // A generator of Java operation wrappers. // -// Such generator is normally ran only once per executable, outputting -// wrappers for the all registered operations it has been compiled with. -// Nonetheless, it is designed to support multiple runs, giving a different -// list of operations on each cycle. +// This generator takes a list of ops definitions in input and outputs +// a Java Op wrapper for each of them in the provided directory. The same +// generator instance can be invoked multiple times with a different list of +// ops definitions. class OpGenerator { public: OpGenerator(const string& base_package, const string& output_dir, const std::vector& api_dirs, Env* env = Env::Default()); - virtual ~OpGenerator() = default; // Generates wrappers for the given list of 'ops'. // // Output files are generated in //, - // where 'lib_package' is derived from 'lib_name'. - Status Run(const OpList& op_list, const string& lib_name); + // where 'lib_package' is derived from ops endpoints. + Status Run(const OpList& op_list); private: - string base_package_; - string output_dir_; - std::vector api_dirs_; + const string base_package_; + const string output_dir_; + const std::vector api_dirs_; Env* env_; - - Status GenerateOp(const OpDef& op_def, const ApiDef& api_def, - const string& lib_name); }; } // namespace java diff --git a/tensorflow/java/src/gen/cc/op_parser.cc b/tensorflow/java/src/gen/cc/op_parser.cc deleted file mode 100644 index 0541e343d8..0000000000 --- a/tensorflow/java/src/gen/cc/op_parser.cc +++ /dev/null @@ -1,417 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/java/src/gen/cc/op_parser.h" - -namespace tensorflow { -namespace java { -namespace { - -string SnakeToCamelCase(const string& str, bool upper = false) { - string result; - bool cap = upper; - for (string::const_iterator it = str.begin(); it != str.end(); ++it) { - const char c = *it; - if (c == '_') { - cap = true; - } else if (cap) { - result += toupper(c); - cap = false; - } else { - result += c; - } - } - return result; -} - -bool IsRealNumber(DataType type) { - for (DataType dt : RealNumberTypes()) { - if (type == dt) { - return true; - } - } - return false; -} - -bool IsRealNumbers(const AttrValue& values) { - if (values.has_list()) { - for (int i = 0; i < values.list().type_size(); ++i) { - if (!IsRealNumber(values.list().type(i))) { - return false; - } - } - return true; - } - return IsRealNumber(values.type()); -} - -string ParseDocumentation(const string& text) { - std::stringstream javadoc_text; - string::const_iterator c_iter = text.cbegin(); - bool code = false; - bool emphasis = false; - bool list = false; - while (c_iter != text.cend()) { - char c = *c_iter++; - int count = 1; - switch (c) { - case '\n': - if (!code) { - // consumes all subsequent newlines, if there are more than one, - // then there are two choices: - // - if the next line starts with an asterisk, we are enumerating - // a list of items - // - otherwise, we are starting a new paragraph - for (; c_iter != text.cend() && *c_iter == '\n'; ++count, ++c_iter) {} - if (c_iter != text.cend()) { - if (count > 1) { - if (*c_iter != '*' && list) { - javadoc_text << "

  • \n
\n"; - list = false; - } else if (*c_iter == '*' && !list) { - javadoc_text << "\n
\n"; + in_list = false; + } else if (!input.starts_with("```")) { + // new paragraph (not required if a
 block follows)
+        javadoc_text << "

\n"; } - } else if (markup.starts_with("```") && text.empty()) { - // create a multiline code block - re2::StringPiece language; - RE2::Consume(&input, "[\\w\\+]+", &language); - if (FindAndCut(&input, markup.ToString() + "\n*", &text)) { - javadoc_text << "

\n{@code" << text << "}\n
\n"; + } else if (markup.starts_with("```")) { + // code blocks + if (FindAndCut(&input, "```\\s*\n*", &text)) { + javadoc_text << "
{@code\n" << text << "}
\n"; } else { - javadoc_text << markup << language; + javadoc_text << markup; } } else if (markup.starts_with("`")) { - // write inlined code + // inlined code if (FindAndCut(&input, markup, &text)) { javadoc_text << "{@code " << text << "}"; } else { javadoc_text << markup; } } else if (markup == "**") { - // emphase text (strong) + // text emphasis (strong) if (FindAndCut(&input, "\\b\\*{2}", &text)) { javadoc_text << "" << ParseDocumentation(text) << ""; } else { javadoc_text << markup; } } else if (markup == "*") { - // emphase text (light) + // text emphasis (normal) if (FindAndCut(&input, "\\b\\*{1}", &text)) { javadoc_text << "" << ParseDocumentation(text) << ""; } else { javadoc_text << markup; } - } else if (markup == "[") { - // add an external link + } else if (markup.starts_with("[")) { + // hyperlinks string label; string link; if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) { @@ -277,6 +281,7 @@ string ParseDocumentation(re2::StringPiece input) { javadoc_text << markup; } } else { + // safe fallback javadoc_text << markup; } } -- GitLab From eac1479f04181fb107c85af29a709eb369831972 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Mon, 30 Apr 2018 07:38:48 -0400 Subject: [PATCH 093/839] Simplify and improve generics handling in generator --- tensorflow/java/build_defs.bzl | 1 + tensorflow/java/src/gen/cc/op_gen_main.cc | 4 +- tensorflow/java/src/gen/cc/op_generator.cc | 155 +++++++++------------ tensorflow/java/src/gen/cc/op_generator.h | 13 +- tensorflow/java/src/gen/cc/op_specs.cc | 81 ++++++----- tensorflow/java/src/gen/cc/op_specs.h | 16 ++- 6 files changed, 132 insertions(+), 138 deletions(-) diff --git a/tensorflow/java/build_defs.bzl b/tensorflow/java/build_defs.bzl index ab7f60d03d..e1916ca4d9 100644 --- a/tensorflow/java/build_defs.bzl +++ b/tensorflow/java/build_defs.bzl @@ -15,6 +15,7 @@ JAVA_VERSION_OPTS = [ XLINT_OPTS = [ "-Werror", "-Xlint:all", + "-Xlint:-processing", "-Xlint:-serial", "-Xlint:-try", "-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index 458141b877..a508c96516 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -67,10 +67,10 @@ int main(int argc, char* argv[]) { QCHECK(parsed_flags_ok && !output_dir.empty()) << usage; std::vector api_dirs = tensorflow::str_util::Split( api_dirs_str, ",", tensorflow::str_util::SkipEmpty()); - tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs); + tensorflow::java::OpGenerator generator(api_dirs); tensorflow::OpList ops; tensorflow::OpRegistry::Global()->Export(false, &ops); - TF_CHECK_OK(generator.Run(ops)); + TF_CHECK_OK(generator.Run(ops, base_package, output_dir)); return 0; } diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 00f84bc9cd..2327a4daf1 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -38,23 +39,18 @@ namespace { const char* kLicenseSnippet = "tensorflow/java/src/gen/resources/license.java.snippet"; -const std::map kPrimitiveAttrTypes = { - { "Boolean", Type::Boolean() }, - { "Byte", Type::Byte() }, - { "Character", Type::Byte() }, - { "Float", Type::Float() }, - { "Integer", Type::Long() }, - { "Long", Type::Long() }, - { "Short", Type::Long() }, - { "Double", Type::Float() }, -}; - enum RenderMode { DEFAULT, SINGLE_OUTPUT, SINGLE_LIST_OUTPUT }; +inline void AddArgument(const Variable& var, const string& description, + Method* method_out, Javadoc* javadoc_out) { + method_out->add_argument(var); + javadoc_out->add_param_tag(var.name(), description); +} + void CollectOpDependencies(const OpSpec& op, RenderMode mode, std::list* out) { out->push_back(Type::Class("Operation", "org.tensorflow")); @@ -81,9 +77,7 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, } for (const AttributeSpec& attribute : op.attributes()) { out->push_back(attribute.var().type()); - if (attribute.var().type().name() == "Class") { - out->push_back(Type::Enum("DataType", "org.tensorflow")); - } + out->push_back(attribute.jni_type()); } for (const AttributeSpec& optional_attribute : op.optional_attributes()) { out->push_back(optional_attribute.var().type()); @@ -92,45 +86,38 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, void WriteSetAttrDirective(const AttributeSpec& attr, bool optional, SourceWriter* writer) { - string var = optional ? "opts." + attr.var().name() : attr.var().name(); + string var_name = optional ? "opts." + attr.var().name() : attr.var().name(); if (attr.iterable()) { - const Type& type = attr.type(); - std::map::const_iterator it = - kPrimitiveAttrTypes.find(type.name()); - if (it != kPrimitiveAttrTypes.end()) { - string array = attr.var().name() + "Array"; - writer->AppendType(it->second) - .Append("[] " + array + " = new ") - .AppendType(it->second) - .Append("[" + var + ".size()];") - .EndLine(); - writer->BeginBlock("for (int i = 0; i < " + array + ".length; ++i)") - .Append(array + "[i] = " + var + ".get(i);") - .EndLine() - .EndBlock() - .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + array) - .Append(");") - .EndLine(); + string array_name = attr.var().name() + "Array"; + writer->AppendType(attr.jni_type()) + .Append("[] " + array_name + " = new ") + .AppendType(attr.jni_type()) + .Append("[" + var_name + ".size()];") + .EndLine() + .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)") + .Append(array_name + "[i] = "); + if (attr.type().kind() == Type::GENERIC) { + writer->Append("DataType.fromClass(" + var_name + ".get(i));"); } else { - writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + var) - .Append(".toArray(new ") - .AppendType(type) - .Append("[" + var + ".size()]));") - .EndLine(); + writer->Append(var_name + ".get(i);"); } + writer->EndLine() + .EndBlock() + .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ") + .Append(array_name + ");") + .EndLine(); } else { - Type type = attr.var().type(); writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", "); - if (type.name() == "Class") { - writer->Append("DataType.fromClass(" + attr.var().name() + "));"); + if (attr.var().type().name() == "Class") { + writer->Append("DataType.fromClass(" + var_name + "));"); } else { - writer->Append(var + ");"); + writer->Append(var_name + ");"); } writer->EndLine(); } } -void RenderFactoryMethod(const OpSpec& op, const Type& op_class, +void RenderFactoryMethods(const OpSpec& op, const Type& op_class, SourceWriter* writer) { Method factory = Method::Create("create", op_class); Javadoc factory_doc = Javadoc::Create( @@ -138,27 +125,24 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class, + " operation to the graph."); Variable scope = Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); - factory.add_argument(scope); - factory_doc.add_param_tag(scope.name(), "Current graph scope"); + AddArgument(scope, "current graph scope", &factory, &factory_doc); for (const ArgumentSpec& input : op.inputs()) { - factory.add_argument(input.var()); - factory_doc.add_param_tag(input.var().name(), input.description()); + AddArgument(input.var(), input.description(), &factory, &factory_doc); } - for (const AttributeSpec& attribute : op.attributes()) { - factory.add_argument(attribute.var()); - factory_doc.add_param_tag(attribute.var().name(), attribute.description()); + for (const AttributeSpec& attr : op.attributes()) { + AddArgument(attr.var(), attr.description(), &factory, &factory_doc); } if (!op.optional_attributes().empty()) { - factory.add_argument(Variable::Varargs("options", Type::Class("Options"))); - factory_doc.add_param_tag("options", "carries optional attributes values"); + AddArgument(Variable::Varargs("options", Type::Class("Options")), + "carries optional attributes values", &factory, &factory_doc); } factory_doc.add_tag("return", "a new instance of " + op_class.name()); + writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc); writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" + op.graph_op_name() + "\", scope.makeOpName(\"" + op_class.name() + "\"));"); writer->EndLine(); - for (const ArgumentSpec& input : op.inputs()) { if (input.iterable()) { writer->Append("opBuilder.addInputList(Operands.asOutputs(" @@ -192,10 +176,9 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class, void RenderConstructor(const OpSpec& op, const Type& op_class, SourceWriter* writer) { - Method constructor = Method::ConstructorFor(op_class) - .add_argument( - Variable::Create("operation", - Type::Class("Operation", "org.tensorflow"))); + Variable operation = + Variable::Create("operation", Type::Class("Operation", "org.tensorflow")); + Method constructor = Method::ConstructorFor(op_class).add_argument(operation); for (const ArgumentSpec& output : op.outputs()) { if (output.iterable() && !output.type().unknown()) { constructor.add_annotation( @@ -237,15 +220,14 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, } void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { - for (const AttributeSpec& attribute : op.optional_attributes()) { + for (const AttributeSpec& attr : op.optional_attributes()) { Method setter = - Method::Create(attribute.var().name(), Type::Class("Options")) - .add_argument(attribute.var()); - Javadoc setter_doc = Javadoc::Create() - .add_param_tag(attribute.var().name(), attribute.description()); + Method::Create(attr.var().name(), Type::Class("Options")); + Javadoc setter_doc = Javadoc::Create(); + AddArgument(attr.var(), attr.description(), &setter, &setter_doc); writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc) - .Append("return new Options()." + attribute.var().name() + "(" - + attribute.var().name() + ");") + .Append("return new Options()." + attr.var().name() + "(" + + attr.var().name() + ");") .EndLine() .EndMethod(); } @@ -311,14 +293,12 @@ void RenderOptionsClass(const OpSpec& op, const Type& op_class, Javadoc options_doc = Javadoc::Create( "Optional attributes for {@link " + op_class.full_name() + "}"); writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); - for (const AttributeSpec& attribute : op.optional_attributes()) { - Method setter = Method::Create(attribute.var().name(), options_class) - .add_argument(attribute.var()); - Javadoc setter_doc = Javadoc::Create() - .add_param_tag(attribute.var().name(), attribute.description()); + for (const AttributeSpec& attr : op.optional_attributes()) { + Method setter = Method::Create(attr.var().name(), options_class); + Javadoc setter_doc = Javadoc::Create(); + AddArgument(attr.var(), attr.description(), &setter, &setter_doc); writer->BeginMethod(setter, PUBLIC, &setter_doc) - .Append("this." + attribute.var().name() + " = " - + attribute.var().name() + ";") + .Append("this." + attr.var().name() + " = " + attr.var().name() + ";") .EndLine() .Append("return this;") .EndLine() @@ -339,12 +319,13 @@ inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) { } void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, - const string& base_package, const string& output_dir, Env* env) { + const string& base_package, const string& output_dir, Env* env, + const std::tm* timestamp) { Type op_class(ClassOf(endpoint, base_package) .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); Javadoc op_javadoc(endpoint.javadoc()); - // implement Operand (or Iterable) if the op has only one output + // op interfaces RenderMode mode = DEFAULT; if (op.outputs().size() == 1) { const ArgumentSpec& output = op.outputs().front(); @@ -360,18 +341,22 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, op_class.add_supertype(operand_inf); } } - // declare all outputs generics at the op class level + // op generic parameters std::set generics; for (const ArgumentSpec& output : op.outputs()) { if (output.type().kind() == Type::GENERIC && !output.type().unknown() && generics.find(output.type().name()) == generics.end()) { op_class.add_parameter(output.type()); op_javadoc.add_param_tag("<" + output.type().name() + ">", - "data type of output {@code " + output.var().name() + "}"); + "data type for {@code " + output.var().name() + "()} output"); generics.insert(output.type().name()); } } - // handle endpoint deprecation + // op annotations + char date[20]; + strftime(date, sizeof date, "%FT%TZ", timestamp); + op_class.add_annotation(Annotation::Create("Generated", "javax.annotation") + .attributes(string("value = \"op_generator\", date = \"") + date + "\"")); if (endpoint.deprecated()) { op_class.add_annotation(Annotation::Create("Deprecated")); string explanation; @@ -384,8 +369,8 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } op_javadoc.add_tag("deprecated", explanation); } - // expose the op in the Ops Graph API only if it is visible if (!op.hidden()) { + // expose the op in the Ops Graph API only if it is visible op_class.add_annotation( Annotation::Create("Operator", "org.tensorflow.op.annotation") .attributes("group = \"" + endpoint.package() + "\"")); @@ -405,15 +390,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, std::list dependencies; CollectOpDependencies(op, mode, &dependencies); writer.WriteFromFile(kLicenseSnippet) - .EndLine() - .Append("// This file is machine generated, DO NOT EDIT!") - .EndLine() .EndLine() .BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc); if (!op.optional_attributes().empty()) { RenderOptionsClass(op, op_class, &writer); } - RenderFactoryMethod(op, op_class, &writer); + RenderFactoryMethods(op, op_class, &writer); RenderGettersAndSetters(op, &writer); if (mode != DEFAULT) { RenderInterfaceImpl(op, mode, &writer); @@ -428,13 +410,8 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } // namespace -OpGenerator::OpGenerator(const string& base_package, const string& output_dir, - const std::vector& api_dirs, Env* env) - : base_package_(base_package), output_dir_(output_dir), api_dirs_(api_dirs), - env_(env) { -} - -Status OpGenerator::Run(const OpList& op_list) { +Status OpGenerator::Run(const OpList& op_list, const string& base_package, + const string& output_dir) { ApiDefMap api_map(op_list); if (!api_dirs_.empty()) { // Only load api files that correspond to the requested "op_list" @@ -449,12 +426,14 @@ Status OpGenerator::Run(const OpList& op_list) { } } api_map.UpdateDocs(); + time_t now; + time(&now); for (const auto& op_def : op_list.op()) { const ApiDef* api_def = api_map.GetApiDef(op_def.name()); if (api_def->visibility() != ApiDef::SKIP) { OpSpec op(OpSpec::Create(op_def, *api_def)); for (const EndpointSpec& endpoint : op.endpoints()) { - GenerateOp(op, endpoint, base_package_, output_dir_, env_); + GenerateOp(op, endpoint, base_package, output_dir, env_, gmtime(&now)); } } } diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 06b08e852a..b789e11fa9 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -36,18 +36,17 @@ namespace java { // ops definitions. class OpGenerator { public: - OpGenerator(const string& base_package, const string& output_dir, - const std::vector& api_dirs, Env* env = Env::Default()); + explicit OpGenerator(const std::vector& api_dirs, + Env* env = Env::Default()) : api_dirs_(api_dirs), env_(env) {} // Generates wrappers for the given list of 'ops'. // - // Output files are generated in //, - // where 'lib_package' is derived from ops endpoints. - Status Run(const OpList& op_list); + // Output files are generated in //, + // where 'op_package' is derived from ops endpoints. + Status Run(const OpList& op_list, const string& base_package, + const string& output_dir); private: - const string base_package_; - const string output_dir_; const std::vector api_dirs_; Env* env_; }; diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index a0e7a180f2..dcc6388614 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -46,14 +46,30 @@ class TypeResolver { explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {} Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out); - Type TypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out); + std::pair TypeOf(const OpDef_AttrDef& attr_def, + bool *iterable_out); bool IsAttributeVisited(const string& attr_name) { return visited_attrs_.find(attr_name) != visited_attrs_.cend(); } + private: const OpDef op_def_; std::map visited_attrs_; - char next_generic_ = 'T'; + char next_generic_letter_ = 'T'; + + std::pair MakeTypePair(const Type& type, const Type& jni_type) { + return std::make_pair(type, jni_type); + } + std::pair MakeTypePair(const Type& type) { + return std::make_pair(type, type); + } + Type NextGeneric() { + char generic_letter = next_generic_letter_++; + if (next_generic_letter_ > 'Z') { + next_generic_letter_ = 'A'; + } + return Type::Generic(string(1, generic_letter)); + } }; Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, @@ -107,7 +123,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, } else { for (const auto& attr_def : op_def_.attr()) { if (attr_def.name() == arg_def.type_attr()) { - type = TypeOf(attr_def, iterable_out); + type = TypeOf(attr_def, iterable_out).first; break; } } @@ -125,51 +141,47 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, return type; } -Type TypeResolver::TypeOf(const OpDef_AttrDef& attr_def, +std::pair TypeResolver::TypeOf(const OpDef_AttrDef& attr_def, bool* iterable_out) { + std::pair types = MakeTypePair(Type::Wildcard()); *iterable_out = false; StringPiece attr_type = attr_def.type(); if (str_util::ConsumePrefix(&attr_type, "list(")) { attr_type.remove_suffix(1); // remove closing brace *iterable_out = true; } - Type type = *iterable_out ? Type::Wildcard() : Type::Class("Object"); - if (attr_type == "type") { - if (*iterable_out) { - type = Type::Enum("DataType", "org.tensorflow"); - } else { - type = Type::Generic(string(1, next_generic_)); - next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1; - if (IsRealNumbers(attr_def.allowed_values())) { - // enforce real numbers datasets by extending java.lang.Number - type.add_supertype(Type::Class("Number")); - } - } - } else if (attr_type == "string") { - type = Type::Class("String"); + if (attr_type == "string") { + types = MakeTypePair(Type::Class("String")); } else if (attr_type == "int") { - type = Type::Class("Integer"); + types = MakeTypePair(Type::Class("Long"), Type::Long()); } else if (attr_type == "float") { - type = Type::Class("Float"); + types = MakeTypePair(Type::Class("Float"), Type::Float()); } else if (attr_type == "bool") { - type = Type::Class("Boolean"); + types = MakeTypePair(Type::Class("Boolean"), Type::Boolean()); } else if (attr_type == "shape") { - type = Type::Class("Shape", "org.tensorflow"); + types = MakeTypePair(Type::Class("Shape", "org.tensorflow")); } else if (attr_type == "tensor") { - type = Type::Class("Tensor", "org.tensorflow") - .add_parameter(Type::Wildcard()); + types = MakeTypePair(Type::Class("Tensor", "org.tensorflow") + .add_parameter(Type::Wildcard())); + + } else if (attr_type == "type") { + Type type = *iterable_out ? Type::Wildcard() : NextGeneric(); + if (IsRealNumbers(attr_def.allowed_values())) { + type.add_supertype(Type::Class("Number")); + } + types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow")); } else { LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type << "\" in operation \"" << op_def_.name() << "\""; } - visited_attrs_.insert(std::make_pair(attr_def.name(), type)); - return type; + visited_attrs_.insert(std::make_pair(attr_def.name(), types.first)); + return types; } string SnakeToCamelCase(const string& str, bool upper = false) { @@ -307,19 +319,19 @@ ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) { bool iterable = false; - Type type = type_resolver->TypeOf(attr_def, &iterable); - // type attributes must be passed explicitly in methods as a Class<> parameter - bool is_explicit = type.kind() == Type::GENERIC && !iterable; - Type var_type = is_explicit ? Type::Class("Class").add_parameter(type) : type; + std::pair types = type_resolver->TypeOf(attr_def, &iterable); + Type var_type = types.first.kind() == Type::GENERIC ? + Type::Class("Class").add_parameter(types.first) : types.first; if (iterable) { - var_type = Type::ListOf(type); + var_type = Type::ListOf(var_type); } return AttributeSpec(attr_api_def.name(), Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type), - type, + types.first, + types.second, ParseDocumentation(attr_api_def.description()), iterable, - attr_api_def.has_default_value() && !is_explicit); + attr_api_def.has_default_value()); } ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, @@ -340,7 +352,6 @@ ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, const ApiDef_Endpoint& endpoint_def) { - std::vector name_tokens = str_util::Split(endpoint_def.name(), "."); string package; string name; @@ -381,7 +392,7 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver); // attributes with a default value are optional - if (attr.optional()) { + if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) { op.optional_attributes_.push_back(attr); } else { op.attributes_.push_back(attr); diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 55c2c3f307..7d64391446 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/java/src/gen/cc/java_defs.h" namespace tensorflow { @@ -87,20 +88,23 @@ class AttributeSpec : public ArgumentSpec { // op_def_name: attribute name, as known by TensorFlow core // var: a variable to represent this attribute in Java // type: the type of this attribute + // jni_type: the type of this attribute in JNI layer (see OperationBuilder) // description: a description of this attribute, in javadoc // iterable: true if this attribute is a list - // optional: true if this attribute does not require to be set explicitly + // has_default_value: true if this attribute has a default value if not set AttributeSpec(const string& op_def_name, const Variable& var, - const Type& type, const string& description, bool iterable, - bool optional) + const Type& type, const Type& jni_type, const string& description, + bool iterable, bool has_default_value) : ArgumentSpec(op_def_name, var, type, description, iterable), - optional_(optional) {} + jni_type_(jni_type), has_default_value_(has_default_value) {} virtual ~AttributeSpec() = default; - bool optional() const { return optional_; } + const Type& jni_type() const { return jni_type_; } + bool has_default_value() const { return has_default_value_; } private: - const bool optional_; + const Type jni_type_; + const bool has_default_value_; }; class OpSpec { -- GitLab From dd1ef8fa8f6861e53e8a7953c171b3e9253043ed Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Thu, 3 May 2018 22:39:35 -0400 Subject: [PATCH 094/839] Second code review --- tensorflow/core/api_def/BUILD | 7 ++ .../java_api/api_def_FilterDataset.pbtxt | 4 + .../java_api/api_def_FlatMapDataset.pbtxt | 4 + .../core/api_def/java_api/api_def_For.pbtxt | 4 + .../java_api/api_def_GeneratorDataset.pbtxt | 4 + .../api_def_GroupByWindowDataset.pbtxt | 4 + .../core/api_def/java_api/api_def_If.pbtxt | 4 + .../java_api/api_def_InterleaveDataset.pbtxt | 4 + .../java_api/api_def_MapAndBatchDataset.pbtxt | 4 + .../api_def/java_api/api_def_MapDataset.pbtxt | 4 + .../java_api/api_def_OneShotIterator.pbtxt | 4 + .../api_def_ParallelInterleaveDataset.pbtxt | 4 + .../java_api/api_def_ParallelMapDataset.pbtxt | 4 + .../api_def/java_api/api_def_RemoteCall.pbtxt | 4 + .../java_api/api_def_ScanDataset.pbtxt | 4 + .../java_api/api_def_SymbolicGradient.pbtxt | 4 + .../core/api_def/java_api/api_def_While.pbtxt | 4 + tensorflow/java/BUILD | 39 ++++------ tensorflow/java/src/gen/cc/java_defs.h | 6 +- tensorflow/java/src/gen/cc/op_gen_main.cc | 2 +- tensorflow/java/src/gen/cc/op_generator.cc | 77 +++++++++++-------- tensorflow/java/src/gen/cc/op_generator.h | 2 +- tensorflow/java/src/gen/cc/op_specs.cc | 25 +++++- tensorflow/java/src/gen/cc/op_specs.h | 17 +++- tensorflow/java/src/gen/cc/source_writer.cc | 20 +++-- tensorflow/java/src/gen/cc/source_writer.h | 2 +- .../java/src/gen/cc/source_writer_test.cc | 2 +- tensorflow/java/src/gen/gen_ops.bzl | 41 +++------- 28 files changed, 195 insertions(+), 109 deletions(-) create mode 100644 tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_For.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_If.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_While.pbtxt diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 19d6438809..06b797e32e 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -4,6 +4,7 @@ # The following targets can be used to access ApiDefs: # :base_api_def # :python_api_def +# :java_api_def package( default_visibility = ["//visibility:private"], @@ -29,6 +30,12 @@ filegroup( visibility = ["//tensorflow:internal"], ) +filegroup( + name = "java_api_def", + srcs = glob(["java_api/*"]), + visibility = ["//tensorflow:internal"], +) + cc_library( name = "excluded_ops_lib", srcs = ["excluded_ops.cc"], diff --git a/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt new file mode 100644 index 0000000000..debd7e5709 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "FilterDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt new file mode 100644 index 0000000000..329ab15ef5 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "FlatMapDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_For.pbtxt b/tensorflow/core/api_def/java_api/api_def_For.pbtxt new file mode 100644 index 0000000000..caabc947bb --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_For.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "For" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt new file mode 100644 index 0000000000..a6e5167c30 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "GeneratorDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt new file mode 100644 index 0000000000..4c0b2084a8 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "GroupByWindowDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_If.pbtxt b/tensorflow/core/api_def/java_api/api_def_If.pbtxt new file mode 100644 index 0000000000..13b8635ca7 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_If.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "If" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt new file mode 100644 index 0000000000..ed748d4d2a --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "InterleaveDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt new file mode 100644 index 0000000000..cb96bf63d8 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "MapAndBatchDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt new file mode 100644 index 0000000000..e0ab8dd9db --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "MapDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt b/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt new file mode 100644 index 0000000000..13130e6882 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "OneShotIterator" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt new file mode 100644 index 0000000000..6a985d24fa --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ParallelInterleaveDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt new file mode 100644 index 0000000000..64f25b9e5e --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ParallelMapDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt b/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt new file mode 100644 index 0000000000..2ccb5c8cf3 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "RemoteCall" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt new file mode 100644 index 0000000000..3463e60049 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ScanDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt b/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt new file mode 100644 index 0000000000..88c3acea74 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SymbolicGradient" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_While.pbtxt b/tensorflow/core/api_def/java_api/api_def_While.pbtxt new file mode 100644 index 0000000000..33756682c3 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_While.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "While" + visibility: SKIP +} \ No newline at end of file diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 17566e1a9c..7cd0208dbf 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -68,34 +68,27 @@ filegroup( ], ) -# Build the gen tool as a library, as it will be linked to a core/ops binary -# files before making it an executable. tf_java_op_gen_srcjar( name = "java_op_gen_sources", api_def_srcs = [ "//tensorflow/core/api_def:base_api_def", + "//tensorflow/core/api_def:java_api_def", ], - gen_base_package = "org.tensorflow.op", - gen_tool = "java_op_gen_tool", - ops_libs = [ - "array_ops", - "candidate_sampling_ops", - "control_flow_ops", - "data_flow_ops", - "image_ops", - "io_ops", - "linalg_ops", - "logging_ops", - "math_ops", - "nn_ops", - "no_op", - "parsing_ops", - "random_ops", - "sparse_ops", - "state_ops", - "string_ops", - "training_ops", - "user_ops", + base_package = "org.tensorflow.op", + gen_tool = ":java_op_gen_tool", +) + +tf_cc_binary( + name = "java_op_gen_tool", + srcs = [ + "src/gen/cc/op_gen_main.cc", + ], + copts = tf_copts(), + linkopts = ["-lm"], + linkstatic = 1, + deps = [ + ":java_op_gen_lib", + "//tensorflow/core:ops", ], ) diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 81ac67eb2f..62575f6683 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -102,10 +102,10 @@ class Type { const Kind& kind() const { return kind_; } const string& name() const { return name_; } const string& package() const { return package_; } - const string full_name() const { + const string canonical_name() const { return package_.empty() ? name_ : package_ + "." + name_; } - bool unknown() const { return name_.empty(); } // only wildcards has no name + bool wildcard() const { return name_.empty(); } // only wildcards has no name const std::list& parameters() const { return parameters_; } Type& add_parameter(const Type& parameter) { parameters_.push_back(parameter); diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index a508c96516..6c35cd9595 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 2327a4daf1..7355b3a395 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -39,13 +38,26 @@ namespace { const char* kLicenseSnippet = "tensorflow/java/src/gen/resources/license.java.snippet"; +// There is three different modes to render an op class, depending on the +// number and type of outputs it has: +// +// DEFAULT: This mode does not provide any specialization for the op class, it +// is applied when the operation does not comply with any other mode +// +// OPERAND: The op class implements the Operand interface, allowing an +// instance to be passed directly in input to another operation +// +// LIST_OPERAND: The op class implements the Iterable> interface, +// allowing an instance to be passed directly as a list input to +// another operation +// enum RenderMode { DEFAULT, - SINGLE_OUTPUT, - SINGLE_LIST_OUTPUT + OPERAND, + LIST_OPERAND }; -inline void AddArgument(const Variable& var, const string& description, +void AddArgument(const Variable& var, const string& description, Method* method_out, Javadoc* javadoc_out) { method_out->add_argument(var); javadoc_out->add_param_tag(var.name(), description); @@ -56,9 +68,9 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, out->push_back(Type::Class("Operation", "org.tensorflow")); out->push_back(Type::Class("OperationBuilder", "org.tensorflow")); out->push_back(Type::Class("Scope", "org.tensorflow.op")); - if (mode == SINGLE_OUTPUT) { + if (mode == OPERAND) { out->push_back(Type::Class("Output", "org.tensorflow")); - } else if (mode == SINGLE_LIST_OUTPUT) { + } else if (mode == LIST_OPERAND) { out->push_back(Type::Interface("Iterator", "java.util")); } // Don't pay attention to duplicate types in the dependency list, they will @@ -180,7 +192,7 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, Variable::Create("operation", Type::Class("Operation", "org.tensorflow")); Method constructor = Method::ConstructorFor(op_class).add_argument(operation); for (const ArgumentSpec& output : op.outputs()) { - if (output.iterable() && !output.type().unknown()) { + if (output.iterable() && !output.type().wildcard()) { constructor.add_annotation( Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); break; @@ -200,7 +212,7 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, + "\");") .EndLine() .Append(output.var().name() + " = Arrays.asList("); - if (!output.type().unknown()) { + if (!output.type().wildcard()) { writer->Append("(") .AppendType(output.var().type().parameters().front()) .Append("[])"); @@ -245,8 +257,8 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, SourceWriter* writer) { ArgumentSpec output = op.outputs().front(); - if (mode == SINGLE_OUTPUT) { - bool cast2obj = output.type().unknown(); + if (mode == OPERAND) { + bool cast2obj = output.type().wildcard(); Type return_type = Type::Class("Output", "org.tensorflow") .add_parameter(cast2obj ? Type::Class("Object") : output.type()); Method as_output = Method::Create("asOutput", return_type) @@ -265,9 +277,9 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, .EndLine() .EndMethod(); - } else if (mode == SINGLE_LIST_OUTPUT) { + } else if (mode == LIST_OPERAND) { Type operand = Type::Interface("Operand", "org.tensorflow"); - if (output.type().unknown()) { + if (output.type().wildcard()) { operand.add_parameter(Type::Class("Object")); } else { operand.add_parameter(output.type()); @@ -291,7 +303,7 @@ void RenderOptionsClass(const OpSpec& op, const Type& op_class, SourceWriter* writer) { Type options_class = Type::Class("Options"); Javadoc options_doc = Javadoc::Create( - "Optional attributes for {@link " + op_class.full_name() + "}"); + "Optional attributes for {@link " + op_class.canonical_name() + "}"); writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); for (const AttributeSpec& attr : op.optional_attributes()) { Method setter = Method::Create(attr.var().name(), options_class); @@ -319,8 +331,7 @@ inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) { } void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, - const string& base_package, const string& output_dir, Env* env, - const std::tm* timestamp) { + const string& base_package, const string& output_dir, Env* env) { Type op_class(ClassOf(endpoint, base_package) .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); Javadoc op_javadoc(endpoint.javadoc()); @@ -329,22 +340,22 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, RenderMode mode = DEFAULT; if (op.outputs().size() == 1) { const ArgumentSpec& output = op.outputs().front(); - Type operand_type(output.type().unknown() ? + Type operand_type(output.type().wildcard() ? Type::Class("Object") : output.type()); Type operand_inf(Type::Interface("Operand", "org.tensorflow") .add_parameter(operand_type)); if (output.iterable()) { - mode = SINGLE_LIST_OUTPUT; + mode = LIST_OPERAND; op_class.add_supertype(Type::IterableOf(operand_inf)); } else { - mode = SINGLE_OUTPUT; + mode = OPERAND; op_class.add_supertype(operand_inf); } } // op generic parameters std::set generics; for (const ArgumentSpec& output : op.outputs()) { - if (output.type().kind() == Type::GENERIC && !output.type().unknown() + if (output.type().kind() == Type::GENERIC && !output.type().wildcard() && generics.find(output.type().name()) == generics.end()) { op_class.add_parameter(output.type()); op_javadoc.add_param_tag("<" + output.type().name() + ">", @@ -353,16 +364,15 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } } // op annotations - char date[20]; - strftime(date, sizeof date, "%FT%TZ", timestamp); - op_class.add_annotation(Annotation::Create("Generated", "javax.annotation") - .attributes(string("value = \"op_generator\", date = \"") + date + "\"")); + op_class.add_annotation( + Annotation::Create("Generated", "javax.annotation") + .attributes("value = \"TensorFlow Java Op Generator\"")); if (endpoint.deprecated()) { op_class.add_annotation(Annotation::Create("Deprecated")); string explanation; if (!op.endpoints().front().deprecated()) { explanation = "use {@link " + - ClassOf(op.endpoints().front(), base_package).full_name() + ClassOf(op.endpoints().front(), base_package).canonical_name() + "} instead"; } else { explanation = op.deprecation_explanation(); @@ -376,14 +386,16 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, .attributes("group = \"" + endpoint.package() + "\"")); } // create op class file - string op_dir = io::JoinPath(output_dir, + const string op_dir_name = io::JoinPath(output_dir, str_util::StringReplace(op_class.package(), ".", "/", true)); - if (!env->FileExists(op_dir).ok()) { - TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir)); + if (!env->FileExists(op_dir_name).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name)) + << op_dir_name; } + const string op_file_name = op_class.name() + ".java"; std::unique_ptr op_file; TF_CHECK_OK(env->NewWritableFile( - io::JoinPath(op_dir, op_class.name() + ".java"), &op_file)); + io::JoinPath(op_dir_name, op_file_name), &op_file)) << op_file_name; // render endpoint source code SourceFileWriter writer(op_file.get()); @@ -420,20 +432,19 @@ Status OpGenerator::Run(const OpList& op_list, const string& base_package, const std::string api_def_file_pattern = io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); if (env_->FileExists(api_def_file_pattern).ok()) { - TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)); + TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)) + << api_def_file_pattern; } } } } api_map.UpdateDocs(); - time_t now; - time(&now); for (const auto& op_def : op_list.op()) { const ApiDef* api_def = api_map.GetApiDef(op_def.name()); if (api_def->visibility() != ApiDef::SKIP) { OpSpec op(OpSpec::Create(op_def, *api_def)); for (const EndpointSpec& endpoint : op.endpoints()) { - GenerateOp(op, endpoint, base_package, output_dir, env_, gmtime(&now)); + GenerateOp(op, endpoint, base_package, output_dir, env_); } } } diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index b789e11fa9..cfe842070a 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index dcc6388614..081062ceaf 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -45,9 +45,26 @@ class TypeResolver { public: explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {} + // Returns the class type of an input/output argument + // + // For example, if the argument's datatype is DT_STRING, this method will + // return "java.lang.String", so the argument can become "Operand" + // in the Ops API Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out); - std::pair TypeOf(const OpDef_AttrDef& attr_def, + + // Returns types of an input attribute + // + // The first element of the pair is the class type of this attribute while + // the second is its JNI/primitive type equivalent, required for explicit + // unboxing. + // + // For example, if the attribute is of type "float", this method will return + // , so the attribute can be used as a "Float" object + // in the Ops API and casted to a "float" when passing through the JNI layer. + std::pair TypesOf(const OpDef_AttrDef& attr_def, bool *iterable_out); + + // Returns true if the type of this attribute has already been resolved bool IsAttributeVisited(const string& attr_name) { return visited_attrs_.find(attr_name) != visited_attrs_.cend(); } @@ -123,7 +140,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, } else { for (const auto& attr_def : op_def_.attr()) { if (attr_def.name() == arg_def.type_attr()) { - type = TypeOf(attr_def, iterable_out).first; + type = TypesOf(attr_def, iterable_out).first; break; } } @@ -141,7 +158,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, return type; } -std::pair TypeResolver::TypeOf(const OpDef_AttrDef& attr_def, +std::pair TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, bool* iterable_out) { std::pair types = MakeTypePair(Type::Wildcard()); *iterable_out = false; @@ -319,7 +336,7 @@ ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) { bool iterable = false; - std::pair types = type_resolver->TypeOf(attr_def, &iterable); + std::pair types = type_resolver->TypesOf(attr_def, &iterable); Type var_type = types.first.kind() == Type::GENERIC ? Type::Class("Class").add_parameter(types.first) : types.first; if (iterable) { diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 7d64391446..81582ea207 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -65,7 +65,6 @@ class ArgumentSpec { const Type& type, const string& description, bool iterable) : op_def_name_(op_def_name), var_(var), type_(type), description_(description), iterable_(iterable) {} - virtual ~ArgumentSpec() = default; const string& op_def_name() const { return op_def_name_; } const Variable& var() const { return var_; } @@ -81,7 +80,7 @@ class ArgumentSpec { const bool iterable_; }; -class AttributeSpec : public ArgumentSpec { +class AttributeSpec { public: // A specification for an operation attribute // @@ -95,14 +94,24 @@ class AttributeSpec : public ArgumentSpec { AttributeSpec(const string& op_def_name, const Variable& var, const Type& type, const Type& jni_type, const string& description, bool iterable, bool has_default_value) - : ArgumentSpec(op_def_name, var, type, description, iterable), + : op_def_name_(op_def_name), var_(var), type_(type), + description_(description), iterable_(iterable), jni_type_(jni_type), has_default_value_(has_default_value) {} - virtual ~AttributeSpec() = default; + const string& op_def_name() const { return op_def_name_; } + const Variable& var() const { return var_; } + const Type& type() const { return type_; } + const string& description() const { return description_; } + bool iterable() const { return iterable_; } const Type& jni_type() const { return jni_type_; } bool has_default_value() const { return has_default_value_; } private: + const string op_def_name_; + const Variable var_; + const Type type_; + const string description_; + const bool iterable_; const Type jni_type_; const bool has_default_value_; }; diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index 7e427787f9..56806cbb6d 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -83,17 +83,19 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) { } SourceWriter& SourceWriter::AppendType(const Type& type) { - if (type.unknown()) { + if (type.wildcard()) { Append("?"); } else { Append(type.name()); if (!type.parameters().empty()) { Append("<"); + bool first = true; for (const Type& t : type.parameters()) { - if (&t != &type.parameters().front()) { + if (!first) { Append(", "); } AppendType(t); + first = false; } Append(">"); } @@ -145,11 +147,13 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers, AppendType(method.return_type()).Append(" "); } Append(method.name()).Append("("); + bool first = true; for (const Variable& v : method.arguments()) { - if (&v != &method.arguments().front()) { + if (!first) { Append(", "); } AppendType(v.type()).Append(v.variadic() ? "... " : " ").Append(v.name()); + first = false; } return Append(")").BeginBlock(); } @@ -294,14 +298,16 @@ SourceWriter& SourceWriter::WriteAnnotations( SourceWriter& SourceWriter::WriteGenerics( const std::list& generics) { Append("<"); + bool first = true; for (const Type* pt : generics) { - if (pt != generics.front()) { + if (!first) { Append(", "); } Append(pt->name()); if (!pt->supertypes().empty()) { Append(" extends ").AppendType(pt->supertypes().front()); } + first = false; } return Append(">"); } @@ -339,7 +345,7 @@ void SourceWriter::TypeVisitor::Visit(const Type& type) { void SourceWriter::GenericNamespace::DoVisit(const Type& type) { // ignore non-generic parameters, wildcards and generics already declared - if (type.kind() == Type::GENERIC && !type.unknown() + if (type.kind() == Type::GENERIC && !type.wildcard() && generic_names_.find(type.name()) == generic_names_.end()) { declared_types_.push_back(&type); generic_names_.insert(type.name()); @@ -348,7 +354,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) { void SourceWriter::TypeImporter::DoVisit(const Type& type) { if (!type.package().empty() && type.package() != current_package_) { - imports_.insert(type.full_name()); + imports_.insert(type.canonical_name()); } } diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index bcae33ccce..1f0febe9a3 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index 875ad99ae2..b9a5fee9be 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index 7017b52649..f4ff34ea03 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -3,33 +3,26 @@ load( "//tensorflow:tensorflow.bzl", "tf_binary_additional_srcs", - "tf_cc_binary", - "tf_copts", ) -# Given a list of "ops_libs" (a list of files in the core/ops directory -# without their .cc extensions), generate Java wrapper code for all operations -# found in the ops files. -# Then, combine all those source files into a single archive (.srcjar). +# Generate Java wrapper classes for all registered core operations and package +# them into a single source archive (.srcjar). # # For example: -# tf_java_op_gen_srcjar("gen_sources", "gen_tool", "my.package", [ "array_ops", "math_ops" ]) +# tf_java_op_gen_srcjar("gen_sources", ":gen_tool", "my.package") # -# will create a genrule named "gen_sources" that first generate source files: -# ops/src/main/java/my/package/array/*.java -# ops/src/main/java/my/package/math/*.java +# will create a genrule named "gen_sources" that generates source files under +# ops/src/main/java/my/package/**/*.java # -# and then archive those source files in: +# and then archive those source files into # ops/gen_sources.srcjar # def tf_java_op_gen_srcjar(name, gen_tool, - gen_base_package, - ops_libs=[], - ops_libs_pkg="//tensorflow/core", + base_package, + api_def_srcs=[], out_dir="ops/", out_src_dir="src/main/java/", - api_def_srcs=[], visibility=["//tensorflow/java:__pkg__"]): gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files @@ -48,23 +41,9 @@ def tf_java_op_gen_srcjar(name, ") | cut -d\" \" -f1))") api_def_args_str = ",".join(api_def_args) - gen_tool_deps = [":java_op_gen_lib"] - for ops_lib in ops_libs: - gen_tool_deps.append(ops_libs_pkg + ":" + ops_lib + "_op_lib") - - tf_cc_binary( - name=gen_tool, - srcs=[ - "src/gen/cc/op_gen_main.cc", - ], - copts=tf_copts(), - linkopts=["-lm"], - linkstatic=1, # Faster to link this one-time-use binary dynamically - deps = gen_tool_deps) - - gen_cmds += ["$(location :" + gen_tool + ")" + + gen_cmds += ["$(location " + gen_tool + ")" + " --output_dir=$(@D)/" + out_src_dir + - " --base_package=" + gen_base_package + + " --base_package=" + base_package + " --api_dirs=" + api_def_args_str] # Generate a source archive containing generated code for these ops. -- GitLab From aaa345f5a662aab524bbee3912c605919239bef6 Mon Sep 17 00:00:00 2001 From: wangsiyu Date: Fri, 4 May 2018 10:52:26 +0800 Subject: [PATCH 095/839] refine by using iterator of partitioned variable --- tensorflow/python/layers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index c050e6be04..f7b2e471b2 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -358,7 +358,7 @@ def _add_elements_to_collection(elements, collection_list): def _should_add_regularizer(variable, existing_variable_set): result = True if isinstance(variable, tf_variables.PartitionedVariable): - for var in variable._get_variable_list(): + for var in variable: if var in existing_variable_set: result = False break -- GitLab From de9256f61a9d71a30b175e46116fc5d87063ceaa Mon Sep 17 00:00:00 2001 From: "William D. Irons" Date: Fri, 4 May 2018 08:19:03 -0500 Subject: [PATCH 096/839] Add conditions:default to mkl build (#19008) If building on a system that is not darwin, linux_x86_64, or windows, the select statement in third_party/mkl/BUILD fails to find a match and fails. Need to use no mkl libraries for non-x86 systems Fixes #18084 --- third_party/mkl/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/mkl/BUILD b/third_party/mkl/BUILD index c2adf578c7..017613abb0 100644 --- a/third_party/mkl/BUILD +++ b/third_party/mkl/BUILD @@ -34,6 +34,7 @@ filegroup( "@org_tensorflow//tensorflow:windows": [ "@mkl_windows//:LICENSE", ], + "//conditions:default": [] }), visibility = ["//visibility:public"], ) @@ -54,5 +55,6 @@ cc_library( "@mkl_windows//:mkl_headers", "@mkl_windows//:mkl_libs_windows", ], + "//conditions:default": [] }), ) -- GitLab From 7f0d43c1b7462645767712cd5942d754a5f7adb7 Mon Sep 17 00:00:00 2001 From: manhyuk Date: Sat, 5 May 2018 01:53:35 +0900 Subject: [PATCH 097/839] fix typo --- tensorflow/contrib/tpu/python/tpu/tpu_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index fbc1173e49..4d7bc6a5a6 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # =================================================================== -"""TPU system metdata and associated tooling.""" +"""TPU system metadata and associated tooling.""" from __future__ import absolute_import from __future__ import division -- GitLab From 9b43bd6459e410fc8d3dd1beba9f9a6a254096ba Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Thu, 3 May 2018 13:40:20 -0700 Subject: [PATCH 098/839] Documentation for tf.contrib.eager.py_func PiperOrigin-RevId: 195303454 --- tensorflow/python/ops/script_ops.py | 66 ++++++++++++++++++++++++++--- 1 file changed, 60 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 9f1dd2c4fd..f87c5dc5e3 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -243,14 +243,68 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): def eager_py_func(func, inp, Tout, name=None): - """Wraps a python function into a TensorFlow op. + """Wraps a python function into a TensorFlow op that executes it eagerly. - When the returned op is executed, `func` is invoked with eager execution - enabled. Inputs are Tensor objects and func must return None or objects - that may be converted to Tensor objects. + This function allows expressing computations in a TensorFlow graph as + Python functions. In particular, it wraps a Python function `func` + in a TensorFlow operation that executes it with eager exeuction enabled. As a + consequence, `tf.contrib.eager.py_func` makes it possible to express control + flow using Python constructs (`if`, `while`, `for`, etc.), instead of + TensorFlow control flow constructs (@{tf.cond}, @{tf.while_loop}). For + example, you might use `tf.contrib.eager.py_func` to implement the log huber + function: + + ```python + def log_huber(x, m): + if tf.abs(x) <= m: + return x ** 2 + else: + return m ** 2 * (1 - 2 * tf.log(m) + tf.log(x ** 2)) + + x = tf.placeholder(tf.float32) + m = tf.placeholder(tf.float32) + + y = tf.contrib.eager.py_func(func=log_huber, inp=[x, m], Tout=tf.float32) + + with tf.Session() as sess: + # The session executes `log_huber` eagerly. Given the feed values below, + # it will take the second branch, so `output` evaluates to 7.24372. + output = sess.run(y, feed_dict={x: 3.0, m: 2.0}) + ``` + + You can also use `tf.contrib.eager.py_func` to debug your models at runtime + using Python tools, i.e., you can isolate portions of your code that + you want to debug, wrap them in Python functions and insert `pdb` tracepoints + or print statements as desired, and wrap those functions in + `tf.contrib.eager.py_func`. + + For more information on eager execution, see @{$programmers_guide/eager}. + + `tf.contrib.eager.py_func` is similar in spirit to @{tf.py_func}, but unlike + the latter, the former lets you use TensorFlow operations in the wrapped + Python function. In particular, while @{tf.py_func} only runs on CPUs and + wraps functions that take NumPy arrays as inputs and return NumPy arrays as + outputs, `tf.contrib.eager.py_func` can be placed on GPUs and wraps functions + that take Tensors as inputs, execute TensorFlow operations in their bodies, + and return Tensors as outputs. + + `tf.contrib.eager.py_func` is not differentiable, though a gradient may be + implemented in the future; if you would like to differentiate through it, + please file an issue on Github. + + Like @{tf.py_func}, `tf.contrib.eager.py_func` has the following limitations + with respect to serialization and distribution: + + * The body of the function (i.e. `func`) will not be serialized in a + `GraphDef`. Therefore, you should not use this function if you need to + serialize your model and restore it in a different environment. + + * The operation must run in the same address space as the Python program + that calls `tf.contrib.eager.py_func()`. If you are using distributed + TensorFlow, you must run a `tf.train.Server` in the same process as the + program that calls `tf.contrib.eager.py_func()` and you must pin the created + operation to a device in that server (e.g. using `with tf.device():`). - This function has the same limitations as `py_func` with respect to - serialization and distribution. Args: func: A Python function which accepts a list of `Tensor` objects -- GitLab From 86d3435503e20e44ab37c87613481f7a35d0c14e Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Thu, 3 May 2018 13:54:29 -0700 Subject: [PATCH 099/839] Fix a typo. PiperOrigin-RevId: 195305770 --- tensorflow/python/estimator/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 41ffa371aa..2f14a6f560 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -657,7 +657,7 @@ class _TrainingExecutor(object): hooks=train_hooks) if not self._continuous_eval_listener.before_eval(): - logging.info('Exiting training and evaluation lopp, as requested by ' + logging.info('Exiting training and evaluation loop, as requested by ' '_ContinuousEvalListener.before_eval.') break -- GitLab From 518dfea0d6d45448a360a49635fe815a28730c46 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 3 May 2018 14:00:56 -0700 Subject: [PATCH 100/839] [XLA:CPU] Remove dead function + DCHECK, NFC There isn't a lot of benefit to fixing the function to do that it says it does since I'm adding support for lowering batch matmul which will break this precondition anyway. PiperOrigin-RevId: 195306803 --- tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 4 ---- tensorflow/compiler/xla/service/cpu/dot_op_emitter.h | 4 ---- 2 files changed, 8 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 495fecc4aa..801c523908 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -557,8 +557,6 @@ DotOpEmitter::DotOpEmitter( return dot_emitter.Emit(); } -bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } - bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (dot_.shape().dimensions_size() != 2) { return false; @@ -908,8 +906,6 @@ tensorflow::Status DotOpEmitter::EmitScalarDot() { } tensorflow::Status DotOpEmitter::EmitCallToRuntime() { - DCHECK(ShapesAreLegalForRuntimeDot()); - // The signature of the Eigen runtime matmul function is: // // (void)(void* run_options, float* out, float* lhs, float* rhs, diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 9d748eb81f..47e0924334 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -99,10 +99,6 @@ class DotOpEmitter { llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix); - // Our runtime operation requires that all arrays have the same layout, - // no padding, and a rank of two. - bool ShapesAreLegalForRuntimeDot() const; - // Represents the dimensions of a matrix-matrix multiply operation. struct MatMultDims { // The number of rows in the LHS. -- GitLab From a4a9e372f6af694e91ef7aaae9f23867d0ec0fc2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 14:11:13 -0700 Subject: [PATCH 101/839] Optimize idempotent ops, e.g., Snapshot(Snapshot(x)) => Snapshot(x) PiperOrigin-RevId: 195308675 --- tensorflow/core/grappler/op_types.cc | 15 ++-- tensorflow/core/grappler/op_types.h | 5 ++ .../optimizers/arithmetic_optimizer.cc | 30 ++++++++ .../optimizers/arithmetic_optimizer.h | 1 + .../optimizers/arithmetic_optimizer_test.cc | 68 +++++++++++++++++++ tensorflow/python/grappler/cluster_test.py | 4 +- .../profiler/internal/run_metadata_test.py | 6 +- 7 files changed, 121 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index c48dc00941..e633ecf789 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -323,6 +323,8 @@ bool IsSize(const NodeDef& node) { return node.op() == "Size"; } bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; } +bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; } + bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; } bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; } @@ -488,14 +490,13 @@ bool IsValueAndOrderAndShapePreserving(const NodeDef& node) { "DeepCopy" "Enter", "Exit", - "Identity", - "IdentityN", "PreventGradient", "Print", "Snapshot", "StopGradient", })); - return value_and_order_and_shape_preserving_ops->count(node.op()) > 0; + return value_and_order_and_shape_preserving_ops->count(node.op()) > 0 || + IsIdentity(node); } bool IsValueAndOrderPreserving(const NodeDef& node) { @@ -505,7 +506,7 @@ bool IsValueAndOrderPreserving(const NodeDef& node) { static const std::unordered_set* value_and_order_preserving_ops = CHECK_NOTNULL((new const std::unordered_set{ "ExpandDims", - "Snapshot", + "Reshape", "Squeeze", })); return value_and_order_preserving_ops->count(node.op()) > 0 || @@ -576,7 +577,7 @@ bool IsUnaryElementWise(const NodeDef& node) { "Tanh", })); return element_wise_ops->count(node.op()) > 0 || - (!IsIdentityN(node) && IsValueAndOrderAndShapePreserving(node)); + IsValueAndOrderAndShapePreserving(node); } bool HasOpDef(const NodeDef& node) { @@ -584,5 +585,9 @@ bool HasOpDef(const NodeDef& node) { return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok(); } +bool IsIdempotent(const NodeDef& node) { + return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node); +} + } // namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index e33dd21538..f6105d710e 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -123,6 +123,7 @@ bool IsShape(const NodeDef& node); bool IsShapeN(const NodeDef& node); bool IsShuffle(const NodeDef& node); bool IsSigmoidGrad(const NodeDef& node); +bool IsSnapshot(const NodeDef& node); bool IsSoftplusGrad(const NodeDef& node); bool IsSoftsignGrad(const NodeDef& node); bool IsSplit(const NodeDef& node); @@ -187,6 +188,10 @@ bool IsValueAndOrderPreserving(const NodeDef& node); // function returns true if the op commutes with all element-wise operations. bool IsValuePreserving(const NodeDef& node); +// Returns true if node is idempotent w.r.t. its first input, i.e. if +// Op(Op(x, y, z), y, z) = Op(x, y, z). +bool IsIdempotent(const NodeDef& node); + bool IsUnaryElementWise(const NodeDef& node); // Returns true if we can find an opdef corresponding to the op of the node. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 2a5654f752..29f49079c4 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -295,6 +295,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage { } } } + DedupControlInputs(target_node); } bool IsInPreserveSet(const NodeDef& node) const { @@ -1690,6 +1691,32 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { std::unordered_set optimized_nodes_; }; +class RemoveIdempotentStage : public ArithmeticOptimizerStage { + public: + explicit RemoveIdempotentStage(const GraphOptimizerContext& ctx, + const ArithmeticOptimizerContext& ctx_ext) + : ArithmeticOptimizerStage("RemoveIdempotent", ctx, ctx_ext) {} + ~RemoveIdempotentStage() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsIdempotent(*node) && !IsInPreserveSet(*node); + } + + Status TrySimplify(NodeDef* node, string* simplified_node_name) override { + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input)); + auto root_scope_and_name = ParseNodeScopeAndName(node->name()); + const string new_name = OptimizedNodeName(root_scope_and_name); + if (input->op() == node->op() && input->device() == node->device() && + IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) { + NodeDef* new_input_node = AddCopyNode(new_name, input); + ForwardControlDependencies(new_input_node, {node}); + *simplified_node_name = new_input_node->name(); + } + return Status::OK(); + } +}; + // Performs the conversion: // Div(x, Sqrt(y)) => Mul(x, Rsqrt(y)) // TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)). @@ -1975,6 +2002,7 @@ void ArithmeticOptimizer::ForwardControlDependencies( } } } + DedupControlInputs(target_node); } // TODO(ezhulenev): extract each individual simplify rewrite into separate @@ -2381,6 +2409,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage(ctx, ctx_ext); if (options_.convert_sqrt_div_to_rsqrt_mul) pipeline.AddStage(ctx, ctx_ext); + if (options_.remove_idempotent) + pipeline.AddStage(ctx, ctx_ext); VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: " << str_util::Join(pipeline.StageNames(), ", "); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 6309dc1a33..3f9feac55f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -67,6 +67,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_negation = true; bool hoist_cwise_unary_chains = true; bool convert_sqrt_div_to_rsqrt_mul = false; + bool remove_idempotent = true; // Choose which arithmetic optimizer stages will be enabled for a given // optimization level by default. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index d32743f3f2..e109e66633 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -83,6 +83,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { GraphDef* output) { TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); item->graph.Swap(output); + output->Clear(); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); } @@ -91,6 +92,7 @@ class ArithmeticOptimizerTest : public GrapplerTest { GraphDef* output) { TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); item->graph.Swap(output); + output->Clear(); TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); } @@ -99,8 +101,10 @@ class ArithmeticOptimizerTest : public GrapplerTest { GraphDef* output) { TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); item->graph.Swap(output); + output->Clear(); TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); item->graph.Swap(output); + output->Clear(); TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); } @@ -168,6 +172,11 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true; } + + void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_idempotent = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -2390,5 +2399,64 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) { } } +TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 3.14f, {32}); + Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {}); + Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {}); + Output sn1 = + ops::Snapshot(s.WithOpName("sn1").WithControlDependencies(ctrl1), a); + Output sn2 = + ops::Snapshot(s.WithOpName("sn2").WithControlDependencies(ctrl2), sn1); + Output out1 = ops::Identity(s.WithOpName("out1"), sn2); + Output id1 = ops::Identity(s.WithOpName("id1"), a); + Output id2 = ops::Identity(s.WithOpName("id2"), id1); + Output out2 = ops::Identity(s.WithOpName("out2"), id2); + GrapplerItem item; + item.fetch = {"out1", "out2"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveIdempotent(&optimizer); + OptimizeTwice(&optimizer, &item, &output); + + EXPECT_EQ(11, output.node_size()); + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "out1") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_sn2", node.input(0)); + found++; + } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_sn2") { + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("Snapshot", node.op()); + EXPECT_EQ("a", node.input(0)); + EXPECT_EQ("^ctrl1", node.input(1)); + EXPECT_EQ("^ctrl2", node.input(2)); + found++; + } else if (node.name() == "out2") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_id2", node.input(0)); + found++; + } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_id2") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("a", node.input(0)); + found++; + } + } + EXPECT_EQ(4, found); + + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(tensors.size(), tensors_expected.size()); + EXPECT_EQ(tensors.size(), item.fetch.size()); + for (int i = 0; i < item.fetch.size(); ++i) { + test::ExpectTensorNear(tensors_expected[i], tensors[i], 1e-6); + } +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py index 26c6f22d34..541747867f 100644 --- a/tensorflow/python/grappler/cluster_test.py +++ b/tensorflow/python/grappler/cluster_test.py @@ -45,7 +45,7 @@ class ClusterTest(test.TestCase): op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts( grappler_item) self.assertTrue(run_time > 0) - self.assertEqual(len(op_perfs), 8) + self.assertEqual(len(op_perfs), 4) self.assertTrue(step_stats.dev_stats) def testNoDetailedStats(self): @@ -129,7 +129,7 @@ class ClusterTest(test.TestCase): disable_detailed_stats=False, disable_timeline=False) as gcluster: op_perfs, run_time, step_stats = gcluster.MeasureCosts(grappler_item) self.assertTrue(run_time > 0) - self.assertEqual(len(op_perfs), 8) + self.assertEqual(len(op_perfs), 4) self.assertTrue(step_stats.dev_stats) def testAvailableOps(self): diff --git a/tensorflow/python/profiler/internal/run_metadata_test.py b/tensorflow/python/profiler/internal/run_metadata_test.py index fd893d6cde..216cc3dd54 100644 --- a/tensorflow/python/profiler/internal/run_metadata_test.py +++ b/tensorflow/python/profiler/internal/run_metadata_test.py @@ -23,6 +23,7 @@ from collections import defaultdict import six from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops @@ -65,7 +66,10 @@ def _run_model(): w = random_ops.random_normal(shape=[SIZE, 2 * SIZE]) y = math_ops.matmul(x, w) - with session.Session() as sess: + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.arithmetic_optimization = ( + rewriter_config_pb2.RewriterConfig.OFF) + with session.Session(config=config) as sess: run_metadata = config_pb2.RunMetadata() opts = builder.time_and_memory() opts['min_micros'] = 0 -- GitLab From 05425f25ee1f8b83624127cf0f403b6751e7d70a Mon Sep 17 00:00:00 2001 From: Nick Desaulniers Date: Thu, 3 May 2018 14:16:27 -0700 Subject: [PATCH 102/839] [TF:XLA] clean up interface to xla::VerifyHloModule It seems that the first argument, platform, is unused. PiperOrigin-RevId: 195309504 --- tensorflow/compiler/xla/tests/test_utils.cc | 3 +-- tensorflow/compiler/xla/tests/test_utils.h | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 997a1d8273..810cc25f1b 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -339,8 +339,7 @@ StatusOr>> MakeFakeArguments( return std::move(arguments); } -Status VerifyHloModule(const se::Platform& platform, HloModule* const module, - bool allow_mixed_precision) { +Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) { return HloVerifier(allow_mixed_precision).Run(module).status(); } diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 30c147910c..f483cdebea 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -68,7 +68,7 @@ StatusOr>> MakeFakeArguments( // Check that a given module satisfies various constraints before trying to // execute it. -Status VerifyHloModule(const se::Platform& platform, HloModule* const module, +Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision = false); } // namespace xla -- GitLab From 316e0bab900d2a513e4e9622940181414e0d0596 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 14:18:07 -0700 Subject: [PATCH 103/839] Add separate get_read and get_updated helpers that work on code exceprts. Handle corner case for AugAssign. Fix bug in _node_sets_self_attribute. PiperOrigin-RevId: 195309809 --- .../pyct/static_analysis/activity.py | 79 ++++++++-- .../pyct/static_analysis/activity_test.py | 136 +++++++++++++++--- 2 files changed, 187 insertions(+), 28 deletions(-) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py index 2c14c2c8c2..4d7b0cbb7b 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py @@ -23,11 +23,12 @@ import copy import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.contrib.autograph.pyct.qual_names import QN from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Add support for PY3 (e.g. Param vs arg). +# TODO(alexbw): Ignore named literals (e.g. None) class Scope(object): @@ -43,16 +44,20 @@ class Scope(object): used: identifiers referenced in this scope """ - def __init__(self, parent, isolated=True): + def __init__(self, parent, isolated=True, add_unknown_symbols=False): """Create a new scope. Args: parent: A Scope or None. isolated: Whether the scope is isolated, that is, whether variables created in this scope should be visible to the parent scope. + add_unknown_symbols: Whether to handle attributed and subscripts + without having first seen the base name. + E.g., analyzing the statement 'x.y = z' without first having seen 'x'. """ self.isolated = isolated self.parent = parent + self.add_unknown_symbols = add_unknown_symbols self.modified = set() self.created = set() self.used = set() @@ -134,13 +139,17 @@ class Scope(object): self.params.add(name) def mark_creation(self, name, writes_create_symbol=False): + """Mark a qualified name as created.""" if name.is_composite(): parent = name.parent - if self.has(parent): - if not writes_create_symbol: - return + if not writes_create_symbol: + return else: - raise ValueError('Unknown symbol "%s".' % parent) + if not self.has(parent): + if self.add_unknown_symbols: + self.mark_read(parent) + else: + raise ValueError('Unknown symbol "%s".' % parent) self.created.add(name) def mark_write(self, name): @@ -163,17 +172,25 @@ class Scope(object): class ActivityAnalyzer(transformer.Base): - """Annotates nodes with local scope information. See Scope.""" + """Annotates nodes with local scope information. - def __init__(self, context, parent_scope): + See Scope. + + The use of this class requires that qual_names.resolve() has been called on + the node. This class will ignore nodes have not been + annotated with their qualified names. + """ + + def __init__(self, context, parent_scope=None, add_unknown_symbols=False): super(ActivityAnalyzer, self).__init__(context) - self.scope = Scope(parent_scope) + self.scope = Scope(parent_scope, None, add_unknown_symbols) self._in_return_statement = False + self._in_aug_assign = False @property def _in_constructor(self): - innermost = self.enclosing_entities[-1] if len(self.enclosing_entities) > 1: + innermost = self.enclosing_entities[-1] parent = self.enclosing_entities[-2] return isinstance(parent, gast.ClassDef) and innermost.name == '__init__' return False @@ -184,6 +201,7 @@ class ActivityAnalyzer(transformer.Base): # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'. if qn.has_attr and qn.parent.qn == ('self',): return True + return False def _track_symbol(self, node, @@ -201,12 +219,14 @@ class ActivityAnalyzer(transformer.Base): self.scope.mark_write(qn.parent) if writes_create_symbol: self.scope.mark_creation(qn, writes_create_symbol=True) + if self._in_aug_assign: + self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Load): self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Param): # Param contexts appear in function defs, so they have the meaning of # defining a variable. - # TODO(mdan): This bay be incorrect with nested functions. + # TODO(mdan): This may be incorrect with nested functions. # For nested functions, we'll have to add the notion of hiding args from # the parent scope, not writing to them. self.scope.mark_creation(qn) @@ -222,6 +242,14 @@ class ActivityAnalyzer(transformer.Base): if self._in_return_statement: self.scope.mark_returned(qn) + def visit_AugAssign(self, node): + # Special rules for AugAssign. In Assign, the target is only written, + # but in AugAssig (e.g. a += b), the target is both read and written. + self._in_aug_assign = True + self.generic_visit(node) + self._in_aug_assign = False + return node + def visit_Name(self, node): self.generic_visit(node) self._track_symbol(node) @@ -295,7 +323,7 @@ class ActivityAnalyzer(transformer.Base): def visit_FunctionDef(self, node): if self.scope: - qn = QN(node.name) + qn = qual_names.QN(node.name) self.scope.mark_write(qn) current_scope = self.scope body_scope = Scope(current_scope, isolated=True) @@ -355,5 +383,32 @@ class ActivityAnalyzer(transformer.Base): return node +def get_read(node, context): + """Return the variable names as QNs (qual_names.py) read by this statement.""" + analyzer = ActivityAnalyzer(context, None, True) + analyzer.visit(node) + return analyzer.scope.used + + +def get_updated(node, context): + """Return the variable names created or mutated by this statement. + + This function considers assign statements, augmented assign statements, and + the targets of for loops, as well as function arguments. + For example, `x[0] = 2` will return `x`, `x, y = 3, 4` will return `x` and + `y`, `for i in range(x)` will return `i`, etc. + Args: + node: An AST node + context: An EntityContext instance + + Returns: + A set of variable names (QNs, see qual_names.py) of all the variables + created or mutated. + """ + analyzer = ActivityAnalyzer(context, None, True) + analyzer.visit(node) + return analyzer.scope.created | analyzer.scope.modified + + def resolve(node, context, parent_scope=None): return ActivityAnalyzer(context, parent_scope).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index ef79a295bf..fdbd349af9 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -123,7 +123,7 @@ class ActivityAnalyzerTest(test.TestCase): recursive=True) node = qual_names.resolve(node) node = activity.resolve(node, ctx) - return node + return node, ctx def test_local_markers(self): @@ -133,7 +133,7 @@ class ActivityAnalyzerTest(test.TestCase): b -= 1 return b - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) self.assertFalse( anno.getanno(node.body[0].body[0].value, NodeAnno.IS_LOCAL)) # c in b = c @@ -156,6 +156,7 @@ class ActivityAnalyzerTest(test.TestCase): expected - actual, actual - expected)) def assertScopeIsRmc(self, scope, used, modified, created): + """Assert the scope contains specific used, modified & created variables.""" self.assertSymbolSetsAre(used, scope.used, 'read') self.assertSymbolSetsAre(modified, scope.modified, 'modified') self.assertSymbolSetsAre(created, scope.created, 'created') @@ -168,7 +169,7 @@ class ActivityAnalyzerTest(test.TestCase): print(a, b) return c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) print_node = node.body[0].body[2] if isinstance(print_node, gast.Print): # Python 2 @@ -191,7 +192,7 @@ class ActivityAnalyzerTest(test.TestCase): foo(a, b) # pylint:disable=undefined-variable return c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) call_node = node.body[0].body[2].value # We basically need to detect which variables are captured by the call # arguments. @@ -208,7 +209,7 @@ class ActivityAnalyzerTest(test.TestCase): foo(a.b, a.c) return a.d - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) call_node = node.body[0].body[1].value self.assertScopeIsRmc( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), @@ -234,7 +235,7 @@ class ActivityAnalyzerTest(test.TestCase): foo(a[0], a[b]) return a[c] - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) call_node = node.body[0].body[2].value self.assertScopeIsRmc( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), @@ -258,7 +259,7 @@ class ActivityAnalyzerTest(test.TestCase): b -= 1 return b, c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) while_node = node.body[0].body[1] self.assertScopeIsRmc( anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), @@ -278,7 +279,7 @@ class ActivityAnalyzerTest(test.TestCase): b -= 1 return b, c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) for_node = node.body[0].body[1] self.assertScopeIsRmc( anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) @@ -299,7 +300,7 @@ class ActivityAnalyzerTest(test.TestCase): u = -y return z, u - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'), @@ -326,7 +327,7 @@ class ActivityAnalyzerTest(test.TestCase): d = 1 return d - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), @@ -358,7 +359,7 @@ class ActivityAnalyzerTest(test.TestCase): d = 1 return d - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), @@ -390,7 +391,7 @@ class ActivityAnalyzerTest(test.TestCase): a = b * b return a - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) inner_if_node = node.body[0].body[0].body[0] self.assertScopeIsRmc( anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',), @@ -413,7 +414,7 @@ class ActivityAnalyzerTest(test.TestCase): b -= f(i) return b, c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) fn_def_node = node.body[0].body[0] self.assertScopeIsRmc( @@ -434,7 +435,7 @@ class ActivityAnalyzerTest(test.TestCase): self.b = a self.b.c = 1 - node = self._parse_and_analyze(TestClass) + node, _ = self._parse_and_analyze(TestClass) init_node = node.body[0].body[0] self.assertScopeIsRmc( anno.getanno(init_node, NodeAnno.BODY_SCOPE), @@ -448,15 +449,118 @@ class ActivityAnalyzerTest(test.TestCase): def test_fn(a): a[0] += 1 - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) fn_node = node.body[0] self.assertScopeIsRmc( anno.getanno(fn_node, NodeAnno.BODY_SCOPE), - ('a',), + ('a', 'a[0]'), ('a', 'a[0]'), ('a',), ) + def test_return_vars_are_read(self): + + def test_fn(a, b, c): # pylint: disable=unused-argument + return c + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('c',), + (), + ( + 'a', + 'b', + 'c', + ), + ) + + def test_aug_assign(self): + + def test_fn(a, b): + a += b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('a', 'b'), + ('a'), + ('a', 'b'), + ) + + def test_aug_assign_rvalues(self): + + a = dict(bar=3) + + def foo(): + return a + + def test_fn(x): + foo()['bar'] += x + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('foo', 'x'), + (), + ('x',), + ) + + def test_params_created(self): + + def test_fn(a, b): # pylint: disable=unused-argument + return b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')), + (('a', 'b'))) + + def test_get_read(self): + + def test_fn(x, y): + z = test_fn(x, y) + return z + + node, ctx = self._parse_and_analyze(test_fn) + node = node.body[0].body[0] + read_vars = activity.get_read(node, ctx) + self.assertEqual(read_vars, set(map(qual_names.QN, ('test_fn', 'x', 'y')))) + + def test_fn2(x, y, z): + z += test_fn2(x, y, z) + return z + + node, ctx = self._parse_and_analyze(test_fn2) + node = node.body[0].body[0] + read_vars = activity.get_read(node, ctx) + self.assertEqual(read_vars, + set(map(qual_names.QN, ('test_fn2', 'x', 'y', 'z')))) + + def test_get_updated(self): + + def test_fn(x, y): + z = test_fn(x, y) + return z + + node, ctx = self._parse_and_analyze(test_fn) + node = node.body[0].body[0] + updated_vars = activity.get_updated(node, ctx) + self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) + + def test_fn2(x, y, z): + z += test_fn2(x, y, z) + return z + + node, ctx = self._parse_and_analyze(test_fn2) + node = node.body[0].body[0] + updated_vars = activity.get_updated(node, ctx) + self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) + if __name__ == '__main__': test.main() -- GitLab From 28d43e5ada3c1e16b81c64b08cbbc273407a0347 Mon Sep 17 00:00:00 2001 From: Zhixian Yan Date: Thu, 3 May 2018 14:34:27 -0700 Subject: [PATCH 104/839] Add tflite listed models with accuracy and performance numbers. PiperOrigin-RevId: 195312636 --- tensorflow/contrib/lite/g3doc/models.md | 87 +++++++++++++++++-------- 1 file changed, 61 insertions(+), 26 deletions(-) diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index d8134d5a00..c1c8ef049f 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -1,28 +1,63 @@ # List of Hosted Models -* [NASNet large](https://storage.googleapis.com/download.tensorflow.org/models/tflite/nasnet_large_2018_03_27.zip) -* [NASNet mobile](https://storage.googleapis.com/download.tensorflow.org/models/tflite/nasnet_mobile_2018_03_27.zip) -* [ResNet v2 101](https://storage.googleapis.com/download.tensorflow.org/models/tflite/resnet_v2_101_2018_03_27.zip) -* [ResNet v2 50](https://storage.googleapis.com/download.tensorflow.org/models/tflite/resnet_v2_50_2018_03_27.zip) -* [Inception ResNet v2](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_resnet_v2_2018_03_27.zip) -* [Inception v4](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v4_2018_03_27.zip) -* [Inception v3 2015](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_2015_2017_11_10.zip) -* [Inception v3 Slim 2016](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) -* [Mobilenet 0.25 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_128_float_2017_11_08.zip) -* [Mobilenet 0.25 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_160_float_2017_11_08.zip) -* [Mobilenet 0.25 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_192_float_2017_11_08.zip) -* [Mobilenet 0.25 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_224_float_2017_11_08.zip) -* [Mobilenet 0.50 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_128_float_2017_11_08.zip) -* [Mobilenet 0.50 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_160_float_2017_11_08.zip) -* [Mobilenet 0.50 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_192_float_2017_11_08.zip) -* [Mobilenet 0.50 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_224_float_2017_11_08.zip) -* [Mobilenet 0.75 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_128_float_2017_11_08.zip) -* [Mobilenet 0.75 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_160_float_2017_11_08.zip) -* [Mobilenet 0.75 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_192_float_2017_11_08.zip) -* [Mobilenet 0.75 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_224_float_2017_11_08.zip) -* [Mobilenet 1.0 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_128_float_2017_11_08.zip) -* [Mobilenet 1.0 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_160_float_2017_11_08.zip) -* [Mobilenet 1.0 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_192_float_2017_11_08.zip) -* [Mobilenet 1.0 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_float_2017_11_08.zip) -* [Mobilenet 1.0 224 Quant](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) -* [Smart Reply 1.0 Android ](https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip) +## Image classification (Float Models) + +Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance +------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------: +DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms +SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms +NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 72.2% | 90.6% | 261 ms | 389 ms +NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.1% | 95.8% | 6697 ms | 7940 ms +ResNet_V2_50 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz) | 102.3 Mb | 68.1% | 88.4% | 942 ms | 1008 ms +ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_101_2018_04_27.tgz) | 178.3 Mb | 70.4% | 89.6% | 1880 ms | 1970 ms +Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 76.9% | 93.5% | 1433 ms | 1522 ms +Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 79.6% | 94.6% | 2986 ms | 3139 ms +Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 76.8% | 93.5% | 2731 ms | 2926 ms +Mobilenet_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.5% | 66.3% | 6.2 ms | 13.0 ms +Mobilenet_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.5% | 70.3% | 8.6 ms | 19.5 ms +Mobilenet_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.7% | 72.3% | 12.1 ms | 27.8 ms +Mobilenet_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.8% | 74.2% | 16.2 ms | 37.3 ms +Mobilenet_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.3% | 79.4% | 18.1 ms | 29.9 ms +Mobilenet_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.1% | 81.9% | 26.8 ms | 45.9 ms +Mobilenet_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.6% | 35.6 ms | 65.3 ms +Mobilenet_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.3% | 84.9% | 47.6 ms | 164.2 ms +Mobilenet_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.1% | 83.9% | 34.6 ms | 48.7 ms +Mobilenet_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.3% | 86.0% | 51.3 ms | 75.2 ms +Mobilenet_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.2% | 87.3% | 71.7 ms | 107.0 ms +Mobilenet_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.4% | 88.2% | 95.7 ms | 143.4 ms +Mobilenet_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.8% | 57.4 ms | 76.8 ms +Mobilenet_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms +Mobilenet_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.0% | 89.2% | 118.6 ms | 167.3 ms +Mobilenet_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 70.9% | 89.9% | 160.1 ms | 224.3 ms + +^ The model files include both TF Lite FlatBuffer and Tensorflow frozen Graph. + +^^ The performance numbers are generated in the benchmark on Pixel-2 using +single thread large core. + +## Image classification (Quantized Models) + +Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance +------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: +Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.9% | 65.8% | 3.7 ms +Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.5% | 69.1% | 5.5 ms +Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.8% | 71.9% | 7.9 ms +Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 73.8% | 10.4 ms +Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.9% | 8.8 ms +Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 81.3% | 13.0 ms +Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.4% | 83.2% | 18.3 ms +Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 62.2% | 84.5% | 24.7 ms +Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 59.8% | 82.8% | 16.2 ms +Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 63.9% | 85.5% | 24.3 ms +Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.2% | 87.1% | 33.8 ms +Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 67.9% | 88.1% | 45.4 ms +Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 64.0% | 85.5% | 24.9 ms +Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.3% | 87.7% | 37.4 ms +Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.0% | 88.9% | 51.9 ms +Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 69.7% | 89.5% | 70.2 ms + +## Other models + +Model | TF Lite FlatBuffer +----------------------- | :----------------: +Smart Reply 1.0 Android | [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), [tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip) -- GitLab From 4f4b15cece96c6cfa749c3fcf3288f1f47986210 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 15:20:05 -0700 Subject: [PATCH 105/839] Fix bug that disabled loop invariant node motion optimizer. Disable it options, since it is broken in the presence of gradient stacks. Get rid of an unnecessary copy of the graph. PiperOrigin-RevId: 195319766 --- .../core/grappler/optimizers/loop_optimizer.cc | 14 ++++++-------- .../core/grappler/optimizers/loop_optimizer.h | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index f7994221bb..5adc5b9227 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -474,15 +474,13 @@ std::vector GetStackPushNodesToConvert( return nodes_to_convert; } -Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) { - const std::unordered_set nodes_to_preserve = item.NodesToPreserve(); - const GraphDef& graph = item.graph; - *optimized_graph = graph; +Status RemoveStackOps(const std::unordered_set& nodes_to_preserve, + GraphDef* optimized_graph) { NodeMap node_map(optimized_graph); SimpleGraphView graph_view; - TF_RETURN_IF_ERROR(graph_view.Initialize(graph)); - for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) { - if (IsStackOp(graph.node(node_idx))) { + TF_RETURN_IF_ERROR(graph_view.Initialize(*optimized_graph)); + for (int node_idx = 0; node_idx < optimized_graph->node_size(); ++node_idx) { + if (IsStackOp(optimized_graph->node(node_idx))) { for (int push_node_idx : GetStackPushNodesToConvert( graph_view, nodes_to_preserve, node_idx)) { // We found push nodes without corresponding pops. Convert them to @@ -517,7 +515,7 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, TF_RETURN_IF_ERROR(linm_optimizer.Optimize()); } if (options_.enable_stack_push_removal) { - TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph)); + TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph)); } return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h index a422505d23..764506f7c1 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.h +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -52,7 +52,7 @@ class LoopOptimizer : public GraphOptimizer { // Granular control for loop optimizer stages. struct LoopOptimizerOptions { - bool enable_loop_invariant_node_motion = true; + bool enable_loop_invariant_node_motion = false; bool enable_stack_push_removal = true; static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) { -- GitLab From f25dd60858bc9ebe7b618aa966c2ddc1eef1f775 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Thu, 3 May 2018 15:39:46 -0700 Subject: [PATCH 106/839] Use tuple instead of list to reduce the chance of it being picked by the list conversions. PiperOrigin-RevId: 195322522 --- tensorflow/contrib/autograph/converters/asserts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py index 2d9e2c58e3..3b0db677ce 100644 --- a/tensorflow/contrib/autograph/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -33,7 +33,7 @@ class AssertsTransformer(transformer.Base): # Note: The lone tf.Assert call will be wrapped with control_dependencies # by side_effect_guards. template = """ - tf.Assert(test, [msg]) + tf.Assert(test, (msg,)) """ if node.msg is None: -- GitLab From 549d63acd35872061ae42c36c94df6dbef18ee2b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 15:42:23 -0700 Subject: [PATCH 107/839] Do not hoist nodes that modify frame info. PiperOrigin-RevId: 195322927 --- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 29f49079c4..adfae2e1a3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1541,7 +1541,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { const ChainLinkSet& ops) const { if (ops.empty()) return true; const NodeDef* op0 = ops.begin()->node; - if (!IsUnaryElementWise(*op0)) return false; + if (ModifiesFrameInfo(*op0) || !IsUnaryElementWise(*op0)) return false; for (const auto& link : ops) { const NodeDef* op = link.node; if (op->device() != root_node.device() || op->op() != op0->op() || -- GitLab From 200f4a2089cd4bef7832679cd121a2dbe85d6180 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Thu, 3 May 2018 15:58:43 -0700 Subject: [PATCH 108/839] Fix oom_test so that it doesn't try to allocate a giant host buffer when run without --config=cuda. Sadly the best way I could come up with is pretty hacky. PiperOrigin-RevId: 195325149 --- tensorflow/compiler/tests/oom_test.py | 29 ++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py index 1434e965e3..d68d32057a 100644 --- a/tensorflow/compiler/tests/oom_test.py +++ b/tensorflow/compiler/tests/oom_test.py @@ -22,6 +22,8 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest @@ -42,20 +44,33 @@ class OutOfMemoryTest(xla_test.XLATestCase): """ def test_loop(): - size = 2e8 + size = int(2e8) while True: with self.test_session(): - # Force the compiled code to not be constant by feeding in an addend. - p = array_ops.placeholder(dtypes.float32, shape=[]) + # Force the compiled code to not be constant by feeding in a + # parameter. + p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1]) with self.test_scope(): - # Create a large R1 tensor. - c = array_ops.zeros([size, 1]) + p + # Create a computation that produces a large R1 tensor as an + # intermediate result. Reduce it down so that if this file was + # compiled without --config=cuda, we don't force a D2H copy of a + # large tensor and potentially OOM the host. + # + # This is a bit tricky because XLA:GPU doesn't currently support RNG + # ops. Here we rely on the fact that XLA doesn't do algebraic + # simplifications on conv(, ). + c = math_ops.reduce_sum( + nn_ops.convolution( + array_ops.ones([1, size, 1]), + p, + padding='SAME', + data_format='NWC')) - c.eval(feed_dict={p: 1.0}) + c.eval(feed_dict={p: [[[1.0]], [[2.0]]]}) size *= 2 self.assertRaises(errors.ResourceExhaustedError, test_loop) -if __name__ == "__main__": +if __name__ == '__main__': googletest.main() -- GitLab From 04d5adbf848eba94e5c352cd0843a094b9fa0a4a Mon Sep 17 00:00:00 2001 From: Jeremy Lau Date: Thu, 3 May 2018 16:08:48 -0700 Subject: [PATCH 109/839] Fix bugs in LogicalBuffer::ToString and BufferValue::ToProto: these functions may be called before set_color(), but color() check fails when no color is set. PiperOrigin-RevId: 195327063 --- tensorflow/compiler/xla/service/buffer_value.cc | 4 +++- tensorflow/compiler/xla/service/logical_buffer.cc | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc index df1a5ca435..2bc556a9e2 100644 --- a/tensorflow/compiler/xla/service/buffer_value.cc +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -59,7 +59,9 @@ LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const { LogicalBufferProto::Location proto_location = ToLocationProto(*instruction(), index()); proto.mutable_defined_at()->Swap(&proto_location); - proto.set_color(color().value()); + if (has_color()) { + proto.set_color(color().value()); + } return proto; } diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index 1b3de8ad17..c742d35a7b 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -32,9 +32,13 @@ LogicalBuffer::LogicalBuffer(HloInstruction* instruction, LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { + string color_string; + if (has_color()) { + color_string = tensorflow::strings::StrCat(" @", color().value()); + } return tensorflow::strings::StrCat(instruction_->name(), "[", tensorflow::str_util::Join(index_, ","), - "](#", id(), " @", color().value(), ")"); + "](#", id(), color_string, ")"); } } // namespace xla -- GitLab From c9a92808ab8c1e19fd0a6bba5b9814c8c2c42511 Mon Sep 17 00:00:00 2001 From: Russell Power Date: Thu, 3 May 2018 16:16:05 -0700 Subject: [PATCH 110/839] Adjust worker shutdown hooks for TPUs PiperOrigin-RevId: 195328247 --- .../contrib/tpu/python/tpu/session_support.py | 47 +++++++++++++++---- .../contrib/tpu/python/tpu/tpu_estimator.py | 23 ++++++++- 2 files changed, 58 insertions(+), 12 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py index 7c25f6693c..3455e0b4a6 100644 --- a/tensorflow/contrib/tpu/python/tpu/session_support.py +++ b/tensorflow/contrib/tpu/python/tpu/session_support.py @@ -126,12 +126,21 @@ class WorkerHeartbeatManager(object): return WorkerHeartbeatManager(self._session, bad_devices, bad_ops, self._request_placeholder) + def __repr__(self): + return 'HeartbeatManager(%s)' % ','.join(self._devices) + def shutdown(self, timeout_ms=10000): """Shutdown all workers after `shutdown_timeout_secs`.""" + logging.info('Shutting down %s.', self) req = event_pb2.WorkerHeartbeatRequest( watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms)) self.configure(req) + # Wait for workers to shutdown. This isn't strictly required + # but it avoids triggering multiple checkpoints with the same lame worker. + logging.info('Waiting %dms for worker shutdown.', timeout_ms) + time.sleep(timeout_ms / 1000) + def all_worker_devices(session): """Return a list of devices for each worker in the system.""" @@ -250,6 +259,7 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): ' in your model definition to allow checkpointing.') with self._graph.as_default(): + logging.info('Installing graceful shutdown hook.') self._session = session_lib.Session( target=training_session.sess_str, graph=self._graph) self._workers = WorkerHeartbeatManager.from_devices( @@ -296,16 +306,33 @@ class GracefulShutdownHook(session_run_hook.SessionRunHook): fn(run_context, self._workers, lame_workers) -def restart_computation(run_context, all_workers, lame_workers): - del run_context, lame_workers - logging.info('Shutting down all workers.') - all_workers.shutdown() +class RestartComputation(object): + """Restart the entire computation. + + This hook shuts down all workers and returns control to the top-level by + throwing a CoordinatorShutdownException. + """ + + def __init__(self, timeout_ms=10000): + self.timeout_ms = timeout_ms + + def __call__(self, run_context, all_workers, lame_workers): + del run_context, lame_workers + all_workers.shutdown(timeout_ms=self.timeout_ms) + + logging.info('Terminating coordinator.') + raise CoordinatorShutdownException() + - logging.info('Terminating coordinator.') - raise CoordinatorShutdownException() +class ShutdownLameWorkers(object): + """Shutdown lamed workers. + + Processing will continue normally (typically by waiting for the down + workers to be restarted). + """ + def __init__(self, timeout_ms=10000): + self.timeout_in_ms = timeout_ms -def shutdown_lame_workers(run_context, all_workers, lame_workers): - del run_context, all_workers - logging.info('Shutting down %s', lame_workers) - lame_workers.shutdown() + def __call__(self, run_context, all_workers, lame_workers): + lame_workers.shutdown(timeout_ms=self.timeout_in_ms) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 534042b42c..a69bfa9a20 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -2049,9 +2049,28 @@ class TPUEstimator(estimator_lib.Estimator): host_ops = host_call.create_tpu_hostcall() if host_ops is None: host_ops = [] + shutdown_hooks = [] - if os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN', '0') != '0': - shutdown_hooks.append(session_support.GracefulShutdownHook()) + shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE', + 'shutdown_worker') + if shutdown_mode: + if shutdown_mode == 'shutdown_worker': + finalizer_hooks = [ + session_support.ShutdownLameWorkers(timeout_ms=1000), + ] + elif shutdown_mode == 'shutdown_computation': + finalizer_hooks = [ + session_support.RestartComputation(timeout_ms=1000), + ] + else: + raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % + shutdown_mode) + + shutdown_hooks.append(session_support.GracefulShutdownHook( + checkpoint_prefix=self.model_dir + '/model.ckpt', + on_shutdown_hooks=finalizer_hooks + )) + with ops.control_dependencies([loss]): global_step = array_ops.identity(training.get_global_step()) hooks = input_hooks + shutdown_hooks -- GitLab From 4a74a5058f7cb3ac096fa582941d9ab801ba6d65 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 16:34:11 -0700 Subject: [PATCH 111/839] Fix flaky test time-outs for dnn_test and rnn_test. PiperOrigin-RevId: 195331183 --- tensorflow/contrib/estimator/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 41a817673d..571e2e3a5d 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -77,6 +77,7 @@ py_test( tags = [ "no_pip", "notsan", + "optonly", # times out http://b/79220679 ], deps = [ ":dnn", @@ -450,6 +451,7 @@ py_test( "no_pip", "noasan", # times out "notsan", + "optonly", # times out http://b/79220679 ], deps = [ ":head", -- GitLab From 213a98d893105945540e0169faa124ac7e1200ba Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 17:03:03 -0700 Subject: [PATCH 112/839] [XLA] Redesign: deprecate ComputationBuilder. PiperOrigin-RevId: 195335330 --- tensorflow/compiler/xla/client/computation.h | 2 + .../compiler/xla/client/computation_builder.h | 2 + tensorflow/compiler/xla/client/lib/BUILD | 5 +- .../compiler/xla/client/lib/arithmetic.cc | 90 +---------------- .../compiler/xla/client/lib/arithmetic.h | 55 +---------- tensorflow/compiler/xla/client/lib/testing.cc | 16 ++- tensorflow/compiler/xla/client/lib/testing.h | 1 - tensorflow/compiler/xla/service/BUILD | 10 +- tensorflow/compiler/xla/service/cpu/BUILD | 4 +- .../xla/service/cpu/sample_harness.cc | 10 +- .../xla/service/hlo_cost_analysis_test.cc | 73 ++++++-------- .../xla/service/hlo_evaluator_test.cc | 10 +- .../xla/service/hlo_tfgraph_builder_test.cc | 1 - .../xla/service/transpose_folding_test.cc | 10 +- .../zero_sized_hlo_elimination_test.cc | 1 - tensorflow/compiler/xla/tests/BUILD | 18 ---- .../xla/tests/local_client_aot_test_helper.cc | 19 ++-- .../xla/tests/set_return_value_test.cc | 98 ------------------- .../xla/tests/vector_ops_simple_test.cc | 3 +- 19 files changed, 85 insertions(+), 343 deletions(-) delete mode 100644 tensorflow/compiler/xla/tests/set_return_value_test.cc diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h index a53fc9e9cf..9a1bcde763 100644 --- a/tensorflow/compiler/xla/client/computation.h +++ b/tensorflow/compiler/xla/client/computation.h @@ -30,6 +30,8 @@ namespace xla { // Wraps a ComputationHandle protobuf with a lifetime. Computation is // movable and not copyable to capture the same kind of unique // ownership that std::unique_ptr represents. +// +// TODO(b/74197823): Deprecated. Use XlaComputation instead. class Computation { public: // Creates a null Computation. diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 9431c2c459..ac1eb915cc 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -48,6 +48,8 @@ namespace xla { // deferred from being handled until Build() is called. // // Thread-compatible. +// +// TODO(b/74197823): Deprecated. Use XlaBuilder instead. class ComputationBuilder { public: // client: client in which to build the computation. diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 59c4a53c05..d49d959a6c 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -22,8 +22,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", @@ -43,9 +41,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 63df449e0b..a1d34796cc 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,7 +17,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -27,28 +28,6 @@ limitations under the License. namespace xla { namespace { -using InstructionGenerator = - ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&, - const ComputationDataHandle&); - -Computation CreateScalarComputation(const string& name, PrimitiveType type, - ComputationBuilder* builder, - InstructionGenerator generator) { - std::unique_ptr b; - if (type == PRED) { - b = builder->CreateSubBuilder(name); - } else { - b = builder->CreateSubBuilder( - tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); - } - - const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - generator(b.get(), lhs, rhs); - return b->BuildAndNoteError(); -} - using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&); XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, @@ -71,71 +50,6 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, } // namespace -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "add", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Add(lhs, rhs); }); -} - -Computation CreateScalarMultiplyComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "mul", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); }); -} - -Computation CreateScalarGeComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "ge", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Ge(lhs, rhs); }); -} - -Computation CreateScalarMaxComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "max", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Max(lhs, rhs); }); -} - -Computation CreateScalarMinComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "min", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); }); -} - -Computation CreateScalarAndComputation(ComputationBuilder* builder) { - return CreateScalarComputation( - "and", PRED, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->And(lhs, rhs); }); -} - -Computation CreateScalarOrComputation(ComputationBuilder* builder) { - return CreateScalarComputation( - "or", PRED, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Or(lhs, rhs); }); -} - -StatusOr Any(const ComputationDataHandle& predicates, - ComputationBuilder* builder) { - auto f = builder->ConstantR0(false); - Computation logical_or = CreateScalarOrComputation(builder); - TF_ASSIGN_OR_RETURN(std::unique_ptr predicates_shape, - builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(*predicates_shape)); - std::iota(all_dimensions.begin(), all_dimensions.end(), 0); - return builder->Reduce(predicates, f, logical_or, all_dimensions); -} - XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index f4d3fc8015..64b6b7d633 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -18,83 +18,38 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -// Creates a scalar add computation and returns it. -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar multiply computation and returns it. -Computation CreateScalarMultiplyComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar ge computation and returns it. -Computation CreateScalarGeComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar max computation and returns it. -Computation CreateScalarMaxComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar min computation and returns it. -Computation CreateScalarMinComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar logical AND computation and returns it. -Computation CreateScalarAndComputation(ComputationBuilder* builder); - -// Creates a scalar logical OR computation and returns it. -Computation CreateScalarOrComputation(ComputationBuilder* builder); - -// Returns whether any predicate in "predicates" is set. -// -// Note: if predicates is zero-sized, Any() vacuously returns false. -StatusOr Any(const ComputationDataHandle& predicates, - ComputationBuilder* builder); - -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Creates a scalar add computation and returns it. XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar multiply computation and returns it. XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar ge computation and returns it. XlaComputation CreateScalarGeComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar max computation and returns it. XlaComputation CreateScalarMaxComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar min computation and returns it. XlaComputation CreateScalarMinComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar logical AND computation and returns it. XlaComputation CreateScalarAndComputation(XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Creates a scalar logical OR computation and returns it. XlaComputation CreateScalarOrComputation(XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false. diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 311dc4bdd7..9cd87f7473 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -46,16 +45,14 @@ int64 DataSizeOfShape(const Shape& shape) { return total_size; } -// Create a ComputationDataHandle for an op what generates fake data with the -// given shape. -ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, - ComputationBuilder* builder) { +// Creates a XlaOp for an op what generates fake data with the given shape. +XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { if (ShapeUtil::IsArray(shape)) { return builder->Broadcast( builder->ConstantLiteral(Literal::One(shape.element_type())), AsInt64Slice(shape.dimensions())); } - std::vector parts; + std::vector parts; for (const Shape& s : shape.tuple_shapes()) { parts.push_back(BuildFakeDataOpOnDevice(s, builder)); } @@ -64,11 +61,10 @@ ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { - ComputationBuilder b( - client, + XlaBuilder b( tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); - Computation computation = b.Build().ConsumeValueOrDie(); + XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape; diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 1dc2622972..9e06141b1f 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0b8b22b44c..9c362d8cad 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -233,7 +233,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1669,10 +1669,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2406,7 +2406,6 @@ tf_cc_test( srcs = ["hlo_tfgraph_builder_test.cc"], deps = [ ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", @@ -2475,7 +2474,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -2512,6 +2511,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index cb81e413a3..7e6d58c7fa 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -365,10 +365,10 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index b3f4609d46..167aa4adda 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -19,10 +19,10 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -48,13 +48,13 @@ int main(int argc, char** argv) { client->TransferToServer(*param1_literal).ConsumeValueOrDie(); // Build computation. - xla::ComputationBuilder builder(client, ""); + xla::XlaBuilder builder(""); auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); auto add = builder.Add(p1, p0, {0}); - xla::StatusOr computation_status = builder.Build(); - xla::Computation computation = computation_status.ConsumeValueOrDie(); + xla::StatusOr computation_status = builder.Build(); + xla::XlaComputation computation = computation_status.ConsumeValueOrDie(); // Execute and transfer result of computation. xla::ExecutionProfile profile; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 81cc7c4bdc..16fdda8a8b 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -20,16 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/logging.h" @@ -58,11 +55,10 @@ class HloCostAnalysisTest : public ::testing::Test { // whitebox accesses to the user computation built from the client, // as shown in the BuildHloGraph functions below. service_(static_cast(ClientLibrary::GetXlaService( - static_cast(client_)->platform()))), - computation_tracker_(service_->computation_tracker()) { + static_cast(client_)->platform()))) { // Create a computation for a unary user function: x => exp(x + 0.5) { - ComputationBuilder builder(client_, "add_and_exp"); + XlaBuilder builder("add_and_exp"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto half = builder.ConstantR0(0.5); builder.Exp(builder.Add(x, half)); @@ -73,7 +69,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary user function: (x, y) => x + y { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -84,7 +80,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x)) { - ComputationBuilder builder(client_, "sigmoid"); + XlaBuilder builder("sigmoid"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto one = builder.ConstantR0(1.0); builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x)))); @@ -95,7 +91,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary max function: (x, y) => max (x, y) { - ComputationBuilder builder(client_, "max"); + XlaBuilder builder("max"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Max(x, y); @@ -106,7 +102,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary GT function: (x, y) => x > y { - ComputationBuilder builder(client_, "gt"); + XlaBuilder builder("gt"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Gt(x, y); @@ -117,35 +113,30 @@ class HloCostAnalysisTest : public ::testing::Test { } // Build HLO graph from the given builder and return the HLO module. - std::unique_ptr BuildHloGraph(ComputationBuilder* builder) { + std::unique_ptr BuildHloGraph(XlaBuilder* builder) { auto computation_status = builder->Build(); TF_CHECK_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto user_computation_status = - computation_tracker_.Resolve(computation.handle()); - TF_CHECK_OK(user_computation_status.status()); - auto user_computation = user_computation_status.ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - return std::move( - computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig()) - .ValueOrDie()); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); + return HloModule::CreateFromProto(computation.proto(), config) + .ConsumeValueOrDie(); } Client* client_; Service* service_; - const ComputationTracker& computation_tracker_; // User computations used for higher order operations (e.g., Map, Reduce). - Computation add_; - Computation add_and_exp_; - Computation sigmoid_; - Computation max_; - Computation gt_; + XlaComputation add_; + XlaComputation add_and_exp_; + XlaComputation sigmoid_; + XlaComputation max_; + XlaComputation gt_; }; TEST_F(HloCostAnalysisTest, MatrixMultiply) { - ComputationBuilder builder(client_, "matrix_multiply"); + XlaBuilder builder("matrix_multiply"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); auto result = builder.Dot(lhs, rhs); @@ -167,7 +158,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { } TEST_F(HloCostAnalysisTest, Map) { - ComputationBuilder builder(client_, "map"); + XlaBuilder builder("map"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); auto result = builder.Map({input}, add_and_exp_, {0}); @@ -184,7 +175,7 @@ TEST_F(HloCostAnalysisTest, Map) { } TEST_F(HloCostAnalysisTest, Convolution) { - ComputationBuilder builder(client_, "convolution"); + XlaBuilder builder("convolution"); auto input = builder.Parameter( 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, @@ -213,7 +204,7 @@ TEST_F(HloCostAnalysisTest, Convolution) { } TEST_F(HloCostAnalysisTest, Reduce) { - ComputationBuilder builder(client_, "reduce"); + XlaBuilder builder("reduce"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto result = @@ -231,7 +222,7 @@ TEST_F(HloCostAnalysisTest, Reduce) { } TEST_F(HloCostAnalysisTest, ReduceWindow) { - ComputationBuilder builder(client_, "reduce_window"); + XlaBuilder builder("reduce_window"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto result = builder.ReduceWindow(input, builder.ConstantR0(0), add_, @@ -248,7 +239,7 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { } TEST_F(HloCostAnalysisTest, SelectAndScatter) { - ComputationBuilder builder(client_, "select_and_scatter"); + XlaBuilder builder("select_and_scatter"); auto operand = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto source = @@ -269,7 +260,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { } TEST_F(HloCostAnalysisTest, Broadcast) { - ComputationBuilder b(client_, "broadcast"); + XlaBuilder b("broadcast"); b.Broadcast(b.ConstantR0(42), {10, 7}); auto hlo_module = BuildHloGraph(&b); HloCostAnalysis analysis(ShapeSize); @@ -280,7 +271,7 @@ TEST_F(HloCostAnalysisTest, Broadcast) { // Calculates the computation cost of a graph with more than one HLO node. TEST_F(HloCostAnalysisTest, FullyConnectedForward) { - ComputationBuilder builder(client_, "fully_connected_forward"); + XlaBuilder builder("fully_connected_forward"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); auto weight = @@ -305,7 +296,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis conv_analysis(ShapeSize); { - ComputationBuilder builder(client_, "conv_looking_matmul"); + XlaBuilder builder("conv_looking_matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), "input"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), @@ -318,7 +309,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis matmul_analysis(ShapeSize); { - ComputationBuilder builder(client_, "matmul"); + XlaBuilder builder("matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); auto rhs = @@ -427,7 +418,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - ComputationBuilder builder(client_, "matmul"); + XlaBuilder builder("matmul"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y"); auto tuple = builder.Tuple({x, y}); @@ -443,7 +434,7 @@ TEST_F(HloCostAnalysisTest, TupleCost) { } TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { - ComputationBuilder builder(client_, "BaseDilatedConvolution"); + XlaBuilder builder("BaseDilatedConvolution"); auto input = builder.Parameter( 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, @@ -458,7 +449,7 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { auto result = builder.ConvGeneralDilated( input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, - ComputationBuilder::CreateDefaultConvDimensionNumbers(2)); + XlaBuilder::CreateDefaultConvDimensionNumbers(2)); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 230147abfe..cc16446778 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -827,7 +827,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1046,7 +1046,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1109,7 +1109,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1180,7 +1180,7 @@ TEST_P(HloEvaluatorTest, *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index f8d98f0678..be156d765d 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index c7c4160345..0319109f7f 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -222,7 +222,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -275,7 +275,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -334,7 +334,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { HloInstruction* transpose_x = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -398,7 +398,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { HloInstruction* transpose_x = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index a4e67cc9d9..f5331280ee 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 54cf0543b8..0571ff5055 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1933,24 +1933,6 @@ xla_test( ], ) -xla_test( - name = "set_return_value_test", - srcs = ["set_return_value_test.cc"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/client/xla_client:xla_computation", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - xla_test( name = "reshape_motion_test", srcs = ["reshape_motion_test.cc"], diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 3704ddd801..a366afe826 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -21,7 +21,8 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -29,27 +30,31 @@ limitations under the License. #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" +namespace { + using xla::string; -xla::Computation Doubler(xla::Client* client) { - xla::ComputationBuilder builder(client, "doubler"); +xla::XlaComputation Doubler() { + xla::XlaBuilder builder("doubler"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); auto x = builder.Parameter(0, r0f32, "x"); builder.Mul(x, builder.ConstantR0(2.0)); return std::move(builder.Build().ValueOrDie()); } +} // namespace + int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie(); - xla::ComputationBuilder builder(client, "aot_test_helper"); + xla::XlaBuilder builder("aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); auto opaque_param = builder.Parameter(0, opaque_shape, "x"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32); - builder.Call(Doubler(client), {sum}); + builder.Call(Doubler(), {sum}); if (argc != 2) { LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU"; @@ -71,8 +76,8 @@ int main(int argc, char** argv) { llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); - xla::Computation computation = builder.Build().ConsumeValueOrDie(); - xla::CompileOnlyClient::AotComputationInstance instance{ + xla::XlaComputation computation = builder.Build().ConsumeValueOrDie(); + xla::CompileOnlyClient::AotXlaComputationInstance instance{ &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32}; xla::cpu::CpuAotCompilationOptions options( diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc deleted file mode 100644 index 29f79ec28a..0000000000 --- a/tensorflow/compiler/xla/tests/set_return_value_test.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace { - -class SetReturnValueTest : public ClientLibraryTestBase {}; - -TEST_F(SetReturnValueTest, NoSetValue) { - ComputationBuilder builder(client_, "no_set_value"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - - std::vector expected = {1.0, 3.0, 4.0, 0.0, -1.0, - 5.0, 6.0, -2.0, -3.0, 7.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValue) { - ComputationBuilder builder(client_, "set_value"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValueAndModify) { - ComputationBuilder builder(client_, "set_value_and_modify"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - auto aaax = builder.Add(alpha, aax); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) { - ComputationBuilder builder(client_, "set_value_multiple_times_and_modify"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(aax); - EXPECT_TRUE(builder_status.ok()); - auto aaax = builder.Add(alpha, aax); - builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - auto aaaax = builder.Add(alpha, aaax); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 3dded3f715..5cce7a2bf8 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -350,7 +349,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { } XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0); auto one = builder.ConstantR0(10); auto x = builder.ConstantR1({-3, 3, 9, 13}); -- GitLab From fc7b593cda65f4a3a3de0cc733270f0864f820e2 Mon Sep 17 00:00:00 2001 From: Ruoxin Sang Date: Thu, 3 May 2018 17:21:26 -0700 Subject: [PATCH 113/839] Clear the stat cache of the target when renaming the file. PiperOrigin-RevId: 195337886 --- .../core/platform/cloud/gcs_file_system.cc | 4 +- .../platform/cloud/gcs_file_system_test.cc | 72 +++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 488f9cc75d..e44e897434 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -1375,9 +1375,9 @@ Status GcsFileSystem::RenameObject(const string& src, const string& target) { request->SetResultBuffer(&output_buffer); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when renaming ", src, " to ", target); - // Flush the target from the block cache. The source will be flushed in the + // Flush the target from the caches. The source will be flushed in the // DeleteFile call below. - file_block_cache_->RemoveFile(target); + ClearFileCaches(target); Json::Value root; TF_RETURN_IF_ERROR(ParseJson(output_buffer, &root)); bool done; diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index c639299954..28be13869b 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -1902,6 +1902,78 @@ TEST(GcsFileSystemTest, RenameFile_Object) { EXPECT_EQ("fedcba98", result); } +TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { + std::vector requests( + {// Stat the target file. + new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "path%2Fdst.txt?fields=size%2Cupdated\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n", + strings::StrCat("{\"size\": \"1000\"," + "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), + // IsDirectory is checking whether there are children objects. + new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?" + "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsrc.txt%2F" + "&maxResults=1\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n", + "{}"), + // IsDirectory is checking if the path exists as an object. + new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "path%2Fsrc.txt?fields=size%2Cupdated\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n", + strings::StrCat("{\"size\": \"1010\"," + "\"updated\": \"2016-04-29T23:15:24.896Z\"}")), + // Copying to the new location. + new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "path%2Fsrc.txt/rewriteTo/b/bucket/o/path%2Fdst.txt\n" + "Auth Token: fake_token\n" + "Post: yes\n" + "Timeouts: 5 1 10\n", + "{\"done\": true}"), + // Deleting the original file. + new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "path%2Fsrc.txt\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n" + "Delete: yes\n", + ""), + new FakeHttpRequest( + "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" + "path%2Fdst.txt?fields=size%2Cupdated\n" + "Auth Token: fake_token\n" + "Timeouts: 5 1 10\n", + strings::StrCat("{\"size\": \"1010\"," + "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))}); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 3600 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + kTestTimeoutConfig, nullptr /* gcs additional header */); + // Do an initial stat of the destination file to load their contents into the + // stat cache. + FileStatistics stat_before_renaming; + TF_EXPECT_OK(fs.Stat("gs://bucket/path/dst.txt", &stat_before_renaming)); + EXPECT_EQ(1000, stat_before_renaming.length); + + TF_EXPECT_OK( + fs.RenameFile("gs://bucket/path/src.txt", "gs://bucket/path/dst.txt")); + + FileStatistics stat_after_renaming; + TF_EXPECT_OK(fs.Stat("gs://bucket/path/dst.txt", &stat_after_renaming)); + EXPECT_EQ(1010, stat_after_renaming.length); +} + /// Tests the scenario when deletion returns a failure, but actually succeeds. TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { std::vector requests( -- GitLab From fa7b5a9d1ab654bbd466487e39a8b3f83c17f3f0 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Thu, 3 May 2018 18:08:34 -0700 Subject: [PATCH 114/839] [XLA] Make LocalShapedBuffer::FromLiteral fallible by passing StatusOr wrapper. PiperOrigin-RevId: 195345724 --- .../xla/python/local_computation_builder.cc | 16 ++++++++-------- .../xla/python/local_computation_builder.h | 6 ++++-- .../xla/python/local_computation_builder.i | 13 +++++++++++++ 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 7102f46737..044458164f 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -104,25 +104,25 @@ static StatusOr ToBuffer(LocalClient* client, } /* static */ -LocalShapedBuffer* LocalShapedBuffer::FromLiteral( +StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const tensorflow::gtl::optional& shape_with_layout) { LocalClient* client = GetOrCreateLocalClient(); - ScopedShapedBuffer buf = [&] { + StatusOr buf = [&] { if (shape_with_layout) { std::unique_ptr relaid = argument.Relayout(shape_with_layout.value()); - return ToBuffer(client, /*device_ordinal=*/0, *relaid) - .ConsumeValueOrDie(); + return ToBuffer(client, /*device_ordinal=*/0, *relaid); } - return ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie(); + return ToBuffer(client, /*device_ordinal=*/0, argument); }(); - return new LocalShapedBuffer(std::move(buf)); + TF_RETURN_IF_ERROR(buf.status()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie()); } -std::unique_ptr LocalShapedBuffer::ToLiteral() const { +StatusOr> LocalShapedBuffer::ToLiteral() const { LocalClient* client = GetOrCreateLocalClient(); - return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie(); + return client->ShapedBufferToLiteral(*shaped_buffer()); } CompiledLocalComputation::CompiledLocalComputation( diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index e1048909ab..5ec097846a 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -59,12 +59,14 @@ StatusOr > TransferFromOutfeedLocalReplica( // client. class LocalShapedBuffer { public: - static LocalShapedBuffer* FromLiteral( + static StatusOr FromLiteral( const Literal& argument, const tensorflow::gtl::optional& shape_with_layout); + LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); const ScopedShapedBuffer* shaped_buffer() const; - std::unique_ptr ToLiteral() const; + + StatusOr > ToLiteral() const; private: ScopedShapedBuffer shaped_buffer_; diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index ac792e8189..b8cce5a5f7 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -205,6 +205,19 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalShapedBuffer*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + %typemap(out) StatusOr< std::unique_ptr > { if ($1.ok()) { std::unique_ptr value = $1.ConsumeValueOrDie(); -- GitLab From 0abbff6c0bdf0ee4690def786513298afc8b772a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 19:45:59 -0700 Subject: [PATCH 115/839] [XLA] Redesign: cleanup client_library_test_base. PiperOrigin-RevId: 195357555 --- .../xla/tests/client_library_test_base.cc | 221 +--------------- .../xla/tests/client_library_test_base.h | 243 ++++++------------ 2 files changed, 92 insertions(+), 372 deletions(-) diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 22660c35dc..c09e7eaf2b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -94,27 +94,13 @@ string ClientLibraryTestBase::TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } -template StatusOr> ClientLibraryTestBase::Execute( - BuilderT* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_output_layout) { - ExecutionOptions execution_options = execution_options_; - if (shape_with_output_layout != nullptr) { - *execution_options.mutable_shape_with_output_layout() = - *shape_with_output_layout; - } - return client_->ExecuteAndTransfer(computation, arguments, - &execution_options); -} - StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -128,17 +114,6 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } -template <> -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_output_layout) { - // Build the computation, as a convenience. - TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); -} - -template <> StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout) { @@ -162,18 +137,6 @@ ClientLibraryTestBase::ExecuteAndTransferReference( &execution_options); } -std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - return Execute(builder, arguments).ConsumeValueOrDie(); -} - -std::unique_ptr ClientLibraryTestBase::ExecuteAndTransferOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie(); -} - string ClientLibraryTestBase::ExecuteToString( XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { auto computation_status = builder->Build(); @@ -191,32 +154,6 @@ string ClientLibraryTestBase::ExecuteToString( } } -string ClientLibraryTestBase::ExecuteToString( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - auto computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status().ToString(); - } - auto computation = computation_status.ConsumeValueOrDie(); - - auto result = - client_->ExecuteAndTransfer(computation, arguments, &execution_options_); - if (!result.ok()) { - return result.status().ToString(); - } else { - return result.ValueOrDie()->ToString(); - } -} - -void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = Literal::CreateR1(expected); - ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, - arguments); -} - void ClientLibraryTestBase::ComputeAndCompareR1( XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments) { @@ -225,18 +162,16 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, @@ -245,7 +180,7 @@ void ClientLibraryTestBase::ComputeAndCompareLiteral( tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::Computation& computation, const Literal& expected, + const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output) { @@ -271,7 +206,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( - const xla::Computation& computation, const Literal& expected, + const xla::XlaComputation& computation, const Literal& /*expected*/, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output, @@ -334,28 +269,8 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( return choose(0); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, - tensorflow::gtl::ArraySlice /*arguments*/, - const std::function& /*verify_output*/) { - return Unimplemented("not yet implemented for XlaComputation"); -} - -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( - const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, - tensorflow::gtl::ArraySlice /*arguments*/, - const std::function& /*verify_output*/, - const Shape* /*output_with_layout*/) { - return Unimplemented("not yet implemented for XlaComputation"); -} - -template tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -412,9 +327,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( return tensorflow::Status::OK(); } -template tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -484,9 +398,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( EXPECT_EQ(expected, actual->GetR1U8AsString()); } -template void ClientLibraryTestBase::ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -497,9 +410,8 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( LiteralTestUtil::ExpectEqual(expected, *actual); } -template void ClientLibraryTestBase::ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -510,60 +422,6 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( LiteralTestUtil::ExpectNear(expected, *actual, error); } -void ClientLibraryTestBase::ComputeAndCompare( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments) { - auto status_or_data = ComputeValueAndReference(builder, operand, arguments); - EXPECT_IS_OK(status_or_data); - if (!status_or_data.ok()) { - return; - } - std::unique_ptr reference, result; - std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*reference, *result); -} - -void ClientLibraryTestBase::ComputeAndCompare( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { - auto status_or_data = ComputeValueAndReference(builder, operand, arguments); - EXPECT_IS_OK(status_or_data); - if (!status_or_data.ok()) { - return; - } - std::unique_ptr reference, result; - std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*reference, *result, error); -} - -StatusOr, std::unique_ptr>> -ClientLibraryTestBase::ComputeValueAndReference( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments) { - // Transfer the arguments to the executor service. We put the unique_ptr's - // into a vector to keep the data alive on the service until the end of this - // function. - std::vector> argument_data; - for (const auto& arg : arguments) { - TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg)); - argument_data.push_back(std::move(data)); - } - - // Create raw pointers to the GlobalData for the rest of the call stack. - std::vector argument_data_ptr; - std::transform( - argument_data.begin(), argument_data.end(), - std::back_inserter(argument_data_ptr), - [](const std::unique_ptr& data) { return data.get(); }); - - TF_ASSIGN_OR_RETURN( - auto reference, - builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments)); - TF_ASSIGN_OR_RETURN(auto result, - ExecuteAndTransfer(builder, argument_data_ptr)); - return std::make_pair(std::move(reference), std::move(result)); -} - void ClientLibraryTestBase::ComputeAndCompare( XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { auto status_or_data = ComputeValueAndReference(builder, arguments); @@ -651,8 +509,8 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() { return computation_status.ConsumeValueOrDie(); } -Computation ClientLibraryTestBase::CreateScalarReluSensitivity() { - ComputationBuilder builder(client_, "relu_sensitivity"); +XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { + XlaBuilder builder("relu_sensitivity"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); auto activation = builder.Parameter(0, shape, "activation"); auto backprop = builder.Parameter(1, shape, "backprop"); @@ -693,14 +551,6 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } -ComputationDataHandle ClientLibraryTestBase::AddParam( - const Literal& argument, ComputationBuilder* builder) { - ComputationDataHandle data_handle; - arguments_.push_back(CreateParameterAndTransferLiteral( - arguments_.size(), argument, "", builder, &data_handle)); - return data_handle; -} - XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaBuilder* builder) { XlaOp data_handle; @@ -709,59 +559,10 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, return data_handle; } -ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( - const Literal& literal, ComputationBuilder* builder) { - return builder->ConstantLiteral( - use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); -} - XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return builder->ConstantLiteral( use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - -template StatusOr> ClientLibraryTestBase::Execute( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - -template StatusOr> ClientLibraryTestBase::Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); - } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 32eea7c2f3..e58979a303 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -25,10 +25,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -91,21 +90,11 @@ class ClientLibraryTestBase : public ::testing::Test { // Convenience methods for building and running a computation with the member // execution options. Modify execution_options_ in your test if you want to // customize the options. - template StatusOr> Execute( - BuilderT* builder, tensorflow::gtl::ArraySlice arguments); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); - // TODO(b/74197823): Remove the template type 'BuilderT' in all methods once - // the migration to XlaBuilder is complete. - - template StatusOr> ExecuteAndTransfer( - BuilderT* builder, tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_output_layout = nullptr); - - StatusOr> ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); StatusOr> ExecuteAndTransfer( @@ -121,101 +110,90 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); - // Convenience OrDie variants of above methods. - std::unique_ptr ExecuteOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - std::unique_ptr ExecuteAndTransferOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. string ExecuteToString(XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); - string ExecuteToString(ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are // templated on the native host type which maps to specific XLA types (See - // ComputationBuilder/XlaBuilder for details). For each rank, two forms are + // XlaBuilder for details). For each rank, two forms are // provided: one for floating point types with an ErrorSpec parameter, and one // for integral types without the ErrorSpec parameter. - template - void ComputeAndCompareR0(BuilderT* builder, NativeT expected, + template + void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR0(BuilderT* builder, NativeT expected, + template + void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR1(BuilderT* builder, + template + void ComputeAndCompareR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR1(BuilderT* builder, + template + void ComputeAndCompareR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // As above, but uses a bitmap to hold the predicate vector to avoid // deficiencies of vector. - void ComputeAndCompareR1(ComputationBuilder* builder, - const tensorflow::core::Bitmap& expected, - tensorflow::gtl::ArraySlice arguments); void ComputeAndCompareR1(XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, + template + void ComputeAndCompareR2(XlaBuilder* builder, + const Array2D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, + template + void ComputeAndCompareR2(XlaBuilder* builder, + const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, + template + void ComputeAndCompareR3(XlaBuilder* builder, + const Array3D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, + template + void ComputeAndCompareR3(XlaBuilder* builder, + const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, + template + void ComputeAndCompareR4(XlaBuilder* builder, + const Array4D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, + template + void ComputeAndCompareR4(XlaBuilder* builder, + const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Build and run the computation and compare the result with the given // literal. shape_with_layout indicates the result layout to request when // calling Execute. - template void ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); - template void ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. - template tensorflow::Status ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); - template tensorflow::Status ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); @@ -227,25 +205,13 @@ class ClientLibraryTestBase : public ::testing::Test { // Convenience method for running a built computation, transferring the // result, and comparing it to the expected tuple literal. - template void ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments); - template void ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - // Convenience method for running a built computation and comparing the result - // with the HloEvaluator. - void ComputeAndCompare(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments); - void ComputeAndCompare(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments, - ErrorSpec error); - // Convenience method for running a built computation and comparing the result // with the reference result. void ComputeAndCompare(XlaBuilder* builder, @@ -257,7 +223,7 @@ class ClientLibraryTestBase : public ::testing::Test { // Create scalar operations for use in reductions. XlaComputation CreateScalarRelu(); XlaComputation CreateScalarMax(); - Computation CreateScalarReluSensitivity(); + XlaComputation CreateScalarReluSensitivity(); // Special case convenience functions for creating filled arrays. @@ -297,34 +263,25 @@ class ClientLibraryTestBase : public ::testing::Test { // server, then stores into "data_handle" the global handle for that // parameter. When the use_bfloat16 flag is set but the literal has F32 // elements, the literal will be converted to BF16 before being transferred. - template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - BuilderT* builder, HandleT* data_handle); + XlaBuilder* builder, XlaOp* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. - template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, BuilderT* builder, - HandleT* data_handle); + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle); // Creates a parameter instruction and sets the value that will be passed to // the computation as specified. This function must be used for all parameters // or none and no parameters must be passed when invoking the computation if // using this mechanism. If using this mechanism, then each parameter must be // set exactly once. The first added parameter gets index 0, then 1 and so on. - ComputationDataHandle AddParam(const Literal& argument, - ComputationBuilder* builder); XlaOp AddParam(const Literal& argument, XlaBuilder* builder); - template - ComputationDataHandle AddParam(const Array& argument, - ComputationBuilder* builder) { - return AddParam(*Literal::CreateFromArray(argument), builder); - } template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { return AddParam(*Literal::CreateFromArray(argument), builder); @@ -333,18 +290,11 @@ class ClientLibraryTestBase : public ::testing::Test { // Creates a constant instruction with the given literal. When the // use_bfloat16 flag is set but the literal has F32 elements, the elements // will be converted to BF16s. - ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, - ComputationBuilder* builder); XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); // Creates a constant instruction with the given array. When the use_bfloat16 // flag is set but the array has float elements, the elements will be // converted to bfloat16s. - template - ComputationDataHandle CreateConstantFromArray(const Array& array, - ComputationBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); - } template XlaOp CreateConstantFromArray(const Array& array, @@ -353,13 +303,6 @@ class ClientLibraryTestBase : public ::testing::Test { } // Same as CreateConstantFromArray, but for scalars. - template - ComputationDataHandle CreateConstantFromScalar(NativeT value, - ComputationBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateR0(value), - builder); - } - template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { return CreateConstantFromLiteral(*Literal::CreateR0(value), @@ -374,12 +317,12 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR0Parameter(NativeT value, int64 parameter_number, const string& name, - BuilderT* builder, - HandleT* data_handle); + XlaBuilder* builder, + XlaOp* data_handle); // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. @@ -389,10 +332,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that @@ -403,10 +346,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that @@ -417,10 +360,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Getter and setter for the use_bfloat16 flag, which indicates whether to run // tests with all float-type input/output converted to bfloat16. @@ -435,21 +378,6 @@ class ClientLibraryTestBase : public ::testing::Test { ExecutionOptions execution_options_; private: - // Build and run the computation with all permutations of output layouts. - tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::Computation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const std::function& verify_output); - // Build and run the computation with all permutations of layouts of all input - // arguments. - tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( - const xla::Computation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const std::function& verify_output, - const Shape* output_with_layout = nullptr); - tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, @@ -462,13 +390,6 @@ class ClientLibraryTestBase : public ::testing::Test { const string& error_message)>& verify_output, const Shape* output_with_layout = nullptr); - // Executes the computation and calculates the expected reference value using - // the HloEvaluator. Returns two literals in the order of (expected, actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments); - // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, // actual). @@ -484,9 +405,9 @@ class ClientLibraryTestBase : public ::testing::Test { std::vector> arguments_; }; -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - BuilderT* builder, NativeT expected, + XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR0(expected); @@ -494,9 +415,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - BuilderT* builder, NativeT expected, + XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -510,9 +431,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - BuilderT* builder, tensorflow::gtl::ArraySlice expected, + XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR1(expected); @@ -520,9 +441,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - BuilderT* builder, tensorflow::gtl::ArraySlice expected, + XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -536,9 +457,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - BuilderT* builder, const Array2D& expected, + XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR2FromArray2D(expected); @@ -546,9 +467,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - BuilderT* builder, const Array2D& expected, + XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -562,9 +483,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - BuilderT* builder, const Array3D& expected, + XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR3FromArray3D(expected); @@ -572,9 +493,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - BuilderT* builder, const Array3D& expected, + XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -588,9 +509,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - BuilderT* builder, const Array4D& expected, + XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR4FromArray4D(expected); @@ -598,9 +519,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - BuilderT* builder, const Array4D& expected, + XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -614,10 +535,10 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments, error); } -template +template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, - BuilderT* builder, HandleT* data_handle) { + XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -628,10 +549,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -642,10 +563,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -656,10 +577,10 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { literal = LiteralTestUtil::ConvertF32ToBF16(*literal); @@ -695,23 +616,21 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( return result; } -template std::unique_ptr ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, const Literal& literal, const string& name, - BuilderT* builder, - HandleT* data_handle) { + XlaBuilder* builder, + XlaOp* data_handle) { return CreateParameterAndTransferLiteral(parameter_number, literal, name, nullptr, builder, data_handle); } -template std::unique_ptr ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, BuilderT* builder, - HandleT* data_handle) { + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle) { const Literal* param_literal = &literal; std::unique_ptr converted_literal; if (use_bfloat16_) { -- GitLab From 8ec11ae8eb7b97caced73ed3971209236e2aef5c Mon Sep 17 00:00:00 2001 From: Yuefeng Zhou Date: Thu, 3 May 2018 22:01:39 -0700 Subject: [PATCH 116/839] Add the MultiWorkerMirroredStrategy PiperOrigin-RevId: 195368876 --- tensorflow/contrib/distribute/python/BUILD | 13 ++ .../distribute/python/mirrored_strategy.py | 1 + .../python/multi_worker_strategy.py | 141 ++++++++++++++++++ .../python/multi_worker_strategy_test.py | 64 ++++++++ .../distribute/python/one_device_strategy.py | 1 + tensorflow/python/training/distribute.py | 20 ++- 6 files changed, 238 insertions(+), 2 deletions(-) create mode 100644 tensorflow/contrib/distribute/python/multi_worker_strategy.py create mode 100644 tensorflow/contrib/distribute/python/multi_worker_strategy_test.py diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index aaafc184bf..8dfcaf6032 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -86,6 +86,19 @@ py_library( ], ) +py_library( + name = "multi_worker_strategy", + srcs = ["multi_worker_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) + py_library( name = "one_device_strategy", srcs = ["one_device_strategy.py"], diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 2e57b02583..8237b23dbb 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -80,6 +80,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): dict((d, i) for i, d in enumerate(devices))) self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device + # TODO(yuefengz): consider setting the default device. def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. See `DistributionStrategy.scope`.""" diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py new file mode 100644 index 0000000000..a552b370eb --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes implementing a mirrored DistributionStrategy for multiple workers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import partial + +from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.training import device_util +from tensorflow.python.training import server_lib +from tensorflow.python.util import nest + + +# TODO(yuefengz): support between-graph replication. +# TODO(yuefengz): merge this class into its base class. +# TODO(yuefengz): in some cases, we probably want to use configure method to +# configure this class. +# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the +# class is introduced. +class MultiWorkerMirroredStrategy(MirroredStrategy): + """Mirrored strategy that works on multiple workers with in-graph replication. + + There are several important concepts for distributed TensorFlow, e.g. + `client`, `job`, 'task', `cluster`, `in-graph replication` and + 'synchronous training' and they have already been defined in the + [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). + The distribution strategy inherits these concepts as well and in addition to + that we also clarify several more concepts: + * **In-graph replication**: the `client` creates a single `tf.Graph` that + specifies tasks for devices on all workers. The `client` then creates a + client session which will talk to the `master` service of a `worker`. Then + the `master` will parition the graph and distribute the work to all + participating workers. + * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one + physical machine. We will have multiple `worker`s with different `task` + index. They all do similar things except for one worker checkpointing model + variables, writing summaries, etc. in addition to its ordinary work. + + This class maps one tower to one device on a worker. It mirrors all model + variables on all towers. For example, if you have two `worker`s and each + `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8 + GPUs. Then like in MirroredStrategy, each tower performs their computation + with their own copy of variables unless in cross-tower model where variable or + tensor reduction happens. + """ + + def __init__(self, + num_gpus_per_worker=1, + worker_job_name=None, + num_workers=None, + cluster=None, + cross_tower_ops=None, + prefetch_on_device=None): + """Initialize the strategy object. + + Args: + num_gpus_per_worker: number of GPUs per work. If it is zero, the local + CPU will be used. + worker_job_name: the job name for `worker`, typically just 'worker'. + num_workers: the number of workers. If it is 0, it regenerates to + single-worker MirroredStrategy. + cluster: a `tf.train.ClusterSpec` object or a dict that can be used to + construct a `tf.train.ClusterSpec` object or a `tf.train.ClusterDef` + proto buffer. It is an alternative way to initialize this object. + cross_tower_ops: the cross tower ops to use. If None, a default one will + be used. If configure method is called, a best one for the configuration + will be chosen. + prefetch_on_device: a boolean to specify whether to prefetech input to + each worker's devices. + + Raises: + ValueError: if got an unexpected `cluster`. + """ + if cluster is None: + self._workers = [ + '/job:%s/task:%d' % (worker_job_name, task_index) + for task_index in range(num_workers) + ] + else: + if isinstance(cluster, (dict, cluster_pb2.ClusterDef)): + cluster_spec = server_lib.ClusterSpec(cluster) + elif isinstance(cluster, server_lib.ClusterSpec): + cluster_spec = cluster + else: + raise ValueError( + "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " + '`tf.train.ClusterDef` object') + + self._workers = [] + for job in sorted(cluster_spec.jobs): + for task in range(cluster_spec.num_tasks(job)): + self._workers.append('/job:%s/task:%d' % (job, task)) + + self._num_gpus_per_worker = num_gpus_per_worker + if num_gpus_per_worker > 0: + self._worker_device_map = { + worker: [ + device_util.canonicalize(worker + '/device:GPU:%d' % gpu) + for gpu in range(num_gpus_per_worker) + ] for worker in self._workers + } + else: + self._worker_device_map = { + worker: [device_util.canonicalize(worker, '/device:CPU:0')] + for worker in self._workers + } + self._devices = nest.flatten(self._worker_device_map.values()) + + super(MultiWorkerMirroredStrategy, self).__init__( + devices=self._devices, prefetch_on_device=prefetch_on_device) + + # Setting `_default_device` will add a device scope in the + # distribution.scope. We set the default device to the first worker. When + # users specify device under distribution.scope by + # with tf.device("/cpu:0"): + # ... + # their ops will end up on the cpu device of its first worker, e.g. + # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. + self._default_device = self._workers[0] + + def distribute_dataset(self, dataset_fn): + return values.MultiWorkerDataset( + partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, + self._prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py new file mode 100644 index 0000000000..ee7588163e --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MultiWorkerMirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import multi_worker_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.training import server_lib + + +@test_util.with_c_api +class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster=server_lib.ClusterSpec({ + 'worker': ['/job:worker/task:0', '/job:worker/task:1'] + }), + num_gpus_per_worker=context.num_gpus()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + +class DeviceScopeTest(test.TestCase): + """Test the device scope of MultiWorkerMirroredStrategy.""" + + def testDeviceScope(self): + with context.graph_mode(): + strategy = multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']}, + num_gpus_per_worker=context.num_gpus()) + with strategy.scope(): + a = constant_op.constant(1.) + with ops.device('/cpu:0'): + b = constant_op.constant(1.) + self.assertEqual(a.device, '/job:worker/task:0') + self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 64aa369201..09b6d4a515 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -40,6 +40,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): super(OneDeviceStrategy, self).__init__() self._device = device self._prefetch_on_device = prefetch_on_device + self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): # No need to distinguish tower-local variables when not mirroring, diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index c16b05102e..21f81ee187 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -290,19 +290,31 @@ def _require_distribution_strategy_scope(distribution_strategy): class _CurrentDistributionContext(object): """Context manager for setting the `DistributionStrategy` and var creator.""" - def __init__(self, distribution_strategy, var_creator_scope, var_scope=None): + def __init__(self, + distribution_strategy, + var_creator_scope, + var_scope=None, + default_device=None): self._context = _CrossTowerThreadMode(distribution_strategy) self._var_creator_scope = var_creator_scope self._var_scope = var_scope + if default_device: + self._device_scope = ops.device(default_device) + else: + self._device_scope = None def __enter__(self): _push_per_thread_mode(self._context) if self._var_scope: self._var_scope.__enter__() self._var_creator_scope.__enter__() + if self._device_scope: + self._device_scope.__enter__() return self._context.distribution_strategy def __exit__(self, exception_type, exception_value, traceback): + if self._device_scope: + self._device_scope.__exit__(exception_type, exception_value, traceback) self._var_creator_scope.__exit__(exception_type, exception_value, traceback) if self._var_scope: self._var_scope.__exit__(exception_type, exception_value, traceback) @@ -557,6 +569,9 @@ class DistributionStrategy(object): # TODO(josh11b): List of towers with their worker and parameter devices # (where the parameter devices may overlap in the ps case). + def __init__(self): + self._default_device = None + def scope(self): """Returns a context manager selecting this DistributionStrategy as current. @@ -587,7 +602,8 @@ class DistributionStrategy(object): self, variable_scope.variable_creator_scope(creator_with_resource_vars), variable_scope.variable_scope( variable_scope.get_variable_scope(), - custom_getter=disable_partitioned_variables)) + custom_getter=disable_partitioned_variables), + self._default_device) def _create_variable(self, next_creator, *args, **kwargs): # Note: should support "colocate_with" argument. -- GitLab From da0dcb21501b765932e392ae710ebbecefeb309c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 23:36:02 -0700 Subject: [PATCH 117/839] Internal change. PiperOrigin-RevId: 195374319 --- .../kernels/bidirectional_sequence_lstm.cc | 6 ++--- tensorflow/contrib/lite/kernels/kernel_util.h | 4 +++ tensorflow/contrib/lite/kernels/lstm.cc | 4 +-- tensorflow/contrib/lite/kernels/mean.cc | 16 ++++++------ tensorflow/contrib/lite/kernels/svdf.cc | 26 +++++++++---------- .../kernels/unidirectional_sequence_lstm.cc | 4 +-- 6 files changed, 30 insertions(+), 30 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 3ac0210f36..a35ba23ced 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -365,8 +365,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayFree(node->temporaries); node->temporaries = TfLiteIntArrayCreate(2); node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* fw_scratch_buffer = - &context->tensors[node->temporaries->data[0]]; + TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, /*index=*/0); fw_scratch_buffer->type = input->type; fw_scratch_buffer->allocation_type = kTfLiteArenaRw; @@ -434,8 +433,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Create a scratch buffer tensor. node->temporaries->data[1] = *(scratch_tensor_index) + 1; - TfLiteTensor* bw_scratch_buffer = - &context->tensors[node->temporaries->data[1]]; + TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, /*index=*/1); bw_scratch_buffer->type = input->type; bw_scratch_buffer->allocation_type = kTfLiteArenaRw; diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index 2f407b5da3..e225443a67 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -32,6 +32,10 @@ inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node, int index) { return &context->tensors[node->outputs->data[index]]; } +inline TfLiteTensor* GetTemporary(TfLiteContext* context, TfLiteNode* node, + int index) { + return &context->tensors[node->temporaries->data[index]]; +} inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; } inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; } diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 668226e674..a1521efbb4 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -290,7 +290,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayFree(node->temporaries); node->temporaries = TfLiteIntArrayCreate(1); node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]]; + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; @@ -378,7 +378,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_peephole = (cell_to_output_weights != nullptr); // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]]; + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/mean.cc index 047bdd1039..98f80e32d9 100644 --- a/tensorflow/contrib/lite/kernels/mean.cc +++ b/tensorflow/contrib/lite/kernels/mean.cc @@ -146,7 +146,7 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, TfLiteIntArrayFree(node->temporaries); node->temporaries = TfLiteIntArrayCreate(3); node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; + TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0); scratch_tensor->type = kTfLiteInt32; scratch_tensor->allocation_type = kTfLiteArenaRw; TfLiteIntArray* index_size = TfLiteIntArrayCreate(1); @@ -156,11 +156,11 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, // Creates a temp tensor to store resolved axis given input data. node->temporaries->data[1] = *scratch_tensor_index + 1; - TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); resolved_axis->type = kTfLiteInt32; // Creates a temp tensor to store temp sums when calculating mean. node->temporaries->data[2] = *scratch_tensor_index + 2; - TfLiteTensor* temp_sum = &context->tensors[node->temporaries->data[2]]; + TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); switch (op_context->input->type) { case kTfLiteFloat32: temp_sum->type = kTfLiteFloat32; @@ -187,8 +187,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { MeanContext op_context(context, node); TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context)); - TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; - TfLiteTensor* temp_sum = &context->tensors[node->temporaries->data[2]]; + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); // Leaves work to Eval if axis is not constant; else resizes output. if (!IsConstantTensor(op_context.axis)) { SetTensorToDynamic(op_context.output); @@ -208,9 +208,9 @@ template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { MeanContext op_context(context, node); int num_axis = static_cast(NumElements(op_context.axis)); - TfLiteTensor* temp_index = &context->tensors[node->temporaries->data[0]]; - TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; - TfLiteTensor* temp_sum = &context->tensors[node->temporaries->data[2]]; + TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0); + TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2); // Resize the output tensor if the output tensor is dynamic. if (IsDynamicTensor(op_context.output)) { TF_LITE_ENSURE_OK(context, diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index c69755447d..13da51c7a7 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -37,7 +37,7 @@ constexpr int kWeightsFeatureTensor = 1; constexpr int kWeightsTimeTensor = 2; constexpr int kBiasTensor = 3; constexpr int kStateTensor = 0; -constexpr int KOutputTensor = 1; +constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; @@ -59,9 +59,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; TfLiteTensor* weights_feature = - &context->tensors[node->inputs->data[kWeightsFeatureTensor]]; - TfLiteTensor* weights_time = - &context->tensors[node->inputs->data[kWeightsTimeTensor]]; + GetInput(context, node, kWeightsFeatureTensor); + TfLiteTensor* weights_time = GetInput(context, node, kWeightsTimeTensor); // Check all the parameters of tensor match within themselves and match the // input configuration. @@ -79,8 +78,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); } - TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]]; - TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]]; + TfLiteTensor* state = GetOutput(context, node, kStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // Resize state. // For each batch, the state is a 2-D tensor: memory_size * num_filters @@ -112,7 +111,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { scratch_size_array->data[0] = batch_size; scratch_size_array->data[1] = num_filters; - TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; + TfLiteTensor* scratch_tensor = GetTemporary(context, node, /*index=*/0); scratch_tensor->type = input->type; scratch_tensor->allocation_type = kTfLiteArenaRw; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor, @@ -124,15 +123,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* weights_feature = - &context->tensors[node->inputs->data[kWeightsFeatureTensor]]; - TfLiteTensor* weights_time = - &context->tensors[node->inputs->data[kWeightsTimeTensor]]; + GetInput(context, node, kWeightsFeatureTensor); + TfLiteTensor* weights_time = GetInput(context, node, kWeightsTimeTensor); - TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]]; - TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]]; - TfLiteTensor* scratch = &context->tensors[node->temporaries->data[0]]; + TfLiteTensor* state = GetOutput(context, node, kStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0); TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc index 3c1256d3a6..5987bf68b5 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc @@ -292,7 +292,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayFree(node->temporaries); node->temporaries = TfLiteIntArrayCreate(1); node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]]; + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; @@ -381,7 +381,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_peephole = (cell_to_output_weights != nullptr); // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = &context->tensors[node->temporaries->data[0]]; + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; float* forget_gate_scratch = nullptr; -- GitLab From 0bb55f02022e88affefc111cf9a8cf70a046d1da Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee Date: Fri, 4 May 2018 00:51:58 -0700 Subject: [PATCH 118/839] Automated g4 rollback of changelist 194829761 PiperOrigin-RevId: 195379693 --- .../xla/service/hlo_module_group_metadata.cc | 7 ------- .../xla/service/hlo_module_group_metadata.h | 3 --- tensorflow/compiler/xla/service/service.cc | 13 +++---------- 3 files changed, 3 insertions(+), 20 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 3367d76ded..54c34ce116 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -194,13 +194,6 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { LOG(FATAL) << "unknown module"; } -int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { - return std::count_if(modules_.begin(), modules_.end(), - [](const HloModule* module) { - return !module->config().is_host_module(); - }); -} - Status HloModuleGroupMetadata::RecordInstructions() { const auto visitor = [this](HloInstruction* hlo) -> Status { if (hlo->opcode() == HloOpcode::kWhile) { diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index d619082616..c48a7ab0b5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -147,9 +147,6 @@ class HloModuleGroupMetadata { // the module in the module vector. int64 GetModuleId(const HloModule* module) const; - // Returns the number of modules for devices (excluding the host module). - int64 GetDeviceModulesCount() const; - // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 6ce03ab39d..495f8801ba 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -626,16 +626,9 @@ Service::ExecuteParallelAndRegisterResult( // profiled. std::map index_to_profiled_streams; - // Build DeviceAssignment for all cores based on the provided device handles. - DeviceAssignment device_assignment(options_.number_of_replicas(), - executables.size()); - for (int64 i = 0; i < executables.size(); i++) { - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); - CHECK_EQ(replicas.size(), arguments[i].size()); - for (int64 replica = 0; replica < replicas.size(); ++replica) { - device_assignment(replica, i) = replicas[replica]->device_ordinal(); - } - } + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + backend->computation_placer()->AssignDevices( + options_.number_of_replicas(), executables.size())); for (int64 i = 0; i < executables.size(); i++) { // Stream executors for the replicas of the current computation. -- GitLab From 1284047dca0dd58745a31cd2fd68da3173c7e120 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 01:47:12 -0700 Subject: [PATCH 119/839] * Don't copy on-host and on-device shapes locally. * Use ForEachMutableElement rather than the iterators, as it is much quicker. There is still room for improvement; ForEachMutableElement is linear in the number of nodes in the shape tree but we want to be linear in the number of nodes in the sub shape tree. But I feel this is a good enough improvement. PiperOrigin-RevId: 195384423 --- tensorflow/compiler/jit/BUILD | 25 ++++++++ tensorflow/compiler/jit/xla_launch_util.cc | 22 ++++--- tensorflow/compiler/jit/xla_launch_util.h | 11 ++++ .../compiler/jit/xla_launch_util_test.cc | 64 +++++++++++++++++++ 4 files changed, 113 insertions(+), 9 deletions(-) create mode 100644 tensorflow/compiler/jit/xla_launch_util_test.cc diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index af2965bba5..07136d6a74 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -360,6 +360,31 @@ tf_cc_test( ], ) +tf_cc_test( + name = "xla_launch_util_test", + size = "small", + srcs = ["xla_launch_util_test.cc"], + deps = [ + ":common", + ":xla_compilation_cache", + ":xla_launch_util", + ":xla_tensor", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core/kernels:variable_ops", + ], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 2a7f04271d..33e53612b9 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -77,16 +77,16 @@ Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) { return Status::OK(); } -namespace { +namespace internal { // Return the 'index''th subtree of the given ShapedBuffer as a // ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the // subtree, and sets the input's buffer pointers to nullptr for the subtree. ScopedShapedBuffer ExtractSubShapedBuffer( ShapedBuffer* shaped_buffer, int index, xla::DeviceMemoryAllocator* allocator) { - xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape( + const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape( shaped_buffer->on_host_shape(), index); - xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape( + const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape( shaped_buffer->on_device_shape(), index); ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, @@ -98,14 +98,18 @@ ScopedShapedBuffer ExtractSubShapedBuffer( sub_shape_tree.CopySubtreeFrom(shape_tree, /*source_base_index=*/{index}, /*target_base_index=*/{}); - for (auto& index_to_buffer : shape_tree) { - if (!index_to_buffer.first.empty() && index_to_buffer.first[0] == index) { - index_to_buffer.second = se::DeviceMemoryBase(nullptr, 0); - } - } + shape_tree.ForEachMutableElement( + [index](const xla::ShapeIndex& shape_index, + tensorflow::se::DeviceMemoryBase* data) { + // shape_index is empty for the root node. Ignore that. + if (!shape_index.empty() && shape_index[0] == index) { + *data = tensorflow::se::DeviceMemoryBase(nullptr, 0); + } + }); return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator); } -} // namespace +} // namespace internal +using internal::ExtractSubShapedBuffer; XlaComputationLaunchContext::XlaComputationLaunchContext( int64 num_resource_args, xla::LocalClient* client, diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 8a6ff3b0c7..38291b0bd4 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -140,6 +140,17 @@ class XlaTensorBuffer : public TensorBuffer { Allocator* allocator_; }; +// Exposed in this header file for microbenchmarking purposes, but this is an +// internal implementation detail. +namespace internal { +// Return the 'index''th subtree of the given ShapedBuffer as a +// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the +// subtree, and sets the input's buffer pointers to nullptr for the subtree. +xla::ScopedShapedBuffer ExtractSubShapedBuffer( + xla::ShapedBuffer* shaped_buffer, int index, + xla::DeviceMemoryAllocator* allocator); +} // namespace internal + } // namespace tensorflow #endif diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc new file mode 100644 index 0000000000..27813efc0b --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains microbenchmarks for performance critical functions in +// xla_launch_util.cc. + +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs +// (cardinality of each non-leaf node's children). +void BM_ExtractSubBuffer(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = xla::ShapeUtil::MakeTupleShape(shapes); + } + xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr, + /*device_ordinal=*/0); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + // Extract a buffer from approximately the middle of the first level of the + // tree. + tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, + /*index=*/fan_out / 2, + /*allocator=*/nullptr) + .release(); + } +} + +BENCHMARK(BM_ExtractSubBuffer) + ->ArgPair(1, 4) + ->ArgPair(1, 8) + ->ArgPair(1, 32) + ->ArgPair(1, 64) + ->ArgPair(1, 128) + ->ArgPair(1, 256) + ->ArgPair(1, 512) + ->ArgPair(2, 4) + ->ArgPair(2, 8) + ->ArgPair(2, 32) + ->ArgPair(2, 64) + ->ArgPair(2, 128); + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + tensorflow::testing::RunBenchmarks(); + return RUN_ALL_TESTS(); +} -- GitLab From 73a1908b3c50d2f665a3a9af491e217d814edb40 Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Fri, 4 May 2018 01:57:02 -0700 Subject: [PATCH 120/839] Prefer non-nested GradientTape.gradient call when only one source is passed. PiperOrigin-RevId: 195385406 --- tensorflow/docs_src/programmers_guide/eager.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/docs_src/programmers_guide/eager.md b/tensorflow/docs_src/programmers_guide/eager.md index 595e6be4af..5926e9f7f4 100644 --- a/tensorflow/docs_src/programmers_guide/eager.md +++ b/tensorflow/docs_src/programmers_guide/eager.md @@ -227,8 +227,8 @@ w = tfe.Variable([[1.0]]) with tf.GradientTape() as tape: loss = w * w -grad = tape.gradient(loss, [w]) -print(grad) # => [tf.Tensor([[ 2.]], shape=(1, 1), dtype=float32)] +grad = tape.gradient(loss, w) +print(grad) # => tf.Tensor([[ 2.]], shape=(1, 1), dtype=float32) ``` Here's an example of `tf.GradientTape` that records forward-pass operations @@ -596,7 +596,7 @@ def line_search_step(fn, init_x, rate=1.0): # Variables are automatically recorded, but manually watch a tensor tape.watch(init_x) value = fn(init_x) - grad, = tape.gradient(value, [init_x]) + grad = tape.gradient(value, init_x) grad_norm = tf.reduce_sum(grad * grad) init_value = value while value > init_value - rate * grad_norm: -- GitLab From c183c5600b1393767c8c85aad34a436feb3bbe75 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 02:04:33 -0700 Subject: [PATCH 121/839] Fixing some linter errors in TF documentation (Github > GitHub, the the > the). PiperOrigin-RevId: 195386172 --- tensorflow/docs_src/deploy/index.md | 2 +- tensorflow/docs_src/get_started/get_started_for_beginners.md | 2 +- tensorflow/docs_src/mobile/android_build.md | 4 ++-- tensorflow/docs_src/mobile/linking_libs.md | 2 +- tensorflow/docs_src/mobile/mobile_intro.md | 4 ++-- tensorflow/docs_src/mobile/tflite/index.md | 2 +- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorflow/docs_src/deploy/index.md b/tensorflow/docs_src/deploy/index.md index 07b1bc9257..61edba04b4 100644 --- a/tensorflow/docs_src/deploy/index.md +++ b/tensorflow/docs_src/deploy/index.md @@ -14,4 +14,4 @@ the following documents: designed for production environments. TensorFlow Serving provides out-of-the-box integration with TensorFlow models. [Source code for TensorFlow Serving](https://github.com/tensorflow/serving) - is available on Github. + is available on GitHub. diff --git a/tensorflow/docs_src/get_started/get_started_for_beginners.md b/tensorflow/docs_src/get_started/get_started_for_beginners.md index fbe0ed74f8..d5a80e22c5 100644 --- a/tensorflow/docs_src/get_started/get_started_for_beginners.md +++ b/tensorflow/docs_src/get_started/get_started_for_beginners.md @@ -233,7 +233,7 @@ The Iris program requires the data from the following two .csv files: * `http://download.tensorflow.org/data/iris_training.csv`, which contains the training set. * `http://download.tensorflow.org/data/iris_test.csv`, which contains the - the test set. + test set. The **training set** contains the examples that we'll use to train the model; the **test set** contains the examples that we'll use to evaluate the trained diff --git a/tensorflow/docs_src/mobile/android_build.md b/tensorflow/docs_src/mobile/android_build.md index c35530061d..f4b07db459 100644 --- a/tensorflow/docs_src/mobile/android_build.md +++ b/tensorflow/docs_src/mobile/android_build.md @@ -26,7 +26,7 @@ If you haven't already, do the following two things: - Install [Android Studio](https://developer.android.com/studio/index.html), following the instructions on their website. -- Clone the TensorFlow repository from Github: +- Clone the TensorFlow repository from GitHub: git clone https://github.com/tensorflow/tensorflow @@ -37,7 +37,7 @@ If you haven't already, do the following two things: 2. From the **Open File or Project** window that appears, navigate to and select the `tensorflow/examples/android` directory from wherever you cloned the - TensorFlow Github repo. Click OK. + TensorFlow GitHub repo. Click OK. If it asks you to do a Gradle Sync, click OK. diff --git a/tensorflow/docs_src/mobile/linking_libs.md b/tensorflow/docs_src/mobile/linking_libs.md index 2a0a77c92d..cf0db59021 100644 --- a/tensorflow/docs_src/mobile/linking_libs.md +++ b/tensorflow/docs_src/mobile/linking_libs.md @@ -32,7 +32,7 @@ include this functionality in your program: 2. Download the nightly precompiled version from [ci.tensorflow.org](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/). -3. Build the JAR file yourself using the instructions [in our Android Github repo](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android) +3. Build the JAR file yourself using the instructions [in our Android GitHub repo](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android) ### iOS diff --git a/tensorflow/docs_src/mobile/mobile_intro.md b/tensorflow/docs_src/mobile/mobile_intro.md index 69b63ae7d2..1b0b9b44b4 100644 --- a/tensorflow/docs_src/mobile/mobile_intro.md +++ b/tensorflow/docs_src/mobile/mobile_intro.md @@ -80,7 +80,7 @@ tracking is especially important for applications where you’re trying to count how many objects are present over time, since it gives you a good idea when a new object enters or leaves the scene. We have some sample code for this available for Android [on -Github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android), +GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android), and also a [more general object detection model](https://github.com/tensorflow/models/tree/master/research/object_detection/README.md) available as well. @@ -231,7 +231,7 @@ process. The next step is to pick an effective model to use. You might be able to avoid training a model from scratch if someone else has already implemented a model similar to what you need; we have a repository of models implemented in -TensorFlow [on Github](https://github.com/tensorflow/models) that you can look +TensorFlow [on GitHub](https://github.com/tensorflow/models) that you can look through. Lean towards the simplest model you can find, and try to get started as soon as you have even a small amount of labelled data, since you’ll get the best results when you’re able to iterate quickly. The shorter the time it takes to diff --git a/tensorflow/docs_src/mobile/tflite/index.md b/tensorflow/docs_src/mobile/tflite/index.md index 11f11ea4dc..01881ccf3b 100644 --- a/tensorflow/docs_src/mobile/tflite/index.md +++ b/tensorflow/docs_src/mobile/tflite/index.md @@ -11,7 +11,7 @@ optimizing the kernels for mobile apps, pre-fused activations, and quantized kernels that allow smaller and faster (fixed-point math) models. Most of our TensorFlow Lite documentation is [on -Github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite) +GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite) for the time being. ## What does TensorFlow Lite contain? -- GitLab From 7a7bbc303c451fea5b3dd93109028531a89a18ab Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 02:22:14 -0700 Subject: [PATCH 122/839] Do not crash on ROOT outfeed operations. PiperOrigin-RevId: 195388075 --- .../compiler/xla/service/cpu/ir_emitter.cc | 8 ++- .../compiler/xla/service/cpu/tests/BUILD | 14 +++++ .../cpu/tests/cpu_literal_caching_test.cc | 16 +----- .../xla/service/cpu/tests/cpu_outfeed_test.cc | 57 +++++++++++++++++++ 4 files changed, 79 insertions(+), 16 deletions(-) create mode 100644 tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index e473389a29..6347ee2a2a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -2563,8 +2563,12 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { // nothing to do since the result was already written directly into the output // buffer. VLOG(2) << "FinishVisit root: " << root->ToString(); - llvm::Value* root_value = GetEmittedValueFor(root); - VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); + if (root->opcode() == HloOpcode::kOutfeed) { + VLOG(2) << " outfeed with value: " + << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0))); + } else { + VLOG(2) << " value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root)); + } auto record_complete_computation = [&](llvm::Value* prof_counter) { if (prof_counter) { diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 4ddb7a85bc..18a915e533 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -161,3 +161,17 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cpu_outfeed_test", + srcs = ["cpu_outfeed_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index b10eb74635..d6e0425c55 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -50,16 +50,10 @@ ENTRY main { const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body out0 = () outfeed(f32[2,3,2] const_a) - out1 = () outfeed(f32[2,3,2] const_b) - - ROOT root = f32[] constant(1) + ROOT out1 = () outfeed(f32[2,3,2] const_b) } )"; - // TODO(b/78879738): The fake "f32[] constant(1)" root is only needed to work - // around b/78879738. Once b/78879738 is fixed, we can set one of the - // outfeeds as the root. - string filecheck_pattern = R"( CHECK: private constant [2 x [3 x [2 x float]]] CHECK-NOT: private constant [2 x [3 x [2 x float]]] @@ -99,16 +93,10 @@ ENTRY main { const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a) - out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b) - - ROOT root = f32[] constant(1) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b) } )"; - // TODO(b/78879738): The fake "f32[] constant(1)" root is only needed to work - // around b/78879738. Once b/78879738 is fixed, we can set one of the - // outfeeds as the root. - string filecheck_pattern = R"( CHECK: private constant [2 x float] CHECK: private constant [2 x [1 x float]] diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc new file mode 100644 index 0000000000..879372eb13 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuOutfeedTest : public CpuCodegenTest {}; + +TEST_F(CpuOutfeedTest, OutfeedRoot) { + const string hlo_text = R"( +HloModule Outfeed + +ENTRY main { + const_a = f32[2,3,2] constant( + f32[2,3,2] + {{{1, 2}, {1001, 1002}, {2001, 2002}}, + {{2, 1}, {2001, 3002}, {2001, 2002}}}) + + ROOT out = () outfeed(f32[2,3,2] const_a) +} +)"; + + string filecheck_pattern = R"( +CHECK: private constant [2 x [3 x [2 x float]]] +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + +} // namespace +} // namespace cpu +} // namespace xla -- GitLab From 34bb6643654b9a207b93d046d5fde807eb7ee499 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 03:27:18 -0700 Subject: [PATCH 123/839] Fix HloSharding::GetSubSharding to return correct array shardings Previously it always returned a tuple sharding even if the specified index was referenceing a non-tuple element. PiperOrigin-RevId: 195393313 --- tensorflow/compiler/xla/service/hlo_sharding.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 994de44123..7f7e3f7dab 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -367,10 +367,14 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, const ShapeIndex& index) const { CHECK(IsTuple()); - ShapeTree sub_shape_tree(ShapeUtil::GetSubshape(shape, index), - Replicate()); + Shape sub_shape = ShapeUtil::GetSubshape(shape, index); + ShapeTree sub_shape_tree(sub_shape, Replicate()); sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - return Tuple(sub_shape_tree); + if (ShapeUtil::IsTuple(sub_shape)) { + return Tuple(sub_shape_tree); + } else { + return sub_shape_tree.element({}); + } } std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { -- GitLab From 2d6170fc0afee7269cab7f84647f2a65b86e7020 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Fri, 4 May 2018 03:43:00 -0700 Subject: [PATCH 124/839] [XLA] Remove template keyword on non-template methods. This is an error with clang trunk. PiperOrigin-RevId: 195394277 --- tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc | 9 +++------ third_party/libxsmm.BUILD | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 6cb470caf8..464cc01214 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -67,8 +67,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 {0.36788f, 1.64872f}}); // row 1 - this->template ComputeAndCompareLiteral(&builder, *expected, {}, - ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { @@ -96,8 +95,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { std::unique_ptr expected = Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 {-0.5f, 1.0f}}); // row 1 - this->template ComputeAndCompareLiteral(&builder, *expected, {}, - ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { @@ -116,8 +114,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { std::unique_ptr expected = Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 {3.0f, -4.0f}}); // row 1 - this->template ComputeAndCompareLiteral(&builder, *expected, {}, - ErrorSpec(1e-6)); + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); } struct TestLinspaceMaxParam { diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD index 78ed1f4e16..4124f2db63 100644 --- a/third_party/libxsmm.BUILD +++ b/third_party/libxsmm.BUILD @@ -38,8 +38,8 @@ genrule( ":libxsmm_interface", ], visibility = [ - "//third_party/eigen3:__pkg__", "//tensorflow/core/kernels:__pkg__", + "//third_party/eigen3:__pkg__", ], ) -- GitLab From 3db0e545d2460be0392dfcaa304231cd2105648e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 4 May 2018 10:18:46 -0700 Subject: [PATCH 125/839] Change RecvTensor RPC implementation to use DeviceContext::CopyDeviceTensorToCPU rather than calling GPUUtil::CopyGPUTensorToCPU. The direct call into the GPU code is problematic for non-GPU devices. PiperOrigin-RevId: 195433287 --- tensorflow/core/distributed_runtime/rpc/BUILD | 1 - .../rpc/grpc_worker_service.cc | 24 +++++++------------ 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index e973a22f45..c2719f5462 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -169,7 +169,6 @@ tf_cuda_library( ":grpc_worker_service_impl", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:worker_proto_cc", diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index bbf7391377..26fad1fc3c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -23,9 +23,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#if GOOGLE_CUDA -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#endif // GOOGLE_CUDA #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" @@ -439,10 +436,10 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); }); env_->rendezvous_mgr->RecvLocalAsync( step_id, parsed, - [opts, response, done, src_dev](const Status& status, - const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, - const Tensor& val, const bool is_dead) { + [opts, response, done, src_dev, request]( + const Status& status, const Rendezvous::Args& send_args, + const Rendezvous::Args& recv_args, const Tensor& val, + const bool is_dead) { opts->ClearCancelCallback(); if (status.ok()) { // DMA can only be used for Tensors that do not fall into @@ -455,8 +452,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, { // Non-DMA cases. if (src_dev->tensorflow_gpu_device_info() && (!on_host)) { -#if GOOGLE_CUDA - const DeviceContext* send_dev_context = send_args.device_context; + DeviceContext* send_dev_context = send_args.device_context; AllocatorAttributes alloc_attrs; alloc_attrs.set_gpu_compatible(true); alloc_attrs.set_on_host(true); @@ -465,7 +461,8 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, CHECK(send_dev_context) << "send dev name: " << src_dev->name() << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); - // "val" is on a GPU. Uses GPUUtil to fill the copy on host. + // "val" is on an accelerator device. Uses the device_context to + // fill the copy on host. StatusCallback copy_ready = [response, done, copy, is_dead](const Status& s) { // The value is now ready to be returned on the wire. @@ -474,11 +471,8 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, delete copy; }; - GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy, - copy_ready); -#else - done(errors::Internal("No GPU device in process")); -#endif // GOOGLE_CUDA + send_dev_context->CopyDeviceTensorToCPU( + &val, request->rendezvous_key(), src_dev, copy, copy_ready); } else { grpc::EncodeTensorToByteBuffer(is_dead, val, response); done(Status::OK()); -- GitLab From a5f44b3519627859fb476a9cad1acc354bfa649f Mon Sep 17 00:00:00 2001 From: Alan Chiao Date: Fri, 4 May 2018 10:31:01 -0700 Subject: [PATCH 126/839] Implement neg op PiperOrigin-RevId: 195435079 --- tensorflow/contrib/lite/builtin_ops.h | 1 + .../lite/g3doc/tf_ops_compatibility.md | 11 ++ tensorflow/contrib/lite/kernels/BUILD | 14 ++ tensorflow/contrib/lite/kernels/neg.cc | 79 +++++++++++ tensorflow/contrib/lite/kernels/neg_test.cc | 80 +++++++++++ tensorflow/contrib/lite/kernels/register.cc | 2 + tensorflow/contrib/lite/model.cc | 1 + tensorflow/contrib/lite/nnapi_delegate.cc | 1 + tensorflow/contrib/lite/schema/schema.fbs | 5 + .../contrib/lite/schema/schema_generated.h | 124 +++++++++++++++++- tensorflow/contrib/lite/testing/BUILD | 1 + .../contrib/lite/testing/generate_examples.py | 25 ++++ .../testing/generated_examples_zip_test.cc | 1 + .../contrib/lite/toco/tflite/operator.cc | 3 +- .../contrib/lite/toco/tflite/operator_test.cc | 1 + 15 files changed, 341 insertions(+), 8 deletions(-) create mode 100644 tensorflow/contrib/lite/kernels/neg.cc create mode 100644 tensorflow/contrib/lite/kernels/neg_test.cc diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 21e0e04ef6..962a7a8970 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -84,6 +84,7 @@ typedef enum { kTfLiteBuiltinArgMax = 56, kTfLiteBuiltinMinimum = 57, kTfLiteBuiltinLess = 58, + kTfLiteBuiltinNeg = 59, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index aa28f8d050..0051ee84ec 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -397,6 +397,17 @@ Options { } ``` +**NEG** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: elementwise negation of the input tensor +} +``` + **PAD** ``` diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 57b3136cce..feab18b5c2 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -158,6 +158,7 @@ cc_library( "mean.cc", "mfcc.cc", "mul.cc", + "neg.cc", "pad.cc", "pooling.cc", "register.cc", @@ -856,6 +857,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "neg_test", + size = "small", + srcs = ["neg_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc new file mode 100644 index 0000000000..692da81727 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/neg.cc @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace neg { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + output->type = input->type; + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +template +void Negate(const T* in_data, int num_elements, T* out_data) { + // TODO(alanchiao): add vectorized version. + for (int i = 0; i < num_elements; ++i) { + out_data[i] = -in_data[i]; + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const int num_elements = NumElements(input); + switch (input->type) { + case kTfLiteInt64: + Negate(input->data.i64, num_elements, output->data.i64); + break; + case kTfLiteInt32: + Negate(input->data.i32, num_elements, output->data.i32); + break; + case kTfLiteFloat32: + Negate(input->data.f, num_elements, output->data.f); + break; + default: + context->ReportError( + context, "Neg only currently supports int64, int32, and float32.", + input->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace neg + +TfLiteRegistration* Register_NEG() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + neg::Prepare, neg::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/neg_test.cc b/tensorflow/contrib/lite/kernels/neg_test.cc new file mode 100644 index 0000000000..3c95ac8cc2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/neg_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class NegOpModel : public SingleOpModel { + public: + NegOpModel(const TensorData& input, const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_NEG, BuiltinOptions_NegOptions, + CreateNegOptions(builder_).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + protected: + int input_; + int output_; +}; + +TEST(NegOpModel, NegFloat) { + NegOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_FLOAT32, {2, 3}}); + m.SetInput({-2.0f, -1.0f, 0.f, 1.0f, 2.0f, 3.0f}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({2.0f, 1.0f, 0.f, -1.0f, -2.0f, -3.0f})); +} + +TEST(NegOpModel, NegInt32) { + NegOpModel m({TensorType_INT32, {2, 3}}, {TensorType_INT32, {2, 3}}); + m.SetInput({-2, -1, 0, 1, 2, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1, 0, -1, -2, -3})); +} + +TEST(NegOpModel, NegInt64) { + NegOpModel m({TensorType_INT64, {2, 3}}, {TensorType_INT64, {2, 3}}); + m.SetInput({-2, -1, 0, 1, 2, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1, 0, -1, -2, -3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index f91d188ffa..29ea718a96 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -81,6 +81,7 @@ TfLiteRegistration* Register_MINIMUM(); TfLiteRegistration* Register_ARG_MAX(); TfLiteRegistration* Register_LESS(); TfLiteRegistration* Register_FLOOR(); +TfLiteRegistration* Register_NEG(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -143,6 +144,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX()); AddBuiltin(BuiltinOperator_LESS, Register_LESS()); AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR()); + AddBuiltin(BuiltinOperator_NEG, Register_NEG()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index e15f1be7d3..590f042e21 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -351,6 +351,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_DEQUANTIZE: case BuiltinOperator_PRELU: case BuiltinOperator_FLOOR: + case BuiltinOperator_NEG: break; case BuiltinOperator_CAST: { TfLiteCastParams* params = MallocPOD(); diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index e1895dd38e..6eac18c4f5 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -372,6 +372,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_MINIMUM: case tflite::BuiltinOperator_ARG_MAX: case tflite::BuiltinOperator_LESS: + case tflite::BuiltinOperator_NEG: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index b16baf02dc..265b1dd3fe 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -136,6 +136,7 @@ enum BuiltinOperator : byte { ARG_MAX = 56, MINIMUM = 57, LESS = 58, + NEG = 59, } // Options for the builtin operators. @@ -181,6 +182,7 @@ union BuiltinOptions { MaximumMinimumOptions, ArgMaxOptions, LessOptions, + NegOptions, } enum Padding : byte { SAME, VALID } @@ -406,6 +408,9 @@ table ArgMaxOptions { table LessOptions { } +table NegOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 57af973460..c172f77aa9 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -154,6 +154,9 @@ struct ArgMaxOptionsT; struct LessOptions; struct LessOptionsT; +struct NegOptions; +struct NegOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -272,11 +275,12 @@ enum BuiltinOperator { BuiltinOperator_ARG_MAX = 56, BuiltinOperator_MINIMUM = 57, BuiltinOperator_LESS = 58, + BuiltinOperator_NEG = 59, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_LESS + BuiltinOperator_MAX = BuiltinOperator_NEG }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[58] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[59] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -335,7 +339,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[58] { BuiltinOperator_MAXIMUM, BuiltinOperator_ARG_MAX, BuiltinOperator_MINIMUM, - BuiltinOperator_LESS + BuiltinOperator_LESS, + BuiltinOperator_NEG }; return values; } @@ -401,6 +406,7 @@ inline const char **EnumNamesBuiltinOperator() { "ARG_MAX", "MINIMUM", "LESS", + "NEG", nullptr }; return names; @@ -454,11 +460,12 @@ enum BuiltinOptions { BuiltinOptions_MaximumMinimumOptions = 39, BuiltinOptions_ArgMaxOptions = 40, BuiltinOptions_LessOptions = 41, + BuiltinOptions_NegOptions = 42, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_LessOptions + BuiltinOptions_MAX = BuiltinOptions_NegOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[42] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[43] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -501,7 +508,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[42] { BuiltinOptions_DequantizeOptions, BuiltinOptions_MaximumMinimumOptions, BuiltinOptions_ArgMaxOptions, - BuiltinOptions_LessOptions + BuiltinOptions_LessOptions, + BuiltinOptions_NegOptions }; return values; } @@ -550,6 +558,7 @@ inline const char **EnumNamesBuiltinOptions() { "MaximumMinimumOptions", "ArgMaxOptions", "LessOptions", + "NegOptions", nullptr }; return names; @@ -728,6 +737,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LessOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NegOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1087,6 +1100,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_LessOptions ? reinterpret_cast(value) : nullptr; } + NegOptionsT *AsNegOptions() { + return type == BuiltinOptions_NegOptions ? + reinterpret_cast(value) : nullptr; + } + const NegOptionsT *AsNegOptions() const { + return type == BuiltinOptions_NegOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -4014,6 +4035,46 @@ inline flatbuffers::Offset CreateLessOptions( flatbuffers::Offset CreateLessOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct NegOptionsT : public flatbuffers::NativeTable { + typedef NegOptions TableType; + NegOptionsT() { + } +}; + +struct NegOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef NegOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NegOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NegOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const NegOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NegOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit NegOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + NegOptionsBuilder &operator=(const NegOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateNegOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + NegOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateNegOptions(flatbuffers::FlatBufferBuilder &_fbb, const NegOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -4254,6 +4315,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const LessOptions *builtin_options_as_LessOptions() const { return builtin_options_type() == BuiltinOptions_LessOptions ? static_cast(builtin_options()) : nullptr; } + const NegOptions *builtin_options_as_NegOptions() const { + return builtin_options_type() == BuiltinOptions_NegOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -4444,6 +4508,10 @@ template<> inline const LessOptions *Operator::builtin_options_as() return builtin_options_as_LessOptions(); } +template<> inline const NegOptions *Operator::builtin_options_as() const { + return builtin_options_as_NegOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -6070,6 +6138,29 @@ inline flatbuffers::Offset CreateLessOptions(flatbuffers::FlatBuffe _fbb); } +inline NegOptionsT *NegOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new NegOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void NegOptions::UnPackTo(NegOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset NegOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NegOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateNegOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateNegOptions(flatbuffers::FlatBufferBuilder &_fbb, const NegOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NegOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNegOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -6417,6 +6508,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_NegOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -6599,6 +6694,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_NegOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -6769,6 +6868,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateLessOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_NegOptions: { + auto ptr = reinterpret_cast(value); + return CreateNegOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -6939,6 +7042,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new LessOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_NegOptions: { + value = new NegOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -7151,6 +7258,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_NegOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index a1162cef38..211de63d58 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -43,6 +43,7 @@ gen_zipped_test_files( "mean.zip", "minimum.zip", "mul.zip", + "neg.zip", "pad.zip", "relu.zip", "relu1.zip", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 2f8f7a1a79..7e892769bf 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2061,6 +2061,31 @@ def make_floor_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_neg_tests(zip_path): + """Make a set of tests to do neg.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[1, 3, 4, 3], [5]], + }] + + def build_graph(parameters): + """Build the neg op testing graph.""" + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + out = tf.negative(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + values = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [values], sess.run(outputs, feed_dict=dict(zip(inputs, [values]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 34abb213c9..0673a3bb46 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -266,6 +266,7 @@ INSTANTIATE_TESTS(maximum) INSTANTIATE_TESTS(mean) INSTANTIATE_TESTS(minimum) INSTANTIATE_TESTS(mul) +INSTANTIATE_TESTS(neg) INSTANTIATE_TESTS(pad) // INSTANTIATE_TESTS(prelu) INSTANTIATE_TESTS(relu) diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index d2e14ac5e0..e18ae805c0 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -872,7 +872,6 @@ std::vector> BuildOperatorList() { // attributes. ops.emplace_back( new SimpleOperator("ADDN", OperatorType::kAddN)); - ops.emplace_back(new SimpleOperator("NEG", OperatorType::kNeg)); ops.emplace_back(new SimpleOperator( "RSQRT", OperatorType::kTensorFlowRsqrt)); // Simple Operators. @@ -901,7 +900,7 @@ std::vector> BuildOperatorList() { "MINIMUM", OperatorType::kTensorFlowMinimum)); ops.emplace_back(new SimpleOperator( "LESS", OperatorType::kTensorFlowLess)); - + ops.emplace_back(new SimpleOperator("NEG", OperatorType::kNeg)); return ops; } } // namespace diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 36ed741541..2b6c32b07c 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -115,6 +115,7 @@ TEST_F(OperatorTest, SimpleOperators) { "MINIMUM", OperatorType::kTensorFlowMinimum); CheckSimpleOperator("LESS", OperatorType::kTensorFlowLess); + CheckSimpleOperator("NEG", OperatorType::kNeg); } TEST_F(OperatorTest, BuiltinAdd) { -- GitLab From 47f1bd90658dd6858fb4bbefd4ef8acbef4ca931 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 4 May 2018 10:37:42 -0700 Subject: [PATCH 127/839] TFTS: Make it easier to swap in different autoregressive models. Adds a very simple LSTM encoder/decoder option as an example. ARModel's new constructor argument is a bit awkward, since Estimator's new graphs mean we need a Model factory rather than a Model (or to un-build the model?). It's still a much more pleasant way to write autoregressive models than fiddling with ARModel directly, since ARModel handles collecting all the features (and the prediction loop, etc.). Happy to hear other ideas for an API. PiperOrigin-RevId: 195436186 --- .../timeseries/python/timeseries/ar_model.py | 284 +++++++++++++++--- .../python/timeseries/ar_model_test.py | 86 ++++-- .../python/timeseries/estimators.py | 17 +- .../python/timeseries/estimators_test.py | 17 +- 4 files changed, 319 insertions(+), 85 deletions(-) diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index 558d9480b4..ce96180c92 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib import distributions +from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import model_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures @@ -29,6 +30,9 @@ from tensorflow.python.estimator import estimator_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.keras._impl.keras.engine import sequential +from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras._impl.keras.layers import core from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -40,12 +44,150 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope +class FlatPredictionModel(training.Model): + """Flattens input and output windows and puts them through dense layers. + + This model does not operate on its own, but rather is a plugin to + `ARModel`. See `ARModel`'s constructor documentation + (`prediction_model_factory`) for a usage example. + """ + + def __init__(self, + num_features, + input_window_size, + output_window_size, + hidden_layer_sizes=None): + """Construct the flat prediction model. + + Args: + num_features: number of input features per time step. + input_window_size: Number of past time steps of data to look at when doing + the regression. + output_window_size: Number of future time steps to predict. Note that + setting it to > 1 empirically seems to give a better fit. + hidden_layer_sizes: list of sizes of hidden layers. + """ + super(FlatPredictionModel, self).__init__() + self._input_flatten = core.Flatten() + self._output_flatten = core.Flatten() + if hidden_layer_sizes: + self._hidden_layers = sequential.Sequential([ + core.Dense(layer_size, activation=nn_ops.relu) + for layer_size in hidden_layer_sizes]) + else: + self._hidden_layers = None + self._mean_transform = core.Dense(num_features * output_window_size, + name="predicted_mean") + self._covariance_transform = core.Dense(num_features * output_window_size, + name="log_sigma_square") + self._prediction_shape = [-1, output_window_size, num_features] + + def call(self, input_window_features, output_window_features): + """Compute predictions from input and output windows. + + Args: + input_window_features: A floating point Tensor with shape [batch size, + input window size, input features]. The batch dimension may not have + static shape information, but the window size and number of input + features are known at graph construction time and recorded in the static + shape information for the `input_window_features` `Tensor`. Note that + `input_window_size` may be zero. + output_window_features: A floating point Tensor with shape [batch size, + output window size, output features]. As with `input_window_features`, + the last two dimensions have static shape information. If there are no + output features, the size of the last dimension will be zero. + Returns: + A dictionary of predictions with keys "mean" and "covariance" (only + diagonal covariances are currently supported). Each has shape + [batch size, output window size, num_features], where num_features is the + same as the constructor argument. + """ + if input_window_features.shape[1].value == 0: + # TODO(allenl): Make reshape()'s static shape information work on + # zero-size Tensors? Currently this special case is required because + # otherwise the Dense layers get unknown last dimensions. + activation = self._output_flatten(output_window_features) + elif output_window_features.shape[2].value == 0: + activation = self._input_flatten(input_window_features) + else: + activation = array_ops.concat( + [self._input_flatten(input_window_features), + self._output_flatten(output_window_features)], + axis=1) + if self._hidden_layers: + activation = self._hidden_layers(activation) + predicted_mean = array_ops.reshape( + self._mean_transform(activation), + self._prediction_shape) + predicted_covariance = array_ops.reshape( + gen_math_ops.exp(self._covariance_transform(activation)), + self._prediction_shape) + return {"mean": predicted_mean, + "covariance": predicted_covariance} + + +class LSTMPredictionModel(training.Model): + """A simple encoder/decoder model using an LSTM. + + This model does not operate on its own, but rather is a plugin to + `ARModel`. See `ARModel`'s constructor documentation + (`prediction_model_factory`) for a usage example. + """ + + def __init__(self, + num_features, + input_window_size, + output_window_size, + num_units=128): + """Construct the LSTM prediction model. + + Args: + num_features: number of input features per time step. + input_window_size: Number of past time steps of data to look at when doing + the regression. + output_window_size: Number of future time steps to predict. Note that + setting it to > 1 empirically seems to give a better fit. + num_units: The number of units in the encoder and decoder LSTM cells. + """ + super(LSTMPredictionModel, self).__init__() + self._encoder = lstm_ops.LSTMBlockFusedCell( + num_units=num_units, name="encoder") + self._decoder = lstm_ops.LSTMBlockFusedCell( + num_units=num_units, name="decoder") + self._mean_transform = core.Dense(num_features, + name="mean_transform") + self._covariance_transform = core.Dense(num_features, + name="covariance_transform") + + def call(self, input_window_features, output_window_features): + """Compute predictions from input and output windows.""" + # Convert to time major + input_window_features = array_ops.transpose(input_window_features, + [1, 0, 2]) + output_window_features = array_ops.transpose(output_window_features, + [1, 0, 2]) + _, encoder_state = self._encoder( + input_window_features, dtype=self.dtype) + decoder_output, _ = self._decoder( + output_window_features, dtype=self.dtype, + initial_state=encoder_state) + + # Switch back to batch major + decoder_output = array_ops.transpose(decoder_output, [1, 0, 2]) + predicted_mean = self._mean_transform(decoder_output) + predicted_covariance = gen_math_ops.exp( + self._covariance_transform(decoder_output)) + return {"mean": predicted_mean, + "covariance": predicted_covariance} + + class ARModel(model.TimeSeriesModel): """Auto-regressive model, both linear and non-linear. Features to the model include time and values of input_window_size timesteps, - and times for output_window_size timesteps. These are passed through zero or - more hidden layers, and then fed to a loss function (e.g. squared loss). + and times for output_window_size timesteps. These are passed through a + configurable prediction model, and then fed to a loss function (e.g. squared + loss). Note that this class can also be used to regress against time only by setting the input_window_size to zero. @@ -58,9 +200,9 @@ class ARModel(model.TimeSeriesModel): input_window_size, output_window_size, num_features, + prediction_model_factory=FlatPredictionModel, num_time_buckets=10, loss=NORMAL_LIKELIHOOD_LOSS, - hidden_layer_sizes=None, exogenous_feature_columns=None): """Constructs an auto-regressive model. @@ -73,6 +215,22 @@ class ARModel(model.TimeSeriesModel): output_window_size: Number of future time steps to predict. Note that setting it to > 1 empirically seems to give a better fit. num_features: number of input features per time step. + prediction_model_factory: A callable taking arguments `num_features`, + `input_window_size`, and `output_window_size` and returning a + `tf.keras.Model`. The `Model`'s `call()` takes two arguments: an input + window and an output window, and returns a dictionary of + predictions. See `FlatPredictionModel` for an example. Example usage: + + ```python + model = ar_model.ARModel( + periodicities=2, num_features=3, + prediction_model_factory=functools.partial( + FlatPredictionModel, + hidden_layer_sizes=[10, 10])) + ``` + + The default model computes predictions as a linear function of flattened + input and output windows. num_time_buckets: Number of buckets into which to divide (time % periodicity) for generating time based features. loss: Loss function to use for training. Currently supported values are @@ -81,18 +239,15 @@ class ARModel(model.TimeSeriesModel): SQUARED_LOSS, the evaluation loss is reported based on un-scaled observations and predictions, while the training loss is computed on normalized data (if input statistics are available). - hidden_layer_sizes: list of sizes of hidden layers. exogenous_feature_columns: A list of `tf.feature_column`s (for example `tf.feature_column.embedding_column`) corresponding to exogenous features which provide extra information to the model but are not part of the series to be predicted. Passed to `tf.feature_column.input_layer`. """ + self._model_factory = prediction_model_factory self.input_window_size = input_window_size self.output_window_size = output_window_size - if hidden_layer_sizes is None: - hidden_layer_sizes = [] - self.hidden_layer_sizes = hidden_layer_sizes self.window_size = self.input_window_size + self.output_window_size self.loss = loss super(ARModel, self).__init__( @@ -115,6 +270,19 @@ class ARModel(model.TimeSeriesModel): assert len(self._periods) or self.input_window_size assert output_window_size > 0 + def initialize_graph(self, input_statistics=None): + super(ARModel, self).initialize_graph(input_statistics=input_statistics) + self._model_scope = variable_scope.variable_scope( + # The trailing slash means we strip all enclosing variable_scopes, which + # unfortunately is necessary because the model gets called inside and + # outside a "while" scope (for prediction and training respectively), + # and the variables names need to match. + "model/", use_resource=True) + self._model_instance = self._model_factory( + num_features=self.num_features, + input_window_size=self.input_window_size, + output_window_size=self.output_window_size) + def get_start_state(self): # State which matches the format we'll return later. Typically this will not # be used by the model directly, but the shapes and dtypes should match so @@ -166,17 +334,6 @@ class ARModel(model.TimeSeriesModel): return array_ops.reshape(predicted_mean, [-1, self.output_window_size, self.num_features]) - def _create_hidden_stack(self, activation, activation_size): - activations = [] - for layer_number, layer_size in enumerate(self.hidden_layer_sizes): - # TODO(agarwal): Migrate to fully_connected in tf slim - activation = model_utils.fully_connected( - activation, activation_size, layer_size, - name="layer_{}".format(layer_number)) - activation_size = layer_size - activations.append((activation, activation_size)) - return activations - def prediction_ops(self, times, values, exogenous_regressors): """Compute model predictions given input data. @@ -195,7 +352,7 @@ class ARModel(model.TimeSeriesModel): self.num_features]. """ times.get_shape().assert_is_compatible_with([None, self.window_size]) - activations = [] + batch_size = array_ops.shape(times)[0] if self.input_window_size: values.get_shape().assert_is_compatible_with( [None, self.input_window_size, self.num_features]) @@ -203,39 +360,66 @@ class ARModel(model.TimeSeriesModel): exogenous_regressors.get_shape().assert_is_compatible_with( [None, self.window_size, self.exogenous_size]) # Create input features. - activation_components = [] + input_window_features = [] + input_feature_size = 0 + output_window_features = [] + output_feature_size = 0 if self._periods: _, time_features = self._compute_time_features(times) - activation_size = self.window_size * self._buckets * len(self._periods) - activation_components.append( - array_ops.reshape(time_features, [-1, activation_size])) - else: - activation_size = 0 + num_time_features = self._buckets * len(self._periods) + time_features = array_ops.reshape( + time_features, + [batch_size, + self.window_size, + num_time_features]) + input_time_features, output_time_features = array_ops.split( + time_features, (self.input_window_size, self.output_window_size), + axis=1) + input_feature_size += num_time_features + output_feature_size += num_time_features + input_window_features.append(input_time_features) + output_window_features.append(output_time_features) if self.input_window_size: inp = array_ops.slice(values, [0, 0, 0], [-1, self.input_window_size, -1]) - inp_size = self.input_window_size * self.num_features - inp = array_ops.reshape(inp, [-1, inp_size]) - activation_components.append(inp) - activation_size += inp_size + input_window_features.append( + array_ops.reshape( + inp, + [batch_size, self.input_window_size, self.num_features])) + input_feature_size += self.num_features if self.exogenous_size: - exogenous_size = self.window_size * self.exogenous_size - activation_size += exogenous_size - exogenous_flattened = array_ops.reshape( - exogenous_regressors, [-1, exogenous_size]) - activation_components.append(exogenous_flattened) - assert activation_size - assert activation_components - activation = array_ops.concat(activation_components, axis=1) - activations.append((activation, activation_size)) - # Create hidden layers. - activations += self._create_hidden_stack(activation, activation_size) - # Create mean and convariance ops. - predicted_mean = self._predicted_mean_op(activations) - predicted_covariance = self._predicted_covariance_op(activations, - self.num_features) - return {"activations": activations, - "mean": predicted_mean, - "covariance": predicted_covariance} + input_exogenous_features, output_exogenous_features = array_ops.split( + exogenous_regressors, + (self.input_window_size, self.output_window_size), + axis=1) + input_feature_size += self.exogenous_size + output_feature_size += self.exogenous_size + input_window_features.append(input_exogenous_features) + output_window_features.append(output_exogenous_features) + assert input_window_features + input_window_features = array_ops.concat(input_window_features, axis=2) + if output_window_features: + output_window_features = array_ops.concat(output_window_features, axis=2) + else: + output_window_features = array_ops.zeros( + [batch_size, self.output_window_size, 0], + dtype=self.dtype) + static_batch_size = times.get_shape()[0].value + input_window_features.set_shape( + [static_batch_size, self.input_window_size, input_feature_size]) + output_window_features.set_shape( + [static_batch_size, self.output_window_size, output_feature_size]) + return self._output_window_predictions(input_window_features, + output_window_features) + + def _output_window_predictions( + self, input_window_features, output_window_features): + with self._model_scope: + predictions = self._model_instance( + input_window_features, output_window_features) + result_shape = [None, self.output_window_size, self.num_features] + for v in predictions.values(): + v.set_shape(result_shape) + return predictions def loss_op(self, targets, prediction_ops): """Create loss_op.""" @@ -286,6 +470,8 @@ class ARModel(model.TimeSeriesModel): values are Tensors of shape [batch_size, predict window size, num_features] and correspond to the values passed in `TIMES`. """ + if not self._graph_initialized: + self.initialize_graph() predict_times = math_ops.cast( ops.convert_to_tensor(features[PredictionFeatures.TIMES]), dtypes.int32) exogenous_regressors = self._process_exogenous_features( @@ -701,9 +887,9 @@ class AnomalyMixtureARModel(ARModel): input_window_size, output_window_size, num_features, + prediction_model_factory=FlatPredictionModel, anomaly_distribution=GAUSSIAN_ANOMALY, num_time_buckets=10, - hidden_layer_sizes=None, exogenous_feature_columns=None): assert (anomaly_prior_probability < 1.0 and anomaly_prior_probability > 0.0) @@ -719,7 +905,7 @@ class AnomalyMixtureARModel(ARModel): input_window_size=input_window_size, output_window_size=output_window_size, loss=ARModel.NORMAL_LIKELIHOOD_LOSS, - hidden_layer_sizes=hidden_layer_sizes, + prediction_model_factory=prediction_model_factory, exogenous_feature_columns=exogenous_feature_columns) def _create_anomaly_ops(self, times, values, prediction_ops_dict): diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py index d078ac8d46..63f5d3568b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py @@ -18,12 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + import numpy as np +from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import input_pipeline from tensorflow.contrib.timeseries.python.timeseries import test_utils -from tensorflow.contrib.timeseries.python.timeseries.ar_model import AnomalyMixtureARModel -from tensorflow.contrib.timeseries.python.timeseries.ar_model import ARModel from tensorflow.contrib.timeseries.python.timeseries.estimators import ARRegressor from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures @@ -91,7 +92,7 @@ class ARModelTest(test.TestCase): np.random.seed(3) data_noise_stddev = 0.2 if max_loss is None: - if loss == ARModel.NORMAL_LIKELIHOOD_LOSS: + if loss == ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS: max_loss = 1.0 else: max_loss = 0.05 / (data_noise_stddev ** 2) @@ -137,7 +138,7 @@ class ARModelTest(test.TestCase): test_loss = test_evaluation["loss"] logging.info("Final test loss: %f", test_loss) self.assertLess(test_loss, max_loss) - if loss == ARModel.SQUARED_LOSS: + if loss == ar_model.ARModel.SQUARED_LOSS: # Test that the evaluation loss is reported without input scaling. self.assertAllClose( test_loss, @@ -169,7 +170,7 @@ class ARModelTest(test.TestCase): predicted_mean = predictions["mean"][:, 0] true_values = predict_true_values[0, :, 0] - if loss == ARModel.NORMAL_LIKELIHOOD_LOSS: + if loss == ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS: variances = predictions["covariance"][:, 0] standard_deviations = np.sqrt(variances) # Note that we may get tighter bounds with more training steps. @@ -180,26 +181,26 @@ class ARModelTest(test.TestCase): def test_time_regression_squared(self): self.train_helper(input_window_size=0, train_steps=350, - loss=ARModel.SQUARED_LOSS) + loss=ar_model.ARModel.SQUARED_LOSS) def test_autoregression_squared(self): self.train_helper(input_window_size=15, - loss=ARModel.SQUARED_LOSS) + loss=ar_model.ARModel.SQUARED_LOSS) def test_autoregression_short_input_window(self): self.train_helper(input_window_size=8, - loss=ARModel.SQUARED_LOSS) + loss=ar_model.ARModel.SQUARED_LOSS) def test_autoregression_normal(self): self.train_helper(input_window_size=10, - loss=ARModel.NORMAL_LIKELIHOOD_LOSS, + loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS, train_steps=300, max_loss=1.5, anomaly_distribution=None) def test_autoregression_normal_multiple_periods(self): self.train_helper(input_window_size=10, - loss=ARModel.NORMAL_LIKELIHOOD_LOSS, + loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS, max_loss=2.0, multiple_periods=True, anomaly_distribution=None) @@ -207,15 +208,15 @@ class ARModelTest(test.TestCase): def test_autoregression_normal_anomalies_normal(self): self.train_helper( input_window_size=10, - loss=ARModel.NORMAL_LIKELIHOOD_LOSS, - anomaly_distribution=AnomalyMixtureARModel.GAUSSIAN_ANOMALY) + loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS, + anomaly_distribution=ar_model.AnomalyMixtureARModel.GAUSSIAN_ANOMALY) def test_autoregression_normal_anomalies_cauchy(self): self.train_helper( input_window_size=10, max_loss=1.5, - loss=ARModel.NORMAL_LIKELIHOOD_LOSS, - anomaly_distribution=AnomalyMixtureARModel.CAUCHY_ANOMALY) + loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS, + anomaly_distribution=ar_model.AnomalyMixtureARModel.CAUCHY_ANOMALY) def test_wrong_window_size(self): estimator = ARRegressor( @@ -237,15 +238,38 @@ class ARModelTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "requires a window of at least"): estimator.evaluate(input_fn=_bad_window_size_input_fn, steps=1) - def test_predictions_direct(self): + def test_predictions_direct_flat(self): + g = ops.Graph() + with g.as_default(): + model = ar_model.ARModel(periodicities=2, + num_features=1, + num_time_buckets=10, + input_window_size=2, + output_window_size=2, + prediction_model_factory=functools.partial( + ar_model.FlatPredictionModel, + hidden_layer_sizes=[40, 10])) + with session.Session(): + predicted_values = model.predict({ + PredictionFeatures.TIMES: [[4, 6, 10]], + PredictionFeatures.STATE_TUPLE: ( + [[1, 2]], [[[1.], [2.]]], [[[], []]]) + }) + variables.global_variables_initializer().run() + self.assertAllEqual(predicted_values["mean"].eval().shape, + [1, 3, 1]) + + def test_predictions_direct_lstm(self): g = ops.Graph() with g.as_default(): - model = ARModel(periodicities=2, - num_features=1, - num_time_buckets=10, - input_window_size=2, - output_window_size=2, - hidden_layer_sizes=[40, 10]) + model = ar_model.ARModel(periodicities=2, + num_features=1, + num_time_buckets=10, + input_window_size=2, + output_window_size=2, + prediction_model_factory=functools.partial( + ar_model.LSTMPredictionModel, + num_units=16)) with session.Session(): predicted_values = model.predict({ PredictionFeatures.TIMES: [[4, 6, 10]], @@ -259,11 +283,11 @@ class ARModelTest(test.TestCase): def test_long_eval(self): g = ops.Graph() with g.as_default(): - model = ARModel(periodicities=2, - num_features=1, - num_time_buckets=10, - input_window_size=2, - output_window_size=1) + model = ar_model.ARModel(periodicities=2, + num_features=1, + num_time_buckets=10, + input_window_size=2, + output_window_size=1) raw_features = { TrainEvalFeatures.TIMES: [[1, 3, 5, 7, 11]], TrainEvalFeatures.VALUES: [[[1.], [2.], [3.], [4.], [5.]]]} @@ -309,11 +333,11 @@ class ARModelTest(test.TestCase): def test_long_eval_discard_indivisible(self): g = ops.Graph() with g.as_default(): - model = ARModel(periodicities=2, - num_features=1, - num_time_buckets=10, - input_window_size=2, - output_window_size=2) + model = ar_model.ARModel(periodicities=2, + num_features=1, + num_time_buckets=10, + input_window_size=2, + output_window_size=2) raw_features = { TrainEvalFeatures.TIMES: [[1, 3, 5, 7, 11]], TrainEvalFeatures.VALUES: [[[1.], [2.], [3.], [4.], [5.]]]} diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index f4608ca2d1..4ec8d26116 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import feature_keys from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib @@ -61,7 +63,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator): input_statistics_generator = math_utils.InputStatisticsFromMiniBatch( dtype=model.dtype, num_features=model.num_features) if state_manager is None: - state_manager = state_management.PassthroughStateManager() + if isinstance(model, ar_model.ARModel): + state_manager = state_management.FilteringOnlyStateManager() + else: + state_manager = state_management.PassthroughStateManager() if optimizer is None: optimizer = train.AdamOptimizer(0.02) self._model = model @@ -246,11 +251,13 @@ class ARRegressor(TimeSeriesRegressor): anomaly_distribution = ar_model.AnomalyMixtureARModel.GAUSSIAN_ANOMALY model = ar_model.ARModel( periodicities=periodicities, num_features=num_features, + prediction_model_factory=functools.partial( + ar_model.FlatPredictionModel, + hidden_layer_sizes=hidden_layer_sizes), exogenous_feature_columns=exogenous_feature_columns, num_time_buckets=num_time_buckets, input_window_size=input_window_size, - output_window_size=output_window_size, loss=loss, - hidden_layer_sizes=hidden_layer_sizes) + output_window_size=output_window_size, loss=loss) else: if loss != ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS: raise ValueError( @@ -261,9 +268,11 @@ class ARRegressor(TimeSeriesRegressor): input_window_size=input_window_size, output_window_size=output_window_size, num_features=num_features, + prediction_model_factory=functools.partial( + ar_model.FlatPredictionModel, + hidden_layer_sizes=hidden_layer_sizes), exogenous_feature_columns=exogenous_feature_columns, num_time_buckets=num_time_buckets, - hidden_layer_sizes=hidden_layer_sizes, anomaly_prior_probability=anomaly_prior_probability, anomaly_distribution=anomaly_distribution) state_manager = state_management.FilteringOnlyStateManager() diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py index eebee053f8..706742ca28 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import tempfile import numpy @@ -178,7 +179,7 @@ class TimeSeriesRegressorTest(test.TestCase): session=sess) self.assertAllEqual([10, 15, 1], predictions["mean"].shape) - def test_fit_restore_fit_ar_regressor(self): + def test_fit_restore_fit_ar_flat(self): def _estimator_fn(model_dir, exogenous_feature_columns): return estimators.ARRegressor( periodicities=10, input_window_size=10, output_window_size=6, @@ -189,6 +190,20 @@ class TimeSeriesRegressorTest(test.TestCase): exogenous_feature_columns=exogenous_feature_columns) self._fit_restore_fit_test_template(_estimator_fn, dtype=dtypes.float32) + def test_fit_restore_fit_ar_lstm(self): + def _estimator_fn(model_dir, exogenous_feature_columns): + return estimators.TimeSeriesRegressor( + model=ar_model.ARModel( + periodicities=10, input_window_size=10, output_window_size=6, + num_features=1, + exogenous_feature_columns=exogenous_feature_columns, + prediction_model_factory=functools.partial( + ar_model.LSTMPredictionModel, + num_units=10)), + config=_SeedRunConfig(), + model_dir=model_dir) + self._fit_restore_fit_test_template(_estimator_fn, dtype=dtypes.float32) + def test_fit_restore_fit_structural_ensemble_regressor(self): dtype = dtypes.float32 def _estimator_fn(model_dir, exogenous_feature_columns): -- GitLab From e32c42a6deed1f8ed1dcdeaaba0acf74685c18e3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 10:47:38 -0700 Subject: [PATCH 128/839] Improve broadcast add implementation. PiperOrigin-RevId: 195437679 --- .../internal/optimized/optimized_ops.h | 87 ++++++++++++++++++- .../internal/reference/reference_ops.h | 81 +++++++++++++++++ 2 files changed, 165 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 3d6042c31f..4776726972 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -2593,7 +2593,7 @@ inline void Add(int left_shift, const uint8* input1_data, } #endif // NEON - for (; i < size; i++) { + for (; i < size; ++i) { const int32 input1_val = input1_offset + input1_data[i]; const int32 input2_val = input2_offset + input2_data[i]; const int32 shifted_input1_val = input1_val * (1 << left_shift); @@ -2750,7 +2750,7 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, int32 output_activation_min, int32 output_activation_max, uint8* output_data, const Dims<4>& output_dims) { - gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); + gemmlowp::ScopedProfilingLabel label("BroadcastAddGeneric/8bit"); NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; @@ -2799,6 +2799,60 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, } } +inline void BroadcastAddFivefold( + int y0, int y1, int y2, int y3, int y4, int left_shift, + const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit"); + + // Fivefold nested loops. The second input resets its position for each + // iteration of the second loop. The first input resets its position at the + // beginning of the fourth loop. The innermost loop is an elementwise add of + // sections of the arrays. + uint8* output_data_ptr = output_data; + const uint8* input1_data_ptr = input1_data; + const uint8* input2_data_reset = input2_data; + for (int i4 = 0; i4 < y4; ++i4) { + const uint8* input2_data_ptr; + for (int i3 = 0; i3 < y3; ++i3) { + input2_data_ptr = input2_data_reset; + for (int i2 = 0; i2 < y2; ++i2) { + for (int i1 = 0; i1 < y1; ++i1) { + for (int i0 = 0; i0 < y0; ++i0) { + const int32 input1_val = input1_offset + input1_data_ptr[i0]; + const int32 input2_val = input2_offset + input2_data_ptr[i0]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data_ptr[i0] = static_cast(clamped_output); + } + input2_data_ptr += y0; + output_data_ptr += y0; + } + input1_data_ptr += y0; + } + } + input2_data_reset = input2_data_ptr; + } +} + template inline void BroadcastAdd(int left_shift, const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, @@ -2827,6 +2881,33 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, output_activation_max, output_data, output_dims); } +template +inline void BroadcastAddFivefold( + int y0, int y1, int y2, int y3, int y4, int left_shift, + const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims, + input1_offset, input1_multiplier, input1_shift, + input2_data, input2_dims, input2_offset, + input2_multiplier, input2_shift, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); +} + inline void Mul(const float* input1_data, const Dims<4>& input1_dims, const float* input2_data, const Dims<4>& input2_dims, float output_activation_min, float output_activation_max, @@ -4375,7 +4456,7 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; -gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); + gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); const int height = MatchingArraySize(input_dims, 2, output_dims, 2); const int width = MatchingArraySize(input_dims, 1, output_dims, 1); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index d41ade4c9d..c6ed614593 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -1189,6 +1189,60 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, } } +inline void BroadcastAddFivefold( + int y0, int y1, int y2, int y3, int y4, int left_shift, + const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAddFivefold/8bit"); + + int sb1 = y0; + int sa2 = y0; + int sb2 = y0 * y1; + int sa3 = y0 * y2; + int sa4 = y0 * y2 * y3; + int sb4 = y0 * y1 * y2; + + uint8* output_data_ptr = output_data; + for (int i4 = 0; i4 < y4; ++i4) { + for (int i3 = 0; i3 < y3; ++i3) { + for (int i2 = 0; i2 < y2; ++i2) { + for (int i1 = 0; i1 < y1; ++i1) { + for (int i0 = 0; i0 < y0; ++i0) { + const int32 input1_val = + input1_offset + + input1_data[i4 * sa4 + i3 * sa3 + i2 * sa2 + i0]; + const int32 input2_val = + input2_offset + + input2_data[i4 * sb4 + i2 * sb2 + i1 * sb1 + i0]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + *output_data_ptr = static_cast(clamped_output); + ++output_data_ptr; + } + } + } + } + } +} + template inline void BroadcastAdd(int left_shift, const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, @@ -1217,6 +1271,33 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, output_activation_max, output_data, output_dims); } +template +inline void BroadcastAddFivefold( + int y0, int y1, int y2, int y3, int y4, int left_shift, + const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + BroadcastAddFivefold(y0, y1, y2, y3, y4, left_shift, input1_data, input1_dims, + input1_offset, input1_multiplier, input1_shift, + input2_data, input2_dims, input2_offset, + input2_multiplier, input2_shift, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); +} + inline void Mul(const float* input1_data, const Dims<4>& input1_dims, const float* input2_data, const Dims<4>& input2_dims, float output_activation_min, float output_activation_max, -- GitLab From 09d0e300f9a136838102c94e4b5cf0d4d0876ace Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 10:50:29 -0700 Subject: [PATCH 129/839] Internal clean up: change scanf to use int64_t instead of int64 PiperOrigin-RevId: 195438212 --- tensorflow/core/lib/strings/numbers.cc | 4 ++-- tensorflow/core/util/command_line_flags.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc index e4b909296e..987e4fe733 100644 --- a/tensorflow/core/lib/strings/numbers.cc +++ b/tensorflow/core/lib/strings/numbers.cc @@ -392,8 +392,8 @@ string FpToString(Fprint fp) { bool StringToFp(const string& s, Fprint* fp) { char junk; - uint64 result; - if (sscanf(s.c_str(), "%llx%c", &result, &junk) == 1) { + uint64_t result; + if (sscanf(s.c_str(), "%lx%c", &result, &junk) == 1) { *fp = result; return true; } else { diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 8c27d01917..b281acb2b0 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -69,8 +69,8 @@ bool ParseInt64Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, str_util::ConsumePrefix(&arg, flag) && str_util::ConsumePrefix(&arg, "=")) { char extra; - int64 parsed_int64; - if (sscanf(arg.data(), "%lld%c", &parsed_int64, &extra) != 1) { + int64_t parsed_int64; + if (sscanf(arg.data(), "%ld%c", &parsed_int64, &extra) != 1) { LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag << "."; *value_parsing_ok = false; -- GitLab From 01a70dc43d32eb5add5f1cb5de2d6c98ed88dd83 Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Fri, 4 May 2018 11:17:17 -0700 Subject: [PATCH 130/839] Add operations before Identity operations should be quantized. Fixes #19014 PiperOrigin-RevId: 195443326 --- .../contrib/quantize/python/quantize.py | 6 ++++-- .../contrib/quantize/python/quantize_test.py | 20 +++++++++---------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index efc1a94b3c..60616ea749 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -33,7 +33,7 @@ from tensorflow.python.platform import tf_logging as logging _QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'} # Activations that are supported by the quantization rewrite. -_ACTIVATION_TYPES = {'Relu', 'Relu6', 'Identity'} +_ACTIVATION_TYPES = {'Relu', 'Relu6'} def Quantize(graph, @@ -267,8 +267,10 @@ def _FindLayersToQuantize(graph): # The input to the activation can come from bias add, fold bias add, the # bypasses. + # TODO(suharshs): We should ideally skip Identity operations instead of + # treating them as an activation. activation_pattern = graph_matcher.OpTypePattern( - '|'.join(_ACTIVATION_TYPES), + '|'.join(_ACTIVATION_TYPES) + '|Identity', inputs=[ graph_matcher.OneofPattern([ bias_add_pattern, folded_bias_add_pattern, bypass_pattern_a, diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 5e479f3946..e7360ae03c 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -74,7 +74,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): weights_initializer=self._WeightInit(0.09), activation_fn=None, scope='test/test') node = math_ops.add(conv, input2, name='test/add') - node = array_ops.identity(node, name='test/identity') + node = nn_ops.relu6(node, name='test/relu6') update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') @@ -97,7 +97,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): for output in quant_op.outputs: consumers.extend(output.consumers()) - self.assertNotIn('test/identity', [c.name for c in consumers]) + self.assertNotIn('test/relu6', [c.name for c in consumers]) def testInsertQuantOpForAddAfterSeparableConv2d(self): self._RunTestOverParameters( @@ -114,7 +114,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): weights_initializer=self._WeightInit(0.09), activation_fn=None, scope='test/test') node = math_ops.add(conv, input2, name='test/add') - node = array_ops.identity(node, name='test/identity') + node = nn_ops.relu6(node, name='test/relu6') update_barrier = control_flow_ops.no_op(name='update_barrier') with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') @@ -135,7 +135,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): for output in quant_op.outputs: consumers.extend(output.consumers()) - self.assertNotIn('test/identity', [c.name for c in consumers]) + self.assertNotIn('test/relu6', [c.name for c in consumers]) def testFinalLayerQuantized(self): self._RunTestOverParameters(self._TestFinalLayerQuantized) @@ -174,7 +174,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): stride=2, padding='SAME', weights_initializer=self._WeightInit(0.09), - activation_fn=array_ops.identity, + activation_fn=nn_ops.relu6, scope='test/test') bypass_tensor = math_ops.add(conv, input2, name='test/add') # The output of the post_activation bypass will be another layer. @@ -184,7 +184,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): stride=2, padding='SAME', weights_initializer=self._WeightInit(0.09), - activation_fn=array_ops.identity, + activation_fn=nn_ops.relu6, scope='test/unused') quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) @@ -212,7 +212,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): stride=2, padding='SAME', weights_initializer=self._WeightInit(0.09), - activation_fn=array_ops.identity, + activation_fn=nn_ops.relu6, scope='test/test1') # The bypass of this conv is the post activation bypass of the previous @@ -227,7 +227,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): scope='test/test2') bypass_tensor = math_ops.add(conv1, conv2, name='test/add') - _ = array_ops.identity(bypass_tensor, name='test/output') + _ = nn_ops.relu6(bypass_tensor, name='test/output') quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) @@ -248,11 +248,11 @@ class QuantizeTest(test_util.TensorFlowTestCase): 'test/test1/act_quant/FakeQuantWithMinMaxVars' in op_names) self.assertTrue('test/act_quant/FakeQuantWithMinMaxVars' in op_names) self.assertEqual( - 'Identity', + 'Relu6', graph.get_operation_by_name( 'test/test1/act_quant/FakeQuantWithMinMaxVars').inputs[0].op.type) self.assertEqual( - 'Identity', + 'Relu6', graph.get_operation_by_name( 'test/act_quant/FakeQuantWithMinMaxVars').inputs[0].op.type) -- GitLab From a2cba4a627f880cf8160de624fc1ad947c01e973 Mon Sep 17 00:00:00 2001 From: mbhuiyan Date: Fri, 4 May 2018 12:02:28 -0700 Subject: [PATCH 131/839] if MKL is used allocation id is set to 9 and 10 --- .../direct_session_with_tracking_alloc_test.cc | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index 0ff022a8bc..29c8c8daec 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -101,18 +101,21 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { EXPECT_EQ(2, shape.dim_size()); EXPECT_EQ(2, shape.dim(0).size()); EXPECT_EQ(1, shape.dim(1).size()); -#ifndef INTEL_MKL +#ifdef INTEL_MKL // if MKL is used, it goes through various additional // graph rewrite pass. In TF, everytime a graph pass // happens, "constant" nodes are allocated // and deallocated. Each allocation calls the - // (FindChunkPtr of BFCAllocator) - // , which increments the value of AllocationId. + // (FindChunkPtr of BFCAllocator), + // which increments the value of AllocationId. // Thus AllocationId becomes more than 3 and 4 if - // MKL is used, they can be 10 and 11 or - // other numbers. If MKL is used - // following check will not hold. - // Thus, skipping the check if MKL is used. + // MKL is used. Now they are 9 and 10 for MKL. + if (node->name() == y->name()) { + EXPECT_EQ(9, cm->AllocationId(node, 0)); + } else { + EXPECT_EQ(10, cm->AllocationId(node, 0)); + } +#else if (node->name() == y->name()) { EXPECT_EQ(3, cm->AllocationId(node, 0)); } else { -- GitLab From 2314acf98fb874317dd17ef3daf438d7af87f900 Mon Sep 17 00:00:00 2001 From: Anya Petrova Date: Fri, 4 May 2018 13:22:03 -0700 Subject: [PATCH 132/839] Fix a small typo. --- SECURITY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SECURITY.md b/SECURITY.md index a5ce3a62ee..01886b613e 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -173,7 +173,7 @@ the progress being made towards a fix and announcement. In addition, please include the following information along with your report: * Your name and affiliation (if any). -* A description the technical details of the vulnerabilities. It is very +* A description of the technical details of the vulnerabilities. It is very important to let us know how we can reproduce your findings. * An explanation who can exploit this vulnerability, and what they gain when doing so -- write an attack scenario. This will help us evaluate your report -- GitLab From f368558429f5ebdbc0a187c3801dccf1ca6963c7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 11:40:01 -0700 Subject: [PATCH 133/839] [XLA] Cleanup client_library_test_base: move definition of CreateParameterAndTransferLiteral to .cc file PiperOrigin-RevId: 195446864 --- .../xla/tests/client_library_test_base.cc | 29 +++++++++++++++++++ .../xla/tests/client_library_test_base.h | 29 ------------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index c09e7eaf2b..41f9a5f666 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -565,4 +565,33 @@ XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + XlaBuilder* builder, + XlaOp* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index e58979a303..16e838e60f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -616,35 +616,6 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( return result; } -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, - const Literal& literal, - const string& name, - XlaBuilder* builder, - XlaOp* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, XlaBuilder* builder, - XlaOp* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) - .ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ -- GitLab From be9b87375adecad9bd8bb12c81b2566c77a68ad7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 11:40:20 -0700 Subject: [PATCH 134/839] [XLA] Redesign: migrate the SWIG wrapped xla client. Added LocalOp that wraps XlaOp, so that it's fully visible to swig. PiperOrigin-RevId: 195446939 --- tensorflow/compiler/xla/python/BUILD | 3 +- .../xla/python/local_computation_builder.cc | 315 ++++++++------- .../xla/python/local_computation_builder.h | 206 +++++----- .../xla/python/local_computation_builder.i | 53 +-- tensorflow/compiler/xla/python/xla_client.py | 362 +++++++----------- 5 files changed, 415 insertions(+), 524 deletions(-) diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index ecb87bd889..932cce943f 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -49,9 +49,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 044458164f..df262c97bf 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/python/local_computation_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/default/thread_annotations.h" @@ -248,7 +249,7 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( return new LocalShapedBuffer(std::move(result_buffer)); } -LocalComputation::LocalComputation(Computation computation) +LocalComputation::LocalComputation(XlaComputation computation) : computation_(std::move(computation)) {} StatusOr LocalComputation::Compile( @@ -271,7 +272,7 @@ StatusOr LocalComputation::Compile( return new CompiledLocalComputation(std::move(local_executable)); } -const Computation& LocalComputation::computation() const { +const XlaComputation& LocalComputation::computation() const { return computation_; } @@ -281,8 +282,12 @@ StatusOr LocalComputation::GetReturnValueShape() const { return std::move(*program_shape.mutable_result()); } +LocalOp::LocalOp(const XlaOp& op) : op_(op) {} + +const XlaOp& LocalOp::op() const { return op_; } + LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) - : builder_(GetOrCreateLocalClient(), computation_name) {} + : builder_(computation_name) {} void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); @@ -291,19 +296,21 @@ void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } StatusOr LocalComputationBuilder::Build() { - TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); return new LocalComputation(std::move(computation)); } -ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { return builder_.Parameter(parameter_number, shape, name); } std::unique_ptr LocalComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - return builder_.GetShape(operand).ConsumeValueOrDie(); + const LocalOp& operand) { + auto result = MakeUnique(); + *result = builder_.GetShape(operand.op()).ValueOrDie(); + return result; } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -311,222 +318,236 @@ StatusOr LocalComputationBuilder::GetReturnValueShape() { return program_shape.result(); } -ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } -void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand, +void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand, shape, outfeed_config); + builder_.Outfeed(operand.op(), shape, outfeed_config); } -ComputationDataHandle LocalComputationBuilder::ConstantLiteral( - const Literal& literal) { +LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return builder_.ConstantLiteral(literal); } -ComputationDataHandle LocalComputationBuilder::Broadcast( - const ComputationDataHandle& operand, +LocalOp LocalComputationBuilder::Broadcast( + const LocalOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return builder_.Broadcast(operand, broadcast_sizes); + return builder_.Broadcast(operand.op(), broadcast_sizes); } -ComputationDataHandle LocalComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - return builder_.Pad(operand, padding_value, padding_config); +LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { + return builder_.Pad(operand.op(), padding_value.op(), padding_config); } -ComputationDataHandle LocalComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, +LocalOp LocalComputationBuilder::Reshape( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return builder_.Reshape(operand, dimensions, new_sizes); + return builder_.Reshape(operand.op(), dimensions, new_sizes); } -ComputationDataHandle LocalComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Collapse(operand, dimensions); +LocalOp LocalComputationBuilder::Collapse( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Collapse(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - return builder_.CrossReplicaSum(operand); +LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { + return builder_.CrossReplicaSum(operand.op()); } -ComputationDataHandle LocalComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, +LocalOp LocalComputationBuilder::Slice( + const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return builder_.Slice(operand, start_indices, limit_indices, strides); + return builder_.Slice(operand.op(), start_indices, limit_indices, strides); } -ComputationDataHandle LocalComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - return builder_.SliceInDim(operand, start_index, limit_index, stride, dimno); +LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, + int64 limit_index, int64 stride, + int64 dimno) { + return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, + dimno); } -ComputationDataHandle LocalComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, +LocalOp LocalComputationBuilder::DynamicSlice( + const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return builder_.DynamicSlice(operand, start_indices, slice_sizes); + return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - return builder_.DynamicUpdateSlice(operand, update, start_indices); +LocalOp LocalComputationBuilder::DynamicUpdateSlice( + const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices) { + return builder_.DynamicUpdateSlice(operand.op(), update.op(), + start_indices.op()); } -ComputationDataHandle LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - return builder_.ConcatInDim(operands, dimension); +LocalOp LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, int64 dimension) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.ConcatInDim(xla_ops, dimension); } -ComputationDataHandle -LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, +LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter) { + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter) { return builder_.SelectAndScatterWithGeneralPadding( - operand, select.computation(), window_dimensions, window_strides, padding, - source, init_value, scatter.computation()); + operand.op(), select.computation(), window_dimensions, window_strides, + padding, source.op(), init_value.op(), scatter.computation()); } -ComputationDataHandle LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - return builder_.Tuple(elements); +LocalOp LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + std::vector xla_ops; + xla_ops.reserve(elements.size()); + for (const auto& op : elements) { + xla_ops.push_back(op.op()); + } + + return builder_.Tuple(xla_ops); } -ComputationDataHandle LocalComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data, index); +LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { + return builder_.GetTupleElement(tuple_data.op(), index); } -ComputationDataHandle LocalComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return builder_.Dot(lhs, rhs); +LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { + return builder_.Dot(lhs.op(), rhs.op()); } -ComputationDataHandle LocalComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::DotGeneral( + const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return builder_.DotGeneral(lhs, rhs, dimension_numbers); + return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, + return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, + padding, lhs_dilation, rhs_dilation, dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand, new_element_type); +LocalOp LocalComputationBuilder::ConvertElementType( + const LocalOp& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand.op(), new_element_type); } -ComputationDataHandle LocalComputationBuilder::Call( +LocalOp LocalComputationBuilder::Call( const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { - return builder_.Call(local_computation.computation(), operands); + tensorflow::gtl::ArraySlice operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.Call(local_computation.computation(), xla_ops); } -ComputationDataHandle LocalComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - return builder_.Transpose(operand, permutation); +LocalOp LocalComputationBuilder::Transpose( + const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + return builder_.Transpose(operand.op(), permutation); } -ComputationDataHandle LocalComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Rev(operand, dimensions); +LocalOp LocalComputationBuilder::Rev( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Rev(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, +LocalOp LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - return builder_.Map(operands, local_computation.computation(), dimensions, - static_operands); + tensorflow::gtl::ArraySlice static_operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + + std::vector static_xla_ops; + static_xla_ops.reserve(static_operands.size()); + for (const auto& op : static_operands) { + static_xla_ops.push_back(op.op()); + } + + return builder_.Map(xla_ops, local_computation.computation(), dimensions, + static_xla_ops); } -ComputationDataHandle LocalComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::Reduce( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return builder_.Reduce(operand, init_value, local_computation.computation(), - dimensions_to_reduce); + return builder_.Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } -ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { return builder_.ReduceWindowWithGeneralPadding( - operand, init_value, local_computation.computation(), window_dimensions, - window_strides, padding); + operand.op(), init_value.op(), local_computation.computation(), + window_dimensions, window_strides, padding); } -ComputationDataHandle LocalComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return builder_.RngNormal(mu, sigma, shape); +LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, + const LocalOp& sigma, + const Shape& shape) { + return builder_.RngNormal(mu.op(), sigma.op(), shape); } -ComputationDataHandle LocalComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return builder_.RngUniform(a, b, shape); +LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { + return builder_.RngUniform(a.op(), b.op(), shape); } -ComputationDataHandle LocalComputationBuilder::While( - const LocalComputation& condition, const LocalComputation& body, - const ComputationDataHandle& init) { - return builder_.While(condition.computation(), body.computation(), init); +LocalOp LocalComputationBuilder::While(const LocalComputation& condition, + const LocalComputation& body, + const LocalOp& init) { + return builder_.While(condition.computation(), body.computation(), init.op()); } -ComputationDataHandle LocalComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, +LocalOp LocalComputationBuilder::Conditional( + const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional(predicate, true_operand, - true_computation.computation(), false_operand, - false_computation.computation()); + return builder_.Conditional( + predicate.op(), true_operand.op(), true_computation.computation(), + false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - return builder_.IsConstant(operand, num_parameters); +StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { + return builder_.IsConstant(operand.op()); } -StatusOr> LocalComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - return builder_.ComputeConstant(operand, output_layout, parameters); +StatusOr LocalComputationBuilder::BuildConstantSubGraph( + const LocalOp& operand) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, + builder_.BuildConstantSubGraph(operand.op())); + return new LocalComputation(std::move(computation)); } #define _FORWARD(method_name, return_sig, args_sig, args) \ @@ -534,23 +555,19 @@ StatusOr> LocalComputationBuilder::ComputeConstant( return builder_.method_name args; \ } -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand), (operand)) - -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ - (lhs, rhs, broadcast_dimensions)) - -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs), \ - (lhs, rhs, ehs)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions), \ + (lhs.op(), rhs.op(), broadcast_dimensions)) + +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \ + (lhs.op(), rhs.op(), ehs.op())) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 5ec097846a..a06b85b4ea 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -17,9 +17,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -97,25 +98,37 @@ class CompiledLocalComputation { std::unique_ptr executable_; }; -// Wraps a Computation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a LocalComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. class LocalComputation { public: - LocalComputation(Computation computation); + LocalComputation(XlaComputation computation); StatusOr Compile( const std::vector& argument_shapes, const ExecutableBuildOptions* build_options); - const Computation& computation() const; + const XlaComputation& computation() const; // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; private: - Computation computation_; + XlaComputation computation_; +}; + +// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// to be made available to Python via SWIG. +class LocalOp { + public: + LocalOp(const XlaOp& op); + + const XlaOp& op() const; + + private: + XlaOp op_; }; // Wraps the ComputationBuilder API in order to: @@ -135,166 +148,137 @@ class LocalComputationBuilder { // Returns an owned LocalComputation to the caller on success. StatusOr Build(); - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); + LocalOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); - std::unique_ptr GetShape(const ComputationDataHandle& operand); + std::unique_ptr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. StatusOr GetReturnValueShape(); - ComputationDataHandle Infeed(const Shape& shape); + LocalOp Infeed(const Shape& shape); - void Outfeed(const ComputationDataHandle& operand, const Shape& shape, + void Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config); - ComputationDataHandle ConstantLiteral(const Literal& literal); + LocalOp ConstantLiteral(const Literal& literal); - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + LocalOp Broadcast(const LocalOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); + LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, + const PaddingConfig& padding_config); - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Reshape(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Collapse(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + LocalOp CrossReplicaSum(const LocalOp& operand); - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp Slice(const LocalOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); + LocalOp SliceInDim(const LocalOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); + LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices); - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, + LocalOp SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter); + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter); - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(tensorflow::gtl::ArraySlice elements); - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); + LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); + LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs); - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); + LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs, + const DotDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + LocalOp ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); + LocalOp ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type); - ComputationDataHandle Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + LocalOp Call(const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands); - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation); + LocalOp Transpose(const LocalOp& operand, + tensorflow::gtl::ArraySlice permutation); - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + LocalOp Map(tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, + LocalOp ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding); - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); + LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape); - ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); + LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - ComputationDataHandle While(const LocalComputation& condition, - const LocalComputation& body, - const ComputationDataHandle& init); + LocalOp While(const LocalComputation& condition, const LocalComputation& body, + const LocalOp& init); - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, - const LocalComputation& false_computation); + LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, + const LocalOp& false_operand, + const LocalComputation& false_computation); - StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters); + StatusOr IsConstant(const LocalOp& operand); - StatusOr > ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand)) -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions)) -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs)) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) @@ -338,7 +322,7 @@ class LocalComputationBuilder { #undef _FORWARD_TRIOP private: - ComputationBuilder builder_; + XlaBuilder builder_; }; // Functions for freeing resources from the Python side. diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index b8cce5a5f7..04c56bbba9 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,9 +22,8 @@ limitations under the License. // // C++ Python // -------------------------------------+--------------------------------------- -// ComputationDataHandle <-> int // ArraySlice <- sequence of int -// ArraySlice <- sequence of int +// ArraySlice <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) @@ -91,12 +90,9 @@ limitations under the License. // One central reason for the Python-side indirection is that the // Python-side objects produced by the typemaps in this file are // further packaged up by xla_client before being passed on. For -// instance, xla_client wraps the long produced for a C++ -// ComputationDataHandle in a Python ComputationDataHandle proto, -// rather than exposing a raw long outside of the client. Similarly, -// the Python pair produced for a C++ Shape is further wrapped in a -// Python class (xla_client.Shape) so as not to expose the raw pair -// externally. +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. // // Other SWIG object wrappers (e.g. of LocalComputation) are further // wrapped by xla_client in order to set up a custom destructor that @@ -124,6 +120,7 @@ using namespace xla; using namespace xla::swig; namespace xla { + namespace swig { bool GetIntAttr(PyObject* o, const char* field, int64* result) { @@ -177,21 +174,6 @@ bool HandleStringAttribute(PyObject* o, tensorflow::ImportNumpy(); %} -// ComputationDataHandle - -%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { - const int64 handle = numpy::PyIntOrPyLongToLong($input); - if (handle == -1 && PyErr_Occurred()) { - SWIG_fail; - } - temp.set_handle(handle); - $1 = &temp; -} - -%typemap(out) ComputationDataHandle { - $result = numpy::LongToPyIntOrPyLong($1.handle()); -} - %typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); @@ -301,33 +283,23 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ComputationDataHandle +// ArraySlice -%typemap(in) tensorflow::gtl::ArraySlice - (std::vector temps) { +%typemap(in) tensorflow::gtl::ArraySlice( + std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; } const int size = PySequence_Size($input); - temps.resize(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - SWIG_fail; - } - const int64 handle = numpy::PyIntOrPyLongToLong(py_int); - if (handle == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); + LocalOp* op; + if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), + SWIG_POINTER_EXCEPTION)) == -1) { SWIG_fail; } - temps[i].set_handle(handle); - Py_DECREF(py_int); + temps.push_back(*op); Py_DECREF(o); } $1 = temps; @@ -934,6 +906,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::Build; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index f6809b6b87..1d5b75d1be 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -335,20 +335,6 @@ def _wrap_shape(shape_info): return Shape.array_shape(dtype, dims) -def _wrap_data_handle(handle): - cdh = xla_data_pb2.ComputationDataHandle() - cdh.handle = handle - return cdh - - -def _unwrap_data_handle(handle_proto): - return handle_proto.handle - - -def _unwrap_data_handles(handle_protos): - return [_unwrap_data_handle(cdh) for cdh in handle_protos] - - def require_numpy_array_layout(value): if isinstance(value, tuple): return tuple(require_numpy_array_layout(x) for x in value) @@ -535,9 +521,9 @@ class ComputationBuilder(object): queue for subsequent use in the computation. Returns: - A ComputationDataHandle message. + A LocalOp. """ - return _wrap_data_handle(self._client.Infeed(shape)) + return self._client.Infeed(shape) def Outfeed(self, operand): """Enqueues an outfeed op onto the computation. @@ -545,9 +531,7 @@ class ComputationBuilder(object): Outfeed operations enqueue data, using the given operand, onto the XLA outfeed queue for subsequent dequeue via the client API. """ - self._client.Outfeed( - _unwrap_data_handle(operand), self.GetShape(operand), - ''.encode('utf-8')) + self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8')) def Constant(self, value): """Enqueues a constant op onto the computation. @@ -557,10 +541,10 @@ class ComputationBuilder(object): to one of the supported types. Returns: - A ComputationDataHandle message. + A LocalOp. """ value = require_numpy_array_layout(value) - return _wrap_data_handle(self._client.ConstantLiteral(value)) + return self._client.ConstantLiteral(value) def ConstantF32Scalar(self, value): """Convenience method to enqueue a scalar F32 constant op. @@ -569,7 +553,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float32)) @@ -580,7 +564,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float64)) @@ -591,7 +575,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int32)) @@ -602,7 +586,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int64)) @@ -613,7 +597,7 @@ class ComputationBuilder(object): value: a boolean value. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.bool)) @@ -629,15 +613,14 @@ class ComputationBuilder(object): parameters, use it for *all* parameters to avoid clashes. Returns: - A ComputationDataHandle message. + A LocalOp. """ if name is None: name = '' if parameter_num is None: parameter_num = next(self._parameter_numbering) - return _wrap_data_handle( - self._client.Parameter(parameter_num, shape, name.encode('utf8'))) + return self._client.Parameter(parameter_num, shape, name.encode('utf8')) def ParameterFromNumpy(self, value, name=None, parameter_num=None): """Enqueues a Parameter op onto the computation. @@ -649,7 +632,7 @@ class ComputationBuilder(object): parameter_num: as in ParameterWithShape. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.ParameterWithShape( Shape.from_pyval(value), name=name, parameter_num=parameter_num) @@ -658,14 +641,13 @@ class ComputationBuilder(object): """Enqueues a broadcast operation onto the computation. Args: - operand: the operand ComputationDataHandle to broadcast. + operand: the operand LocalOp to broadcast. sizes: an iterable of broadcast sizes. Returns: - A ComputationDataHandle representing the added broadcast op. + A LocalOp representing the added broadcast op. """ - return _wrap_data_handle( - self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + return self._client.Broadcast(operand, sizes) def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -675,10 +657,9 @@ class ComputationBuilder(object): dimension: the dimension in which to perform the concatenation. Returns: - A ComputationDataHandle representing the added concatenate op. + A LocalOp representing the added concatenate op. """ - return _wrap_data_handle( - self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + return self._client.ConcatInDim(operands, dimension) def ConvertElementType(self, operand, new_element_type): """Enqueues an element type conversion operation onto the computation. @@ -688,14 +669,12 @@ class ComputationBuilder(object): new_element_type: the target primitive type. Returns: - A ComputationDataHandle representing the added conversion op. + A LocalOp representing the added conversion op. """ - return _wrap_data_handle( - self._client.ConvertElementType( - _unwrap_data_handle(operand), new_element_type)) + return self._client.ConvertElementType(operand, new_element_type) def GetShape(self, operand): - return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + return _wrap_shape(self._client.GetShape(operand)) def GetReturnValueShape(self): return _wrap_shape(self._client.GetReturnValueShape()) @@ -707,40 +686,35 @@ class ComputationBuilder(object): """Enqueues a Pad operation onto the computation. Args: - operand: ComputationDataHandle representing the array to pad. - padding_value: ComputationDataHandle representing the scalar pad value. + operand: LocalOp representing the array to pad. + padding_value: LocalOp representing the scalar pad value. padding_config: either an xla_data_pb2.PaddingConfig or a list of integer triples (edge_padding_low, edge_padding_high, interior_padding) representing the configuration of the padding operation. Returns: - A ComputationDataHandle representing the added Pad op. + A LocalOp representing the added Pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): padding_config = GetPaddingConfigFromTriples(padding_config) - return _wrap_data_handle( - self._client.Pad(_unwrap_data_handle(operand), - _unwrap_data_handle(padding_value), - padding_config)) + return self._client.Pad(operand, padding_value, padding_config) def Reshape(self, operand, dimensions, new_sizes): """Enqueues a reshape op onto the computation. Args: - operand: ComputationDataHandle representing the array to be reshaped. + operand: LocalOp representing the array to be reshaped. dimensions: sequence of integers encoding the order in which dimensions are collapsed or None, in which case dimensions are flattened in order. new_sizes: sequence of integers encoding the new dimension sizes (shape). Returns: - A ComputationDataHandle representing the added Reshape op. + A LocalOp representing the added Reshape op. """ if dimensions is None: ndim = len(self.GetShape(operand).dimensions()) dimensions = tuple(range(ndim)) - return _wrap_data_handle( - self._client.Reshape( - _unwrap_data_handle(operand), dimensions, new_sizes)) + return self._client.Reshape(operand, dimensions, new_sizes) def CrossReplicaSum(self, operand): """CrossReplicaSum op. @@ -749,67 +723,56 @@ class ComputationBuilder(object): operand: the operand to sum across replica instances. Returns: - A ComputationDataHandle that has the sum of the value among all replicas. + A LocalOp that has the sum of the value among all replicas. """ - return _wrap_data_handle( - self._client.CrossReplicaSum(_unwrap_data_handle(operand))) + return self._client.CrossReplicaSum(operand) def Collapse(self, operand, dimensions): """Collapse op.""" - return _wrap_data_handle( - self._client.Collapse(_unwrap_data_handle(operand), dimensions)) + return self._client.Collapse(operand, dimensions) def Trans(self, operand): """Specialized matrix transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + return self._client.Transpose(operand, [1, 0]) def Transpose(self, operand, permutation): """Transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), permutation)) + return self._client.Transpose(operand, permutation) def Rev(self, operand, dimensions): """Rev op.""" - return _wrap_data_handle( - self._client.Rev(_unwrap_data_handle(operand), dimensions)) + return self._client.Rev(operand, dimensions) def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin """Clamp op.""" - return _wrap_data_handle( - self._client.Clamp(_unwrap_data_handle(min), - _unwrap_data_handle(operand), - _unwrap_data_handle(max))) + return self._client.Clamp(min, operand, max) def SelectAndScatter(self, operand, select, window_dimensions, window_strides, padding, source, init_value, scatter): """Select and scatter op, used by the gradient of ReduceWindow. Args: - operand: ComputationDataHandle for array of dimension N and type T over + operand: LocalOp for array of dimension N and type T over which the windows slide. select: Computation of type (T, T) -> Pred to apply to the elements of each window to indicate which element is selected. window_dimensions: sequence of N integers for dimensions of the window. window_strides: sequence of N integers for the strides of the window. padding: PaddingType representing either 'SAME' or 'VALID ' padding. - source: ComputationDataHandle for array of type T with values to scatter. - init_value: ComputationDataHandle of scalar type T for initial out value. + source: LocalOp for array of type T with values to scatter. + init_value: LocalOp of scalar type T for initial out value. scatter: Computation of type (T, T) -> T to apply to each scatter source element with its destination element. Returns: - A ComputationDataHandle representing the added SelectAndScatter op. + A LocalOp representing the added SelectAndScatter op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.SelectAndScatterWithGeneralPadding( - _unwrap_data_handle(operand), select.c_local_computation, - window_dimensions, window_strides, pads, - _unwrap_data_handle(source), _unwrap_data_handle(init_value), - scatter.c_local_computation)) + return self._client.SelectAndScatterWithGeneralPadding( + operand, select.c_local_computation, window_dimensions, window_strides, + pads, source, init_value, scatter.c_local_computation) def Select(self, pred, on_true, on_false): """Element-wise selection op. @@ -817,17 +780,13 @@ class ComputationBuilder(object): Constructs an output array from elements of two input arrays, based on the values of a predicate array. """ - return _wrap_data_handle( - self._client.Select( - _unwrap_data_handle(pred), - _unwrap_data_handle(on_true), - _unwrap_data_handle(on_false))) + return self._client.Select(pred, on_true, on_false) def Slice(self, operand, start_indices, limit_indices, strides=None): """Enqueues a slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_indices: iterable of N integers containing the starting indices of the slice for each dimension. limit_indices: iterable of N integers containing the ending indices @@ -836,207 +795,177 @@ class ComputationBuilder(object): each dimension. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ if strides is None: start_indices = list(start_indices) strides = [1] * len(start_indices) - return _wrap_data_handle( - self._client.Slice( - _unwrap_data_handle(operand), start_indices, limit_indices, - strides)) + return self._client.Slice(operand, start_indices, limit_indices, strides) def SliceInDim(self, operand, start_index, limit_index, stride, dimno): """Enqueues a slice-in-dimension operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_index: an integer containing the start index of the slice. limit_index: an integer containing the end index of the slice. stride: an integer containing the stride size for the slice. dimno: an integer indicating the dimension along which to slice. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ - return _wrap_data_handle( - self._client.SliceInDim( - _unwrap_data_handle(operand), start_index, limit_index, stride, - dimno)) + return self._client.SliceInDim(operand, start_index, limit_index, stride, + dimno) def DynamicSlice(self, operand, start_indices, slice_sizes): """Enqueues a slice op with dynamic start indices onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. - start_indices: ComputationDataHandle for the 1D array of N integers + operand: LocalOp for the N dimensional array to be sliced. + start_indices: LocalOp for the 1D array of N integers containing the starting indices of the slice. slice_sizes: iterable of N integers containing the slice sizes in each dimension. Returns: - A ComputationDataHandle representing the added DynamicSlice op. + A LocalOp representing the added DynamicSlice op. """ - return _wrap_data_handle( - self._client.DynamicSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(start_indices), - slice_sizes)) + return self._client.DynamicSlice(operand, start_indices, slice_sizes) def DynamicUpdateSlice(self, operand, update, start_indices): """Enqueues a dynamic update slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be updated. + operand: LocalOp for the N dimensional array to be updated. update: N dimensional array comprising the slice update. start_indices: Rank-1 array of N integers comprising the starting indices of the slice along each dimension. Returns: - A ComputationDataHandle representing the added DynamicUpdateSlice op. + A LocalOp representing the added DynamicUpdateSlice op. """ - return _wrap_data_handle( - self._client.DynamicUpdateSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(update), - _unwrap_data_handle(start_indices))) + return self._client.DynamicUpdateSlice(operand, update, start_indices) def Tuple(self, *ops): """Enqueues a tuple operation onto the computation. Args: - ops: a sequence of tuple operands (each a ComputationDataHandle). + ops: a sequence of tuple operands (each a LocalOp). Returns: - A ComputationDataHandle representing the added Tuple op. + A LocalOp representing the added Tuple op. """ - return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + return self._client.Tuple(ops) def GetTupleElement(self, tup, index): """Enqueues a 'get tuple element' operation onto the computation. Args: - tup: the tuple operand (a ComputationDataHandle). + tup: the tuple operand (a LocalOp). index: numeric index to select from the tuple. Returns: - A ComputationDataHandle representing the added GetTupleElement op. + A LocalOp representing the added GetTupleElement op. """ - return _wrap_data_handle( - self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + return self._client.GetTupleElement(tup, index) def Call(self, computation_to_apply, operands): """Enqueues a call operation onto the computation. Args: computation_to_apply: a Computation object. - operands: an iterable of ComputationDataHandle. The number and types of + operands: an iterable of LocalOp. The number and types of operands must match the arity of computation_to_apply. Returns: - A ComputationDataHandle representing the added call op. + A LocalOp representing the added call op. """ - return _wrap_data_handle( - self._client.Call(computation_to_apply.c_local_computation, - _unwrap_data_handles(operands))) + return self._client.Call(computation_to_apply.c_local_computation, operands) def Map(self, operands, computation_to_apply, dimensions, static_operands=()): """Enqueues a map operation onto the computation. Args: - operands: an iterable of ComputationDataHandle. + operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. static_operands: auxiliary arguments passed to the applied computation. Returns: - A ComputationDataHandle representing the added Map op. + A LocalOp representing the added Map op. """ - return _wrap_data_handle( - self._client.Map( - _unwrap_data_handles(operands), - computation_to_apply.c_local_computation, - dimensions, - _unwrap_data_handles(static_operands))) + return self._client.Map(operands, computation_to_apply.c_local_computation, + dimensions, static_operands) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a Computation object - binary reduction function. dimensions: sequence of dimensions (integers) to reduce on. Returns: - A ComputationDataHandle representing the added Reduce op. + A LocalOp representing the added Reduce op. """ - return _wrap_data_handle( - self._client.Reduce( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - dimensions)) + return self._client.Reduce(operand, init_value, + computation_to_apply.c_local_computation, + dimensions) def ReduceWindow(self, operand, init_value, computation_to_apply, window_dimensions, window_strides, padding): """Enqueues a windowed reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a binary reduction function (Computation). window_dimensions: dimensions of window (sequence of integers). window_strides: strides for window (sequence of integers). padding: PaddingType representing either 'SAME' or 'VALID' padding. Returns: - A ComputationDataHandle representing the added ReduceWindow op. + A LocalOp representing the added ReduceWindow op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.ReduceWindowWithGeneralPadding( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads)) + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, pads) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. Args: - mu: A ComputationDataHandle to an F32 scalar specifying the mean. - sigma: A ComputationDataHandle to an F32 scalar specifying the standard + mu: A LocalOp to an F32 scalar specifying the mean. + sigma: A LocalOp to an F32 scalar specifying the standard deviation. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of F32 values. + Returns: a LocalOp to the generated array of F32 values. """ shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) - return _wrap_data_handle( - self._client.RngNormal( - _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) + return self._client.RngNormal(mu, sigma, shape) def RngUniform(self, a, b, dims): """Enqueues an RngUniform operation onto the computation. Args: - a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) specifying the low end of the interval [a, b) over which values are generated. - b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) specifying the high end of the interval [a, b) over which values are generated. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of values with the + Returns: a LocalOp to the generated array of values with the same numeric type (F32, S32, or U32) as the arguments a and b. """ shape = Shape.array_shape(self.GetShape(a).element_type(), dims) - return _wrap_data_handle( - self._client.RngUniform( - _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) + return self._client.RngUniform(a, b, shape) def While(self, cond, body, init): """Enqueues a While operation onto the computation. @@ -1044,112 +973,105 @@ class ComputationBuilder(object): Args: cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T - init: a ComputationDataHandle for the initial parameter, which has type T + init: a LocalOp for the initial parameter, which has type T - Returns: a ComputationDataHandle representing the While operation. + Returns: a LocalOp representing the While operation. """ - return _wrap_data_handle( - self._client.While(cond.c_local_computation, - body.c_local_computation, - _unwrap_data_handle(init))) + return self._client.While(cond.c_local_computation, + body.c_local_computation, init) def Conditional(self, pred, true_operand, true_computation, false_operand, false_computation): """Enqueues a Conditional operation onto the computation. Args: - predicate: a ComputationDataHandle to test, which has scalar type PRED - true_operand: a ComputationDataHandle of type T_0 + predicate: a LocalOp to test, which has scalar type PRED + true_operand: a LocalOp of type T_0 true_computation: a Computation to apply to true_operand, type T_0 -> S false_operand: a ComputationDatahandle of type T_1 false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a ComputationDataHandle representing the Conditional operation. + Returns: a LocalOp representing the Conditional operation. """ - return _wrap_data_handle( - self._client.Conditional( - _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), - true_computation.c_local_computation, - _unwrap_data_handle(false_operand), - false_computation.c_local_computation)) + return self._client.Conditional( + pred, true_operand, true_computation.c_local_computation, false_operand, + false_computation.c_local_computation) - def IsConstant(self, operand, num_parameters=0): - """Enqueues an IsConstant operation onto the computation. + def IsConstant(self, operand): + """Checks whether the given operand is a compile-time constant. Args: operand: a ComputationDataHandle to test. - num_parameters: optional int, number of computation parameters to treat as - constant (default 0). Returns: bool indicating whether `operand` is a compile-time constant, - meaning its value does not depend on parameters with index greater than or - equal to `num_parameters`. + meaning its value does not depend on any parametersor, or on stateful + operators such as `RngNormal` or `Infeed`. + """ + return self._client.IsConstant(operand) + + def BuildConstantSubGraph(self, operand): + """Builds a constant sub graph. + + Args: + operand: a LocalOp to test. + Returns: a LocalComputation that is rooted on the given `operand` which is a + compile-time constant. """ - return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters) + return self._client.BuildConstantSubGraph(operand) def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. Args: - lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. - rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. + rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. - Returns: a ComputationDataHandle representing the Dot operation. + Returns: a LocalOp representing the Dot operation. """ - return _wrap_data_handle( - self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + return self._client.Dot(lhs, rhs) def DotGeneral(self, lhs, rhs, dimension_numbers): """Enqueues a general dot operation onto the computation. Args: - lhs: ComputationDataHandle for the left-hand-side array. - rhs: ComputationDataHandle for the right-hand-side array. + lhs: LocalOp for the left-hand-side array. + rhs: LocalOp for the right-hand-side array. dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of integers representing the dimensions to treat as contracting dimensions and batch dimensions on each input operand. - Returns: a ComputationDataHandle representing the DotGeneral operation. + Returns: a LocalOp representing the DotGeneral operation. """ if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) - return _wrap_data_handle( - self._client.DotGeneral( - _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), - dimension_numbers)) + return self._client.DotGeneral(lhs, rhs, dimension_numbers) def Conv(self, lhs, rhs, window_strides, padding): """Enqueues a Conv operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. - Returns: a ComputationDataHandle representing the Conv operation. + Returns: a LocalOp representing the Conv operation. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(lhs).dimensions()[2:], self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - pads, - (), - (), - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), + (), dimension_numbers) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of kernel strides. padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. @@ -1159,14 +1081,9 @@ class ComputationBuilder(object): A ComputationdataHandle representing the added ConvWithGeneralPadding op. """ dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - padding, - lhs_dilation, - rhs_dilation, - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1196,15 +1113,14 @@ def _forward_methods_to_local_builder(): """Generate a forwarding method that wraps/unwraps data handles.""" def forward(self, *args, **kwargs): - unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + arg_list = list(args) - if is_binop and len(unwrapped_args) < 3: - unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + if is_binop and len(arg_list) < 3: + arg_list.append(kwargs.get('broadcast_dimensions', ())) - return _wrap_data_handle( - target_method( - self._client, # pylint: disable=protected-access - *unwrapped_args)) + return target_method( + self._client, # pylint: disable=protected-access + *arg_list) return forward -- GitLab From 5ca373b4b64167f8b0fcab96d7d2e7886ea31b6a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 12:28:42 -0700 Subject: [PATCH 135/839] Some fixes to support another TF graph: 1. Fix ResolveBatchNormalization to avoid deleting arrays that may still be used. 2. Correctly count the number of ops using a given array, even when some ops use the same array as more than one of their inputs. 3. In PropagateFixedSizes for Concatenation ops, when resolving a -1 wildcard to a fixed value, we were doing so in a local 'axis' variable without actually updating op->axis! The resulting -1 value still in op->axis tripped runtime code, causing the concatenation to misbehave during inference. PiperOrigin-RevId: 195454037 --- .../graph_transformations/propagate_fixed_sizes.cc | 11 +++++------ .../resolve_batch_normalization.cc | 6 +++--- tensorflow/contrib/lite/toco/tooling_util.cc | 4 ++++ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 4923f83d91..b02b02c5be 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -670,8 +670,7 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { const auto& first_input_array = model->GetArray(op->inputs[0]); output_array.copy_shape(first_input_array.shape()); // Negative axis means the count starts at the back of the dims(). - int axis = op->axis; - if (axis < 0) axis += first_input_array.shape().dims().size(); + if (op->axis < 0) op->axis += first_input_array.shape().dims().size(); // Determine the concat size, and enfore that all inputs have // the same dimensions count. int concat_size = 0; @@ -684,14 +683,14 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { CHECK_EQ(input_array.shape().dimensions_count(), output_array.shape().dimensions_count()); const std::vector& input_dims = input_array.shape().dims(); - CHECK_LT(axis, input_dims.size()); - concat_size += input_dims[axis]; + CHECK_LT(op->axis, input_dims.size()); + concat_size += input_dims[op->axis]; } // Write out the concat_size on the output array shape. auto& output_shape = *output_array.mutable_shape(); auto& output_dims = *output_shape.mutable_dims(); - CHECK_LT(axis, output_shape.dimensions_count()); - output_dims[axis] = concat_size; + CHECK_LT(op->axis, output_shape.dimensions_count()); + output_dims[op->axis] = concat_size; } void ProcessRangeOperator(Model* model, RangeOperator* op) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc index 2b3ee36ad1..8f2c1f8162 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -134,9 +134,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { } // Remove the old param arrays - model->EraseArray(bn_op->inputs[1]); - model->EraseArray(bn_op->inputs[2]); - model->EraseArray(bn_op->inputs[3]); + DeleteArrayIfUsedOnce(bn_op->inputs[1], model); + DeleteArrayIfUsedOnce(bn_op->inputs[2], model); + DeleteArrayIfUsedOnce(bn_op->inputs[3], model); // Remove the old operator DCHECK_EQ(bn_it->get(), bn_op); diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 86ee1f3761..341d45e753 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -143,6 +143,10 @@ int CountOpsWithInput(const Model& model, const string& array_name) { for (auto& input : op->inputs) { if (input == array_name) { count++; + // Breaking here is important: some graphs have ops that use the + // same array as more than one of their inputs, and in that case + // we want it counted only once. + break; } } } -- GitLab From 67b5e724121c5874425936fe01318642508d9975 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 4 May 2018 14:40:02 -0700 Subject: [PATCH 136/839] [XLA:GPU] Mark floating-point division as an inexpensive op. "Expensive" really means "so expensive you'd choose not to fuse in order to avoid doing it twice". FP division definitely isn't that expensive. PiperOrigin-RevId: 195473524 --- .../xla/service/gpu/instruction_fusion.cc | 13 +++++ .../xla/service/gpu/instruction_fusion.h | 2 + .../service/gpu/instruction_fusion_test.cc | 56 +++++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 85ecbe8fdb..c5eb721185 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -48,6 +48,19 @@ bool IsFusile(const HloInstruction& hlo) { } // namespace +/*static*/ bool GpuInstructionFusion::IsExpensive( + const HloInstruction& instruction) { + switch (instruction.opcode()) { + // We say that floating-point division is cheap on the GPU. + case HloOpcode::kDivide: + return !ShapeUtil::ElementIsFloating(instruction.shape()) && + InstructionFusion::IsExpensive(instruction); + + default: + return InstructionFusion::IsExpensive(instruction); + } +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index bb2990e6df..9fb06b0a24 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,6 +27,8 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + static bool IsExpensive(const HloInstruction& instruction); + bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; HloInstruction::FusionKind ChooseKind( diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 4b231c449f..6c9a805ad6 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -253,5 +253,61 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); } +// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is +// duplicated and fused into both reduces. +TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = f32[] constant(0) + one = f32[] constant(1) + p0 = f32[100] parameter(0) + recip = f32[100] divide(one, p0) + sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT root = (f32[], f32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())); +} + +// Compute sum(100/p0), where p0 has type s32, twice. Check that the division +// is *not* duplicated and fused into both reduces, because we say that integer +// division is not cheap. +TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = s32[] constant(0) + one_hundred = s32[] constant(100) + p0 = s32[100] parameter(0) + recip = s32[100] divide(one_hundred, p0) + sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT mul = (s32[], s32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + } // namespace gpu } // namespace xla -- GitLab From 4d0388d22060a61f40965127c153c681b2412c50 Mon Sep 17 00:00:00 2001 From: James Qin Date: Fri, 4 May 2018 14:53:58 -0700 Subject: [PATCH 137/839] Fix build failure for macos py3 PiperOrigin-RevId: 195475780 --- tensorflow/python/debug/examples/debug_tflearn_iris.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py index 00090b21fe..7cbaae46b4 100644 --- a/tensorflow/python/debug/examples/debug_tflearn_iris.py +++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py @@ -140,7 +140,7 @@ def main(_): # Make predictions, using tfdbg hook. predict_results = classifier.predict(test_input_fn, hooks=hooks) - print("A prediction result: %s" % predict_results.next()) + print("A prediction result: %s" % next(predict_results)) if __name__ == "__main__": -- GitLab From cb1775e9525ae621d23708a3d64a6cad897be95e Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 4 May 2018 15:14:00 -0700 Subject: [PATCH 138/839] Identify and prune nodes that can never be executed PiperOrigin-RevId: 195478951 --- tensorflow/core/grappler/optimizers/BUILD | 1 + .../grappler/optimizers/loop_optimizer.cc | 140 ++++++++++++++++++ .../core/grappler/optimizers/loop_optimizer.h | 1 + .../optimizers/loop_optimizer_test.cc | 107 +++++++++++++ tensorflow/core/grappler/utils.h | 4 +- 5 files changed, 251 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 5b5e1e024e..900dfa95c5 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -604,6 +604,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index 5adc5b9227..7d3520febc 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" @@ -504,6 +505,140 @@ Status RemoveStackOps(const std::unordered_set& nodes_to_preserve, return Status::OK(); } +Status RemoveDeadBranches(const std::unordered_set& nodes_to_preserve, + GraphDef* optimized_graph) { + std::unordered_set dead_nodes; + std::unordered_map> dead_merge_inputs; + // TODO(bsteiner): also rewrite switches as identity. For now we just record + // them + std::unordered_set + identity_switches; + + GraphView view(optimized_graph); + for (const NodeDef& node : optimized_graph->node()) { + if (!IsSwitch(node)) { + continue; + } + if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { + continue; + } + GraphView::InputPort ctrl_port(&node, 1); + GraphView::OutputPort ctrl_node = view.GetRegularFanin(ctrl_port); + if (!IsConstant(*ctrl_node.node)) { + continue; + } + Tensor selector; + CHECK(selector.FromProto(ctrl_node.node->attr().at("value").tensor())); + const int dead_fanout = selector.scalar()() ? 0 : 1; + GraphView::OutputPort dead(const_cast(&node), dead_fanout); + identity_switches.insert(dead); + + SetVector zombie_inputs; + for (const GraphView::InputPort& port : view.GetFanout(dead)) { + if (dead_nodes.find(port.node) == dead_nodes.end()) { + zombie_inputs.PushBack(port); + } + } + // If we encounter a single node that must be preserved in the fanout of the + // switch node we need to preserve the entire switch fanout: we therefore + // work on a local copy that only gets committed to the master copy once the + // whole fanout has been explored. + std::unordered_set local_dead_nodes = dead_nodes; + std::unordered_map> local_dead_merge_inputs = + dead_merge_inputs; + bool found_node_to_preserve = false; + while (!found_node_to_preserve && !zombie_inputs.Empty()) { + GraphView::InputPort dead = zombie_inputs.PopBack(); + if (nodes_to_preserve.find(dead.node->name()) != + nodes_to_preserve.end()) { + found_node_to_preserve = true; + break; + } + + if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) { + continue; + } + + if (IsMerge(*dead.node)) { + const int fanout = dead.node->attr().at("N").i(); + if (fanout > 2) { + // This never happens in practice, so we'll just skip these to + // simplify the code for now. + found_node_to_preserve = true; + break; + } + GraphView::OutputPort value_index(dead.node, 1); + const std::unordered_set& + index_fanout = view.GetFanout(value_index); + if (!index_fanout.empty()) { + // The 2nd output (that indicates which input is propagated) is + // connected. This never happens in practice, so we'll just skip this + // case to simplify the code for now. + found_node_to_preserve = true; + break; + } + + bool fully_dead = false; + if (dead.port_id < 0) { + // If the control dependency never gets triggered the merge will also + // never get triggered. + local_dead_nodes.insert(dead.node); + fully_dead = true; + } else { + local_dead_merge_inputs[dead.node].insert(dead.port_id); + if (local_dead_merge_inputs[dead.node].size() == + dead.node->attr().at("N").i()) { + fully_dead = true; + } + if (fully_dead) { + local_dead_nodes.insert(dead.node); + for (const GraphView::InputPort& port : + view.GetFanouts(*dead.node, true)) { + zombie_inputs.PushBack(port); + } + } + } + } else { + if (local_dead_nodes.insert(dead.node).second) { + for (const GraphView::InputPort& dead_fanout : + view.GetFanouts(*dead.node, true)) { + zombie_inputs.PushBack(dead_fanout); + } + } + } + } + if (!found_node_to_preserve) { + std::swap(dead_nodes, local_dead_nodes); + std::swap(dead_merge_inputs, local_dead_merge_inputs); + } + } + + int last = optimized_graph->node_size() - 1; + for (int i = optimized_graph->node_size() - 1; i >= 0; --i) { + NodeDef* node = optimized_graph->mutable_node(i); + if (dead_nodes.find(node) != dead_nodes.end()) { + optimized_graph->mutable_node()->SwapElements(i, last); + last--; + } + } + optimized_graph->mutable_node()->DeleteSubrange(last + 1, dead_nodes.size()); + + for (const auto& itr : dead_merge_inputs) { + NodeDef* dead_node = itr.first; + if (dead_nodes.find(dead_node) != dead_nodes.end()) { + // The node has been pruned since all its inputs are dead. + continue; + } + const std::set& dead_inputs = itr.second; + for (int index : dead_inputs) { + dead_node->mutable_input()->DeleteSubrange(index, 1); + } + dead_node->set_op("Identity"); + dead_node->mutable_attr()->erase("N"); + } + return Status::OK(); +} + } // namespace Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -517,6 +652,11 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (options_.enable_stack_push_removal) { TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph)); } + if (opt_level_ == RewriterConfig::AGGRESSIVE && + options_.enable_dead_branch_removal) { + TF_RETURN_IF_ERROR( + RemoveDeadBranches(item.NodesToPreserve(), optimized_graph)); + } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h index 764506f7c1..85b8e65543 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.h +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -54,6 +54,7 @@ class LoopOptimizer : public GraphOptimizer { struct LoopOptimizerOptions { bool enable_loop_invariant_node_motion = false; bool enable_stack_push_removal = true; + bool enable_dead_branch_removal = true; static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) { LoopOptimizerOptions options; diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc index 10ec544424..6fd177b710 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -589,5 +589,112 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) { } } +TEST_F(LoopOptimizerTest, RemoveDeadBranches) { + Scope scope = Scope::NewRootScope(); + Output v_in = ops::Variable(scope.WithOpName("v_in"), {3}, DT_FLOAT); + + Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), false, TensorShape({})); + ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1); + Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false); + Output sqrt1 = ops::Sqrt(scope.WithOpName("sqrt1"), s1.output_true); + + Output ctrl2 = ops::Const(scope.WithOpName("ctrl2"), true, TensorShape({})); + ops::Switch s2(scope.WithOpName("switch2"), v_in, ctrl2); + Output square2 = ops::Square(scope.WithOpName("square2"), s2.output_false); + Output sqrt2 = ops::Sqrt(scope.WithOpName("sqrt2"), s2.output_true); + + Output ctrl3 = ops::Const(scope.WithOpName("ctrl3"), false, TensorShape({})); + ops::Switch s3(scope.WithOpName("switch3"), v_in, ctrl3); + Output square3 = ops::Square(scope.WithOpName("square3"), s3.output_false); + Output sqrt3 = ops::Sqrt(scope.WithOpName("sqrt3"), s3.output_true); + + Output ctrl4 = ops::Const(scope.WithOpName("ctrl4"), false, TensorShape({})); + ops::Switch s4(scope.WithOpName("switch4"), v_in, ctrl4); + Output square4 = ops::Square(scope.WithOpName("square4"), s4.output_false); + Output sqrt4 = ops::Sqrt(scope.WithOpName("sqrt4"), s4.output_true); + + ops::Merge m1(scope.WithOpName("m1"), {square1, sqrt1}); + ops::Merge m2(scope.WithOpName("m2"), {v_in, square1}); + ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1}); + ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2}); + ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1}); + ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2), + {v_in, square1}); + ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1), + {v_in, square1}); + + ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1); + Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false); + Output id2 = ops::Identity(scope.WithOpName("id2"), s5.output_true); + ops::Merge m8(scope.WithOpName("m8"), {id1, id2}); + + ops::Switch s6(scope.WithOpName("switch6"), v_in, ctrl1); + Output id3 = ops::Identity(scope.WithOpName("id3"), s6.output_false); + Output id4 = ops::Identity(scope.WithOpName("id4"), s6.output_true); + ops::Merge m9(scope.WithOpName("m9"), {id3, id4}); + + GrapplerItem item; + item.fetch.push_back("m8"); + item.fetch.push_back("id4"); + + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_CHECK_OK(status); + + for (const NodeDef& node : output.node()) { + // These nodes should have been pruned + EXPECT_NE("Square1", node.name()); + EXPECT_NE("Sqrt2", node.name()); + EXPECT_NE("m5", node.name()); + EXPECT_NE("m7", node.name()); + + if (node.name() == "m1") { + // sqrt1 is dead + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("square1", node.input(0)); + } else if (node.name() == "m2") { + // both inputs are alive + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("v_in", node.input(0)); + EXPECT_EQ("square1", node.input(1)); + } else if (node.name() == "m3") { + // sqrt1 is dead + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("v_in", node.input(0)); + } else if (node.name() == "m4") { + // both inputs are alive + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("square1", node.input(0)); + EXPECT_EQ("sqrt2", node.input(1)); + } else if (node.name() == "m6") { + // both inputs are alive and the control dependency can get triggered + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("v_in", node.input(0)); + EXPECT_EQ("square1", node.input(1)); + EXPECT_EQ("^sqrt2", node.input(2)); + } else if (node.name() == "m8") { + // The node is to be preserved because of a fetch + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("id1", node.input(0)); + EXPECT_EQ("id2", node.input(1)); + } else if (node.name() == "m9") { + // The node is to be preserved because of a fetch + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("id3", node.input(0)); + EXPECT_EQ("id4", node.input(1)); + } + } +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index b87ae05546..1c6fef59ea 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -65,7 +65,7 @@ class NodeMap { // A vector with a set. The set stores the same elements as the vector, and // quickly answers whether a value is in the vector. Duplicated elements are not // allowed for now. -template +template > class SetVector { public: // Returns false if value already existed in the set, true otherwise. @@ -91,7 +91,7 @@ class SetVector { void Reserve(int64 size) { vector_.reserve(size); } private: - std::unordered_set set_; + std::unordered_set set_; std::vector vector_; }; -- GitLab From c92de2f3fc81c701ab29408a8a84cd6e41e96fe5 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Sat, 5 May 2018 10:44:20 -0400 Subject: [PATCH 139/839] Skip all ops with function attribute by default --- tensorflow/core/api_def/BUILD | 6 ------ .../api_def/java_api/api_def_FilterDataset.pbtxt | 4 ---- .../api_def/java_api/api_def_FlatMapDataset.pbtxt | 4 ---- tensorflow/core/api_def/java_api/api_def_For.pbtxt | 4 ---- .../java_api/api_def_GeneratorDataset.pbtxt | 4 ---- .../java_api/api_def_GroupByWindowDataset.pbtxt | 4 ---- tensorflow/core/api_def/java_api/api_def_If.pbtxt | 4 ---- .../java_api/api_def_InterleaveDataset.pbtxt | 4 ---- .../java_api/api_def_MapAndBatchDataset.pbtxt | 4 ---- .../core/api_def/java_api/api_def_MapDataset.pbtxt | 4 ---- .../api_def/java_api/api_def_OneShotIterator.pbtxt | 4 ---- .../api_def_ParallelInterleaveDataset.pbtxt | 4 ---- .../java_api/api_def_ParallelMapDataset.pbtxt | 4 ---- .../core/api_def/java_api/api_def_RemoteCall.pbtxt | 4 ---- .../api_def/java_api/api_def_ScanDataset.pbtxt | 4 ---- .../java_api/api_def_SymbolicGradient.pbtxt | 4 ---- .../core/api_def/java_api/api_def_While.pbtxt | 4 ---- tensorflow/java/BUILD | 1 - tensorflow/java/src/gen/cc/op_generator.cc | 14 +++++++++++++- 19 files changed, 13 insertions(+), 72 deletions(-) delete mode 100644 tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_For.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_If.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_While.pbtxt diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 06b797e32e..1454a1d9b2 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -30,12 +30,6 @@ filegroup( visibility = ["//tensorflow:internal"], ) -filegroup( - name = "java_api_def", - srcs = glob(["java_api/*"]), - visibility = ["//tensorflow:internal"], -) - cc_library( name = "excluded_ops_lib", srcs = ["excluded_ops.cc"], diff --git a/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt deleted file mode 100644 index debd7e5709..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "FilterDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt deleted file mode 100644 index 329ab15ef5..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "FlatMapDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_For.pbtxt b/tensorflow/core/api_def/java_api/api_def_For.pbtxt deleted file mode 100644 index caabc947bb..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_For.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "For" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt deleted file mode 100644 index a6e5167c30..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "GeneratorDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt deleted file mode 100644 index 4c0b2084a8..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "GroupByWindowDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_If.pbtxt b/tensorflow/core/api_def/java_api/api_def_If.pbtxt deleted file mode 100644 index 13b8635ca7..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_If.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "If" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt deleted file mode 100644 index ed748d4d2a..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "InterleaveDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt deleted file mode 100644 index cb96bf63d8..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "MapAndBatchDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt deleted file mode 100644 index e0ab8dd9db..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "MapDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt b/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt deleted file mode 100644 index 13130e6882..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "OneShotIterator" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt deleted file mode 100644 index 6a985d24fa..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "ParallelInterleaveDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt deleted file mode 100644 index 64f25b9e5e..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "ParallelMapDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt b/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt deleted file mode 100644 index 2ccb5c8cf3..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "RemoteCall" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt deleted file mode 100644 index 3463e60049..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "ScanDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt b/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt deleted file mode 100644 index 88c3acea74..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "SymbolicGradient" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_While.pbtxt b/tensorflow/core/api_def/java_api/api_def_While.pbtxt deleted file mode 100644 index 33756682c3..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_While.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "While" - visibility: SKIP -} \ No newline at end of file diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 7cd0208dbf..0cc8e7c3e2 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -72,7 +72,6 @@ tf_java_op_gen_srcjar( name = "java_op_gen_sources", api_def_srcs = [ "//tensorflow/core/api_def:base_api_def", - "//tensorflow/core/api_def:java_api_def", ], base_package = "org.tensorflow.op", gen_tool = ":java_op_gen_tool", diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 7355b3a395..f4cefbe933 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -420,6 +420,18 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, writer.EndType(); } +bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) { + if (api_def.visibility() == ApiDef::SKIP) { + return false; + } + for (const auto& attr : op_def.attr()) { + if (attr.type() == "func") { + return false; // TODO(karllessard) add support for function attributes + } + } + return true; +} + } // namespace Status OpGenerator::Run(const OpList& op_list, const string& base_package, @@ -441,7 +453,7 @@ Status OpGenerator::Run(const OpList& op_list, const string& base_package, api_map.UpdateDocs(); for (const auto& op_def : op_list.op()) { const ApiDef* api_def = api_map.GetApiDef(op_def.name()); - if (api_def->visibility() != ApiDef::SKIP) { + if (CanGenerateOp(op_def, *api_def)) { OpSpec op(OpSpec::Create(op_def, *api_def)); for (const EndpointSpec& endpoint : op.endpoints()) { GenerateOp(op, endpoint, base_package, output_dir, env_); -- GitLab From 90bbbdcc42a67c93ba8dcbc66f9c1d06909c48cb Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Sat, 5 May 2018 10:48:22 -0400 Subject: [PATCH 140/839] Remove comment left-over --- tensorflow/core/api_def/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 1454a1d9b2..19d6438809 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -4,7 +4,6 @@ # The following targets can be used to access ApiDefs: # :base_api_def # :python_api_def -# :java_api_def package( default_visibility = ["//visibility:private"], -- GitLab From ab48fb528221152299fb08da8116d2eca54b8423 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 4 May 2018 15:40:07 -0700 Subject: [PATCH 141/839] [XLA] Print allowed attributes when the user specifies an invalid attr. PiperOrigin-RevId: 195482974 --- .../compiler/xla/tools/parser/hlo_parser.cc | 30 +++++++++++++------ .../xla/tools/parser/hlo_parser_test.cc | 2 +- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 3a945fb3b1..40dc0730ce 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -30,6 +30,7 @@ namespace { using tensorflow::StringPiece; using tensorflow::gtl::optional; +using tensorflow::str_util::Join; using tensorflow::str_util::Split; using tensorflow::str_util::SplitAndParseAsInts; using tensorflow::strings::Printf; @@ -53,7 +54,7 @@ class HloParser { std::unique_ptr ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + string GetError() const { return Join(error_, "\n"); } private: // ParseXXX returns false if an error occurred. @@ -245,7 +246,7 @@ bool HloParser::Error(LocTy loc, StringPiece msg) { error_lines.push_back(std::string(lexer_.GetLine(loc))); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + error_.push_back(Join(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } @@ -1488,11 +1489,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - tensorflow::str_util::Join( - elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); - }), + Join(elems_seen_until_dim, ",", + [](string* out, const int64& num_elems) { + tensorflow::strings::StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1680,7 +1680,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", tensorflow::str_util::Join(index, ", "), "]")); + ": [", Join(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -1848,7 +1848,19 @@ bool HloParser::ParseAttributeHelper( } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - return Error(loc, Printf("unexpected attribute %s", name.c_str())); + string allowed_attrs; + if (attrs.empty()) { + allowed_attrs = "No attributes are allowed here."; + } else { + allowed_attrs = StrCat( + "Allowed attributes: ", + Join(attrs, ", ", + [&](string* out, const std::pair& kv) { + StrAppend(out, kv.first); + })); + } + return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), + allowed_attrs.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 4e085bc89c..d38d8907a6 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -1138,7 +1138,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { )"; ExpectHasSubstr(Parse(original).status().error_message(), - "unexpected attribute calls"); + "unexpected attribute \"calls\""); } TEST_F(HloParserTest, MissingAttribute) { -- GitLab From 008a3b69a601dc68fd940eb8a03b0c445714a339 Mon Sep 17 00:00:00 2001 From: Karmel Allison Date: Fri, 4 May 2018 16:01:02 -0700 Subject: [PATCH 142/839] Add the ability to export separate SavedModels for train and eval mode to Estimator with two new methods, available in tf.contrib: export_all_saved_models and export_saved_model_for_mode. PiperOrigin-RevId: 195485922 --- tensorflow/contrib/estimator/BUILD | 38 ++ tensorflow/contrib/estimator/__init__.py | 3 + .../estimator/python/estimator/export.py | 216 ++++++++++ .../estimator/python/estimator/export_test.py | 391 ++++++++++++++++++ tensorflow/python/estimator/BUILD | 1 + tensorflow/python/estimator/estimator.py | 346 +++++++++++++--- tensorflow/python/estimator/estimator_test.py | 336 ++++++++++++++- tensorflow/python/estimator/export/export.py | 325 +++++++++++---- .../python/estimator/export/export_output.py | 223 +++++++++- .../estimator/export/export_output_test.py | 110 +++++ .../python/estimator/export/export_test.py | 253 +++++++++++- tensorflow/python/estimator/model_fn.py | 8 + tensorflow/python/saved_model/builder_impl.py | 54 ++- tensorflow/python/saved_model/constants.py | 6 + .../python/saved_model/saved_model_test.py | 90 ++++ .../python/saved_model/signature_constants.py | 6 + .../python/saved_model/signature_def_utils.py | 2 + .../saved_model/signature_def_utils_impl.py | 56 +++ .../saved_model/signature_def_utils_test.py | 95 +++++ .../python/saved_model/tag_constants.py | 5 + 20 files changed, 2373 insertions(+), 191 deletions(-) create mode 100644 tensorflow/contrib/estimator/python/estimator/export.py create mode 100644 tensorflow/contrib/estimator/python/estimator/export_test.py diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 571e2e3a5d..e9a68801ef 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -17,6 +17,7 @@ py_library( ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":export", ":extenders", ":head", ":linear", @@ -180,6 +181,43 @@ py_test( ], ) +py_library( + name = "export", + srcs = [ + "python/estimator/export.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "export_test", + size = "medium", + srcs = ["python/estimator/export_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/62863147 + deps = [ + ":export", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", + ], +) + py_library( name = "head", srcs = [ diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index d43b3ea6bf..ec502f86dd 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * from tensorflow.contrib.estimator.python.estimator.linear import * @@ -56,6 +57,8 @@ _allowed_symbols = [ 'TowerOptimizer', 'RNNClassifier', 'RNNEstimator', + 'export_saved_model_for_mode', + 'export_all_saved_models', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py new file mode 100644 index 0000000000..e7e366a3f2 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export.py @@ -0,0 +1,216 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper for methods to export train/eval graphs from Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import model_fn as model_fn_lib + + +def export_saved_model_for_mode( + estimator, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn, steps=1000) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + export_dir = tf.contrib.estimator.export_saved_model_for_mode( + classifier, + export_dir_base='my_model/', + input_receiver_fn=train_rcvr_fn, + mode=model_fn_lib.ModeKeys.TRAIN) + + # export_dir is a timestamped directory with the SavedModel, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + ... + ``` + + This method takes an input_receiver_fn and mode. For the mode passed in, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Finally, it creates a timestamped export directory below the + export_dir_base, and writes a `SavedModel` into it containing + the `MetaGraphDef` for the given mode and its associated signatures. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating with mode will be exported. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + # pylint: enable=protected-access + + +def export_all_saved_models( + estimator, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. + + This is a wrapper around export_saved_model_for_mode that accepts + multiple modes simultaneously and creates directories for each under + export_dir_base. See `Estimator.export_saved_model_for_mode` for + further details as to how the export works for each mode. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + serve_rcvr_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( + feature_spec) + + rcvr_fn_map = { + model_fn_lib.ModeKeys.TRAIN: train_rcvr_fn, + model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn, + } + + export_dirs = tf.contrib.estimator.export_all_saved_models( + classifier, + export_dir_base='my_model/', + input_receiver_fn_map=rcvr_fn_map) + + # export_dirs is a dict of directories with SavedModels, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], + export_dirs[tf.estimator.ModeKeys.TRAIN]) + ... + ``` + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_all_saved_models( + export_dir_base, input_receiver_fn_map, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py new file mode 100644 index 0000000000..89d02582e1 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export_test.py @@ -0,0 +1,391 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib wrapping of export_saved_model_for_mode functionality. + +These are direct copies of the tests included in core, with import locations +changed. These should be removed when the functionality in core is part of the +public API. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.python.client import session +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import training +from tensorflow.python.util import compat + + +def _model_fn_for_export_tests(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + update_global_step = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([update_global_step]): + train_op = constant_op.constant(2.) + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=train_op, + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + + +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +class EstimatorExportTest(test.TestCase): + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_saved_model_for_mode( + est, export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 2) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items(): + export_dir = export_dirs[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dirs = contrib_export.export_all_saved_models( + est, export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + for _, export_dir in export_dirs.items(): + self._validate_exported_files(export_dir) + + return export_dirs, tmpdir + + def _validate_exported_files(self, export_dir): + self.assertTrue(gfile.Exists(export_dir)) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 56dec1eaa1..b25cc7aa26 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -91,6 +91,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python/saved_model:tag_constants", "@six_archive//:six", ], ) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 530a4a24ef..9ae64d230e 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -37,9 +37,8 @@ from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util -from tensorflow.python.estimator.export.export import build_all_signature_defs -from tensorflow.python.estimator.export.export import get_temp_export_dir -from tensorflow.python.estimator.export.export import get_timestamped_export_dir +from tensorflow.python.estimator.export import export as export_helpers +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -51,7 +50,6 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants -from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import device_setter @@ -609,73 +607,283 @@ class Estimator(object): are provided, or no checkpoint can be found. """ # pylint: enable=line-too-long + return self._export_saved_model_for_mode( + export_dir_base, + serving_input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=model_fn_lib.ModeKeys.PREDICT) + + def _export_all_saved_models( + self, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. + + This is a wrapper around export_saved_model_for_mode that accepts + multiple modes simultaneously and creates directories for each under + export_dir_base. See `Estimator.export_saved_model_for_mode` for + further details as to how the export works for each mode. + + See tf.contrib.estimator.export_all_saved_models for the currently + exposed version of this function. + + Args: + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + # TODO(b/65561022): Consider allowing multiple input_receiver_fns per mode. + exported = {} + for mode, input_receiver_fn in input_receiver_fn_map.items(): + export_mode_dir = os.path.join( + compat.as_bytes(export_dir_base), + compat.as_bytes(mode)) + gfile.MakeDirs(export_mode_dir) + + exported_path = self._export_saved_model_for_mode( + export_mode_dir, + input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + + exported[mode] = exported_path + + return exported + + def _export_saved_model_for_mode( + self, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + + See tf.contrib.estimator.export_saved_model_for_mode for the currently + exposed version of this function. + + This method takes an input_receiver_fn and mode. For the mode passed in, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Finally, it creates a timestamped export directory below the + export_dir_base, and writes a `SavedModel` into it containing + the `MetaGraphDef` for the given mode and its associated signatures. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + + Args: + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating with mode will be exported. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long with context.graph_mode(): - if serving_input_receiver_fn is None: - raise ValueError('serving_input_receiver_fn must be defined.') + if not input_receiver_fn: + raise ValueError('An input_receiver_fn must be defined.') - with ops.Graph().as_default() as g: - self._create_and_assert_global_step(g) - random_seed.set_random_seed(self._config.tf_random_seed) - serving_input_receiver = serving_input_receiver_fn() + if not checkpoint_path: + # Locate the latest checkpoint + checkpoint_path = saver.latest_checkpoint(self._model_dir) + if not checkpoint_path: + raise ValueError("Couldn't find trained model at %s." % self._model_dir) - # Call the model_fn and collect the export_outputs. - estimator_spec = self._call_model_fn( - features=serving_input_receiver.features, - labels=None, - mode=model_fn_lib.ModeKeys.PREDICT, - config=self.config) - - # Build the SignatureDefs from receivers and all outputs - signature_def_map = build_all_signature_defs( - serving_input_receiver.receiver_tensors, - estimator_spec.export_outputs, - serving_input_receiver.receiver_tensors_alternatives) - - if not checkpoint_path: - # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) - if not checkpoint_path: - raise ValueError( - "Couldn't find trained model at %s." % self._model_dir) - - export_dir = get_timestamped_export_dir(export_dir_base) - temp_export_dir = get_temp_export_dir(export_dir) - - # TODO(soergel): Consider whether MonitoredSession makes sense here - with tf_session.Session(config=self._session_config) as session: - - saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( - sharded=True) - saver_for_restore.restore(session, checkpoint_path) - - local_init_op = ( - estimator_spec.scaffold.local_init_op or - monitored_session.Scaffold.default_local_init_op()) - - # Perform the export - builder = saved_model_builder.SavedModelBuilder(temp_export_dir) - builder.add_meta_graph_and_variables( - session, [tag_constants.SERVING], - signature_def_map=signature_def_map, - assets_collection=ops.get_collection( - ops.GraphKeys.ASSET_FILEPATHS), - legacy_init_op=local_init_op, - strip_default_attrs=strip_default_attrs) - builder.save(as_text) - - # Add the extra assets - if assets_extra: - assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir), - compat.as_bytes('assets.extra')) - for dest_relative, source in assets_extra.items(): - dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), - compat.as_bytes(dest_relative)) - dest_path = os.path.dirname(dest_absolute) - gfile.MakeDirs(dest_path) - gfile.Copy(source, dest_absolute) - - gfile.Rename(temp_export_dir, export_dir) - return export_dir + export_dir = export_helpers.get_timestamped_export_dir(export_dir_base) + temp_export_dir = export_helpers.get_temp_export_dir(export_dir) + + builder = saved_model_builder.SavedModelBuilder(temp_export_dir) + + self._add_meta_graph_and_variables_for_mode( + builder, input_receiver_fn, checkpoint_path, + strip_default_attrs, mode) + + builder.save(as_text) + + # Add the extra assets + if assets_extra: + assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir), + compat.as_bytes('assets.extra')) + for dest_relative, source in assets_extra.items(): + dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), + compat.as_bytes(dest_relative)) + dest_path = os.path.dirname(dest_absolute) + gfile.MakeDirs(dest_path) + gfile.Copy(source, dest_absolute) + + gfile.Rename(temp_export_dir, export_dir) + return export_dir + + def _add_meta_graph_and_variables_for_mode( + self, builder, input_receiver_fn, checkpoint_path, strip_default_attrs, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Loads variables and adds them along with a MetaGraphDef for saving. + + Args: + builder: instance of SavedModelBuilder that will be used for saving. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating which mode will be exported. + """ + # pylint: enable=line-too-long + with ops.Graph().as_default() as g: + self._create_and_assert_global_step(g) + random_seed.set_random_seed(self._config.tf_random_seed) + + input_receiver = input_receiver_fn() + + # Call the model_fn and collect the export_outputs. + estimator_spec = self._call_model_fn( + features=input_receiver.features, + labels=getattr(input_receiver, 'labels', None), + mode=mode, + config=self.config) + + export_outputs = self._get_export_outputs_for_spec(estimator_spec) + + # Build the SignatureDefs from receivers and all outputs + signature_def_map = export_helpers.build_all_signature_defs( + input_receiver.receiver_tensors, + export_outputs, + getattr(input_receiver, 'receiver_tensors_alternatives', None), + serving_only=(mode == model_fn_lib.ModeKeys.PREDICT)) + + with tf_session.Session(config=self._session_config) as session: + + export_tags = model_fn_lib.EXPORT_TAG_MAP[mode] + + local_init_op = ( + estimator_spec.scaffold.local_init_op or + monitored_session.Scaffold.default_local_init_op()) + + saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( + sharded=True) + saver_for_restore.restore(session, checkpoint_path) + + # We add the train op explicitly for now, so that we don't have to + # change the Builder public interface. Note that this is a no-op + # for prediction, where train_op is None. + builder._add_train_op(estimator_spec.train_op) # pylint: disable=protected-access + + builder.add_meta_graph_and_variables( + session, + tags=export_tags, + signature_def_map=signature_def_map, + assets_collection=ops.get_collection( + ops.GraphKeys.ASSET_FILEPATHS), + strip_default_attrs=strip_default_attrs, + legacy_init_op=local_init_op) + + def _get_export_outputs_for_spec(self, estimator_spec): + """Given an EstimatorSpec, determine what our export outputs should be. + + EstimatorSpecs contain export_outputs that are used for serving, but for + training and eval graphs, we must wrap the tensors of interest in + appropriate ExportOutput objects. + + Args: + estimator_spec: EstimatorSpec object that will be exported. + + Returns: + a dict mapping export_output_name to ExportOutput object. + + Raises: + ValueError: if an appropriate ExportOutput cannot be found for the + passed EstimatorSpec.mode + """ + mode = estimator_spec.mode + if mode == model_fn_lib.ModeKeys.PREDICT: + outputs = estimator_spec.export_outputs + else: + if mode == model_fn_lib.ModeKeys.TRAIN: + output_class = export_output.TrainOutput + elif mode == model_fn_lib.ModeKeys.EVAL: + output_class = export_output.EvalOutput + else: + raise ValueError( + 'Export output type not found for mode: {}'.format(mode)) + + export_out = output_class( + loss=estimator_spec.loss, + predictions=estimator_spec.predictions, + metrics=estimator_spec.eval_metric_ops) + outputs = {mode: export_out} + + return outputs def _get_features_from_input_fn(self, input_fn, mode): """Extracts the `features` from return values of `input_fn`.""" @@ -1544,3 +1752,5 @@ def _get_default_warm_start_settings(warm_start_from): else: raise ValueError('warm_start_from must be a string or a WarmStartSettings, ' 'instead got {}'.format(type(warm_start_from))) + + diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 76b45b7f57..02088e5134 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -1865,6 +1865,41 @@ def _model_fn_for_export_tests(features, labels, mode): 'test': export_output.ClassificationOutput(scores, classes)}) +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + def _model_fn_with_saveables_for_export_tests(features, labels, mode): _, _ = features, labels table = saver_test_utils.CheckpointedOp(name='v2') @@ -1881,21 +1916,41 @@ def _model_fn_with_saveables_for_export_tests(features, labels, mode): 'test': export_output.PredictOutput({'prediction': prediction})}) +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn(feature_spec, label_spec) + + _VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n' _EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n' class EstimatorExportTest(test.TestCase): - def test_export_savedmodel_proto_roundtrip(self): - tmpdir = tempfile.mkdtemp() - est = estimator.Estimator(model_fn=_model_fn_for_export_tests) - est.train(input_fn=dummy_input_fn, steps=1) + def test_export_savedmodel_proto_roundtrip_raw_receiver(self): feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( feature_spec) + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=dummy_input_fn, steps=1) + # Perform the export. export_dir_base = os.path.join( compat.as_bytes(tmpdir), compat.as_bytes('export')) @@ -1904,6 +1959,266 @@ class EstimatorExportTest(test.TestCase): # Check that all the files are in the right places. self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertTrue('weight' in graph_ops) + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = est._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 2) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items(): + export_dir = export_dirs[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dirs = est._export_all_saved_models( + export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + for _, export_dir in export_dirs.items(): + self._validate_exported_files(export_dir) + + return export_dirs, tmpdir + + def _validate_exported_files(self, export_dir): self.assertTrue(gfile.Exists(export_dir)) self.assertTrue(gfile.Exists(os.path.join( compat.as_bytes(export_dir), @@ -1918,18 +2233,6 @@ class EstimatorExportTest(test.TestCase): compat.as_bytes(export_dir), compat.as_bytes('variables/variables.data-00000-of-00001')))) - # Restore, to validate that the export was well-formed. - with ops.Graph().as_default() as graph: - with session.Session(graph=graph) as sess: - loader.load(sess, [tag_constants.SERVING], export_dir) - graph_ops = [x.name for x in graph.get_operations()] - self.assertTrue('input_example_tensor' in graph_ops) - self.assertTrue('ParseExample/ParseExample' in graph_ops) - self.assertTrue('weight' in graph_ops) - - # Clean up. - gfile.DeleteRecursively(tmpdir) - def test_export_savedmodel_with_saveables_proto_roundtrip(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator( @@ -2485,5 +2788,6 @@ class EstimatorIntegrationTest(test.TestCase): serving_input_receiver_fn) self.assertTrue(gfile.Exists(export_dir)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 41c1f5a2e2..9aafb56679 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -40,6 +40,60 @@ from tensorflow.python.util.tf_export import tf_export _SINGLE_FEATURE_DEFAULT_NAME = 'feature' _SINGLE_RECEIVER_DEFAULT_NAME = 'input' +_SINGLE_LABEL_DEFAULT_NAME = 'label' + + +def _wrap_and_check_receiver_tensors(receiver_tensors): + """Ensure that receiver_tensors is a dict of str to Tensor mappings. + + Args: + receiver_tensors: dict of str to Tensors, or a single Tensor. + + Returns: + dict of str to Tensors; this is the original dict if one was passed, or + the original tensor wrapped in a dictionary. + + Raises: + ValueError: if receiver_tensors is None, or has non-string keys, + or non-Tensor values + """ + if receiver_tensors is None: + raise ValueError('receiver_tensors must be defined.') + if not isinstance(receiver_tensors, dict): + receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} + for name, tensor in receiver_tensors.items(): + _check_tensor_key(name, error_label='receiver_tensors') + _check_tensor(tensor, name, error_label='receiver_tensor') + return receiver_tensors + + +def _check_tensor(tensor, name, error_label='feature'): + """Check that passed `tensor` is a Tensor or SparseTensor.""" + if not (isinstance(tensor, ops.Tensor) + or isinstance(tensor, sparse_tensor.SparseTensor)): + fmt_name = ' {}'.format(name) if name else '' + value_error = ValueError( + '{}{} must be a Tensor or SparseTensor.'.format(error_label, fmt_name)) + # NOTE(ericmc): This if-else block is a specific carve-out for + # LabeledTensor, which has a `.tensor` attribute and which is + # convertible to tf.Tensor via ops.convert_to_tensor. + # Allowing all types convertible to tf.Tensor is considered by soergel@ + # to be too permissive. + # TODO(soergel): accept any type convertible to Tensor, + # as in cl/193238295 snapshot #6. + if hasattr(tensor, 'tensor'): + try: + ops.convert_to_tensor(tensor) + except TypeError: + raise value_error + else: + raise value_error + + +def _check_tensor_key(name, error_label='feature'): + if not isinstance(name, six.string_types): + raise ValueError( + '{} keys must be strings: {}.'.format(error_label, name)) @tf_export('estimator.export.ServingInputReceiver') @@ -51,16 +105,18 @@ class ServingInputReceiver(collections.namedtuple( The expected return values are: features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or `SparseTensor`, specifying the features to be passed to the model. - receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying - input nodes where this receiver expects to be fed by default. Typically, - this is a single placeholder expecting serialized `tf.Example` protos. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. receiver_tensors_alternatives: a dict of string to additional - groups of receiver tensors, each of which may be a `Tensor` or a dict of - string to `Tensor`. These named receiver tensor alternatives generate - additional serving signatures, which may be used to feed inputs at - different points within the input receiver subgraph. A typical usage is - to allow feeding raw feature `Tensor`s *downstream* of the - tf.parse_example() op. Defaults to None. + groups of receiver tensors, each of which may be a `Tensor`, + `SparseTensor`, or dict of string to `Tensor` or`SparseTensor`. + These named receiver tensor alternatives generate additional serving + signatures, which may be used to feed inputs at different points within + the input receiver subgraph. A typical usage is to allow feeding raw + feature `Tensor`s *downstream* of the tf.parse_example() op. + Defaults to None. """ def __new__(cls, features, receiver_tensors, @@ -70,36 +126,10 @@ class ServingInputReceiver(collections.namedtuple( if not isinstance(features, dict): features = {_SINGLE_FEATURE_DEFAULT_NAME: features} for name, tensor in features.items(): - if not isinstance(name, six.string_types): - raise ValueError('feature keys must be strings: {}.'.format(name)) - if not (isinstance(tensor, ops.Tensor) - or isinstance(tensor, sparse_tensor.SparseTensor)): - value_error = ValueError( - 'feature {} must be a Tensor or SparseTensor.'.format(name)) - # NOTE(ericmc): This if-else block is a specific carve-out for - # LabeledTensor, which has a `.tensor` attribute and which is - # convertible to tf.Tensor via ops.convert_to_tensor. - # Allowing all types convertible to tf.Tensor is considered by soergel@ - # to be too permissive. - if hasattr(tensor, 'tensor'): - try: - ops.convert_to_tensor(tensor) - except TypeError: - raise value_error - else: - raise value_error - - if receiver_tensors is None: - raise ValueError('receiver_tensors must be defined.') - if not isinstance(receiver_tensors, dict): - receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} - for name, tensor in receiver_tensors.items(): - if not isinstance(name, six.string_types): - raise ValueError( - 'receiver_tensors keys must be strings: {}.'.format(name)) - if not isinstance(tensor, ops.Tensor): - raise ValueError( - 'receiver_tensor {} must be a Tensor.'.format(name)) + _check_tensor_key(name) + _check_tensor(tensor, name) + + receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) if receiver_tensors_alternatives is not None: if not isinstance(receiver_tensors_alternatives, dict): @@ -115,14 +145,9 @@ class ServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives[alternative_name] = ( receiver_tensors_alt) for name, tensor in receiver_tensors_alt.items(): - if not isinstance(name, six.string_types): - raise ValueError( - 'receiver_tensors keys must be strings: {}.'.format(name)) - if not (isinstance(tensor, ops.Tensor) - or isinstance(tensor, sparse_tensor.SparseTensor)): - raise ValueError( - 'receiver_tensor {} must be a Tensor or SparseTensor.'.format( - name)) + _check_tensor_key(name, error_label='receiver_tensors_alternative') + _check_tensor( + tensor, name, error_label='receiver_tensors_alternative') return super(ServingInputReceiver, cls).__new__( cls, @@ -155,25 +180,25 @@ class TensorServingInputReceiver(collections.namedtuple( The expected return values are: features: A single `Tensor` or `SparseTensor`, representing the feature to be passed to the model. - receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying - input nodes where this receiver expects to be fed by default. Typically, - this is a single placeholder expecting serialized `tf.Example` protos. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. receiver_tensors_alternatives: a dict of string to additional - groups of receiver tensors, each of which may be a `Tensor` or a dict of - string to `Tensor`. These named receiver tensor alternatives generate - additional serving signatures, which may be used to feed inputs at - different points within the input receiver subgraph. A typical usage is - to allow feeding raw feature `Tensor`s *downstream* of the - tf.parse_example() op. Defaults to None. + groups of receiver tensors, each of which may be a `Tensor`, + `SparseTensor`, or dict of string to `Tensor` or`SparseTensor`. + These named receiver tensor alternatives generate additional serving + signatures, which may be used to feed inputs at different points within + the input receiver subgraph. A typical usage is to allow feeding raw + feature `Tensor`s *downstream* of the tf.parse_example() op. + Defaults to None. """ def __new__(cls, features, receiver_tensors, receiver_tensors_alternatives=None): if features is None: raise ValueError('features must be defined.') - if not (isinstance(features, ops.Tensor) - or isinstance(features, sparse_tensor.SparseTensor)): - raise ValueError('feature must be a Tensor or SparseTensor.') + _check_tensor(features, None) receiver = ServingInputReceiver( features=features, @@ -187,6 +212,49 @@ class TensorServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives=receiver.receiver_tensors_alternatives) +class SupervisedInputReceiver(collections.namedtuple( + 'SupervisedInputReceiver', + ['features', 'labels', 'receiver_tensors'])): + """A return type for a training_input_receiver_fn or eval_input_receiver_fn. + + This differs from a ServingInputReceiver in that (1) this receiver expects + a set of labels to be passed in with features, and (2) this receiver does + not support receiver_tensors_alternatives, which are primarily used for + serving. + + The expected return values are: + features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the features to be passed to the model. + labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the labels to be passed to the model. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. + + """ + + def __new__(cls, features, labels, receiver_tensors): + # Both features and labels can be dicts or raw tensors. + for input_vals, error_label in ((features, 'feature'), (labels, 'label')): + if input_vals is None: + raise ValueError('{}s must be defined.'.format(error_label)) + if isinstance(input_vals, dict): + for name, tensor in input_vals.items(): + _check_tensor_key(name, error_label=error_label) + _check_tensor(tensor, name, error_label=error_label) + else: + _check_tensor(input_vals, None, error_label=error_label) + + receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) + + return super(SupervisedInputReceiver, cls).__new__( + cls, + features=features, + labels=labels, + receiver_tensors=receiver_tensors) + + @tf_export('estimator.export.build_parsing_serving_input_receiver_fn') def build_parsing_serving_input_receiver_fn(feature_spec, default_batch_size=None): @@ -216,6 +284,23 @@ def build_parsing_serving_input_receiver_fn(feature_spec, return serving_input_receiver_fn +def _placeholder_from_tensor(t, default_batch_size=None): + shape_list = t.get_shape().as_list() + shape_list[0] = default_batch_size + shape = tensor_shape.TensorShape(shape_list) + + # Reuse the feature tensor's op name (t.op.name) for the placeholder, + # excluding the index from the tensor's name (t.name): + # t.name = "%s:%d" % (t.op.name, t._value_index) + return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name) + + +def _placeholders_from_receiver_tensors_dict( + input_vals, default_batch_size=None): + return {name: _placeholder_from_tensor(t, default_batch_size) + for name, t in input_vals.items()} + + @tf_export('estimator.export.build_raw_serving_input_receiver_fn') def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """Build a serving_input_receiver_fn expecting feature Tensors. @@ -233,17 +318,9 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """ def serving_input_receiver_fn(): """A serving_input_receiver_fn that expects features to be fed directly.""" - receiver_tensors = {} - for name, t in features.items(): - shape_list = t.get_shape().as_list() - shape_list[0] = default_batch_size - shape = tensor_shape.TensorShape(shape_list) - - # Reuse the feature tensor's op name (t.op.name) for the placeholder, - # excluding the index from the tensor's name (t.name): - # t.name = "%s:%d" % (t.op.name, t._value_index) - receiver_tensors[name] = array_ops.placeholder( - dtype=t.dtype, shape=shape, name=t.op.name) + receiver_tensors = _placeholders_from_receiver_tensors_dict( + features, default_batch_size) + # TODO(b/34885899): remove the unnecessary copy # The features provided are simply the placeholders, but we defensively copy # the dict because it may be mutated. @@ -252,13 +329,100 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None): return serving_input_receiver_fn +def build_raw_supervised_input_receiver_fn( + features, labels, default_batch_size=None): + """Build a supervised_input_receiver_fn for raw features and labels. + + This function wraps tensor placeholders in a supervised_receiver_fn + with the expectation that the features and labels appear precisely as + the model_fn expects them. Features and labels can therefore be dicts of + tensors, or raw tensors. + + Args: + features: a dict of string to `Tensor` or `Tensor`. + labels: a dict of string to `Tensor` or `Tensor`. + default_batch_size: the number of query examples expected per batch. + Leave unset for variable batch size (recommended). + + Returns: + A supervised_input_receiver_fn. + + Raises: + ValueError: if features and labels have overlapping keys. + """ + # Check for overlapping keys before beginning. + try: + feat_keys = features.keys() + except AttributeError: + feat_keys = [_SINGLE_RECEIVER_DEFAULT_NAME] + try: + label_keys = labels.keys() + except AttributeError: + label_keys = [_SINGLE_LABEL_DEFAULT_NAME] + + overlap_keys = set(feat_keys) & set(label_keys) + if overlap_keys: + raise ValueError('Features and labels must have distinct keys. ' + 'Found overlapping keys: {}'.format(overlap_keys)) + + def supervised_input_receiver_fn(): + """A receiver_fn that expects pass-through features and labels.""" + if not isinstance(features, dict): + features_cp = _placeholder_from_tensor(features, default_batch_size) + receiver_features = {_SINGLE_RECEIVER_DEFAULT_NAME: features_cp} + else: + receiver_features = _placeholders_from_receiver_tensors_dict( + features, default_batch_size) + features_cp = receiver_features + + if not isinstance(labels, dict): + labels_cp = _placeholder_from_tensor(labels, default_batch_size) + receiver_labels = {_SINGLE_LABEL_DEFAULT_NAME: labels_cp} + else: + receiver_labels = _placeholders_from_receiver_tensors_dict( + labels, default_batch_size) + labels_cp = receiver_labels + + receiver_tensors = dict(receiver_features) + receiver_tensors.update(receiver_labels) + return SupervisedInputReceiver(features_cp, labels_cp, receiver_tensors) + + return supervised_input_receiver_fn + + ### Below utilities are specific to SavedModel exports. def build_all_signature_defs(receiver_tensors, export_outputs, - receiver_tensors_alternatives=None): - """Build `SignatureDef`s for all export outputs.""" + receiver_tensors_alternatives=None, + serving_only=True): + """Build `SignatureDef`s for all export outputs. + + Args: + receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying + input nodes where this receiver expects to be fed by default. Typically, + this is a single placeholder expecting serialized `tf.Example` protos. + export_outputs: a dict of ExportOutput instances, each of which has + an as_signature_def instance method that will be called to retrieve + the signature_def for all export output tensors. + receiver_tensors_alternatives: a dict of string to additional + groups of receiver tensors, each of which may be a `Tensor` or a dict of + string to `Tensor`. These named receiver tensor alternatives generate + additional serving signatures, which may be used to feed inputs at + different points within the input receiver subgraph. A typical usage is + to allow feeding raw feature `Tensor`s *downstream* of the + tf.parse_example() op. Defaults to None. + serving_only: boolean; if true, resulting signature defs will only include + valid serving signatures. If false, all requested signatures will be + returned. + + Returns: + signature_def representing all passed args. + + Raises: + ValueError: if export_outputs is not a dict + """ if not isinstance(receiver_tensors, dict): receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} if export_outputs is None or not isinstance(export_outputs, dict): @@ -293,17 +457,24 @@ def build_all_signature_defs(receiver_tensors, _log_signature_report(signature_def_map, excluded_signatures) # The above calls to export_output.as_signature_def should return only - # valid signatures; if there is a validity problem, they raise ValueError, - # which we ignore above. Consequently the call to is_valid_signature here - # should not remove anything else; it's just an extra sanity check. - return {k: v for k, v in signature_def_map.items() - if signature_def_utils.is_valid_signature(v)} + # valid signatures; if there is a validity problem, they raise a ValueError, + # in which case we exclude that signature from signature_def_map above. + # The is_valid_signature check ensures that the signatures produced are + # valid for serving, and acts as an additional sanity check for export + # signatures produced for serving. We skip this check for training and eval + # signatures, which are not intended for serving. + if serving_only: + signature_def_map = {k: v for k, v in signature_def_map.items() + if signature_def_utils.is_valid_signature(v)} + return signature_def_map _FRIENDLY_METHOD_NAMES = { signature_constants.CLASSIFY_METHOD_NAME: 'Classify', signature_constants.REGRESS_METHOD_NAME: 'Regress', signature_constants.PREDICT_METHOD_NAME: 'Predict', + signature_constants.SUPERVISED_TRAIN_METHOD_NAME: 'Train', + signature_constants.SUPERVISED_EVAL_METHOD_NAME: 'Eval', } diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 87b964be37..d387ea2940 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -38,6 +38,8 @@ class ExportOutput(object): __metaclass__ = abc.ABCMeta + _SEPARATOR_CHAR = '/' + @abc.abstractmethod def as_signature_def(self, receiver_tensors): """Generate a SignatureDef proto for inclusion in a MetaGraphDef. @@ -51,6 +53,52 @@ class ExportOutput(object): """ pass + def _check_output_key(self, key, error_label): + # For multi-head models, the key can be a tuple. + if isinstance(key, tuple): + key = self._SEPARATOR_CHAR.join(key) + + if not isinstance(key, six.string_types): + raise ValueError( + '{} output key must be a string; got {}.'.format(error_label, key)) + return key + + def _wrap_and_check_outputs( + self, outputs, single_output_default_name, error_label=None): + """Wraps raw tensors as dicts and checks type. + + Note that we create a new dict here so that we can overwrite the keys + if necessary. + + Args: + outputs: A `Tensor` or a dict of string to `Tensor`. + single_output_default_name: A string key for use in the output dict + if the provided `outputs` is a raw tensor. + error_label: descriptive string for use in error messages. If none, + single_output_default_name will be used. + + Returns: + A dict of tensors + + Raises: + ValueError: if the outputs dict keys are not strings or tuples of strings + or the values are not Tensors. + """ + if not isinstance(outputs, dict): + outputs = {single_output_default_name: outputs} + + output_dict = {} + for key, value in outputs.items(): + error_name = error_label or single_output_default_name + key = self._check_output_key(key, error_name) + if not isinstance(value, ops.Tensor): + raise ValueError( + '{} output value must be a Tensor; got {}.'.format( + error_name, value)) + + output_dict[key] = value + return output_dict + @tf_export('estimator.export.ClassificationOutput') class ClassificationOutput(ExportOutput): @@ -154,9 +202,6 @@ class RegressionOutput(ExportOutput): return signature_def_utils.regression_signature_def(examples, self.value) -_SINGLE_OUTPUT_DEFAULT_NAME = 'output' - - @tf_export('estimator.export.PredictOutput') class PredictOutput(ExportOutput): """Represents the output of a generic prediction head. @@ -165,6 +210,7 @@ class PredictOutput(ExportOutput): Named outputs must be provided as a dict from string to `Tensor`, """ + _SINGLE_OUTPUT_DEFAULT_NAME = 'output' def __init__(self, outputs): """Constructor for PredictOutput. @@ -177,16 +223,9 @@ class PredictOutput(ExportOutput): ValueError: if the outputs is not dict, or any of its keys are not strings, or any of its values are not `Tensor`s. """ - if not isinstance(outputs, dict): - outputs = {_SINGLE_OUTPUT_DEFAULT_NAME: outputs} - for key, value in outputs.items(): - if not isinstance(key, six.string_types): - raise ValueError( - 'Prediction output key must be a string; got {}.'.format(key)) - if not isinstance(value, ops.Tensor): - raise ValueError( - 'Prediction output value must be a Tensor; got {}.'.format(value)) - self._outputs = outputs + + self._outputs = self._wrap_and_check_outputs( + outputs, self._SINGLE_OUTPUT_DEFAULT_NAME, error_label='Prediction') @property def outputs(self): @@ -195,3 +234,161 @@ class PredictOutput(ExportOutput): def as_signature_def(self, receiver_tensors): return signature_def_utils.predict_signature_def(receiver_tensors, self.outputs) + + +class _SupervisedOutput(ExportOutput): + """Represents the output of a supervised training or eval process.""" + __metaclass__ = abc.ABCMeta + + LOSS_NAME = 'loss' + PREDICTIONS_NAME = 'predictions' + METRICS_NAME = 'metrics' + + METRIC_VALUE_SUFFIX = 'value' + METRIC_UPDATE_SUFFIX = 'update_op' + + _loss = None + _predictions = None + _metrics = None + + def __init__(self, loss=None, predictions=None, metrics=None): + """Constructor for SupervisedOutput (ie, Train or Eval output). + + Args: + loss: dict of Tensors or single Tensor representing calculated loss. + predictions: dict of Tensors or single Tensor representing model + predictions. + metrics: dict of (metric_value, update_op) tuples, or a single tuple. + metric_value must be a Tensor, and update_op must be a Tensor or Op. + + Raises: + ValueError: if any of the outputs' dict keys are not strings or tuples of + strings or the values are not Tensors (or Operations in the case of + update_op). + """ + + if loss is not None: + loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME) + self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME) + if predictions is not None: + pred_dict = self._wrap_and_check_outputs( + predictions, self.PREDICTIONS_NAME) + self._predictions = self._prefix_output_keys( + pred_dict, self.PREDICTIONS_NAME) + if metrics is not None: + self._metrics = self._wrap_and_check_metrics(metrics) + + def _prefix_output_keys(self, output_dict, output_name): + """Prepend output_name to the output_dict keys if it doesn't exist. + + This produces predictable prefixes for the pre-determined outputs + of SupervisedOutput. + + Args: + output_dict: dict of string to Tensor, assumed valid. + output_name: prefix string to prepend to existing keys. + + Returns: + dict with updated keys and existing values. + """ + + new_outputs = {} + for key, val in output_dict.items(): + key = self._prefix_key(key, output_name) + new_outputs[key] = val + return new_outputs + + def _prefix_key(self, key, output_name): + if key.find(output_name) != 0: + key = output_name + self._SEPARATOR_CHAR + key + return key + + def _wrap_and_check_metrics(self, metrics): + """Handle the saving of metrics. + + Metrics is either a tuple of (value, update_op), or a dict of such tuples. + Here, we separate out the tuples and create a dict with names to tensors. + + Args: + metrics: dict of (metric_value, update_op) tuples, or a single tuple. + + Returns: + dict of output_names to tensors + + Raises: + ValueError: if the dict key is not a string, or the metric values or ops + are not tensors. + """ + if not isinstance(metrics, dict): + metrics = {self.METRICS_NAME: metrics} + + outputs = {} + for key, (metric_val, metric_op) in metrics.items(): + key = self._check_output_key(key, self.METRICS_NAME) + key = self._prefix_key(key, self.METRICS_NAME) + + val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX + op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX + if not isinstance(metric_val, ops.Tensor): + raise ValueError( + '{} output value must be a Tensor; got {}.'.format( + key, metric_val)) + if (not isinstance(metric_op, ops.Tensor) and + not isinstance(metric_op, ops.Operation)): + raise ValueError( + '{} update_op must be a Tensor or Operation; got {}.'.format( + key, metric_op)) + outputs[val_name] = metric_val + outputs[op_name] = metric_op + + return outputs + + @property + def loss(self): + return self._loss + + @property + def predictions(self): + return self._predictions + + @property + def metrics(self): + return self._metrics + + @abc.abstractmethod + def _get_signature_def_fn(self): + """Returns a function that produces a SignatureDef given desired outputs.""" + pass + + def as_signature_def(self, receiver_tensors): + signature_def_fn = self._get_signature_def_fn() + return signature_def_fn( + receiver_tensors, self.loss, self.predictions, self.metrics) + + +class TrainOutput(_SupervisedOutput): + """Represents the output of a supervised training process. + + This class generates the appropriate signature def for exporting + training output by type-checking and wrapping loss, predictions, and metrics + values. + """ + + def _get_signature_def_fn(self): + return signature_def_utils.supervised_train_signature_def + + +class EvalOutput(_SupervisedOutput): + """Represents the output of a supervised eval process. + + This class generates the appropriate signature def for exporting + eval output by type-checking and wrapping loss, predictions, and metrics + values. + """ + + def _get_signature_def_fn(self): + return signature_def_utils.supervised_eval_signature_def + + + + diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index 7090e53d80..b21ba91b0f 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -225,5 +225,115 @@ class ExportOutputTest(test.TestCase): }) +class MockSupervisedOutput(export_output_lib._SupervisedOutput): + """So that we can test the abstract class methods directly.""" + + def _get_signature_def_fn(self): + pass + + +class SupervisedOutputTest(test.TestCase): + + def test_supervised_outputs_valid(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), + constant_op.constant([10])), + "metrics2": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(outputter.loss["loss/my_loss"], loss["my_loss"]) + self.assertEqual( + outputter.predictions["predictions/output1"], predictions["output1"]) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + self.assertEqual( + outputter.metrics["metrics2/update_op"], metrics["metrics2"][1]) + + # Single Tensor is OK too + outputter = MockSupervisedOutput( + loss["my_loss"], predictions["output1"], metrics["metrics"]) + self.assertEqual(outputter.loss, {"loss": loss["my_loss"]}) + self.assertEqual( + outputter.predictions, {"predictions": predictions["output1"]}) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + + def test_supervised_outputs_none(self): + outputter = MockSupervisedOutput( + constant_op.constant([0]), None, None) + self.assertEqual(len(outputter.loss), 1) + self.assertEqual(outputter.predictions, None) + self.assertEqual(outputter.metrics, None) + + def test_supervised_outputs_invalid(self): + with self.assertRaisesRegexp(ValueError, "predictions output value must"): + MockSupervisedOutput(constant_op.constant([0]), [3], None) + with self.assertRaisesRegexp(ValueError, "loss output value must"): + MockSupervisedOutput("str", None, None) + with self.assertRaisesRegexp(ValueError, "metrics output value must"): + MockSupervisedOutput(None, None, (15.3, 4)) + with self.assertRaisesRegexp(ValueError, "loss output key must"): + MockSupervisedOutput({25: "Tensor"}, None, None) + + def test_supervised_outputs_tuples(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {("my", "loss"): constant_op.constant([0])} + predictions = {(u"output1", "2"): constant_op.constant(["foo"])} + metrics = {("metrics", "twice"): (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(set(outputter.loss.keys()), set(["loss/my/loss"])) + self.assertEqual(set(outputter.predictions.keys()), + set(["predictions/output1/2"])) + self.assertEqual(set(outputter.metrics.keys()), + set(["metrics/twice/value", "metrics/twice/update_op"])) + + def test_supervised_outputs_no_prepend(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {"loss": constant_op.constant([0])} + predictions = {u"predictions": constant_op.constant(["foo"])} + metrics = {u"metrics": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(set(outputter.loss.keys()), set(["loss"])) + self.assertEqual(set(outputter.predictions.keys()), set(["predictions"])) + self.assertEqual(set(outputter.metrics.keys()), + set(["metrics/value", "metrics/update_op"])) + + def test_train_signature_def(self): + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = export_output_lib.TrainOutput(loss, predictions, metrics) + + receiver = {u"features": constant_op.constant(100, shape=(100, 2)), + "labels": constant_op.constant(100, shape=(100, 1))} + sig_def = outputter.as_signature_def(receiver) + + self.assertTrue("loss/my_loss" in sig_def.outputs) + self.assertTrue("metrics/value" in sig_def.outputs) + self.assertTrue("predictions/output1" in sig_def.outputs) + self.assertTrue("features" in sig_def.inputs) + + def test_eval_signature_def(self): + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + + outputter = export_output_lib.EvalOutput(loss, predictions, None) + + receiver = {u"features": constant_op.constant(100, shape=(100, 2)), + "labels": constant_op.constant(100, shape=(100, 1))} + sig_def = outputter.as_signature_def(receiver) + + self.assertTrue("loss/my_loss" in sig_def.outputs) + self.assertFalse("metrics/value" in sig_def.outputs) + self.assertTrue("predictions/output1" in sig_def.outputs) + self.assertTrue("features" in sig_def.inputs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index c203be7dac..0af587f2a8 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -54,7 +54,7 @@ ops.register_tensor_conversion_function(LabeledTensorMock, _convert_labeled_tensor_mock_to_tensor) -class ExportTest(test_util.TensorFlowTestCase): +class ServingInputReceiverTest(test_util.TensorFlowTestCase): def test_serving_input_receiver_constructor(self): """Tests that no errors are raised when input is expected.""" @@ -161,6 +161,165 @@ class ExportTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): _ = export.ServingInputReceiver(feature, receiver_tensor) + +class SupervisedInputReceiverTest(test_util.TensorFlowTestCase): + + def test_input_receiver_constructor(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + labels = { + "classes": constant_op.constant([0] * 100), + } + + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + export.SupervisedInputReceiver(features, labels, receiver_tensors) + + def test_input_receiver_raw_values(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + + labels = { + "classes": constant_op.constant([0] * 100), + } + + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + rec = export.SupervisedInputReceiver( + features["feature2"], labels, receiver_tensors) + self.assertIsInstance(rec.features, sparse_tensor.SparseTensor) + + rec = export.SupervisedInputReceiver( + features, labels["classes"], receiver_tensors) + self.assertIsInstance(rec.labels, ops.Tensor) + + def test_input_receiver_features_invalid(self): + features = constant_op.constant([0] * 100) + labels = constant_op.constant([0]) + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + + with self.assertRaisesRegexp(ValueError, "features must be defined"): + export.SupervisedInputReceiver( + features=None, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "feature keys must be strings"): + export.SupervisedInputReceiver( + features={1: constant_op.constant([1])}, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "label keys must be strings"): + export.SupervisedInputReceiver( + features=features, + labels={1: constant_op.constant([1])}, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature feature1 must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features={"feature1": [1]}, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features=[1], + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "label must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features=features, + labels=100, + receiver_tensors=receiver_tensors) + + def test_input_receiver_receiver_tensors_invalid(self): + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + labels = constant_op.constant([0]) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors must be defined"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors=None) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors keys must be strings"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors={ + 1: array_ops.placeholder(dtypes.string, name="example0")}) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensor example1 must be a Tensor"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors={"example1": [1]}) + + def test_single_feature_single_receiver(self): + feature = constant_op.constant(5) + label = constant_op.constant(5) + receiver_tensor = array_ops.placeholder(dtypes.string) + input_receiver = export.SupervisedInputReceiver( + feature, label, receiver_tensor) + + # single receiver is automatically named + receiver_key, = input_receiver.receiver_tensors.keys() + self.assertEqual("input", receiver_key) + + def test_multi_feature_single_receiver(self): + features = {"foo": constant_op.constant(5), + "bar": constant_op.constant(6)} + labels = {"value": constant_op.constant(5)} + receiver_tensor = array_ops.placeholder(dtypes.string) + _ = export.SupervisedInputReceiver(features, labels, receiver_tensor) + + def test_multi_feature_multi_receiver(self): + features = {"foo": constant_op.constant(5), + "bar": constant_op.constant(6)} + labels = {"value": constant_op.constant(5)} + receiver_tensors = {"baz": array_ops.placeholder(dtypes.int64), + "qux": array_ops.placeholder(dtypes.float32)} + _ = export.SupervisedInputReceiver(features, labels, receiver_tensors) + + def test_feature_labeled_tensor(self): + feature = LabeledTensorMock() + label = constant_op.constant(5) + receiver_tensor = array_ops.placeholder(dtypes.string) + _ = export.SupervisedInputReceiver(feature, label, receiver_tensor) + + +class ExportTest(test_util.TensorFlowTestCase): + def test_build_parsing_serving_input_receiver_fn(self): feature_spec = {"int_feature": parsing_ops.VarLenFeature(dtypes.int64), "float_feature": parsing_ops.VarLenFeature(dtypes.float32)} @@ -237,6 +396,69 @@ class ExportTest(test_util.TensorFlowTestCase): dtypes.int32, serving_input_receiver.receiver_tensors["feature_2"].dtype) + def test_build_raw_supervised_input_receiver_fn(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn = export.build_raw_supervised_input_receiver_fn( + features, labels) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual(set(["feature_1", "feature_2"]), + set(input_receiver.features.keys())) + self.assertEqual(set(["foo", "bar"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["feature_1", "feature_2", "foo", "bar"]), + set(input_receiver.receiver_tensors.keys())) + self.assertEqual( + dtypes.string, input_receiver.receiver_tensors["feature_1"].dtype) + self.assertEqual( + dtypes.int32, input_receiver.receiver_tensors["feature_2"].dtype) + + def test_build_raw_supervised_input_receiver_fn_raw_tensors(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn1 = export.build_raw_supervised_input_receiver_fn( + features["feature_1"], labels) + input_receiver_fn2 = export.build_raw_supervised_input_receiver_fn( + features["feature_1"], labels["foo"]) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn1() + self.assertIsInstance(input_receiver.features, ops.Tensor) + self.assertEqual(set(["foo", "bar"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["input", "foo", "bar"]), + set(input_receiver.receiver_tensors.keys())) + + input_receiver = input_receiver_fn2() + self.assertIsInstance(input_receiver.features, ops.Tensor) + self.assertIsInstance(input_receiver.labels, ops.Tensor) + self.assertEqual(set(["input", "label"]), + set(input_receiver.receiver_tensors.keys())) + + def test_build_raw_supervised_input_receiver_fn_batch_size(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn = export.build_raw_supervised_input_receiver_fn( + features, labels, default_batch_size=10) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual([10], input_receiver.receiver_tensors["feature_1"].shape) + self.assertEqual([10], input_receiver.features["feature_1"].shape) + + def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"feature_1": constant_op.constant([5]), + "bar": constant_op.constant([6])} + with self.assertRaises(ValueError): + export.build_raw_supervised_input_receiver_fn(features, labels) + def test_build_all_signature_defs_without_receiver_alternatives(self): receiver_tensor = array_ops.placeholder(dtypes.string) output_1 = constant_op.constant([1.]) @@ -404,6 +626,35 @@ class ExportTest(test_util.TensorFlowTestCase): self.assertTrue(int(time_1) < int(time_2)) self.assertTrue(int(time_2) < int(time_3)) + def test_build_all_signature_defs_serving_only(self): + receiver_tensor = {"input": array_ops.placeholder(dtypes.string)} + output_1 = constant_op.constant([1.]) + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(outputs=output_1), + "train": export_output.TrainOutput(loss=output_1), + } + + signature_defs = export.build_all_signature_defs( + receiver_tensor, export_outputs) + + expected_signature_defs = { + "serving_default": signature_def_utils.predict_signature_def( + receiver_tensor, {"output": output_1}) + } + + self.assertDictEqual(expected_signature_defs, signature_defs) + + signature_defs = export.build_all_signature_defs( + receiver_tensor, export_outputs, serving_only=False) + + expected_signature_defs.update({ + "train": signature_def_utils.supervised_train_signature_def( + receiver_tensor, loss={"loss": output_1}) + }) + + self.assertDictEqual(expected_signature_defs, signature_defs) + class TensorServingReceiverTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 8111ab564c..4ab2578769 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest @@ -53,6 +54,13 @@ class ModeKeys(object): LOSS_METRIC_KEY = 'loss' AVERAGE_LOSS_METRIC_KEY = 'average_loss' +# Mapping of the modes to appropriate tag_constants that are used for saving. +EXPORT_TAG_MAP = { + ModeKeys.PREDICT: [tag_constants.SERVING], + ModeKeys.TRAIN: [tag_constants.TRAINING], + ModeKeys.EVAL: [tag_constants.EVAL], +} + @tf_export('estimator.EstimatorSpec') class EstimatorSpec( diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 3447d917e9..071033b066 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -168,6 +168,25 @@ class SavedModelBuilder(object): raise TypeError("main_op needs to be an Operation: %r" % main_op) ops.add_to_collection(constants.MAIN_OP_KEY, main_op) + def _add_train_op(self, train_op): + """Add train op to the SavedModel. + + Note that this functionality is in development, and liable to be + moved elsewhere. + + Args: + train_op: Op or group of ops that are used for training. These are + stored as a collection with key TRAIN_OP_KEY, but not executed. + + Raises: + TypeError if Train op is not of type `Operation`. + """ + if train_op is not None: + if (not isinstance(train_op, ops.Tensor) and + not isinstance(train_op, ops.Operation)): + raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op) + ops.add_to_collection(constants.TRAIN_OP_KEY, train_op) + def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map): """Tags the meta graph def and adds it to the SavedModel. @@ -238,6 +257,20 @@ class SavedModelBuilder(object): for outputs_key in outputs: self._validate_tensor_info(outputs[outputs_key]) + def _add_collections( + self, assets_collection, legacy_init_op, main_op, train_op): + """Add asset and op collections to be saved.""" + # Save asset files and write them to disk, if any. + self._save_and_write_assets(assets_collection) + + if main_op is None: + # Add legacy init op to the SavedModel. + self._maybe_add_legacy_init_op(legacy_init_op) + else: + self._add_main_op(main_op) + + self._add_train_op(train_op) + def add_meta_graph(self, tags, signature_def_map=None, @@ -285,14 +318,8 @@ class SavedModelBuilder(object): # properly populated. self._validate_signature_def_map(signature_def_map) - # Save asset files and write them to disk, if any. - self._save_and_write_assets(assets_collection) - - if main_op is None: - # Add legacy init op to the SavedModel. - self._maybe_add_legacy_init_op(legacy_init_op) - else: - self._add_main_op(main_op) + # Add assets and ops + self._add_collections(assets_collection, legacy_init_op, main_op, None) # Initialize a saver to generate a sharded output for all saveables in the # current scope. @@ -351,6 +378,7 @@ class SavedModelBuilder(object): strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + """ # pylint: enable=line-too-long if self._has_saved_variables: @@ -362,8 +390,8 @@ class SavedModelBuilder(object): # properly populated. self._validate_signature_def_map(signature_def_map) - # Save asset files and write them to disk, if any. - self._save_and_write_assets(assets_collection) + # Add assets and ops + self._add_collections(assets_collection, legacy_init_op, main_op, None) # Create the variables sub-directory, if it does not exist. variables_dir = os.path.join( @@ -376,12 +404,6 @@ class SavedModelBuilder(object): compat.as_text(variables_dir), compat.as_text(constants.VARIABLES_FILENAME)) - if main_op is None: - # Add legacy init op to the SavedModel. - self._maybe_add_legacy_init_op(legacy_init_op) - else: - self._add_main_op(main_op) - # Initialize a saver to generate a sharded output for all saveables in the # current scope. saver = tf_saver.Saver( diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py index 34206c6f6d..61c6ffbd0d 100644 --- a/tensorflow/python/saved_model/constants.py +++ b/tensorflow/python/saved_model/constants.py @@ -41,6 +41,10 @@ MAIN_OP_KEY = "saved_model_main_op" tf_export("saved_model.constants.MAIN_OP_KEY").export_constant( __name__, "MAIN_OP_KEY") +# CollectionDef key for the SavedModel train op. +# Not exported while export_all_saved_models is in contrib. +TRAIN_OP_KEY = "saved_model_train_op" + # Schema version for SavedModel. SAVED_MODEL_SCHEMA_VERSION = 1 tf_export("saved_model.constants.SAVED_MODEL_SCHEMA_VERSION").export_constant( @@ -65,3 +69,5 @@ tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant( VARIABLES_FILENAME = "variables" tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant( __name__, "VARIABLES_FILENAME") + + diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 804255375e..a4d994fd43 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -734,6 +734,96 @@ class SavedModelTest(test.TestCase): builder.add_meta_graph_and_variables( sess, ["foo"], legacy_init_op=legacy_init_op) + def testTrainOp(self): + export_dir = self._get_export_dir("test_train_op") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + # Add `v1` and `v2` variables to the graph. + v1 = variables.Variable(1, name="v1") + ops.add_to_collection("v", v1) + v2 = variables.Variable(2, name="v2") + ops.add_to_collection("v", v2) + + sess.run(variables.global_variables_initializer()) + train_op = state_ops.assign_add(v1, v2) + + sess.run(train_op) + # TODO(karmel): remove explicit call when in the public method. + builder._add_train_op(train_op) + builder.add_meta_graph_and_variables(sess, ["foo"]) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo"], export_dir) + self.assertEqual(3, ops.get_collection("v")[0].eval()) + self.assertEqual(2, ops.get_collection("v")[1].eval()) + self.assertIsInstance( + ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor) + + def testTrainOpGroup(self): + export_dir = self._get_export_dir("test_train_op_group") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + # Add `v1` and `v2` variables to the graph. + v1 = variables.Variable(1, name="v1") + ops.add_to_collection("v", v1) + v2 = variables.Variable(2, name="v2") + ops.add_to_collection("v", v2) + + sess.run(variables.global_variables_initializer()) + train_op = control_flow_ops.group() + + sess.run(train_op) + # TODO(karmel): remove explicit call when in the public method. + builder._add_train_op(train_op) + builder.add_meta_graph_and_variables(sess, ["foo"]) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo"], export_dir) + self.assertEqual(1, ops.get_collection("v")[0].eval()) + self.assertEqual(2, ops.get_collection("v")[1].eval()) + self.assertIsInstance( + ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Operation) + + def testTrainOpAfterVariables(self): + export_dir = self._get_export_dir("test_train_op_after_variables") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + # Add `v1` and `v2` variables to the graph. + v1 = variables.Variable(1, name="v1") + ops.add_to_collection("v", v1) + v2 = variables.Variable(2, name="v2") + ops.add_to_collection("v", v2) + + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph_and_variables(sess, ["pre_foo"]) + + train_op = state_ops.assign_add(v1, v2) + sess.run(train_op) + # TODO(karmel): remove explicit call when in the public method. + builder._add_train_op(train_op) + builder.add_meta_graph(["foo"]) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo"], export_dir) + self.assertIsInstance( + ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor) + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["pre_foo"], export_dir) + self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY)) + def testMultipleAssets(self): export_dir = self._get_export_dir("test_multiple_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py index 819f351291..99007a9634 100644 --- a/tensorflow/python/saved_model/signature_constants.py +++ b/tensorflow/python/saved_model/signature_constants.py @@ -94,3 +94,9 @@ tf_export("saved_model.signature_constants.REGRESS_OUTPUTS").export_constant( __name__, "REGRESS_OUTPUTS") ################################################################################ +# Train/Eval API constants. +# Not exported while export_all_saved_models is in contrib. + +SUPERVISED_TRAIN_METHOD_NAME = "tensorflow/supervised/training" + +SUPERVISED_EVAL_METHOD_NAME = "tensorflow/supervised/eval" diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py index ea0f52f17e..27d6b70e9d 100644 --- a/tensorflow/python/saved_model/signature_def_utils.py +++ b/tensorflow/python/saved_model/signature_def_utils.py @@ -26,6 +26,8 @@ from tensorflow.python.saved_model.signature_def_utils_impl import classificatio from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def +from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def +from tensorflow.python.saved_model.signature_def_utils_impl import supervised_train_signature_def # pylint: enable=unused-import del absolute_import diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py index d033159188..f8ad788f77 100644 --- a/tensorflow/python/saved_model/signature_def_utils_impl.py +++ b/tensorflow/python/saved_model/signature_def_utils_impl.py @@ -185,6 +185,62 @@ def predict_signature_def(inputs, outputs): return signature_def +def supervised_train_signature_def( + inputs, loss, predictions=None, metrics=None): + return _supervised_signature_def( + signature_constants.SUPERVISED_TRAIN_METHOD_NAME, inputs, loss=loss, + predictions=predictions, metrics=metrics) + + +def supervised_eval_signature_def( + inputs, loss, predictions=None, metrics=None): + return _supervised_signature_def( + signature_constants.SUPERVISED_EVAL_METHOD_NAME, inputs, loss=loss, + predictions=predictions, metrics=metrics) + + +def _supervised_signature_def( + method_name, inputs, loss=None, predictions=None, + metrics=None): + """Creates a signature for training and eval data. + + This function produces signatures that describe the inputs and outputs + of a supervised process, such as training or evaluation, that + results in loss, metrics, and the like. Note that this function only requires + inputs to be not None. + + Args: + method_name: Method name of the SignatureDef as a string. + inputs: dict of string to `Tensor`. + loss: dict of string to `Tensor` representing computed loss. + predictions: dict of string to `Tensor` representing the output predictions. + metrics: dict of string to `Tensor` representing metric ops. + + Returns: + A train- or eval-flavored signature_def. + + Raises: + ValueError: If inputs or outputs is `None`. + """ + if inputs is None or not inputs: + raise ValueError('{} inputs cannot be None or empty.'.format(method_name)) + + signature_inputs = {key: utils.build_tensor_info(tensor) + for key, tensor in inputs.items()} + + signature_outputs = {} + for output_set in (loss, predictions, metrics): + if output_set is not None: + sig_out = {key: utils.build_tensor_info(tensor) + for key, tensor in output_set.items()} + signature_outputs.update(sig_out) + + signature_def = build_signature_def( + signature_inputs, signature_outputs, method_name) + + return signature_def + + @tf_export('saved_model.signature_def_utils.is_valid_signature') def is_valid_signature(signature_def): """Determine whether a SignatureDef can be served by TensorFlow Serving.""" diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py index b2bd14db8c..ebc5450633 100644 --- a/tensorflow/python/saved_model/signature_def_utils_test.py +++ b/tensorflow/python/saved_model/signature_def_utils_test.py @@ -180,6 +180,101 @@ class SignatureDefUtilsTest(test.TestCase): self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype) self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim)) + def testTrainSignatureDef(self): + self._testSupervisedSignatureDef( + signature_def_utils_impl.supervised_train_signature_def, + signature_constants.SUPERVISED_TRAIN_METHOD_NAME) + + def testEvalSignatureDef(self): + self._testSupervisedSignatureDef( + signature_def_utils_impl.supervised_eval_signature_def, + signature_constants.SUPERVISED_EVAL_METHOD_NAME) + + def _testSupervisedSignatureDef(self, fn_to_test, method_name): + inputs = { + "input-1": constant_op.constant("a", name="input-1"), + "input-2": constant_op.constant("b", name="input-2"), + } + loss = {"loss-1": constant_op.constant(0.45, name="loss-1")} + predictions = { + "classes": constant_op.constant([100], name="classes"), + } + metrics_val = constant_op.constant(100.0, name="metrics_val") + metrics = { + "metrics/value": metrics_val, + "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"), + } + + signature_def = fn_to_test(inputs, loss, predictions, metrics) + + self.assertEqual(method_name, signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(2, len(signature_def.inputs)) + input1_tensor_info_actual = (signature_def.inputs["input-1"]) + self.assertEqual("input-1:0", input1_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype) + self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim)) + input2_tensor_info_actual = (signature_def.inputs["input-2"]) + self.assertEqual("input-2:0", input2_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype) + self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim)) + + # Check outputs in signature def. + self.assertEqual(4, len(signature_def.outputs)) + self.assertEqual("loss-1:0", signature_def.outputs["loss-1"].name) + self.assertEqual(types_pb2.DT_FLOAT, signature_def.outputs["loss-1"].dtype) + + self.assertEqual("classes:0", signature_def.outputs["classes"].name) + self.assertEqual(1, len(signature_def.outputs["classes"].tensor_shape.dim)) + + self.assertEqual( + "metrics_val:0", signature_def.outputs["metrics/value"].name) + self.assertEqual( + types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype) + + self.assertEqual( + "metrics_op:0", signature_def.outputs["metrics/update_op"].name) + self.assertEqual( + types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype) + + def testTrainSignatureDefMissingInputs(self): + self._testSupervisedSignatureDefMissingInputs( + signature_def_utils_impl.supervised_train_signature_def, + signature_constants.SUPERVISED_TRAIN_METHOD_NAME) + + def testEvalSignatureDefMissingInputs(self): + self._testSupervisedSignatureDefMissingInputs( + signature_def_utils_impl.supervised_eval_signature_def, + signature_constants.SUPERVISED_EVAL_METHOD_NAME) + + def _testSupervisedSignatureDefMissingInputs(self, fn_to_test, method_name): + inputs = { + "input-1": constant_op.constant("a", name="input-1"), + "input-2": constant_op.constant("b", name="input-2"), + } + loss = {"loss-1": constant_op.constant(0.45, name="loss-1")} + predictions = { + "classes": constant_op.constant([100], name="classes"), + } + metrics_val = constant_op.constant(100, name="metrics_val") + metrics = { + "metrics/value": metrics_val, + "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"), + } + + with self.assertRaises(ValueError): + signature_def = fn_to_test( + {}, loss=loss, predictions=predictions, metrics=metrics) + + signature_def = fn_to_test(inputs, loss=loss) + self.assertEqual(method_name, signature_def.method_name) + self.assertEqual(1, len(signature_def.outputs)) + + signature_def = fn_to_test(inputs, metrics=metrics, loss=loss) + self.assertEqual(method_name, signature_def.method_name) + self.assertEqual(3, len(signature_def.outputs)) + def testGetShapeAndTypes(self): inputs = { "input-1": constant_op.constant(["a", "b"]), diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py index 5a797da791..c82154e7b9 100644 --- a/tensorflow/python/saved_model/tag_constants.py +++ b/tensorflow/python/saved_model/tag_constants.py @@ -32,6 +32,9 @@ TRAINING = "train" tf_export("saved_model.tag_constants.TRAINING").export_constant( __name__, "TRAINING") +# Tag for the `eval` graph. Not exported while the export logic is in contrib. +EVAL = "eval" + # Tag for the `gpu` graph. GPU = "gpu" tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU") @@ -39,3 +42,5 @@ tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU") # Tag for the `tpu` graph. TPU = "tpu" tf_export("saved_model.tag_constants.TPU").export_constant(__name__, "TPU") + + -- GitLab From 037e52e20157985d3f385f8e0426cdde3f5aae2b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 16:37:27 -0700 Subject: [PATCH 143/839] Expose read-only versions of tensors in tflite. PiperOrigin-RevId: 195491701 --- tensorflow/contrib/lite/interpreter.h | 37 ++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 1074f64263..0450e86ae7 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -201,7 +201,7 @@ class Interpreter { // Overrides execution plan. This bounds checks indices sent in. TfLiteStatus SetExecutionPlan(const std::vector& new_plan); - // Get a tensor data structure. + // Get a mutable tensor data structure. // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { @@ -210,9 +210,14 @@ class Interpreter { return &context_.tensors[tensor_index]; } + // Get an immutable tensor data structure. + const TfLiteTensor* tensor(int tensor_index) const { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + // Get a pointer to an operation and registration data structure if in bounds. - // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this - // read/write access to structure const std::pair* node_and_registration( int node_index) const { if (node_index >= nodes_and_registration_.size() || node_index < 0) @@ -220,7 +225,8 @@ class Interpreter { return &nodes_and_registration_[node_index]; } - // Perform a checked cast to the appropriate tensor type. + // Perform a checked cast to the appropriate tensor type (mutable pointer + // version). template T* typed_tensor(int tensor_index) { if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { @@ -231,6 +237,18 @@ class Interpreter { return nullptr; } + // Perform a checked cast to the appropriate tensor type (immutable pointer + // version). + template + const T* typed_tensor(int tensor_index) const { + if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType()) { + return reinterpret_cast(tensor_ptr->data.raw); + } + } + return nullptr; + } + // Return a pointer into the data of a given input tensor. The given index // must be between 0 and inputs().size(). template @@ -238,13 +256,20 @@ class Interpreter { return typed_tensor(inputs_[index]); } - // Return a pointer into the data of a given output tensor. The given index - // must be between 0 and outputs().size(). + // Return a mutable pointer into the data of a given output tensor. The given + // index must be between 0 and outputs().size(). template T* typed_output_tensor(int index) { return typed_tensor(outputs_[index]); } + // Return an immutable pointer into the data of a given output tensor. The + // given index must be between 0 and outputs().size(). + template + const T* typed_output_tensor(int index) const { + return typed_tensor(outputs_[index]); + } + // Change the dimensionality of a given tensor. Note, this is only acceptable // for tensor indices that are inputs. // Returns status of failure or success. -- GitLab From fa1d92f70adf52d9258384e8528f9a7203a141dd Mon Sep 17 00:00:00 2001 From: Bjarke Hammersholt Roune Date: Fri, 4 May 2018 16:51:06 -0700 Subject: [PATCH 144/839] Add infrastructure for a backend-specific configuration for each op. This is intentionally not exposed in ComputationBuilder and is not intended for use or to be set at all prior to the last backend-specific part of compilation. PiperOrigin-RevId: 195493500 --- tensorflow/compiler/xla/service/hlo.proto | 3 + .../compiler/xla/service/hlo_computation.cc | 52 ++++----- .../compiler/xla/service/hlo_computation.h | 20 ++-- .../compiler/xla/service/hlo_graph_dumper.cc | 43 +++++--- .../compiler/xla/service/hlo_graph_dumper.h | 5 +- .../compiler/xla/service/hlo_instruction.cc | 100 +++++------------- .../compiler/xla/service/hlo_instruction.h | 59 +++++++++-- tensorflow/compiler/xla/service/hlo_module.cc | 12 +++ tensorflow/compiler/xla/service/hlo_module.h | 19 ++++ .../compiler/xla/service/hlo_verifier.cc | 71 ++++++------- tensorflow/compiler/xla/statusor.h | 11 +- tensorflow/compiler/xla/statusor_test.cc | 8 ++ .../compiler/xla/tools/parser/hlo_parser.cc | 10 +- .../xla/tools/parser/hlo_parser_test.cc | 22 +++- 14 files changed, 259 insertions(+), 176 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index aa6860880b..1f7c1cffd3 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -147,6 +147,9 @@ message HloInstructionProto { repeated int64 called_computation_ids = 38; xla.OpSharding sharding = 40; + + // Backend configuration for the instruction. Has backend-specific meaning. + string backend_config = 43; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 594413e88f..17e43c3cb8 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -347,6 +347,11 @@ std::list HloComputation::MakeEmbeddedComputationsList() // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after // construction. + // + // TODO(b/78350259): This violates const-correctness, since while the original + // computation is not returned, we still retrieve non-const computations from + // a const one. Consider also avoiding const for HloComputation, or review XLA + // for const-correctness of non-HloInstruction* types like this. ComputeComputationPostOrder(const_cast(this), &visited, &post_order); @@ -723,18 +728,25 @@ Status HloComputation::Accept( return this->Accept(&visitor); } -std::unique_ptr HloComputation::Clone(const string& suffix, - HloModule* module) { +std::unique_ptr HloComputation::Clone( + const string& suffix, HloModule* module, + HloInstruction::CloneMap* clone_map) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, suffix); + module, clone_map, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, const string& suffix) { + HloModule* module, HloInstruction::CloneMap* clone_map, + const string& suffix) { + HloInstruction::CloneMap local_clone_map; + if (clone_map == nullptr) { + clone_map = &local_clone_map; + } + // Look up instr in the replacements map, and return either the replacement, // or instr, if the replacement isn't present. // @@ -756,24 +768,19 @@ std::unique_ptr HloComputation::CloneWithReplacements( } } - std::unordered_map clone_map; std::vector> instructions; std::unique_ptr new_instr = nullptr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); - // If replaced_operand is null, that means 'replacements' asked us not to - // include operand in the new computation. But we can't do that, because - // operand is used by instr. CHECK_NE(replaced_operand, nullptr) - << "replacements map tried to eliminate a used instruction " - << operand->ToString() << ", used by " << instr->ToString(); - new_operands.push_back(FindOrDie(clone_map, replaced_operand)); + << "Replacements map specifies to leave out " << operand->ToString() + << ", but it is used by " << instr->ToString() << "."; + new_operands.push_back(FindOrDie(*clone_map, replaced_operand)); } - new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, module); - InsertOrDie(&clone_map, instr, new_instr.get()); + new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands, + module, clone_map); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -781,27 +788,24 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction()))); + /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(clone_map, instr); + HloInstruction* new_instr = FindOrDie(*clone_map, instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - - // successor may not be in clone_map, because it might have been - // removed by the replacements map. - if (replaced_successor == nullptr) { - continue; - } + CHECK_NE(replaced_successor, nullptr) + << "Replacements map specifies to leave out " << successor->ToString() + << ", but it is control-depended-on by " << instr->ToString() << "."; TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(clone_map, replaced_successor))); + FindOrDie(*clone_map, replaced_successor))); } } // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before + // destroyed. HloInstructions need to be detached from their operands before // they're destroyed, otherwise they stick around in the operands' users lists // and cause use-after-frees. for (auto& kv : replacements) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 9d3f6e9a2c..9898355625 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -291,11 +291,17 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr); + // + // If the module pointer is not nullptr, then the cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the computation is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. + std::unique_ptr Clone( + const string& suffix = "clone", HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -305,7 +311,9 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, const string& suffix = "clone"); + HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index bb4db89f0a..794f1b4682 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -322,11 +322,13 @@ class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, const DebugOptions& debug_options, bool show_metadata, - const HloExecutionProfile* profile, NodeFilter filter) + bool show_backend_config, const HloExecutionProfile* profile, + NodeFilter filter) : computation_(computation), label_(label.ToString()), debug_options_(debug_options), show_metadata_(show_metadata), + show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -365,6 +367,7 @@ class HloDotDumper { string GetInstructionNodeShape(const HloInstruction* instr); string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeMetadata(const HloInstruction* instr); + string GetInstructionNodeBackendConfig(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); @@ -393,6 +396,7 @@ class HloDotDumper { const string label_; // overall name for the graph const DebugOptions& debug_options_; const bool show_metadata_; + const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -611,6 +615,10 @@ tooltip = " "; if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); } + string node_backend_config = GetInstructionNodeBackendConfig(parent_instr); + if (!node_backend_config.empty()) { + StrAppend(&subcomp_label, "
", node_backend_config); + } bool highlight = filter_.Highlight(parent_instr); const char* fillcolor; @@ -765,6 +773,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string node_shape = GetInstructionNodeShape(instr); string node_label = GetInstructionNodeLabel(instr); string node_metadata = GetInstructionNodeMetadata(instr); + string node_backend_config = GetInstructionNodeBackendConfig(instr); string extra_info = GetInstructionNodeExtraInfo(instr); string inlined_constants = GetInstructionNodeInlinedOperands(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); @@ -782,8 +791,8 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. string node_body = node_label; - for (const string& s : - {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_metadata, + node_backend_config, extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } @@ -1078,6 +1087,15 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { return Join(lines, "
"); } +string HloDotDumper::GetInstructionNodeBackendConfig( + const HloInstruction* instr) { + if (!show_backend_config_ || instr->backend_config().empty()) { + return ""; + } + + return StrCat("backend_config=\"", instr->backend_config(), "\""); +} + string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { std::vector lines; @@ -1404,7 +1422,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1414,9 +1432,10 @@ string DumpGraph(const HloComputation& computation, const string& label, &graph)); graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = HloDotDumper(&computation, label, debug_options, show_metadata, - hlo_execution_profile, NodeFilter()) - .Dump(); + graph = + HloDotDumper(&computation, label, debug_options, show_metadata, + show_backend_config, hlo_execution_profile, NodeFilter()) + .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1427,15 +1446,15 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); - string graph = - HloDotDumper(node.parent(), label, debug_options, show_metadata, - /*profile=*/nullptr, filter) - .Dump(); + string graph = HloDotDumper(node.parent(), label, debug_options, + show_metadata, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 2704aae1e3..fc8e1468ac 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_metadata = false); + bool show_metadata = false, bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,8 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false); + bool show_metadata = false, + bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a714d0e114..2c733726a6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -109,6 +109,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->name_ = proto.name(); instruction->metadata_ = proto.metadata(); + instruction->set_backend_config(proto.backend_config()); if (proto.has_literal()) { TF_ASSIGN_OR_RETURN(instruction->literal_, Literal::CreateFromProto(proto.literal())); @@ -1231,12 +1232,15 @@ bool HloInstruction::HasSideEffect() const { std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module) const { + HloModule* module, CloneMap* clone_map) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " %" << new_operand->name(); } + if (module == nullptr) { + module = GetModule(); + } std::unique_ptr clone; @@ -1342,7 +1346,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kFft: CHECK_EQ(new_operands.size(), 1); - return CreateFft(shape, new_operands[0], fft_type_, fft_length_); + clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); + break; case HloOpcode::kCrossReplicaSum: clone = CreateCrossReplicaSum(shape, new_operands); break; @@ -1415,9 +1420,15 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kConstant: clone = CreateConstant(literal_->CloneToUnique()); break; - case HloOpcode::kFusion: - clone = CloneFusionWithNewOperands(shape, new_operands, module); + case HloOpcode::kFusion: { + CHECK_NE(module, nullptr); + auto new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", module, clone_map)); + clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), + /*operands=*/new_operands, + /*fusion_computation=*/new_fused_computation); break; + } case HloOpcode::kParameter: clone = CreateParameter(parameter_number_, shape, name_); break; @@ -1481,15 +1492,19 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); + clone->set_backend_config(backend_config()); + if (clone_map != nullptr) { + InsertOrDie(clone_map, this, clone.get()); + } return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr HloInstruction::Clone(const string& suffix, - HloModule* module) const { +std::unique_ptr HloInstruction::Clone( + const string& suffix, HloModule* module, CloneMap* clone_map) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module); + CloneWithNewOperands(shape_, operands_, module, clone_map); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1526,71 +1541,6 @@ std::unique_ptr HloInstruction::Clone(const string& suffix, return clone; } -std::unique_ptr HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(parent() != nullptr); - - auto new_instruction = - WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - // Add the operands to our new fusion instruction. - for (HloInstruction* new_operand : operands) { - new_instruction->AppendOperand(new_operand); - } - // Clone all the fused instructions for the new fusion instruction. - HloInstructionMap old_to_new; - std::list> new_fused_instructions; - // Create the list of fused parameters by mapping through the cloned, - // fused instructions. - for (HloInstruction* old_fused_parameter : - fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back( - old_fused_parameter->Clone("clone", module)); - HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); - InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); - } - for (auto old_fused_instruction : - fused_instructions_computation()->MakeInstructionPostOrder()) { - if (old_fused_instruction->opcode() == HloOpcode::kParameter) { - FindOrDie(old_to_new, old_fused_instruction); - continue; - } - std::vector new_operands; - for (int64 operand_idx = 0; - operand_idx < old_fused_instruction->operand_count(); ++operand_idx) { - HloInstruction* old_operand = - old_fused_instruction->mutable_operand(operand_idx); - new_operands.push_back(FindOrDie(old_to_new, old_operand)); - } - new_fused_instructions.push_back( - old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands, module)); - HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); - new_fused_instruction->set_parent(parent_); - InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); - } - new_instruction->fusion_kind_ = fusion_kind_; - auto computation_builder = HloComputation::Builder( - fused_instructions_computation()->name() + ".clone", - new_instruction.get()); - // We iterated the fusion instructions in reverse post order which means - // that we must reverse our new list of fusion instructions. - for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); - new_fused_instruction_iter != new_fused_instructions.rend(); - ++new_fused_instruction_iter) { - computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); - } - if (module == nullptr) { - module = GetModule(); - } - auto fused_root_ = fused_expression_root(); - new_instruction->called_computations_.push_back( - CHECK_NOTNULL(module)->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - return new_instruction; -} - std::pair HloInstruction::LatestNonGteAncestorAndIndex() const { const HloInstruction* hlo = this; @@ -2172,6 +2122,9 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } + if (options.print_backend_config() && !backend_config().empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + } return result; } @@ -2357,6 +2310,7 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); } + return extra; } @@ -2386,6 +2340,7 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; + proto.set_backend_config(backend_config()); if (literal_ != nullptr) { *proto.mutable_literal() = literal_->ToProto(); } @@ -2971,6 +2926,7 @@ Status HloInstruction::AcceptOrdered( continue; } + // TODO(b/78350259): Eliminate const laundering. HloInstruction* instruction = const_cast(const_instruction); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a5e9aecb9e..19c8c11453 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -66,6 +66,7 @@ class HloPrintOptions { : print_large_constants_(false), print_subcomputation_references_(true), print_metadata_(true), + print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), print_program_shape_(true), @@ -77,6 +78,7 @@ class HloPrintOptions { .set_print_large_constants(true) .set_print_subcomputation_references(true) .set_print_metadata(false) + .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) .set_print_percent(false); @@ -99,12 +101,18 @@ class HloPrintOptions { return *this; } - // If true, metatdata will be printed. + // If true, metadata will be printed. HloPrintOptions& set_print_metadata(bool value) { print_metadata_ = value; return *this; } + // If true, backend_config will be printed. + HloPrintOptions& set_print_backend_config(bool value) { + print_backend_config_ = value; + return *this; + } + // If true, operands' shapes will be printed. HloPrintOptions& set_print_operand_shape(bool value) { print_operand_shape_ = value; @@ -141,6 +149,7 @@ class HloPrintOptions { return print_subcomputation_references_; } bool print_metadata() const { return print_metadata_; } + bool print_backend_config() const { return print_metadata_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } @@ -151,6 +160,7 @@ class HloPrintOptions { bool print_large_constants_; bool print_subcomputation_references_; bool print_metadata_; + bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; bool print_program_shape_; @@ -643,6 +653,8 @@ class HloInstruction { // Detaches an instruction from its operands. That is, remove the instruction // from each operand's user set. This should only be called prior to // deallocating the instruction. + // + // TODO(b/78305363): Make this automatic when deleting an instruction. void DetachFromOperands(); // Performs a postorder DFS visit using this node as the root. If @@ -1157,23 +1169,30 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kRng RandomDistribution random_distribution() const; + // See documentation for Clone(). + using CloneMap = std::unordered_map; + // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. If the module - // pointer is not nullptr, it will be the module where the cloned computations - // will be added to (in order to support deep cloning). Ignores the control - // predecessors and successors of this HLO instruction. + // the instruction to form the name of the cloned instruction. Ignores the + // control predecessors and successors of this HLO instruction. + // + // If the module pointer is not nullptr, then any cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the instruction is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr) const; + HloModule* module = nullptr, + CloneMap* clone_map = nullptr) const; - // Clones the HLO instruction as above but with new shape and operands. If - // the module pointer is not nullptr, it will be the module where the cloned - // computations will be added to (in order to support deep cloning). Ignores - // the control predecessors and successors of this HLO instruction. + // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; + HloModule* module = nullptr, CloneMap* clone_map = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -1262,6 +1281,19 @@ class HloInstruction { // if no id has been assigned yet). int unique_id() const { return unique_id_; } + // Returns the backend-specific configuration for how a backend should compile + // this HLO. The meaning of the field is backend specific. Not for use before + // or during general HLO optimization, since HLO optimizations do not preserve + // this field and they cannot interpret it due to its meaning being backend + // specific. + // + // TODO(b/78194644): Introduce structured configuration format as per + // go/xla-heuristics. + const string& backend_config() const { return backend_config_; } + void set_backend_config(string backend_config) { + backend_config_ = std::move(backend_config); + } + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1283,6 +1315,7 @@ class HloInstruction { // Get/Set the number of partitions per outer dimension (in order, starting // with outer-most dimension first). Currently used by the parallel cpu // backend to partition HLOs into parallel tasks. + // // TODO(b/62783254) Replace these methods with a more general way to // annotate HLOs with backend-specific information. const std::vector& outer_dimension_partitions() const { @@ -1510,6 +1543,10 @@ class HloInstruction { // The string representation of the infeed configuration. string infeed_config_; + // The backend-specific configuration for how a backend should compile this + // HLO. See the documentation on backend_config(). + string backend_config_; + // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index c7a7192867..5308fb5848 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -46,6 +46,18 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) config_(config), unique_id_(next_unique_module_id_++) {} +StatusOr HloModule::LaunderConstInstructionFromModule( + const HloInstruction* hlo) { + if (hlo == nullptr) { + return nullptr; + } + + TF_RET_CHECK(hlo->GetModule() == this); + + // TODO(b/78350259): Eliminate const laundering. + return const_cast(hlo); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, bool uniquify_names) { diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index f9674df812..1604a72612 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -217,6 +217,25 @@ class HloModule { // the lifetime of this process. int unique_id() const { return unique_id_; } + // Returns a non-const version of the passed-in const HloInstruction*. This is + // safe on the argument that if you have a non-const module, then you can + // access all instructions in the module as non-const. + // + // Returns an error if the passed-in instruction is not from this module, + // except that it is allowed to pass in a null pointer. + // + // TODO(b/78350259): Eliminate const laundering. The argument above is not + // reliable since at any time someone could add or discover a way for a + // non-const module to transitively contain a const HloInstruction. The + // reliable way to do this would be to create a const laundering map from a + // module, mapping each encountered HloInstruction to its non-const version + // and then look up each instruction in need of laundering in that map, but + // this is much more expensive and complicated. This returns a Status instead + // of doing a CHECK-failure in part to make it strongly apparent that this is + // something that can fail. + StatusOr LaunderConstInstructionFromModule( + const HloInstruction* hlo); + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 8a30cbf9cd..096ebb7946 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -116,7 +116,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // produces no HLO value in the graph. if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { - return InvalidArgument( + return InternalError( "Expected outfeed to have shape compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), @@ -200,7 +200,7 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { transpose->operand(0)->shape(), transpose->dimensions())); } -Status ShapeVerifier::HandleParameter(HloInstruction*) { +Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { return tensorflow::Status::OK(); } @@ -410,7 +410,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { if (fp_type == PRIMITIVE_TYPE_INVALID) { fp_type = subshape.element_type(); } else if (fp_type != subshape.element_type()) { - return FailedPrecondition( + return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", instruction->ToString().c_str()); @@ -490,7 +490,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } } if (!compatible) { - return InvalidArgument( + return InternalError( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", ShapeUtil::HumanString(inferred_shape).c_str(), @@ -541,7 +541,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2) { if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( + return InternalError( "Expected to have the same channel id, actual channel ids are: %s " "(%lld), %s (%lld)", instr1->ToString().c_str(), instr1->channel_id(), @@ -571,22 +571,22 @@ string ComputationsToString( Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { - return FailedPrecondition("Computation %s has a null parent pointer", - computation->name().c_str()); + return InternalError("Computation %s has a null parent pointer", + computation->name().c_str()); } if (computation->parent() != module) { - return FailedPrecondition( + return InternalError( "Computation %s parent() does not point to parent module", computation->name().c_str()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { - return FailedPrecondition("Instruction %s has a null parent pointer", - instruction->name().c_str()); + return InternalError("Instruction %s has a null parent pointer", + instruction->name().c_str()); } if (instruction->parent() != computation) { - return FailedPrecondition( + return InternalError( "Instruction %s parent() does not point to parent computation", instruction->name().c_str()); } @@ -602,7 +602,7 @@ Status VerifyHloStructure(HloModule* module) { for (int i = 0; i < instruction->operand_count(); ++i) { const HloInstruction* operand = instruction->operand(i); if (operand->parent() != instruction->parent()) { - return FailedPrecondition( + return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", i, operand->name().c_str(), instruction->name().c_str(), @@ -619,7 +619,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { - return FailedPrecondition( + return InternalError( "Instruction of fused computation does not match expected instruction " "%s.", fusion->ToString().c_str()); @@ -635,37 +635,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto* instruction : fused_computation->instructions()) { if (fused_root == instruction) { if (root_owned) { - return FailedPrecondition("Root appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Root appears more than once in %s.", + fusion->ToString().c_str()); } root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { - return FailedPrecondition("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Parameter appears more than once in %s.", + fusion->ToString().c_str()); } parameter_owned[i] = true; } } } if (!root_owned) { - return FailedPrecondition("Root not found in computation of %s.", - fusion->ToString().c_str()); + return InternalError("Root not found in computation of %s.", + fusion->ToString().c_str()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { - return FailedPrecondition("Parameter %d not found in computation of %s.", - i, fusion->ToString().c_str()); + return InternalError("Parameter %d not found in computation of %s.", i, + fusion->ToString().c_str()); } } // Fused root must have no users. if (fused_root->user_count() != 0) { - return FailedPrecondition("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", + fusion->ToString().c_str()); } // All uses of fused instructions must be in the fusion computation, and every @@ -674,13 +674,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { if (instruction->user_count() == 0) { - return FailedPrecondition( - "Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + return InternalError("Non-root instruction %s in %s must have users.", + instruction->ToString().c_str(), + fusion->ToString().c_str()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { - return FailedPrecondition( + return InternalError( "Non-root instruction %s in %s may not have external users.", instruction->ToString().c_str(), fusion->ToString().c_str()); } @@ -695,34 +695,33 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return FailedPrecondition( - "Unexpected negative parameter number %lld in %s.", param_no, - fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %lld in %s.", + param_no, fusion->ToString().c_str()); } if (param_no >= fused_parameters.size()) { - return FailedPrecondition( + return InternalError( "Unexpected parameter number %lld in %s: higher then number of " "parameters %lu.", param_no, fusion->ToString().c_str(), fused_parameters.size()); } if (parameter_numbers[param_no]) { - return FailedPrecondition( + return InternalError( "Did not expect parameter number %lld more than once in %s.", param_no, fusion->ToString().c_str()); } parameter_numbers[param_no] = true; if (!ShapeUtil::Compatible(fused_param->shape(), fusion->operand(param_no)->shape())) { - return FailedPrecondition( + return InternalError( "Shape mismatch between parameter number %lld and its operand in %s.", param_no, fusion->ToString().c_str()); } } - // Make sure all the parameter_numbers entries were seen + // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { - return FailedPrecondition("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + return InternalError("Did not see parameter number %d in %s.", i, + fusion->ToString().c_str()); } } diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index cccbce5fc8..0e1387c939 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -13,13 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// StatusOr is the union of a Status object and a T -// object. StatusOr models the concept of an object that is either a -// usable value, or an error Status explaining why such a value is -// not present. To this end, StatusOr does not allow its Status -// value to be Status::OK. Furthermore, the value of a StatusOr -// must not be null. This is enforced by a debug check in most cases, -// but even when it is not, clients must not set the value to null. +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. // // The primary use-case for StatusOr is as the return value of a // function which may fail. diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index f9d25945bc..7d76370e85 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) { static_assert(std::is_same::element_type, char>(), ""); } +TEST(StatusOr, NullPointerStatusOr) { + // As a very special case, null-plain-pointer StatusOr used to be an + // error. Test that it no longer is. + StatusOr null_status(nullptr); + EXPECT_TRUE(null_status.ok()); + EXPECT_EQ(null_status.ValueOrDie(), nullptr); +} + TEST(StatusOr, TestNoDefaultConstructorInitialization) { // Explicitly initialize it with an error code. StatusOr statusor(tensorflow::errors::Cancelled("")); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 40dc0730ce..156a06c596 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -440,6 +440,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + optional backend_config; + attrs["backend_config"] = {/*required=*/false, AttrTy::kString, + &backend_config}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -1094,8 +1098,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_name(name); - // Add common attrs (sharding, control predecessors) to the instruction, if - // they were seen. + // Add shared attributes like metadata to the instruction, if they were seen. if (sharding) { instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); @@ -1112,6 +1115,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (metadata) { instruction->set_metadata(*metadata); } + if (backend_config) { + instruction->set_backend_config(std::move(*backend_config)); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index d38d8907a6..e100d8cda1 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -65,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar" } )" @@ -81,13 +81,14 @@ ENTRY %constant_s32 () -> s32[] { )" }, -// f32 constant, but the value is not a decimal +// f32 constant, but the value is not a decimal and there is a backend +// configuration { "ConstantF32", R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { - ROOT %constant = f32[] constant(42) + ROOT %constant = f32[] constant(42), backend_config="this is a configuration" } )" @@ -1013,6 +1014,19 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { // but the constant names will not be exactly the same. } +TEST_F(HloParserTest, ConfigurationField) { + const string original = R"(HloModule AModule +ENTRY %configuration_test() -> s32[] { + %constant = s32[] constant(42), backend_config="foo bar" +})"; + auto result = Parse(original); + TF_ASSERT_OK(result.status()); + EXPECT_EQ("foo bar", result.ValueOrDie() + ->entry_computation() + ->root_instruction() + ->backend_config()); +} + TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { const string original = R"(HloModule some_2_module @@ -1092,7 +1106,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; -- GitLab From bf228e1435da0032d2529de93661b742ee8a7048 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 4 May 2018 17:03:52 -0700 Subject: [PATCH 145/839] [tf.data] Adding `num_parallel_calls` to `map_and_batch`. PiperOrigin-RevId: 195495206 --- .../kernel_tests/batch_dataset_op_test.py | 44 +- .../contrib/data/python/ops/batching.py | 47 +- .../base_api/api_def_MapAndBatchDataset.pbtxt | 35 +- .../api_def_MapAndBatchDatasetV2.pbtxt | 54 ++ .../kernels/data/map_and_batch_dataset_op.cc | 773 +++++++++--------- tensorflow/core/ops/dataset_ops.cc | 13 + 6 files changed, 538 insertions(+), 428 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_MapAndBatchDatasetV2.pbtxt diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 6588fd04ac..2568b899d7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -427,7 +427,9 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, + num_parallel_calls=None, + num_parallel_batches=None): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). @@ -446,6 +448,7 @@ class BatchDatasetTest(test.TestCase): batching.map_and_batch( map_func=_map_fn, batch_size=batch_size, + num_parallel_calls=num_parallel_calls, num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer @@ -497,12 +500,18 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatchDataset(self): + def testMapAndBatch(self): return self._testMapAndBatchDatasetHelper() - def testMapAndBatchDatasetWithParallelBatching(self): + def testMapAndBatchWithParallelBatches(self): return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + def testMapAndBatchWithSequentialCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) + + def testMapAndBatchWithParallelCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): iterator = ( dataset_ops.Dataset.range(10).apply( @@ -682,7 +691,7 @@ class UnbatchDatasetSerializationTest( class MapAndBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): - def testSerializationCore(self): + def testNumParallelBatches(self): range_size = 11 num_repeats = 2 batch_size = 5 @@ -709,6 +718,33 @@ class MapAndBatchDatasetSerializationTest( self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), num_outputs_drop_remainder) + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 42ec2b0b01..b9393de4e9 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -466,14 +466,14 @@ def assert_element_shape(expected_shapes): class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, + def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, drop_remainder): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches_t = ops.convert_to_tensor( - num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._num_parallel_calls_t = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") self._drop_remainder_t = ops.convert_to_tensor( drop_remainder, dtype=dtypes.bool, name="drop_remainder") @@ -483,12 +483,12 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def _as_variant_tensor(self): # pylint: disable=protected-access input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.map_and_batch_dataset( + return gen_dataset_ops.map_and_batch_dataset_v2( input_resource, self._map_func.captured_inputs, f=self._map_func, batch_size=self._batch_size_t, - num_parallel_batches=self._num_parallel_batches_t, + num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), @@ -511,8 +511,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def map_and_batch(map_func, batch_size, - num_parallel_batches=1, - drop_remainder=False): + num_parallel_batches=None, + drop_remainder=False, + num_parallel_calls=None): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -528,21 +529,37 @@ def map_and_batch(map_func, nested structure of tensors. batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. - num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the - number of batches to create in parallel. On one hand, higher values can - help mitigate the effect of stragglers. On the other hand, higher values - can increase contention if CPU is scarce. - drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the - last batch should be dropped in case its size is smaller than desired; - the default behavior is not to drop the smaller batch. + num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, + representing the number of batches to create in parallel. On one hand, + higher values can help mitigate the effect of stragglers. On the other + hand, higher values can increase contention if CPU is scarce. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in case its size is smaller than + desired; the default behavior is not to drop the smaller batch. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number of elements to process in parallel. If not + specified, `batch_size * num_parallel_batches` elements will be + processed in parallel. Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. + + Raises: + ValueError: If both `num_parallel_batches` and `num_parallel_calls` are + specified. """ + if num_parallel_batches is None and num_parallel_calls is None: + num_parallel_calls = batch_size + elif num_parallel_batches is not None and num_parallel_calls is None: + num_parallel_calls = batch_size * num_parallel_batches + elif num_parallel_batches is not None and num_parallel_calls is not None: + raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " + "arguments are mutually exclusive.") + def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches, drop_remainder) + num_parallel_calls, drop_remainder) return _apply_fn diff --git a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt index bf544703de..e230c51edf 100644 --- a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt @@ -1,5 +1,19 @@ op { graph_op_name: "MapAndBatchDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <