From 75259500f80266998f232a94853b0bc08d2925cc Mon Sep 17 00:00:00 2001 From: KB Sriram Date: Wed, 28 Feb 2018 07:16:20 -0800 Subject: [PATCH 001/738] C++ gradients: Fractional*Pool, Soft{Plus,Sign} 1. Adds gradients for four nn ops: FractionalAvgPool FractionalMaxPool SoftPlus SoftSign 2. Update randomization to allow numeric gradient checks on max pooling algorithms with more than one pool. Resolves https://github.com/tensorflow/tensorflow/issues/17330 --- tensorflow/cc/gradients/nn_grad.cc | 47 ++++++++++++++ tensorflow/cc/gradients/nn_grad_test.cc | 84 ++++++++++++++++++++----- 2 files changed, 115 insertions(+), 16 deletions(-) diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 63a67f09f6..4b89dac4c0 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -272,6 +272,53 @@ Status LRNGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("LRN", LRNGradHelper); +Status SoftplusGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softplus", SoftplusGradHelper); + +Status SoftsignGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softsign", SoftsignGradHelper); + +Status FractionalAvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalAvgPoolGrad( + scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)), + grad_inputs[0], op.output(1), op.output(2), + internal::FractionalAvgPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalAvgPool", FractionalAvgPoolGradHelper); + +Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalMaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], op.output(1), + op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index c4eba7ecb0..b4d457a9d1 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -28,6 +28,8 @@ namespace { using ops::BiasAdd; using ops::Conv2D; using ops::Elu; +using ops::FractionalAvgPool; +using ops::FractionalMaxPool; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; @@ -41,6 +43,8 @@ using ops::Relu; using ops::Relu6; using ops::Selu; using ops::Softmax; +using ops::Softplus; +using ops::Softsign; class NNGradTest : public ::testing::Test { protected: @@ -71,22 +75,30 @@ class NNGradTest : public ::testing::Test { EXPECT_LT(max_error, 1e-3); } - // Sets tensor with random values, ensuring that the max value is largest by - // a reasonable amount. - // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which - // perturbations by the numeric gradient computation in the gradient checker - // can change the max value if values are too close together. + // Sets tensor with random values, ensuring that every pair of elements are at + // least a reasonable amount apart. + // This is an issue for max pooling operations, in which perturbations by the + // numeric gradient computation in the gradient checker can change the max + // value if a pool has values that are too close together. template - void SetRandomValuesWithBumpedMax(Tensor* tensor) { + void SetRandomValuesForMaxPooling(Tensor* tensor) { auto tensor_flat = tensor->flat(); - tensor_flat.setRandom(); - int32 max_index = 0; - for (size_t i = 1; i < tensor->NumElements(); i++) { - if (tensor_flat(i) > tensor_flat(max_index)) { - max_index = i; - } + // First set the array to an increasing sequence of values spaced + // a reasonable amount apart + T cur = 0; + for (size_t i = 0; i < tensor->NumElements(); i++) { + tensor_flat(i) = cur; + cur += 5e-2; + } + // Fischer-Yates shuffle the array + for (size_t i = tensor->NumElements() - 1; i >= 1; i--) { + // j <- random integer 0 <= j <= i + size_t j = random::New64() % (i + 1); + // swap values at i, j + T tmp = tensor_flat(i); + tensor_flat(i) = tensor_flat(j); + tensor_flat(j) = tmp; } - tensor_flat(max_index) += 1e-2; } Scope scope_; @@ -189,7 +201,7 @@ TEST_F(NNGradTest, MaxPoolGradHelper) { const std::vector strides{1, 2, 2, 1}; auto y = MaxPool(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -202,7 +214,7 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { Tensor strides = test::AsTensor({1, 2, 2, 1}, {4}); auto y = MaxPoolV2(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -215,7 +227,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) { const std::vector strides{1, 3, 3, 3, 1}; auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -248,5 +260,45 @@ TEST_F(NNGradTest, LRN){ RunTest(x, x_shape, y, x_shape); } +TEST_F(NNGradTest, SoftplusGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softplus(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, SoftsignGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softsign(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, FractionalAvgPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalAvgPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalAvgPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_shape, y.output, y_shape); +} + +TEST_F(NNGradTest, FractionalMaxPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalMaxPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalMaxPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesForMaxPooling(&x_init_value); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_init_value, y.output, y_shape); +} + } // namespace } // namespace tensorflow -- GitLab From 1da3a47287aa911287d6667dd837dc2a7ddaa8f1 Mon Sep 17 00:00:00 2001 From: Smit Shilu Date: Thu, 22 Mar 2018 10:58:51 -0400 Subject: [PATCH 002/738] Update BUILD exports_files(["LICENSE"]) gives error while building on Mac and Ubuntu --- tensorflow/contrib/lite/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index dafe6f136e..1c5bc29763 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -6,8 +6,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") -exports_files(["LICENSE"]) - exports_files(glob([ "testdata/*.bin", "testdata/*.pb", -- GitLab From e7f3ed2477c7910e68573880efd2310e149ca785 Mon Sep 17 00:00:00 2001 From: mbhuiyan Date: Wed, 4 Apr 2018 10:52:49 -0700 Subject: [PATCH 003/738] Fixing a unit test failure for INTEL MKL where memeory allocation check failed because of use of INTEL MKL --- .../direct_session_with_tracking_alloc_test.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index 31fb128f93..0ff022a8bc 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -101,11 +101,24 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { EXPECT_EQ(2, shape.dim_size()); EXPECT_EQ(2, shape.dim(0).size()); EXPECT_EQ(1, shape.dim(1).size()); +#ifndef INTEL_MKL + // if MKL is used, it goes through various additional + // graph rewrite pass. In TF, everytime a graph pass + // happens, "constant" nodes are allocated + // and deallocated. Each allocation calls the + // (FindChunkPtr of BFCAllocator) + // , which increments the value of AllocationId. + // Thus AllocationId becomes more than 3 and 4 if + // MKL is used, they can be 10 and 11 or + // other numbers. If MKL is used + // following check will not hold. + // Thus, skipping the check if MKL is used. if (node->name() == y->name()) { EXPECT_EQ(3, cm->AllocationId(node, 0)); } else { EXPECT_EQ(4, cm->AllocationId(node, 0)); } +#endif } EXPECT_LE(0, cm->MaxExecutionTime(node)); EXPECT_GE(run_duration_micros, cm->MaxExecutionTime(node)); -- GitLab From 9d1aa895adda8644ddbb55b5e1dbb0797ea6cbb0 Mon Sep 17 00:00:00 2001 From: Jie Date: Wed, 11 Apr 2018 14:42:15 -0700 Subject: [PATCH 004/738] [tftrt update] Added support for TRT plugin during conversion - converter & shape inference are now aware of plugin factory. - each plugin does serialization of plugin type & input dimensions - wrapper for nvinfer1::IPlugin & nvinfer1::PluginFactory * compatible with TRT 3.0.4 plugin API. * future plugin API changes willl be updated. --- tensorflow/contrib/tensorrt/BUILD | 26 ++++++ .../contrib/tensorrt/convert/convert_graph.cc | 4 +- .../contrib/tensorrt/convert/convert_nodes.cc | 84 ++++++++++++++--- .../contrib/tensorrt/kernels/trt_engine_op.cc | 4 +- .../contrib/tensorrt/plugin/trt_plugin.cc | 89 +++++++++++++++++++ .../contrib/tensorrt/plugin/trt_plugin.h | 81 +++++++++++++++++ .../tensorrt/plugin/trt_plugin_factory.cc | 81 +++++++++++++++++ .../tensorrt/plugin/trt_plugin_factory.h | 83 +++++++++++++++++ .../tensorrt/plugin/trt_plugin_utils.cc | 36 ++++++++ .../tensorrt/plugin/trt_plugin_utils.h | 51 +++++++++++ .../contrib/tensorrt/shape_fn/trt_shfn.cc | 4 +- 11 files changed, 528 insertions(+), 15 deletions(-) create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin.cc create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin.h create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 2f316767b3..98f18835b0 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -67,6 +67,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_plugins", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]) + tf_custom_op_library_additional_deps(), @@ -86,6 +87,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_plugins", ":trt_resources", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", @@ -222,6 +224,7 @@ tf_cuda_library( ], deps = [ ":segment", + ":trt_plugins", ":trt_logging", ":trt_resources", "//tensorflow/core/grappler:grappler_item", @@ -272,3 +275,26 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +# Library for the plugin factory +#cc_library( +tf_cuda_library( + name = "trt_plugins", + srcs = [ + "plugin/trt_plugin.cc", + "plugin/trt_plugin_factory.cc", + "plugin/trt_plugin_utils.cc", + ], + hdrs = [ + "plugin/trt_plugin.h", + "plugin/trt_plugin_factory.h", + "plugin/trt_plugin_utils.h", + ], + linkstatic = 1, + deps = [ + #"@protobuf_archive//:protobuf_headers", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index b412b296e0..899e1721e6 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -75,7 +76,8 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(ben,jie): ... }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) - return candidate_ops.count(node->type_string()); + return (candidate_ops.count(node->type_string()) || + PluginFactoryTensorRT::GetInstance().IsPlugin(&node->type_string())); } void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 567b4af88d..a03c1e224a 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -246,6 +247,15 @@ class TFAttrs { return attrs_.count(key) ? this->get(key) : default_value; } + std::vector GetAllAttrKey() { + std::vector attr_list; + for (AttrMap::iterator iter = attrs_.begin(); iter != attrs_.end(); + iter++) { + attr_list.emplace_back(iter->first); + } + return attr_list; + } + private: typedef std::map AttrMap; AttrMap attrs_; @@ -262,6 +272,12 @@ std::vector TFAttrs::get>(string key) const { return std::vector(attr.begin(), attr.end()); } +template <> +std::vector TFAttrs::get>(string key) const { + auto attr = this->at(key)->list().f(); + return std::vector(attr.begin(), attr.end()); +} + template <> std::vector TFAttrs::get>(string key) const { auto attr = this->at(key)->list().s(); @@ -424,6 +440,7 @@ using OpConverter = class Converter { std::unordered_map trt_tensors_; std::unordered_map op_registry_; + OpConverter plugin_converter_; nvinfer1::INetworkDefinition* trt_network_; std::list> temp_bufs_; tensorflow::tensorrt::TRTWeightStore* weight_store_; @@ -444,8 +461,8 @@ class Converter { * remove this and annotate the edge as a control dependency. ************************************************************************/ // skip control nodes - if (input_name[0] == '^' ) continue; - string name = input_name; + if (input_name[0] == '^') continue; + string name = input_name; auto first = name.find_first_of(':'); if (first != string::npos && first + 2 == name.size() && name[first + 1] == '0') @@ -490,13 +507,17 @@ class Converter { std::vector inputs; TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); string op = node_def.op(); - if (!op_registry_.count(op)) { - return tensorflow::errors::Unimplemented( - "No converter registered for op: " + op); - } - OpConverter op_converter = op_registry_.at(op); std::vector outputs; - TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); + if (PluginFactoryTensorRT::GetInstance().IsPlugin(&op)) { + TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); + } else { + if (!op_registry_.count(op)) { + return tensorflow::errors::Unimplemented( + "No converter registered for op: " + op); + } + OpConverter op_converter = op_registry_.at(op); + TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); + } for (size_t i = 0; i < outputs.size(); ++i) { TRT_TensorOrWeights output = outputs.at(i); // TODO(jie): tf protobuf seems to be omitting the :0 suffix @@ -1158,9 +1179,9 @@ tensorflow::Status BinaryTensorOpTensor( CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) - return tensorflow::errors::Unimplemented( - "binary op: " + node_def.op() + - " not supported at: " + node_def.name()); + return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + + " not supported at: " + + node_def.name()); nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( *const_cast(tensor_l), @@ -1173,6 +1194,43 @@ tensorflow::Status BinaryTensorOpTensor( return tensorflow::Status::OK(); } +tensorflow::Status ConvertPlugin(Converter& ctx, + const tensorflow::NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs) { + // prepare input + std::vector all_inputs; + for (auto input : inputs) { + all_inputs.emplace_back(const_cast(input.tensor())); + } + + // plugin is owned by PluginFactory + // TODO(jie): destroy plugins later (resource management) + PluginTensorRT* plugin = + PluginFactoryTensorRT::GetInstance().CreatePlugin(&node_def.op()); + + // passing attributes + // TODO(jie): support more general attribute + TFAttrs attrs(node_def); + auto attr_key_vector = attrs.GetAllAttrKey(); + for (auto attr_key : attr_key_vector) { + std::cout << attr_key << std::endl; + // TODO(jie): support only list of float for toy example here. + auto data = attrs.get>(attr_key); + size_t size_data = data.size() * sizeof(float); + plugin->SetAttribute(attr_key, static_cast(data.data()), size_data); + } + + nvinfer1::IPluginLayer* layer = + ctx.network()->addPlugin(&all_inputs[0], int(inputs.size()), *plugin); + + for (int i = 0; i < layer->getNbOutputs(); i++) { + nvinfer1::ITensor* output_tensor = layer->getOutput(i); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + } + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertPlaceholder( Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, @@ -2073,6 +2131,8 @@ void Converter::register_op_converters() { op_registry_["Reshape"] = ConvertReshape; op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; + + plugin_converter_ = ConvertPlugin; } } // namespace @@ -2511,7 +2571,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( std::vector input_names; std::vector input_dtypes; for (const std::pair& input : s.input_inds) { - VLOG(2) << "parsing input. Node id= " << input.first ; + VLOG(2) << "parsing input. Node id= " << input.first; int node_id = input.first; int output_idx = input.second; tensorflow::Node* node = s.graph.FindNodeId(node_id); diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index b32371b642..8881c48fe6 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/core/platform/logging.h" @@ -58,7 +59,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { IRuntime* infer = nvinfer1::createInferRuntime(logger); trt_engine_ptr_.reset(infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), nullptr)); + serialized_engine.c_str(), serialized_engine.size(), + &PluginFactoryTensorRT::GetInstance())); trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); // Runtime is safe to delete after engine creation infer->destroy(); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc new file mode 100644 index 0000000000..0e4a157d79 --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include +#include +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { + // sanity check. + assert(EncodeOpName(GetPluginName()) != + *static_cast(serialized_data)); + const char* buffer = static_cast(serialized_data) + + sizeof(input_dim_list_.size()); + + size_t count = *reinterpret_cast(buffer); + buffer += sizeof(size_t); + + for (int i = 0; i < count; i++) { + nvinfer1::Dims dim; + std::memcpy(&(dim.nbDims), buffer, sizeof(dim.nbDims)); + buffer += sizeof(dim.nbDims); + std::memcpy(dim.d, buffer, sizeof(dim.d)); + buffer += sizeof(dim.d); + std::memcpy(dim.type, buffer, sizeof(dim.type)); + buffer += sizeof(dim.type); + input_dim_list_.emplace_back(dim); + } +} + +size_t PluginTensorRT::getSerializationSize() { + nvinfer1::Dims dim; + return sizeof(size_t) + sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + + sizeof(dim.d) + sizeof(dim.type); +} + +void PluginTensorRT::serialize(void* serialized_data) { + size_t encode_op_name = EncodeOpName(GetPluginName()); + char* buffer = static_cast(serialized_data); + std::memcpy(buffer, &encode_op_name, sizeof(size_t)); + buffer += sizeof(size_t); + + auto list_size = input_dim_list_.size(); + std::memcpy(buffer, &list_size, sizeof(input_dim_list_.size())); + buffer += sizeof(input_dim_list_.size()); + + for (int i = 0; i < input_dim_list_.size(); i++) { + auto dim = input_dim_list_[i]; + std::memcpy(buffer, &(dim.nbDims), sizeof(dim.nbDims)); + buffer += sizeof(dim.nbDims); + std::memcpy(buffer, dim.d, sizeof(dim.d)); + buffer += sizeof(dim.d); + std::memcpy(buffer, dim.type, sizeof(dim.type)); + buffer += sizeof(dim.type); + } +} + +bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr, + const size_t size) { + if (attr_map_.count(key) != 0) return false; + + attr_map_.emplace(key, std::vector(size)); + std::memcpy(attr_map_[key].data(), ptr, size); + return true; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h new file mode 100644 index 0000000000..1bbfe62a4e --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN + +#include +#include +#include +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +using std::string; +using std::unordered_map; + +class PluginTensorRT : public nvinfer1::IPlugin { + public: + PluginTensorRT(){}; + PluginTensorRT(const void* serialized_data, size_t length); + // PluginTensorRT(const void* serialized_data, size_t length, size_t + // &incremental); + virtual string GetPluginName() = 0; + virtual bool Finalize() = 0; + + virtual bool SetAttribute(const string& key, const void* ptr, + const size_t size) = 0; + virtual bool GetAttribute(const string& key, const void* ptr, + size_t& size) = 0; + + void configure(const nvinfer1::Dims* inputs, int nbInputs, + const nvinfer1::Dims* outputs, int nbOutputs, + int maxBatchSize) override { + for (int index = 0; index < nbInputs; index++) { + nvinfer1::Dims dim; + dim.nbDims = inputs[index].nbDims; + for (int i = 0; i < dim.nbDims; i++) { + dim.d[i] = inputs[index].d[i]; + dim.type[i] = inputs[index].type[i]; + } + input_dim_list_.emplace_back(dim); + } + return; + } + + virtual bool StoreAttribute(const string& key, const void* ptr, + const size_t size); + + virtual size_t getSerializationSize() override; + virtual void serialize(void* buffer) override; + + protected: + std::unordered_map > attr_map_; + + std::vector input_dim_list_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc new file mode 100644 index 0000000000..799c609a3e --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layerName, + const void* serial_data, + size_t serial_length) { + size_t parsed_byte = 0; + // extract op_name from serial_data + size_t encoded_op_name = + ExtractOpName(serial_data, serial_length, parsed_byte); + + if (!IsPlugin(encoded_op_name)) { + return nullptr; + } + + // should I lock plugins here? + instance_m_.lock(); + auto plugin_ptr = + plugin_registry_[encoded_op_name].first(serial_data, serial_length); + // string op_name = "IncPluginTRT"; + // auto plugin_ptr = plugin_registry_[EncodeLayerName(&op_name)].second(); + // auto plugin_ptr = plugin_registry_.begin()->second.second(); + owned_plugins_.emplace_back(plugin_ptr); + instance_m_.unlock(); + + return plugin_ptr; +} + +PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string* op_name) { + if (!IsPlugin(op_name)) return nullptr; + + instance_m_.lock(); + auto plugin_ptr = plugin_registry_[EncodeLayerName(op_name)].second(); + owned_plugins_.emplace_back(plugin_ptr); + instance_m_.unlock(); + + return plugin_ptr; +} + +bool PluginFactoryTensorRT::RegisterPlugin( + const string* op_name, PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func) { + if (IsPlugin(op_name)) return false; + + // get instance_m_ first before write to registry; + instance_m_.lock(); + auto ret = plugin_registry_.emplace( + EncodeLayerName(op_name), + std::make_pair(deserialize_func, construct_func)); + instance_m_.unlock(); + + return ret.second; +} + +void PluginFactoryTensorRT::DestroyPlugins() { return; } + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h new file mode 100644 index 0000000000..e68f4629d0 --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY + +#include +#include +#include +#include "trt_plugin.h" +#include "trt_plugin_utils.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { + public: + // deserialization method + // virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const void* + // serialData, size_t serialLength) override; + PluginTensorRT* createPlugin(const char* layerName, const void* serialData, + size_t serialLength) override; + + // construction + PluginTensorRT* CreatePlugin(const string* op_name); + + static PluginFactoryTensorRT& GetInstance() { + static PluginFactoryTensorRT factory_instance; + return factory_instance; + } + + bool RegisterPlugin(const string* op_name, + PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func); + + bool IsPlugin(const size_t encode_name) { + return plugin_registry_.find(encode_name) != plugin_registry_.end(); + } + + bool IsPlugin(const string* op_name) { + return IsPlugin(EncodeLayerName(op_name)); + } + + size_t EncodeLayerName(const string* op_name) { + return EncodeOpName(*op_name); + } + + void DestroyPlugins(); + + protected: + std::unordered_map > + plugin_registry_; + + // TODO(jie): Owned plugin should be associated with different sessions; + // should really hand ownership of plugins to resource management; + std::vector > owned_plugins_; + std::mutex instance_m_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc new file mode 100644 index 0000000000..b14480cfa6 --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +size_t ExtractOpName(const void* serial_data, size_t serial_length, + size_t& incremental) { + incremental = sizeof(size_t); + if (serial_length < incremental) return 0; + size_t encoded_op_name = *static_cast(serial_data); + return encoded_op_name; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h new file mode 100644 index 0000000000..e9675d84cd --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS + +#include +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +typedef std::function + PluginDeserializeFunc; + +typedef std::function PluginConstructFunc; + +inline size_t EncodeOpName(std::string str) { + return std::hash{}(str); +} + +// TODO(jie): work on error handling here +size_t ExtractOpName(const void* serial_data, size_t serial_length, + size_t& incremental); + +// size_t Deserialize(const char* serial_data, size_t serial_length, size_t +// &incremental); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 8b475177bc..30b5616475 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -33,7 +34,8 @@ tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine)); nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), nullptr); + serialized_engine.c_str(), serialized_engine.size(), + &tensorrt::PluginFactoryTensorRT::GetInstance()); int num_batch = -1; std::vector<::tensorflow::DataType> input_type; -- GitLab From 0eb443db1a5654168f396702cae39f5dc3fc7e2e Mon Sep 17 00:00:00 2001 From: imsheridan Date: Tue, 17 Apr 2018 20:25:51 +0800 Subject: [PATCH 005/738] Add deprecated_args decoration to array_ops/ sparse_ops --- tensorflow/python/ops/array_ops.py | 2 ++ tensorflow/python/ops/sparse_ops.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index ceeabe090d..06da2485c3 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2690,6 +2690,8 @@ reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ # pylint: disable=redefined-builtin @tf_export("reverse_sequence") +@deprecation.deprecated_args(None, "Use the `seq_axis` argument instead", "seq_dim") +@deprecation.deprecated_args(None, "Use the `batch_axis` argument instead", "batch_dim") def reverse_sequence(input, seq_lengths, seq_axis=None, diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index c580052c32..73ab216f35 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -110,6 +110,7 @@ def _convert_to_sparse_tensors(sp_inputs): # pylint: disable=protected-access @tf_export("sparse_concat") +@deprecation.deprecated_args(None, "concat_dim is deprecated, use axis instead", "concat_dim") def sparse_concat(axis, sp_inputs, name=None, @@ -616,6 +617,7 @@ class KeywordRequired(object): @tf_export("sparse_split") +@deprecation.deprecated_args(None, "split_dim is deprecated, use axis instead", "split_dim") def sparse_split(keyword_required=KeywordRequired(), sp_input=None, num_split=None, -- GitLab From d6838f52c7daea81c57cdeab8e98c4cd617e5f8b Mon Sep 17 00:00:00 2001 From: imsheridan Date: Tue, 17 Apr 2018 20:30:13 +0800 Subject: [PATCH 006/738] fix typo to keep consistence --- tensorflow/python/ops/array_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 06da2485c3..b6a1f5a272 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2690,8 +2690,8 @@ reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ # pylint: disable=redefined-builtin @tf_export("reverse_sequence") -@deprecation.deprecated_args(None, "Use the `seq_axis` argument instead", "seq_dim") -@deprecation.deprecated_args(None, "Use the `batch_axis` argument instead", "batch_dim") +@deprecation.deprecated_args(None, "seq_dim is deprecated, use seq_axis instead", "seq_dim") +@deprecation.deprecated_args(None, "batch_dim is deprecated, use batch_axis instead", "batch_dim") def reverse_sequence(input, seq_lengths, seq_axis=None, -- GitLab From df0ce53aee6f4e14b3f1c9e0e772a1f7bd1bb95a Mon Sep 17 00:00:00 2001 From: Jie Date: Mon, 16 Apr 2018 17:47:00 -0700 Subject: [PATCH 007/738] [PR comment addressed] adding plugin test for registration updating plugin API wrapper addressing comments in the PR addressing coding style issues removing commented code --- tensorflow/contrib/tensorrt/BUILD | 15 ++- .../contrib/tensorrt/convert/convert_graph.cc | 2 +- .../contrib/tensorrt/convert/convert_nodes.cc | 10 +- .../custom_plugin_examples/inc_op_plugin.cc | 89 ++++++++++++++ .../custom_plugin_examples/inc_op_plugin.h | 114 ++++++++++++++++++ .../contrib/tensorrt/kernels/trt_engine_op.cc | 2 +- .../contrib/tensorrt/plugin/trt_plugin.cc | 37 ++++-- .../contrib/tensorrt/plugin/trt_plugin.h | 41 +++---- .../tensorrt/plugin/trt_plugin_factory.cc | 28 +++-- .../tensorrt/plugin/trt_plugin_factory.h | 33 +++-- .../tensorrt/plugin/trt_plugin_utils.cc | 18 ++- .../tensorrt/plugin/trt_plugin_utils.h | 11 +- .../tensorrt/plugin/trt_plugins_test.cc | 112 +++++++++++++++++ .../contrib/tensorrt/shape_fn/trt_shfn.cc | 2 +- 14 files changed, 423 insertions(+), 91 deletions(-) create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 98f18835b0..751f1d3482 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -277,7 +277,6 @@ tf_cc_test( ) # Library for the plugin factory -#cc_library( tf_cuda_library( name = "trt_plugins", srcs = [ @@ -292,9 +291,21 @@ tf_cuda_library( ], linkstatic = 1, deps = [ - #"@protobuf_archive//:protobuf_headers", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), ) +tf_cuda_cc_test( + name = "trt_plugins_test", + size = "small", + srcs = ["plugin/trt_plugins_test.cc"], + deps = [ + ":trt_plugins", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:nv_infer", + ]), +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 899e1721e6..91faba7e21 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -77,7 +77,7 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) return (candidate_ops.count(node->type_string()) || - PluginFactoryTensorRT::GetInstance().IsPlugin(&node->type_string())); + PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index a03c1e224a..d02c1ebf50 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -249,9 +249,8 @@ class TFAttrs { std::vector GetAllAttrKey() { std::vector attr_list; - for (AttrMap::iterator iter = attrs_.begin(); iter != attrs_.end(); - iter++) { - attr_list.emplace_back(iter->first); + for (auto & attr_item : attrs_) { + attr_list.emplace_back(attr_item.first); } return attr_list; } @@ -508,7 +507,7 @@ class Converter { TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); string op = node_def.op(); std::vector outputs; - if (PluginFactoryTensorRT::GetInstance().IsPlugin(&op)) { + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) { TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); } else { if (!op_registry_.count(op)) { @@ -1207,14 +1206,13 @@ tensorflow::Status ConvertPlugin(Converter& ctx, // plugin is owned by PluginFactory // TODO(jie): destroy plugins later (resource management) PluginTensorRT* plugin = - PluginFactoryTensorRT::GetInstance().CreatePlugin(&node_def.op()); + PluginFactoryTensorRT::GetInstance()->CreatePlugin(node_def.op()); // passing attributes // TODO(jie): support more general attribute TFAttrs attrs(node_def); auto attr_key_vector = attrs.GetAllAttrKey(); for (auto attr_key : attr_key_vector) { - std::cout << attr_key << std::endl; // TODO(jie): support only list of float for toy example here. auto data = attrs.get>(attr_key); size_t size_data = data.size() * sizeof(float); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc new file mode 100644 index 0000000000..2155079e8b --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "inc_op_plugin.h" + +namespace tensorflow { +namespace tensorrt { + +const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; + +IncOpPlugin* CreateIncPlugin() { + return new IncOpPlugin(); +} + + +IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { + return new IncOpPlugin(buffer, length); +} + +bool RegisterIncOpPlugin() { + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(IncOpPlugin::plugin_name_)) + return false; + return PluginFactoryTensorRT::GetInstance()->RegisterPlugin(IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); +} + + +IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) : + PluginTensorRT(serialized_data, length) +{ + // account for the consumed pointer. + size_t consumed_data = PluginTensorRT::getSerializationSize(); + assert(length-consumed_data >= sizeof(float)); + SetAttribute("inc", serialized_data+consumed_data, sizeof(float)); +} + +bool IncOpPlugin::SetAttribute(const string &key, const void *ptr, const size_t size) { + if (strcmp(key.c_str(), "inc")==0 && size == sizeof(float)) { + StoreAttribute(key, ptr, size); // save the attribute to own the data; + inc_ = *static_cast(ptr); + return true; + } + return false; +} + +bool IncOpPlugin::GetAttribute(const string &key, const void *ptr, size_t &size) { + if (attr_map_.find(key) != attr_map_.end()) { + ptr = attr_map_[key].data(); + size = attr_map_[key].size(); + return true; + } + return false; +} + +int IncOpPlugin::enqueue(int batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream) { + int count = 1; + for (int i=0; i(inputs[0]); + float *output = reinterpret_cast(outputs[0]); + IncrementKernel(input, inc_, output, count, stream); + return 0; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h new file mode 100644 index 0000000000..52b68487e6 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -0,0 +1,114 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN +#define TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include +#include +#include +#include +#include +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +using std::string; +using std::unordered_map; + +class IncOpPlugin : public PluginTensorRT +{ +public: + static const string plugin_name_; + IncOpPlugin() {}; + IncOpPlugin(const void* serialized_data, size_t length); + const string GetPluginName() override {return plugin_name_;}; + bool Finalize() override {return true;}; + bool SetAttribute(const string &key, const void *ptr, const size_t size) override; + bool GetAttribute(const string &key, const void *ptr, size_t &size) override; + + // TRT IPlugin methods + int getNbOutputs() const override {return 1;} + + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override { + assert(index==0); + assert(nbInputDims==1); + return inputs[0]; + } + + // no configure needed + // use configure to setup input dimensions + void configure(const nvinfer1::Dims *inputs, int nbInputs, const nvinfer1::Dims *outputs, int nbOutputs, int maxBatchSize) override { + assert(nbInputs==1); + PluginTensorRT::configure(inputs, nbInputs, outputs, nbOutputs, maxBatchSize); + return; + } + + int initialize() override { + return 0; + } + + void terminate() override { + return; + } + + size_t getWorkspaceSize(int maxBatchSize) const override { + return 0; + } + + int enqueue(int batchSize, const void*const *inputs, void** outputs, void* workspace, cudaStream_t stream) override; + + size_t getSerializationSize() override { + return PluginTensorRT::getSerializationSize() + sizeof(float); + } + + void serialize(void* buffer) override { + // serializa parent stuff + // OpName + PluginTensorRT::serialize(buffer); + + // incremented buffer after parent serialization; + buffer = static_cast(buffer) + PluginTensorRT::getSerializationSize(); + + std::memcpy(buffer, &inc_, sizeof(float)); + buffer = static_cast(buffer) + sizeof(float); + return; + } + +protected: + float inc_; + nvinfer1::Dims dim_; + // std::unordered_map > attr_map_; +}; + +IncOpPlugin* CreateIncPlugin(); +IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); +bool RegisterIncOpPlugin(); +void IncrementKernel(const float* d_input, float inc, float* d_output, int count, cudaStream_t stream); + + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 8881c48fe6..162301fb52 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -60,7 +60,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { IRuntime* infer = nvinfer1::createInferRuntime(logger); trt_engine_ptr_.reset(infer->deserializeCudaEngine( serialized_engine.c_str(), serialized_engine.size(), - &PluginFactoryTensorRT::GetInstance())); + PluginFactoryTensorRT::GetInstance())); trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); // Runtime is safe to delete after engine creation infer->destroy(); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc index 0e4a157d79..7600703775 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -26,10 +26,10 @@ namespace tensorrt { PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { // sanity check. - assert(EncodeOpName(GetPluginName()) != - *static_cast(serialized_data)); - const char* buffer = static_cast(serialized_data) + - sizeof(input_dim_list_.size()); + const char* buffer = static_cast(serialized_data); + size_t op_name_char_count = *reinterpret_cast(buffer); + buffer += sizeof(size_t); + buffer += op_name_char_count; size_t count = *reinterpret_cast(buffer); buffer += sizeof(size_t); @@ -46,18 +46,37 @@ PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { } } +void PluginTensorRT::configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) { + for (int index = 0; index < num_inputs; index++) { + nvinfer1::Dims dim; + dim.nbDims = inputs[index].nbDims; + for (int i = 0; i < dim.nbDims; i++) { + dim.d[i] = inputs[index].d[i]; + dim.type[i] = inputs[index].type[i]; + } + input_dim_list_.emplace_back(dim); + } + return; +} + size_t PluginTensorRT::getSerializationSize() { nvinfer1::Dims dim; - return sizeof(size_t) + sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + - sizeof(dim.d) + sizeof(dim.type); + return sizeof(size_t) + GetPluginName().size() + + sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + sizeof(dim.d) + + sizeof(dim.type); } void PluginTensorRT::serialize(void* serialized_data) { - size_t encode_op_name = EncodeOpName(GetPluginName()); + size_t op_name_size = GetPluginName().size(); char* buffer = static_cast(serialized_data); - std::memcpy(buffer, &encode_op_name, sizeof(size_t)); + std::memcpy(buffer, &op_name_size, sizeof(size_t)); buffer += sizeof(size_t); + std::memcpy(buffer, GetPluginName().data(), op_name_size); + buffer += op_name_size; + auto list_size = input_dim_list_.size(); std::memcpy(buffer, &list_size, sizeof(input_dim_list_.size())); buffer += sizeof(input_dim_list_.size()); @@ -73,7 +92,7 @@ void PluginTensorRT::serialize(void* serialized_data) { } } -bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr, +bool PluginTensorRT::StoreAttribute(const std::string& key, const void* ptr, const size_t size) { if (attr_map_.count(key) != 0) return false; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index 1bbfe62a4e..59b92657f6 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -28,46 +28,37 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -using std::string; -using std::unordered_map; - +// A wrapper class for TensorRT plugin +// User application should inherit from this class to write custom kernels. +// Allows user to insert custom op in TensorRT engine +// To register plugin in converter, user should also register custom +// tensorflow::tensorrt::PluginDeserializeFunc & +// tensorflow::tensorrt::PluginConstructFunc through +// tensorflow::tensorrt::PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: PluginTensorRT(){}; PluginTensorRT(const void* serialized_data, size_t length); - // PluginTensorRT(const void* serialized_data, size_t length, size_t - // &incremental); - virtual string GetPluginName() = 0; + virtual const std::string& GetPluginName() = 0; virtual bool Finalize() = 0; - virtual bool SetAttribute(const string& key, const void* ptr, + virtual bool SetAttribute(const std::string& key, const void* ptr, const size_t size) = 0; - virtual bool GetAttribute(const string& key, const void* ptr, + virtual bool GetAttribute(const std::string& key, const void* ptr, size_t& size) = 0; - void configure(const nvinfer1::Dims* inputs, int nbInputs, - const nvinfer1::Dims* outputs, int nbOutputs, - int maxBatchSize) override { - for (int index = 0; index < nbInputs; index++) { - nvinfer1::Dims dim; - dim.nbDims = inputs[index].nbDims; - for (int i = 0; i < dim.nbDims; i++) { - dim.d[i] = inputs[index].d[i]; - dim.type[i] = inputs[index].type[i]; - } - input_dim_list_.emplace_back(dim); - } - return; - } - - virtual bool StoreAttribute(const string& key, const void* ptr, + void configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) override; + + virtual bool StoreAttribute(const std::string& key, const void* ptr, const size_t size); virtual size_t getSerializationSize() override; virtual void serialize(void* buffer) override; protected: - std::unordered_map > attr_map_; + std::unordered_map > attr_map_; std::vector input_dim_list_; }; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 799c609a3e..44b10394c8 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -21,13 +21,13 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layerName, +PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) { size_t parsed_byte = 0; // extract op_name from serial_data - size_t encoded_op_name = - ExtractOpName(serial_data, serial_length, parsed_byte); + std::string encoded_op_name = + ExtractOpName(serial_data, serial_length, &parsed_byte); if (!IsPlugin(encoded_op_name)) { return nullptr; @@ -37,20 +37,18 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layerName, instance_m_.lock(); auto plugin_ptr = plugin_registry_[encoded_op_name].first(serial_data, serial_length); - // string op_name = "IncPluginTRT"; - // auto plugin_ptr = plugin_registry_[EncodeLayerName(&op_name)].second(); - // auto plugin_ptr = plugin_registry_.begin()->second.second(); owned_plugins_.emplace_back(plugin_ptr); instance_m_.unlock(); return plugin_ptr; } -PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string* op_name) { +PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( + const std::string& op_name) { if (!IsPlugin(op_name)) return nullptr; instance_m_.lock(); - auto plugin_ptr = plugin_registry_[EncodeLayerName(op_name)].second(); + auto plugin_ptr = plugin_registry_[op_name].second(); owned_plugins_.emplace_back(plugin_ptr); instance_m_.unlock(); @@ -58,21 +56,27 @@ PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string* op_name) { } bool PluginFactoryTensorRT::RegisterPlugin( - const string* op_name, PluginDeserializeFunc deserialize_func, + const std::string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func) { if (IsPlugin(op_name)) return false; // get instance_m_ first before write to registry; instance_m_.lock(); auto ret = plugin_registry_.emplace( - EncodeLayerName(op_name), - std::make_pair(deserialize_func, construct_func)); + op_name, std::make_pair(deserialize_func, construct_func)); instance_m_.unlock(); return ret.second; } -void PluginFactoryTensorRT::DestroyPlugins() { return; } +void PluginFactoryTensorRT::DestroyPlugins() { + instance_m_.lock(); + for (auto& owned_plugin_ptr : owned_plugins_) { + owned_plugin_ptr.release(); + } + owned_plugins_.clear(); + instance_m_.unlock(); +} } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index e68f4629d0..824efcff35 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -32,39 +32,34 @@ namespace tensorrt { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { public: // deserialization method - // virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const void* - // serialData, size_t serialLength) override; - PluginTensorRT* createPlugin(const char* layerName, const void* serialData, - size_t serialLength) override; + PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, + size_t serial_length) override; - // construction - PluginTensorRT* CreatePlugin(const string* op_name); + // plugin construction, PluginFactoryTensorRT owns the plugin; + PluginTensorRT* CreatePlugin(const std::string& op_name); - static PluginFactoryTensorRT& GetInstance() { - static PluginFactoryTensorRT factory_instance; + static PluginFactoryTensorRT* GetInstance() { + static PluginFactoryTensorRT* factory_instance = nullptr; + if (factory_instance == nullptr) { + factory_instance = new PluginFactoryTensorRT(); + } return factory_instance; } - bool RegisterPlugin(const string* op_name, + bool RegisterPlugin(const std::string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func); - bool IsPlugin(const size_t encode_name) { - return plugin_registry_.find(encode_name) != plugin_registry_.end(); + bool IsPlugin(const std::string& op_name) { + return plugin_registry_.find(op_name) != plugin_registry_.end(); } - bool IsPlugin(const string* op_name) { - return IsPlugin(EncodeLayerName(op_name)); - } - - size_t EncodeLayerName(const string* op_name) { - return EncodeOpName(*op_name); - } + size_t CountOwnedPlugins() { return owned_plugins_.size(); } void DestroyPlugins(); protected: - std::unordered_map > plugin_registry_; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc index b14480cfa6..8b65e8b41c 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -21,12 +22,17 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -size_t ExtractOpName(const void* serial_data, size_t serial_length, - size_t& incremental) { - incremental = sizeof(size_t); - if (serial_length < incremental) return 0; - size_t encoded_op_name = *static_cast(serial_data); - return encoded_op_name; +std::string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental) { + size_t op_name_char_count = *static_cast(serial_data); + *incremental = sizeof(size_t) + op_name_char_count; + + assert(serial_length >= *incremental); + + const char* buffer = static_cast(serial_data) + sizeof(size_t); + std::string op_name(buffer, op_name_char_count); + + return op_name; } } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index e9675d84cd..d4da8b261e 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -31,16 +31,9 @@ typedef std::function typedef std::function PluginConstructFunc; -inline size_t EncodeOpName(std::string str) { - return std::hash{}(str); -} - // TODO(jie): work on error handling here -size_t ExtractOpName(const void* serial_data, size_t serial_length, - size_t& incremental); - -// size_t Deserialize(const char* serial_data, size_t serial_length, size_t -// &incremental); +std::string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc new file mode 100644 index 0000000000..2856b0f87d --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +class StubPlugin : public PluginTensorRT { + public: + static const std::string plugin_name_; + StubPlugin(){}; + StubPlugin(const void* serialized_data, size_t length) + : PluginTensorRT(serialized_data, length){}; + const std::string& GetPluginName() override { return plugin_name_; }; + virtual bool Finalize() { return true; }; + virtual bool SetAttribute(const std::string& key, const void* ptr, + const size_t size) { + return true; + }; + virtual bool GetAttribute(const std::string& key, const void* ptr, + size_t& size) { + return true; + }; + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, + int nbInputDims) override { + return inputs[0]; + } + int initialize() override { return 0; } + void terminate() override { return; } + size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } + int enqueue(int batchSize, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override { + return 0; + } +}; + +const std::string StubPlugin::plugin_name_ = "StubPlugin"; + +StubPlugin* CreateStubPlugin() { return new StubPlugin(); } + +StubPlugin* CreateStubPluginDeserialize(const void* serialized_data, + size_t length) { + return new StubPlugin(serialized_data, length); +} + +class PluginTest : public ::testing::Test { + public: + bool RegisterStubPlugin() { + if (PluginFactoryTensorRT::GetInstance()->IsPlugin( + StubPlugin::plugin_name_)) + return true; + return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( + StubPlugin::plugin_name_, CreateStubPluginDeserialize, + CreateStubPlugin); + } + + protected: +}; + +TEST_F(PluginTest, Registration) { + EXPECT_FALSE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + EXPECT_TRUE(RegisterStubPlugin()); + + ASSERT_TRUE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); +} + +TEST_F(PluginTest, CreationDeletion) { + EXPECT_TRUE(RegisterStubPlugin()); + ASSERT_TRUE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + + PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); + ASSERT_TRUE(PluginFactoryTensorRT::GetInstance()->CreatePlugin( + StubPlugin::plugin_name_)); + ASSERT_EQ(1, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); + PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); + ASSERT_EQ(0, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); +} + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 30b5616475..f36495f6b6 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -35,7 +35,7 @@ tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( serialized_engine.c_str(), serialized_engine.size(), - &tensorrt::PluginFactoryTensorRT::GetInstance()); + tensorrt::PluginFactoryTensorRT::GetInstance()); int num_batch = -1; std::vector<::tensorflow::DataType> input_type; -- GitLab From 419dbc8f44efe06612845ec291b98bb49e873639 Mon Sep 17 00:00:00 2001 From: Jie Date: Wed, 18 Apr 2018 14:42:42 -0700 Subject: [PATCH 008/738] [PR comment addressed] Added custom plugin example registered tensorflow custom op & plugin kernel python wrapper to import custom op & register plugin clang-format --- tensorflow/contrib/tensorrt/BUILD | 1 + .../contrib/tensorrt/convert/convert_nodes.cc | 2 +- .../tensorrt/custom_plugin_examples/BUILD | 110 ++++++++++++++++++ .../custom_plugin_examples/__init__.py | 24 ++++ .../tensorrt/custom_plugin_examples/inc_op.py | 30 +++++ .../inc_op_kernel.cu.cc | 44 +++++++ .../custom_plugin_examples/inc_op_kernel.h | 34 ++++++ .../custom_plugin_examples/inc_op_plugin.cc | 55 +++++---- .../custom_plugin_examples/inc_op_plugin.h | 85 +++++++------- .../custom_plugin_examples/ops/inc_op.cc | 34 ++++++ .../custom_plugin_examples/plugin_wrap.i | 31 +++++ .../test/plugin_test.py | 93 +++++++++++++++ .../contrib/tensorrt/plugin/trt_plugin.cc | 1 - .../contrib/tensorrt/plugin/trt_plugin.h | 10 +- .../tensorrt/plugin/trt_plugin_factory.cc | 14 +-- .../tensorrt/plugin/trt_plugin_factory.h | 6 +- .../tensorrt/plugin/trt_plugin_utils.cc | 4 +- .../tensorrt/plugin/trt_plugin_utils.h | 5 +- .../tensorrt/plugin/trt_plugins_test.cc | 6 +- 19 files changed, 485 insertions(+), 104 deletions(-) create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 751f1d3482..9c81c12705 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -291,6 +291,7 @@ tf_cuda_library( ], linkstatic = 1, deps = [ + "//tensorflow/core:platform_base", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index d02c1ebf50..874be96c78 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -249,7 +249,7 @@ class TFAttrs { std::vector GetAllAttrKey() { std::vector attr_list; - for (auto & attr_item : attrs_) { + for (const auto& attr_item : attrs_) { attr_list.emplace_back(attr_item.first); } return attr_list; diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD new file mode 100644 index 0000000000..5603ed0ccf --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -0,0 +1,110 @@ +package(default_visibility = ["//tensorflow:__subpackages__"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_cuda_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_py_wrap_cc", + "tf_copts", +) +load( + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", +) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") + +tf_kernel_library( + name = "_inc_op_plugin_kernel", + srcs = [ + "inc_op_plugin.cc", + ], + hdrs = [ + ], + gpu_srcs = [ + "inc_op_kernel.cu.cc", + "inc_op_kernel.h", + "inc_op_plugin.h", + ], + deps = if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + "//tensorflow/contrib/tensorrt:trt_plugins", + ]), +) + +tf_gen_op_libs( + op_lib_names = [ + "inc_op", + ], + deps = if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + "//tensorflow/contrib/tensorrt:trt_plugins", + ]), +) + +tf_gen_op_wrapper_py( + name = "inc_op", + gen_locally = True, + deps = [ + ":inc_op_op_lib", + ], +) + +tf_py_wrap_cc( + name = "plugin_wrap", + srcs = [ + "plugin_wrap.i", + ], + copts = tf_copts(), + deps = [ + ":_inc_op_plugin_kernel", + "//tensorflow/core:framework_lite", + "//util/python:python_headers", + ], +) + +tf_custom_op_library( + name = "_inc_op.so", + srcs = ["ops/inc_op.cc"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "//tensorflow/contrib/tensorrt:trt_plugins", + ]), +) + +tf_custom_op_py_library( + name = "inc_op_loader", + srcs = ["inc_op.py"], + dso = [ + ":_inc_op.so", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:resources", + ], +) + +py_library( + name = "inc_op_py", + srcs_version = "PY2AND3", + deps = [ + ":inc_op", + ":inc_op_loader", + ], +) + +py_library( + name = "init_py", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":inc_op_py", + ":plugin_wrap", + ], +) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py new file mode 100644 index 0000000000..a61d008941 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Import custom op for plugin and register it in plugin factory registry.""" + +from ops import gen_inc_op +from plugin_wrap import inc_op_register +from inc_op import * + +# pylint: disable=unused-import,wildcard-import,g-import-not-at-top +inc_op = gen_inc_op.inc_plugin_trt +inc_op_register() +# pylint: enable=unused-import,wildcard-import,g-import-not-at-top diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py new file mode 100644 index 0000000000..ef8e26fbde --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py @@ -0,0 +1,30 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import platform +import os + +if platform.system() != "Windows": + from tensorflow.contrib.util import loader + from tensorflow.python.platform import resource_loader + + _inc_op = loader.load_op_library( + os.path.join(os.path.dirname(os.path.realpath(__file__)),"_inc_op.so")) +else: + raise RuntimeError("Windows not supported") diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc new file mode 100644 index 0000000000..5dd6b9bf94 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +__global__ void VecInc(const float* vec, float inc, float* dest, int n) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < n) dest[i] = vec[i] + inc; +} + +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream) { + int threads_per_block = 256; + int blocks_per_grid = (count + threads_per_block - 1) / threads_per_block; + + VecInc<<>>(d_input, inc, + d_output, count); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h new file mode 100644 index 0000000000..ec269143e8 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP +#define TENSORFLOW_CONTRIB_TENSORRT_INC_OP + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +__global__ void VecInc(float* vec, float inc, float* dest, int n); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_INC_OP diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 2155079e8b..21617fa8b5 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -13,24 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "inc_op_plugin.h" namespace tensorflow { namespace tensorrt { -const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; - -IncOpPlugin* CreateIncPlugin() { - return new IncOpPlugin(); -} +const std::string IncOpPlugin::plugin_name_ = "IncPluginTRT"; +IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); } IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { return new IncOpPlugin(buffer, length); @@ -39,45 +34,49 @@ IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { bool RegisterIncOpPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin(IncOpPlugin::plugin_name_)) return false; - return PluginFactoryTensorRT::GetInstance()->RegisterPlugin(IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); + return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( + IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); } - -IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) : - PluginTensorRT(serialized_data, length) -{ +IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) + : PluginTensorRT(serialized_data, length) { // account for the consumed pointer. size_t consumed_data = PluginTensorRT::getSerializationSize(); - assert(length-consumed_data >= sizeof(float)); - SetAttribute("inc", serialized_data+consumed_data, sizeof(float)); + assert(length - consumed_data >= sizeof(float)); + const char* buffer = reinterpret_cast(serialized_data); + SetAttribute("inc", buffer + consumed_data, sizeof(float)); } -bool IncOpPlugin::SetAttribute(const string &key, const void *ptr, const size_t size) { - if (strcmp(key.c_str(), "inc")==0 && size == sizeof(float)) { - StoreAttribute(key, ptr, size); // save the attribute to own the data; +bool IncOpPlugin::SetAttribute(const std::string& key, const void* ptr, + const size_t size) { + if (strcmp(key.c_str(), "inc") == 0 && size == sizeof(float)) { + StoreAttribute(key, ptr, size); // save the attribute to own the data; inc_ = *static_cast(ptr); return true; } return false; } -bool IncOpPlugin::GetAttribute(const string &key, const void *ptr, size_t &size) { - if (attr_map_.find(key) != attr_map_.end()) { - ptr = attr_map_[key].data(); - size = attr_map_[key].size(); +bool IncOpPlugin::GetAttribute(const std::string& key, const void** ptr, + size_t* size) const { + const auto& iter = attr_map_.find(key); + if (iter != attr_map_.end()) { + *ptr = iter->second.data(); + *size = iter->second.size(); return true; } return false; } -int IncOpPlugin::enqueue(int batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream) { +int IncOpPlugin::enqueue(int batch_size, const void* const* inputs, + void** outputs, void*, cudaStream_t stream) { int count = 1; - for (int i=0; i(inputs[0]); - float *output = reinterpret_cast(outputs[0]); + count *= batch_size; + const float* input = reinterpret_cast(inputs[0]); + float* output = reinterpret_cast(outputs[0]); IncrementKernel(input, inc_, output, count, stream); return 0; } diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 52b68487e6..a4774d354c 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -16,13 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN #define TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include -#include -#include -#include #include +#include #include +#include +#include +#include +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -31,50 +31,44 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -using std::string; -using std::unordered_map; - -class IncOpPlugin : public PluginTensorRT -{ -public: - static const string plugin_name_; - IncOpPlugin() {}; +class IncOpPlugin : public PluginTensorRT { + public: + static const std::string plugin_name_; + IncOpPlugin(){}; IncOpPlugin(const void* serialized_data, size_t length); - const string GetPluginName() override {return plugin_name_;}; - bool Finalize() override {return true;}; - bool SetAttribute(const string &key, const void *ptr, const size_t size) override; - bool GetAttribute(const string &key, const void *ptr, size_t &size) override; - - // TRT IPlugin methods - int getNbOutputs() const override {return 1;} - - nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override { - assert(index==0); - assert(nbInputDims==1); + const std::string& GetPluginName() const override { return plugin_name_; }; + bool Finalize() override { return true; }; + bool SetAttribute(const std::string& key, const void* ptr, + const size_t size) override; + bool GetAttribute(const std::string& key, const void** ptr, + size_t* size) const override; + + int getNbOutputs() const override { return 1; } + + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, + int num_input_dims) override { + assert(index == 0); + assert(num_input_dims == 1); return inputs[0]; } - // no configure needed // use configure to setup input dimensions - void configure(const nvinfer1::Dims *inputs, int nbInputs, const nvinfer1::Dims *outputs, int nbOutputs, int maxBatchSize) override { - assert(nbInputs==1); - PluginTensorRT::configure(inputs, nbInputs, outputs, nbOutputs, maxBatchSize); - return; + void configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) override { + assert(nb_inputs == 1); + PluginTensorRT::configure(inputs, num_inputs, outputs, num_outputs, + max_batch_size); } - int initialize() override { - return 0; - } + int initialize() override { return 0; } - void terminate() override { - return; - } + void terminate() override {} - size_t getWorkspaceSize(int maxBatchSize) const override { - return 0; - } + size_t getWorkspaceSize(int max_batch_size) const override { return 0; } - int enqueue(int batchSize, const void*const *inputs, void** outputs, void* workspace, cudaStream_t stream) override; + int enqueue(int batch_size, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override; size_t getSerializationSize() override { return PluginTensorRT::getSerializationSize() + sizeof(float); @@ -86,24 +80,23 @@ public: PluginTensorRT::serialize(buffer); // incremented buffer after parent serialization; - buffer = static_cast(buffer) + PluginTensorRT::getSerializationSize(); + buffer = + static_cast(buffer) + PluginTensorRT::getSerializationSize(); std::memcpy(buffer, &inc_, sizeof(float)); buffer = static_cast(buffer) + sizeof(float); - return; } -protected: + protected: float inc_; nvinfer1::Dims dim_; - // std::unordered_map > attr_map_; }; -IncOpPlugin* CreateIncPlugin(); +IncOpPlugin* CreateIncPlugin(); IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); bool RegisterIncOpPlugin(); -void IncrementKernel(const float* d_input, float inc, float* d_output, int count, cudaStream_t stream); - +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc new file mode 100644 index 0000000000..0dfead8f57 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +using namespace tensorflow; + +REGISTER_OP("IncPluginTRT") + .Attr("inc: list(float)") + .Input("input: float32") + .Output("output: float32") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i new file mode 100644 index 0000000000..9882daa842 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* Wrap inc_op_plugin */ +%module inc_op_plugin +%{ +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); +%} + +%{ +bool inc_op_register() { + return tensorflow::tensorrt::RegisterIncOpPlugin(); +} +%} + +extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); + +bool inc_op_register(); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py new file mode 100644 index 0000000000..52f49ae00e --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py @@ -0,0 +1,93 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to show usage of TensorRT custom op & plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# normally we should do import tensorflow as tf and then +# tf.placeholder, tf.constant, tf.nn.conv2d etc but +# it looks like internal builds don't like it so +# importing every module individually + +from tensorflow.contrib import tensorrt as trt +from tensorflow.core.protobuf import config_pb2 as cpb2 +from tensorflow.python.client import session as csess +from tensorflow.python.framework import dtypes as dtypes +from tensorflow.python.framework import importer as importer +from tensorflow.python.framework import ops as ops +from tensorflow.python.ops import array_ops as aops +from tensorflow.python.ops import nn as nn +from tensorflow.python.ops import nn_ops as nn_ops +import numpy as np + +# import custom_op as plugin op +# the python api handles registration to the plugin factory +from tensorflow.contrib.tensorrt import custom_plugin_examples as cpe + +def get_plugin_graph_def(): + """Create a simple graph and return its graph_def.""" + g = ops.Graph() + with g.as_default(): + a = aops.placeholder( + dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") + relu = nn.relu(a, "relu") + v = nn_ops.max_pool( + relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + + # insert custom_op in the graph + v = cpe.inc_op(v, inc=[16.5], name="plugin_test") + + v = v*2.0 + v = nn.relu(v) + v = nn.relu(v) + aops.squeeze(v, name="output") + return g.as_graph_def() + +def run_graph(gdef, dumm_inp): + """Run given graphdef once.""" + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + ops.reset_default_graph() + g = ops.Graph() + with g.as_default(): + inp, out = importer.import_graph_def( + graph_def=gdef, return_elements=["input", "output"]) + inp = inp.outputs[0] + out = out.outputs[0] + + with csess.Session( + config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + val = sess.run(out, {inp: dumm_inp}) + return val + +if "__main__" in __name__: + inp_dims = (5, 24, 24, 2) + dummy_input = np.ones(inp_dims).astype(np.float32) + orig_graph = get_plugin_graph_def() # graph with plugin node + + # trigger conversion. + # plugin nodes have been registered during import, converter will be able to + # create corresponding plugin layer during conversion. + trt_graph = trt.create_inference_graph( + input_graph_def=orig_graph, + outputs=["output"], + max_batch_size=inp_dims[0], + max_workspace_size_bytes=1 << 25, + precision_mode="FP32", + minimum_segment_size=2 + ) + o2 = run_graph(trt_graph, dummy_input) + print (o2) diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc index 7600703775..82c549dbf5 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -58,7 +58,6 @@ void PluginTensorRT::configure(const nvinfer1::Dims* inputs, int num_inputs, } input_dim_list_.emplace_back(dim); } - return; } size_t PluginTensorRT::getSerializationSize() { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index 59b92657f6..772974a769 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -32,20 +32,18 @@ namespace tensorrt { // User application should inherit from this class to write custom kernels. // Allows user to insert custom op in TensorRT engine // To register plugin in converter, user should also register custom -// tensorflow::tensorrt::PluginDeserializeFunc & -// tensorflow::tensorrt::PluginConstructFunc through -// tensorflow::tensorrt::PluginFactoryTensorRT +// PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: PluginTensorRT(){}; PluginTensorRT(const void* serialized_data, size_t length); - virtual const std::string& GetPluginName() = 0; + virtual const std::string& GetPluginName() const = 0; virtual bool Finalize() = 0; virtual bool SetAttribute(const std::string& key, const void* ptr, const size_t size) = 0; - virtual bool GetAttribute(const std::string& key, const void* ptr, - size_t& size) = 0; + virtual bool GetAttribute(const std::string& key, const void** ptr, + size_t* size) const = 0; void configure(const nvinfer1::Dims* inputs, int num_inputs, const nvinfer1::Dims* outputs, int num_outputs, diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 44b10394c8..776bce119d 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -33,12 +33,10 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, return nullptr; } - // should I lock plugins here? - instance_m_.lock(); + std::lock_guard lock(instance_m_); auto plugin_ptr = plugin_registry_[encoded_op_name].first(serial_data, serial_length); owned_plugins_.emplace_back(plugin_ptr); - instance_m_.unlock(); return plugin_ptr; } @@ -47,10 +45,9 @@ PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( const std::string& op_name) { if (!IsPlugin(op_name)) return nullptr; - instance_m_.lock(); + std::lock_guard lock(instance_m_); auto plugin_ptr = plugin_registry_[op_name].second(); owned_plugins_.emplace_back(plugin_ptr); - instance_m_.unlock(); return plugin_ptr; } @@ -60,22 +57,19 @@ bool PluginFactoryTensorRT::RegisterPlugin( PluginConstructFunc construct_func) { if (IsPlugin(op_name)) return false; - // get instance_m_ first before write to registry; - instance_m_.lock(); + std::lock_guard lock(instance_m_); auto ret = plugin_registry_.emplace( op_name, std::make_pair(deserialize_func, construct_func)); - instance_m_.unlock(); return ret.second; } void PluginFactoryTensorRT::DestroyPlugins() { - instance_m_.lock(); + std::lock_guard lock(instance_m_); for (auto& owned_plugin_ptr : owned_plugins_) { owned_plugin_ptr.release(); } owned_plugins_.clear(); - instance_m_.unlock(); } } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 824efcff35..08fd376844 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -39,10 +39,8 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { PluginTensorRT* CreatePlugin(const std::string& op_name); static PluginFactoryTensorRT* GetInstance() { - static PluginFactoryTensorRT* factory_instance = nullptr; - if (factory_instance == nullptr) { - factory_instance = new PluginFactoryTensorRT(); - } + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); return factory_instance; } diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc index 8b65e8b41c..c5d3f38280 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -22,8 +22,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -std::string ExtractOpName(const void* serial_data, size_t serial_length, - size_t* incremental) { +string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental) { size_t op_name_char_count = *static_cast(serial_data); *incremental = sizeof(size_t) + op_name_char_count; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index d4da8b261e..a94c67bba0 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -32,8 +33,8 @@ typedef std::function typedef std::function PluginConstructFunc; // TODO(jie): work on error handling here -std::string ExtractOpName(const void* serial_data, size_t serial_length, - size_t* incremental); +string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc index 2856b0f87d..9ef0fce972 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -51,9 +51,9 @@ class StubPlugin : public PluginTensorRT { return inputs[0]; } int initialize() override { return 0; } - void terminate() override { return; } + void terminate() override {} size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } - int enqueue(int batchSize, const void* const* inputs, void** outputs, + int enqueue(int batch_size, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override { return 0; } @@ -78,8 +78,6 @@ class PluginTest : public ::testing::Test { StubPlugin::plugin_name_, CreateStubPluginDeserialize, CreateStubPlugin); } - - protected: }; TEST_F(PluginTest, Registration) { -- GitLab From abfbbb86295c67eb1ac7c92235dbd5fb4b707169 Mon Sep 17 00:00:00 2001 From: Haggai Date: Wed, 18 Apr 2018 22:23:35 -0700 Subject: [PATCH 009/738] Remove reliance on TF core in XLA CPU Fft --- tensorflow/compiler/xla/service/cpu/BUILD | 1 - .../xla/service/cpu/runtime_fft_impl.h | 18 +++--------------- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 246b802861..6428ca528c 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -513,7 +513,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 984cb0616e..4f6b363364 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -21,8 +21,6 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" // 'tensorflow' namespace is used so that int64 and other types don't require @@ -71,11 +69,9 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -88,8 +84,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(in_dims); + const Eigen::DSizes zero_start_indices; full_fft.device(device) = input.template fft(axes); @@ -112,11 +108,9 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -129,8 +123,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. @@ -179,7 +172,6 @@ template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, int32 fft_type, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { - CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; switch (fft_type) { case ::xla::FftType::FFT: EigenFftC2C( @@ -203,8 +195,6 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, device, static_cast(out), static_cast(operand), input_batch, fft_length0, fft_length1, fft_length2); break; - default: - LOG(FATAL) << "Unsupported FFT type: " << fft_type; } } @@ -229,8 +219,6 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; - default: - LOG(FATAL) << "Unsupported FFT rank " << fft_rank; } } -- GitLab From 6343b8dd77ba94c74acc3c04c985a5535b2b8169 Mon Sep 17 00:00:00 2001 From: Haggai Date: Wed, 18 Apr 2018 22:26:42 -0700 Subject: [PATCH 010/738] Add single-threaded support for XLA CPU Fft --- tensorflow/compiler/xla/service/cpu/BUILD | 17 ++++++++++ .../compiler/xla/service/cpu/cpu_runtime.cc | 2 ++ .../compiler/xla/service/cpu/cpu_runtime.h | 1 + .../compiler/xla/service/cpu/ir_emitter.cc | 8 ++++- .../cpu/runtime_single_threaded_fft.cc | 32 +++++++++++++++++++ .../service/cpu/runtime_single_threaded_fft.h | 31 ++++++++++++++++++ .../xla/service/cpu/simple_orc_jit.cc | 2 ++ 7 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 6428ca528c..4862f9e2f9 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -176,6 +176,7 @@ cc_library( ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", + ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", "@llvm//:core", @@ -574,6 +575,22 @@ cc_library( ], ) +cc_library( + name = "runtime_single_threaded_fft", + srcs = [ + "runtime_fft_impl.h", + "runtime_single_threaded_fft.cc", + ], + hdrs = ["runtime_single_threaded_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_single_threaded_matmul", srcs = ["runtime_single_threaded_matmul.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 872b0be1f8..4fcab483d6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -50,6 +50,8 @@ extern const char* const kEigenConvF16SymbolName = extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedFftSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedFft"; extern const char* const kEigenSingleThreadedMatMulF16SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index e392e231b4..0cc45dac61 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -51,6 +51,7 @@ extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 3405277d44..8c2ca7104c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1171,7 +1171,13 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - const char* fn_name = runtime::kEigenFftSymbolName; + + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + const char* fn_name = multi_threaded_eigen + ? runtime::kEigenFftSymbolName + : runtime::kEigenSingleThreadedFftSymbolName; + llvm::Function* fft_func = llvm::cast( module_->getOrInsertFunction(fn_name, fft_type)); fft_func->setCallingConv(llvm::CallingConv::C); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc new file mode 100644 index 0000000000..2613ddb127 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type, + fft_rank, input_batch, fft_length0, fft_length1, + fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h new file mode 100644 index 0000000000..dcd133d012 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index b7ce5bbe47..7bd17002e3 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" @@ -190,6 +191,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); -- GitLab From 8149077125f6c2701713fef12fe0f0caac729e27 Mon Sep 17 00:00:00 2001 From: Krish Ravindranath Date: Thu, 19 Apr 2018 14:53:10 -0400 Subject: [PATCH 011/738] changes error to ValueError, notes that shuffle must be provided and should be set True for training --- tensorflow/python/estimator/inputs/numpy_io.py | 5 +++-- tensorflow/python/estimator/inputs/pandas_io.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py index a6f4712910..5b5eb41466 100644 --- a/tensorflow/python/estimator/inputs/numpy_io.py +++ b/tensorflow/python/estimator/inputs/numpy_io.py @@ -139,8 +139,9 @@ def numpy_input_fn(x, TypeError: `x` is not a dict or array, or if `shuffle` is not bool. """ if not isinstance(shuffle, bool): - raise TypeError('shuffle must be explicitly set as boolean; ' - 'got {}'.format(shuffle)) + raise ValueError('shuffle must be provided and explicitly set as boolean ' + '(it is recommended to set it as True for training); ' + 'got {}'.format(shuffle)) def input_fn(): """Numpy input function.""" diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py index bd06843021..16825e09de 100644 --- a/tensorflow/python/estimator/inputs/pandas_io.py +++ b/tensorflow/python/estimator/inputs/pandas_io.py @@ -75,8 +75,9 @@ def pandas_input_fn(x, 'pandas_input_fn should not be called without pandas installed') if not isinstance(shuffle, bool): - raise TypeError('shuffle must be explicitly set as boolean; ' - 'got {}'.format(shuffle)) + raise ValueError('shuffle must be provided and explicitly set as boolean ' + '(it is recommended to set it as True for training); ' + 'got {}'.format(shuffle)) x = x.copy() if y is not None: -- GitLab From 459d61cbe8ab9cbb86b2bb7eac602ff565d54fde Mon Sep 17 00:00:00 2001 From: Jie Date: Thu, 19 Apr 2018 13:48:14 -0700 Subject: [PATCH 012/738] [PR comment addressed] switched from std::string to TF string custom_plugin_examples python test added (bazel) style guide violation addressed --- .../contrib/tensorrt/convert/convert_nodes.cc | 22 ++--- .../tensorrt/custom_plugin_examples/BUILD | 42 ++++++--- .../custom_plugin_examples/__init__.py | 12 +-- .../inc_op_kernel.cu.cc | 2 - .../custom_plugin_examples/inc_op_kernel.h | 3 +- .../{inc_op_plugin.cc => inc_op_plugin.cu.cc} | 9 +- .../custom_plugin_examples/inc_op_plugin.h | 18 ++-- .../custom_plugin_examples/ops/inc_op.cc | 4 +- .../{test => }/plugin_test.py | 46 +++++----- tensorflow/contrib/tensorrt/log/trt_logger.h | 2 +- .../contrib/tensorrt/plugin/trt_plugin.cc | 3 +- .../contrib/tensorrt/plugin/trt_plugin.h | 14 +-- .../tensorrt/plugin/trt_plugin_factory.cc | 7 +- .../tensorrt/plugin/trt_plugin_factory.h | 8 +- .../tensorrt/plugin/trt_plugin_utils.cc | 2 +- .../tensorrt/plugin/trt_plugins_test.cc | 19 ++-- tensorflow/contrib/tensorrt/plugin_test.py | 88 +++++++++++++++++++ .../tensorrt/resources/trt_resources.h | 2 +- 18 files changed, 205 insertions(+), 98 deletions(-) rename tensorflow/contrib/tensorrt/custom_plugin_examples/{inc_op_plugin.cc => inc_op_plugin.cu.cc} (91%) rename tensorflow/contrib/tensorrt/custom_plugin_examples/{test => }/plugin_test.py (67%) create mode 100644 tensorflow/contrib/tensorrt/plugin_test.py diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 874be96c78..c8a96e5dba 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -241,9 +241,9 @@ class TFAttrs { return attrs_.at(key); } template - T get(string key) const; + T get(const string& key) const; template - T get(string key, const T& default_value) const { + T get(const string& key, const T& default_value) const { return attrs_.count(key) ? this->get(key) : default_value; } @@ -261,29 +261,29 @@ class TFAttrs { }; template <> -string TFAttrs::get(string key) const { +string TFAttrs::get(const string& key) const { return this->at(key)->s(); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().i(); return std::vector(attr.begin(), attr.end()); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().f(); return std::vector(attr.begin(), attr.end()); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().s(); return std::vector(attr.begin(), attr.end()); } template <> -nvinfer1::Dims TFAttrs::get(string key) const { +nvinfer1::Dims TFAttrs::get(const string& key) const { auto values = this->get>(key); nvinfer1::Dims dims; dims.nbDims = values.size(); @@ -293,24 +293,24 @@ nvinfer1::Dims TFAttrs::get(string key) const { } template <> -nvinfer1::DataType TFAttrs::get(string key) const { +nvinfer1::DataType TFAttrs::get(const string& key) const { nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); return trt_dtype; } template <> -tensorflow::DataType TFAttrs::get(string key) const { +tensorflow::DataType TFAttrs::get(const string& key) const { return this->at(key)->type(); } template <> -float TFAttrs::get(string key) const { +float TFAttrs::get(const string& key) const { return this->at(key)->f(); } template <> -bool TFAttrs::get(string key) const { +bool TFAttrs::get(const string& key) const { return this->at(key)->b(); } diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 5603ed0ccf..3b1a7fb6f3 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -1,3 +1,9 @@ +# Description: +# Example for plugin support in TensorRT(http://developer.nvidia.com/tensorrt) +# through TensorFlow integration. Targeting TensorRT 3.0.4 +# APIs are meant to change while upgrading TRT. +# add init_py into pip package BUILD dependency to install it. + package(default_visibility = ["//tensorflow:__subpackages__"]) load( @@ -8,6 +14,7 @@ load( "tf_gen_op_wrapper_py", "tf_py_wrap_cc", "tf_copts", + "tf_py_test", ) load( "@local_config_tensorrt//:build_defs.bzl", @@ -18,19 +25,16 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library") tf_kernel_library( name = "_inc_op_plugin_kernel", - srcs = [ - "inc_op_plugin.cc", - ], - hdrs = [ - ], gpu_srcs = [ "inc_op_kernel.cu.cc", "inc_op_kernel.h", + "inc_op_plugin.cu.cc", "inc_op_plugin.h", ], - deps = if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + deps = [ "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", ]), ) @@ -38,9 +42,10 @@ tf_gen_op_libs( op_lib_names = [ "inc_op", ], - deps = if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + deps = [ "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", ]), ) @@ -70,9 +75,8 @@ tf_custom_op_library( srcs = ["ops/inc_op.cc"], deps = [ "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ "//tensorflow/contrib/tensorrt:trt_plugins", - ]), + ], ) tf_custom_op_py_library( @@ -97,6 +101,22 @@ py_library( ], ) +tf_py_test( + name = "plugin_test", + size = "small", + srcs = [ + "plugin_test.py", + ], + additional_deps = [ + ":init_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/contrib/tensorrt:init_py", + "//tensorflow/python:platform", + "//tensorflow/python:client_testlib", + "//tensorflow/python:tf_optimizer", + ], +) + py_library( name = "init_py", srcs = [ diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py index a61d008941..e4cd0ae8a0 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py @@ -14,11 +14,13 @@ # ============================================================================= """Import custom op for plugin and register it in plugin factory registry.""" -from ops import gen_inc_op -from plugin_wrap import inc_op_register -from inc_op import * +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tensorrt.custom_plugin_examples.ops import gen_inc_op +from tensorflow.contrib.tensorrt.custom_plugin_examples.plugin_wrap import inc_op_register +from tensorflow.contrib.tensorrt.custom_plugin_examples import inc_op as import_inc_op_so -# pylint: disable=unused-import,wildcard-import,g-import-not-at-top inc_op = gen_inc_op.inc_plugin_trt inc_op_register() -# pylint: enable=unused-import,wildcard-import,g-import-not-at-top diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 5dd6b9bf94..38e1e01d95 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -14,10 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" -#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" #if GOOGLE_CUDA -#define EIGEN_USE_GPU #if GOOGLE_TENSORRT namespace tensorflow { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h index ec269143e8..13156dad8f 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -17,13 +17,14 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_INC_OP #if GOOGLE_CUDA -#define EIGEN_USE_GPU #if GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { __global__ void VecInc(float* vec, float inc, float* dest, int n); +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc similarity index 91% rename from tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc rename to tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc index 21617fa8b5..508ced587b 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +#include +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA @@ -23,7 +24,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -const std::string IncOpPlugin::plugin_name_ = "IncPluginTRT"; +const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); } @@ -47,7 +48,7 @@ IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) SetAttribute("inc", buffer + consumed_data, sizeof(float)); } -bool IncOpPlugin::SetAttribute(const std::string& key, const void* ptr, +bool IncOpPlugin::SetAttribute(const string& key, const void* ptr, const size_t size) { if (strcmp(key.c_str(), "inc") == 0 && size == sizeof(float)) { StoreAttribute(key, ptr, size); // save the attribute to own the data; @@ -57,7 +58,7 @@ bool IncOpPlugin::SetAttribute(const std::string& key, const void* ptr, return false; } -bool IncOpPlugin::GetAttribute(const std::string& key, const void** ptr, +bool IncOpPlugin::GetAttribute(const string& key, const void** ptr, size_t* size) const { const auto& iter = attr_map_.find(key); if (iter != attr_map_.end()) { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index a4774d354c..87404a755c 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -18,10 +18,6 @@ limitations under the License. #include #include -#include -#include -#include -#include #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA @@ -33,14 +29,14 @@ namespace tensorrt { class IncOpPlugin : public PluginTensorRT { public: - static const std::string plugin_name_; - IncOpPlugin(){}; + static const string plugin_name_; + IncOpPlugin() {}; IncOpPlugin(const void* serialized_data, size_t length); - const std::string& GetPluginName() const override { return plugin_name_; }; + const string& GetPluginName() const override { return plugin_name_; }; bool Finalize() override { return true; }; - bool SetAttribute(const std::string& key, const void* ptr, + bool SetAttribute(const string& key, const void* ptr, const size_t size) override; - bool GetAttribute(const std::string& key, const void** ptr, + bool GetAttribute(const string& key, const void** ptr, size_t* size) const override; int getNbOutputs() const override { return 1; } @@ -56,7 +52,7 @@ class IncOpPlugin : public PluginTensorRT { void configure(const nvinfer1::Dims* inputs, int num_inputs, const nvinfer1::Dims* outputs, int num_outputs, int max_batch_size) override { - assert(nb_inputs == 1); + assert(num_inputs == 1); PluginTensorRT::configure(inputs, num_inputs, outputs, num_outputs, max_batch_size); } @@ -95,8 +91,6 @@ class IncOpPlugin : public PluginTensorRT { IncOpPlugin* CreateIncPlugin(); IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); bool RegisterIncOpPlugin(); -void IncrementKernel(const float* d_input, float inc, float* d_output, - int count, cudaStream_t stream); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc index 0dfead8f57..7466e59090 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc @@ -19,7 +19,7 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -using namespace tensorflow; +namespace tensorflow { REGISTER_OP("IncPluginTRT") .Attr("inc: list(float)") @@ -30,5 +30,7 @@ REGISTER_OP("IncPluginTRT") return Status::OK(); }); +} // namespace tensorflow + #endif // GOOGLE_CUDA #endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py similarity index 67% rename from tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py rename to tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index 52f49ae00e..9f773c66a9 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -23,43 +23,44 @@ from __future__ import print_function # it looks like internal builds don't like it so # importing every module individually -from tensorflow.contrib import tensorrt as trt -from tensorflow.core.protobuf import config_pb2 as cpb2 -from tensorflow.python.client import session as csess -from tensorflow.python.framework import dtypes as dtypes -from tensorflow.python.framework import importer as importer -from tensorflow.python.framework import ops as ops -from tensorflow.python.ops import array_ops as aops -from tensorflow.python.ops import nn as nn -from tensorflow.python.ops import nn_ops as nn_ops -import numpy as np +from tensorflow.contrib import tensorrt +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.framework import errors +import numpy # import custom_op as plugin op -# the python api handles registration to the plugin factory -from tensorflow.contrib.tensorrt import custom_plugin_examples as cpe +# the python api handles registration to the plugin factory +from tensorflow.contrib.tensorrt import custom_plugin_examples def get_plugin_graph_def(): """Create a simple graph and return its graph_def.""" g = ops.Graph() with g.as_default(): - a = aops.placeholder( + a = array_ops.placeholder( dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") relu = nn.relu(a, "relu") v = nn_ops.max_pool( relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") # insert custom_op in the graph - v = cpe.inc_op(v, inc=[16.5], name="plugin_test") + v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") v = v*2.0 v = nn.relu(v) v = nn.relu(v) - aops.squeeze(v, name="output") + array_ops.squeeze(v, name="output") return g.as_graph_def() def run_graph(gdef, dumm_inp): """Run given graphdef once.""" - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) ops.reset_default_graph() g = ops.Graph() with g.as_default(): @@ -68,20 +69,20 @@ def run_graph(gdef, dumm_inp): inp = inp.outputs[0] out = out.outputs[0] - with csess.Session( - config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + with session.Session( + config=config_pb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: val = sess.run(out, {inp: dumm_inp}) return val if "__main__" in __name__: inp_dims = (5, 24, 24, 2) - dummy_input = np.ones(inp_dims).astype(np.float32) + dummy_input = numpy.ones(inp_dims).astype(numpy.float32) orig_graph = get_plugin_graph_def() # graph with plugin node # trigger conversion. # plugin nodes have been registered during import, converter will be able to # create corresponding plugin layer during conversion. - trt_graph = trt.create_inference_graph( + trt_graph = tensorrt.create_inference_graph( input_graph_def=orig_graph, outputs=["output"], max_batch_size=inp_dims[0], @@ -90,4 +91,7 @@ if "__main__" in __name__: minimum_segment_size=2 ) o2 = run_graph(trt_graph, dummy_input) - print (o2) + if o2.reshape([-1])[0] == 35: + print("pass") + else: + raise RuntimeError("contrib/tensorrt/custom_plugin_examples wrong result") diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h index 7f3544f8cf..3495dc6318 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -28,7 +28,7 @@ namespace tensorrt { // Logger for GIE info/warning/errors class Logger : public nvinfer1::ILogger { public: - Logger(string name = "DefaultLogger") : name_(name){}; + Logger(string name = "DefaultLogger") : name_(name) {}; void log(nvinfer1::ILogger::Severity severity, const char* msg) override; private: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc index 82c549dbf5..062f86e8bb 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -25,7 +25,6 @@ namespace tensorflow { namespace tensorrt { PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { - // sanity check. const char* buffer = static_cast(serialized_data); size_t op_name_char_count = *reinterpret_cast(buffer); buffer += sizeof(size_t); @@ -91,7 +90,7 @@ void PluginTensorRT::serialize(void* serialized_data) { } } -bool PluginTensorRT::StoreAttribute(const std::string& key, const void* ptr, +bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr, const size_t size) { if (attr_map_.count(key) != 0) return false; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index 772974a769..dca377c2d2 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN #include -#include #include #include +#include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -35,28 +35,28 @@ namespace tensorrt { // PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: - PluginTensorRT(){}; + PluginTensorRT() {}; PluginTensorRT(const void* serialized_data, size_t length); - virtual const std::string& GetPluginName() const = 0; + virtual const string& GetPluginName() const = 0; virtual bool Finalize() = 0; - virtual bool SetAttribute(const std::string& key, const void* ptr, + virtual bool SetAttribute(const string& key, const void* ptr, const size_t size) = 0; - virtual bool GetAttribute(const std::string& key, const void** ptr, + virtual bool GetAttribute(const string& key, const void** ptr, size_t* size) const = 0; void configure(const nvinfer1::Dims* inputs, int num_inputs, const nvinfer1::Dims* outputs, int num_outputs, int max_batch_size) override; - virtual bool StoreAttribute(const std::string& key, const void* ptr, + virtual bool StoreAttribute(const string& key, const void* ptr, const size_t size); virtual size_t getSerializationSize() override; virtual void serialize(void* buffer) override; protected: - std::unordered_map > attr_map_; + std::unordered_map > attr_map_; std::vector input_dim_list_; }; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 776bce119d..736a1321fe 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -26,7 +26,7 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, size_t serial_length) { size_t parsed_byte = 0; // extract op_name from serial_data - std::string encoded_op_name = + string encoded_op_name = ExtractOpName(serial_data, serial_length, &parsed_byte); if (!IsPlugin(encoded_op_name)) { @@ -41,8 +41,7 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, return plugin_ptr; } -PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( - const std::string& op_name) { +PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string& op_name) { if (!IsPlugin(op_name)) return nullptr; std::lock_guard lock(instance_m_); @@ -53,7 +52,7 @@ PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( } bool PluginFactoryTensorRT::RegisterPlugin( - const std::string& op_name, PluginDeserializeFunc deserialize_func, + const string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func) { if (IsPlugin(op_name)) return false; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 08fd376844..4e4a3af4ca 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -36,7 +36,7 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { size_t serial_length) override; // plugin construction, PluginFactoryTensorRT owns the plugin; - PluginTensorRT* CreatePlugin(const std::string& op_name); + PluginTensorRT* CreatePlugin(const string& op_name); static PluginFactoryTensorRT* GetInstance() { static PluginFactoryTensorRT* factory_instance = @@ -44,11 +44,11 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { return factory_instance; } - bool RegisterPlugin(const std::string& op_name, + bool RegisterPlugin(const string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func); - bool IsPlugin(const std::string& op_name) { + bool IsPlugin(const string& op_name) { return plugin_registry_.find(op_name) != plugin_registry_.end(); } @@ -57,7 +57,7 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { void DestroyPlugins(); protected: - std::unordered_map > plugin_registry_; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc index c5d3f38280..a8f60886c0 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -30,7 +30,7 @@ string ExtractOpName(const void* serial_data, size_t serial_length, assert(serial_length >= *incremental); const char* buffer = static_cast(serial_data) + sizeof(size_t); - std::string op_name(buffer, op_name_char_count); + string op_name(buffer, op_name_char_count); return op_name; } diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc index 9ef0fce972..b834c5511f 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" @@ -31,18 +30,17 @@ namespace test { class StubPlugin : public PluginTensorRT { public: - static const std::string plugin_name_; - StubPlugin(){}; + static const string plugin_name_; + StubPlugin() {}; StubPlugin(const void* serialized_data, size_t length) - : PluginTensorRT(serialized_data, length){}; - const std::string& GetPluginName() override { return plugin_name_; }; + : PluginTensorRT(serialized_data, length) {}; + const string& GetPluginName() override { return plugin_name_; }; virtual bool Finalize() { return true; }; - virtual bool SetAttribute(const std::string& key, const void* ptr, + virtual bool SetAttribute(const string& key, const void* ptr, const size_t size) { return true; }; - virtual bool GetAttribute(const std::string& key, const void* ptr, - size_t& size) { + virtual bool GetAttribute(const string& key, const void* ptr, size_t& size) { return true; }; int getNbOutputs() const override { return 1; } @@ -59,7 +57,7 @@ class StubPlugin : public PluginTensorRT { } }; -const std::string StubPlugin::plugin_name_ = "StubPlugin"; +const string StubPlugin::plugin_name_ = "StubPlugin"; StubPlugin* CreateStubPlugin() { return new StubPlugin(); } @@ -72,8 +70,9 @@ class PluginTest : public ::testing::Test { public: bool RegisterStubPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin( - StubPlugin::plugin_name_)) + StubPlugin::plugin_name_)) { return true; + } return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( StubPlugin::plugin_name_, CreateStubPluginDeserialize, CreateStubPlugin); diff --git a/tensorflow/contrib/tensorrt/plugin_test.py b/tensorflow/contrib/tensorrt/plugin_test.py new file mode 100644 index 0000000000..7c3e765bff --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin_test.py @@ -0,0 +1,88 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to show usage of TensorRT custom op & plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import tensorrt +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +import numpy as np + +# import custom_op as plugin op +# the python api handles registration to the plugin factory +from tensorflow.contrib.tensorrt import custom_plugin_examples + +def get_plugin_graph_def(): + """Create a simple graph and return its graph_def.""" + g = ops.Graph() + with g.as_default(): + a = array_ops.placeholder( + dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") + relu = nn.relu(a, "relu") + v = nn_ops.max_pool( + relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + + # insert custom_op in the graph + v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") + + v = v*2.0 + v = nn.relu(v) + v = nn.relu(v) + array_ops.squeeze(v, name="output") + return g.as_graph_def() + +def run_graph(gdef, dumm_inp): + """Run given graphdef once.""" + gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + ops.reset_default_graph() + g = ops.Graph() + with g.as_default(): + inp, out = importer.import_graph_def( + graph_def=gdef, return_elements=["input", "output"]) + inp = inp.outputs[0] + out = out.outputs[0] + + with session.Session( + config=config_pb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + val = sess.run(out, {inp: dumm_inp}) + return val + +if "__main__" in __name__: + inp_dims = (5, 24, 24, 2) + dummy_input = np.ones(inp_dims).astype(np.float32) + orig_graph = get_plugin_graph_def() # graph with plugin node + + # trigger conversion. + # plugin nodes have been registered during import, converter will be able to + # create corresponding plugin layer during conversion. + trt_graph = tensorrt.create_inference_graph( + input_graph_def=orig_graph, + outputs=["output"], + max_batch_size=inp_dims[0], + max_workspace_size_bytes=1 << 25, + precision_mode="FP32", + minimum_segment_size=2 + ) + o2 = run_graph(trt_graph, dummy_input) + print (o2) diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index 3c85968ae7..5164247f93 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -82,7 +82,7 @@ class TRTWeightStore : public tensorflow::ResourceBase { class TRTEngineResource : public tensorflow::ResourceBase { public: - TRTEngineResource() : runtime_(nullptr), ctx_(nullptr){}; + TRTEngineResource() : runtime_(nullptr), ctx_(nullptr) {}; string DebugString() override { return string(""); } nvinfer1::IRuntime* runtime_; nvinfer1::IExecutionContext* ctx_; -- GitLab From 2ef955b6d354378a7ca19f1f3cafccfc17f79013 Mon Sep 17 00:00:00 2001 From: Haggai Date: Fri, 20 Apr 2018 18:57:12 -0700 Subject: [PATCH 013/738] Abort on invalid fft type or rank --- tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 4f6b363364..0bf693edd0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -195,6 +195,9 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, device, static_cast(out), static_cast(operand), input_batch, fft_length0, fft_length1, fft_length2); break; + default: + // Unsupported FFT type + abort(); } } @@ -219,6 +222,9 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; + default: + // Unsupported FFT rank + abort(); } } -- GitLab From 364f6eae07fa8f0e2f89a9f665d0af430ea96669 Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Sat, 21 Apr 2018 14:45:30 -0300 Subject: [PATCH 014/738] Create pull_request_template.md --- pull_request_template.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 pull_request_template.md diff --git a/pull_request_template.md b/pull_request_template.md new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/pull_request_template.md @@ -0,0 +1 @@ + -- GitLab From ea3d7ab5455f54a67e24428f159e9170be408d71 Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Sat, 21 Apr 2018 14:57:38 -0300 Subject: [PATCH 015/738] Create Pull Request Template --- PULL_REQUEST_TEMPLATE.md | 20 ++++++++++++++++++++ pull_request_template.md | 1 - 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 PULL_REQUEST_TEMPLATE.md delete mode 100644 pull_request_template.md diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000..075bbc9945 --- /dev/null +++ b/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,20 @@ + + +##### Pull Request Checklist + +- [ ] Read [contributing guideline](CONTRIBUTING.md). +- [ ] Read [code of conduct](CODE_OF_CONDUCT.md). +- [ ] Fill [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style) +- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). + +##### Issue Fix + +- [ ] Yes +- [ ] No + +Fixed issue: + +##### Description + diff --git a/pull_request_template.md b/pull_request_template.md deleted file mode 100644 index 8b13789179..0000000000 --- a/pull_request_template.md +++ /dev/null @@ -1 +0,0 @@ - -- GitLab From 7f78414776718a350b1beb612dd8b1c26ff3f6a4 Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Tue, 24 Apr 2018 22:52:29 -0300 Subject: [PATCH 016/738] Merge PR Template to Contributing - Remove pull request template. - Add check list in contributing as a kind of TL;DR for that file. --- CONTRIBUTING.md | 11 +++++++++++ PULL_REQUEST_TEMPLATE.md | 20 -------------------- 2 files changed, 11 insertions(+), 20 deletions(-) delete mode 100644 PULL_REQUEST_TEMPLATE.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3dad41a88c..2e9d8c65e2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,16 @@ # Contributing guidelines +## Pull Request Checklist + +Before sending your pull requests, make sure you followed this list. + +- [ ] Read [contributing guidelines](CONTRIBUTING.md). +- [ ] Read [Code of Conduct](CODE_OF_CONDUCT.md). +- [ ] Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). +- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). + ## How to become a contributor and submit your own code ### Contributor License Agreements diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index 075bbc9945..0000000000 --- a/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,20 +0,0 @@ - - -##### Pull Request Checklist - -- [ ] Read [contributing guideline](CONTRIBUTING.md). -- [ ] Read [code of conduct](CODE_OF_CONDUCT.md). -- [ ] Fill [Contributor License Agreement (CLA)](https://cla.developers.google.com/). -- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). -- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style) -- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). - -##### Issue Fix - -- [ ] Yes -- [ ] No - -Fixed issue: - -##### Description - -- GitLab From 7f70c7a38fc2f4aaa9ceb52240c9112886adda5c Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Tue, 24 Apr 2018 23:00:05 -0300 Subject: [PATCH 017/738] Make more like a table of contents --- CONTRIBUTING.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2e9d8c65e2..8669c25c45 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,12 +4,12 @@ Before sending your pull requests, make sure you followed this list. -- [ ] Read [contributing guidelines](CONTRIBUTING.md). -- [ ] Read [Code of Conduct](CODE_OF_CONDUCT.md). -- [ ] Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). -- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). -- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). -- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). +- Read [contributing guidelines](CONTRIBUTING.md). +- Read [Code of Conduct](CODE_OF_CONDUCT.md). +- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). +- Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). ## How to become a contributor and submit your own code -- GitLab From 090d21c18f16303e740136e8a4e0f62c63df4e10 Mon Sep 17 00:00:00 2001 From: wangsiyu Date: Thu, 3 May 2018 18:31:29 +0800 Subject: [PATCH 018/738] fix bug of declaring regularization loss mutiple times when reusing partitioned variables in tf.layers --- tensorflow/python/layers/base.py | 13 ++++++++++++- tensorflow/python/layers/base_test.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 64db49c900..c050e6be04 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -233,7 +233,8 @@ class Layer(base_layer.Layer): getter=vs.get_variable) if regularizer: - if context.executing_eagerly() or variable not in existing_variables: + if context.executing_eagerly() or _should_add_regularizer( + variable, existing_variables): self._handle_weight_regularization(name, variable, regularizer) if init_graph is not None: @@ -354,3 +355,13 @@ def _add_elements_to_collection(elements, collection_list): if element not in collection_set: collection.append(element) +def _should_add_regularizer(variable, existing_variable_set): + result = True + if isinstance(variable, tf_variables.PartitionedVariable): + for var in variable._get_variable_list(): + if var in existing_variable_set: + result = False + break + else: + result = variable not in existing_variable_set + return result diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index f08b552840..361e3de7aa 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -30,6 +30,7 @@ from tensorflow.python.layers import core as core_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope @@ -95,6 +96,20 @@ class BaseLayerTest(test.TestCase): regularizer=regularizer) self.assertEqual(len(layer.losses), 1) + def testReusePartitionedVaraiblesAndRegularizers(self): + regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3 + partitioner = partitioned_variables.fixed_size_partitioner(3) + for i in xrange(2): + with variable_scope.variable_scope(variable_scope.get_variable_scope(), + partitioner=partitioner, + reuse=False if i == 0 else True): + layer = base_layers.Layer(name='my_layer') + variable = layer.add_variable( + 'reg_part_var', [4, 4], + initializer=init_ops.zeros_initializer(), + regularizer=regularizer) + self.assertEqual(len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 3) + def testNoEagerActivityRegularizer(self): with context.eager_mode(): with self.assertRaisesRegexp(ValueError, 'activity_regularizer'): -- GitLab From a2d35bddc7a1ab58b859ef396501472d7986ff0f Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 07:50:13 -0700 Subject: [PATCH 019/738] Fix build dependency; add missing OpKernel; fix some formatting issues --- .../tensorrt/custom_plugin_examples/BUILD | 105 ++++++++---------- .../tensorrt/custom_plugin_examples/inc_op.py | 5 +- .../inc_op_kernel.cu.cc | 42 +++++++ .../custom_plugin_examples/inc_op_kernel.h | 2 +- .../{inc_op_plugin.cu.cc => inc_op_plugin.cc} | 13 ++- .../custom_plugin_examples/inc_op_plugin.h | 18 +-- .../custom_plugin_examples/plugin_test.py | 10 +- .../contrib/tensorrt/kernels/trt_engine_op.cc | 2 +- tensorflow/contrib/tensorrt/log/trt_logger.h | 2 +- .../tensorrt/plugin/trt_plugin_factory.cc | 6 + .../tensorrt/plugin/trt_plugin_factory.h | 12 +- .../tensorrt/plugin/trt_plugins_test.cc | 47 +++++--- 12 files changed, 162 insertions(+), 102 deletions(-) rename tensorflow/contrib/tensorrt/custom_plugin_examples/{inc_op_plugin.cu.cc => inc_op_plugin.cc} (90%) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 3b1a7fb6f3..a45d4423bb 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -8,74 +8,39 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load( "//tensorflow:tensorflow.bzl", + "tf_copts", "tf_custom_op_library", - "tf_cuda_library", + "tf_custom_op_library_additional_deps", "tf_gen_op_libs", "tf_gen_op_wrapper_py", - "tf_py_wrap_cc", - "tf_copts", - "tf_py_test", + "tf_kernel_library", ) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", ) -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") - -tf_kernel_library( - name = "_inc_op_plugin_kernel", - gpu_srcs = [ - "inc_op_kernel.cu.cc", - "inc_op_kernel.h", - "inc_op_plugin.cu.cc", - "inc_op_plugin.h", - ], - deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) tf_gen_op_libs( op_lib_names = [ "inc_op", ], - deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), ) tf_gen_op_wrapper_py( name = "inc_op", - gen_locally = True, deps = [ ":inc_op_op_lib", ], ) -tf_py_wrap_cc( - name = "plugin_wrap", - srcs = [ - "plugin_wrap.i", - ], - copts = tf_copts(), - deps = [ - ":_inc_op_plugin_kernel", - "//tensorflow/core:framework_lite", - "//util/python:python_headers", - ], -) - tf_custom_op_library( name = "_inc_op.so", srcs = ["ops/inc_op.cc"], deps = [ "//tensorflow/core:lib_proto_parsing", - "//tensorflow/contrib/tensorrt:trt_plugins", ], ) @@ -85,6 +50,10 @@ tf_custom_op_py_library( dso = [ ":_inc_op.so", ], + kernels = [ + ":inc_op_op_lib", + ":inc_op_plugin_kernel", + ], srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_for_generated_wrappers", @@ -101,30 +70,54 @@ py_library( ], ) -tf_py_test( - name = "plugin_test", - size = "small", - srcs = [ - "plugin_test.py", +tf_kernel_library( + name = "inc_op_plugin_kernel", + srcs = ["inc_op_plugin.cc"], + hdrs = [ + "inc_op_plugin.h", ], - additional_deps = [ - ":init_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/contrib/tensorrt:init_py", - "//tensorflow/python:platform", - "//tensorflow/python:client_testlib", - "//tensorflow/python:tf_optimizer", + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc" + ], + deps = [ + "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), +) + +tf_py_wrap_cc( + name = "plugin_wrap", + srcs = ["plugin_wrap.i"], + copts = tf_copts(), + deps = [ + ":inc_op_plugin_kernel", + "//tensorflow/core:framework_lite", + "//util/python:python_headers", ], ) py_library( name = "init_py", - srcs = [ - "__init__.py", - ], + srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ ":inc_op_py", ":plugin_wrap", ], ) + +tf_py_test( + name = "plugin_test", + size = "small", + srcs = ["plugin_test.py"], + additional_deps = [ + ":init_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/contrib/tensorrt:init_py", + "//tensorflow/python:platform", + "//tensorflow/python:client_testlib", + "//tensorflow/python:tf_optimizer", + ], +) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py index ef8e26fbde..47fd55e2f6 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py @@ -18,13 +18,14 @@ from __future__ import division from __future__ import print_function import platform -import os if platform.system() != "Windows": + # pylint: disable=g-import-not-at-top from tensorflow.contrib.util import loader from tensorflow.python.platform import resource_loader + # pylint: enable=g-import-not-at-top _inc_op = loader.load_op_library( - os.path.join(os.path.dirname(os.path.realpath(__file__)),"_inc_op.so")) + resource_loader.get_path_to_datafile("_inc_op.so")) else: raise RuntimeError("Windows not supported") diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 38e1e01d95..ee9fbe0ea1 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -15,8 +15,14 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/stream_executor.h" + #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" namespace tensorflow { namespace tensorrt { @@ -35,6 +41,42 @@ void IncrementKernel(const float* d_input, float inc, float* d_output, d_output, count); } +// Note: this kernel definition is not needed in the plugin_test rule, but it is +// required for correctness of the TF program, i.e. if not using plugin or when +// run with trt optimization pass, the test should work. +class IncPluginTRT : public OpKernel { + public: + explicit IncPluginTRT(OpKernelConstruction* context) : OpKernel(context) { + std::vector inc_list; + OP_REQUIRES_OK(context, context->GetAttr("inc", &inc_list)); + OP_REQUIRES(context, inc_list.size() == 1, + errors::InvalidArgument( + "The increment list should contain single element.")); + inc_ = inc_list[0]; + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_tensor = context->input(0); + const TensorShape& input_shape = input_tensor.shape(); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &output_tensor)); + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast(context->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + IncrementKernel(input_tensor.flat().data(), inc_, + output_tensor->flat().data(), + input_shape.num_elements(), *stream); + } + + private: + float inc_; +}; + +REGISTER_KERNEL_BUILDER(Name("IncPluginTRT").Device(DEVICE_GPU), IncPluginTRT); + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h index 13156dad8f..1d0ec0b6b0 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -18,11 +18,11 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" namespace tensorflow { namespace tensorrt { -__global__ void VecInc(float* vec, float inc, float* dest, int n); void IncrementKernel(const float* d_input, float inc, float* d_output, int count, cudaStream_t stream); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc similarity index 90% rename from tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc rename to tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 508ced587b..489bc15def 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" -#include #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA @@ -24,7 +23,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; +const char* kPluginName = "IncPluginTRT"; IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); } @@ -33,14 +32,16 @@ IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { } bool RegisterIncOpPlugin() { - if (PluginFactoryTensorRT::GetInstance()->IsPlugin(IncOpPlugin::plugin_name_)) + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(kPluginName)) return false; return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( - IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); + kPluginName, CreateIncPluginDeserialize, CreateIncPlugin); } +IncOpPlugin::IncOpPlugin() : plugin_name_(kPluginName) {} + IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) - : PluginTensorRT(serialized_data, length) { + : PluginTensorRT(serialized_data, length), plugin_name_(kPluginName) { // account for the consumed pointer. size_t consumed_data = PluginTensorRT::getSerializationSize(); assert(length - consumed_data >= sizeof(float)); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 87404a755c..0676abe768 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -29,13 +29,17 @@ namespace tensorrt { class IncOpPlugin : public PluginTensorRT { public: - static const string plugin_name_; - IncOpPlugin() {}; + IncOpPlugin(); + IncOpPlugin(const void* serialized_data, size_t length); + const string& GetPluginName() const override { return plugin_name_; }; + bool Finalize() override { return true; }; + bool SetAttribute(const string& key, const void* ptr, const size_t size) override; + bool GetAttribute(const string& key, const void** ptr, size_t* size) const override; @@ -71,14 +75,11 @@ class IncOpPlugin : public PluginTensorRT { } void serialize(void* buffer) override { - // serializa parent stuff - // OpName + // Serialize parent data. PluginTensorRT::serialize(buffer); - - // incremented buffer after parent serialization; + // Incremented buffer after parent serialization. buffer = static_cast(buffer) + PluginTensorRT::getSerializationSize(); - std::memcpy(buffer, &inc_, sizeof(float)); buffer = static_cast(buffer) + sizeof(float); } @@ -86,6 +87,9 @@ class IncOpPlugin : public PluginTensorRT { protected: float inc_; nvinfer1::Dims dim_; + + private: + const string plugin_name_; }; IncOpPlugin* CreateIncPlugin(); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index 9f773c66a9..d1815fdf33 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -39,6 +39,7 @@ import numpy # the python api handles registration to the plugin factory from tensorflow.contrib.tensorrt import custom_plugin_examples + def get_plugin_graph_def(): """Create a simple graph and return its graph_def.""" g = ops.Graph() @@ -49,15 +50,16 @@ def get_plugin_graph_def(): v = nn_ops.max_pool( relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - # insert custom_op in the graph + # insert custom_op in the graph v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") - v = v*2.0 + v = v * 2.0 v = nn.relu(v) v = nn.relu(v) array_ops.squeeze(v, name="output") return g.as_graph_def() + def run_graph(gdef, dumm_inp): """Run given graphdef once.""" gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) @@ -74,6 +76,7 @@ def run_graph(gdef, dumm_inp): val = sess.run(out, {inp: dumm_inp}) return val + if "__main__" in __name__: inp_dims = (5, 24, 24, 2) dummy_input = numpy.ones(inp_dims).astype(numpy.float32) @@ -88,8 +91,7 @@ if "__main__" in __name__: max_batch_size=inp_dims[0], max_workspace_size_bytes=1 << 25, precision_mode="FP32", - minimum_segment_size=2 - ) + minimum_segment_size=2) o2 = run_graph(trt_graph, dummy_input) if o2.reshape([-1])[0] == 35: print("pass") diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index c39bc12f73..71453631e2 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h index 3495dc6318..96ccacb791 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -28,7 +28,7 @@ namespace tensorrt { // Logger for GIE info/warning/errors class Logger : public nvinfer1::ILogger { public: - Logger(string name = "DefaultLogger") : name_(name) {}; + Logger(string name = "DefaultLogger") : name_(name) {} void log(nvinfer1::ILogger::Severity severity, const char* msg) override; private: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 736a1321fe..b608e602a7 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -21,6 +21,12 @@ limitations under the License. namespace tensorflow { namespace tensorrt { +PluginFactoryTensorRT* PluginFactoryTensorRT::GetInstance() { + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); + return factory_instance; +} + PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 4e4a3af4ca..a088ffb842 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -31,19 +31,15 @@ namespace tensorrt { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { public: - // deserialization method + static PluginFactoryTensorRT* GetInstance(); + + // Deserialization method PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) override; - // plugin construction, PluginFactoryTensorRT owns the plugin; + // Plugin construction, PluginFactoryTensorRT owns the plugin. PluginTensorRT* CreatePlugin(const string& op_name); - static PluginFactoryTensorRT* GetInstance() { - static PluginFactoryTensorRT* factory_instance = - new PluginFactoryTensorRT(); - return factory_instance; - } - bool RegisterPlugin(const string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc index b834c5511f..ae5a3e8742 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" @@ -20,8 +22,6 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorrt/include/NvInfer.h" namespace tensorflow { @@ -30,34 +30,49 @@ namespace test { class StubPlugin : public PluginTensorRT { public: - static const string plugin_name_; - StubPlugin() {}; + static const char* kPluginName; + + StubPlugin() : plugin_name_(kPluginName) {} + StubPlugin(const void* serialized_data, size_t length) - : PluginTensorRT(serialized_data, length) {}; - const string& GetPluginName() override { return plugin_name_; }; - virtual bool Finalize() { return true; }; + : PluginTensorRT(serialized_data, length) {} + + const string& GetPluginName() override { return plugin_name_; } + + virtual bool Finalize() { return true; } + virtual bool SetAttribute(const string& key, const void* ptr, const size_t size) { return true; - }; + } + virtual bool GetAttribute(const string& key, const void* ptr, size_t& size) { return true; - }; + } + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override { return inputs[0]; } + int initialize() override { return 0; } + void terminate() override {} + size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } + int enqueue(int batch_size, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override { return 0; } + + private: + const string plugin_name_; }; -const string StubPlugin::plugin_name_ = "StubPlugin"; +const char* StubPlugin::kPluginName = "StubPlugin"; StubPlugin* CreateStubPlugin() { return new StubPlugin(); } @@ -70,32 +85,32 @@ class PluginTest : public ::testing::Test { public: bool RegisterStubPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin( - StubPlugin::plugin_name_)) { + StubPlugin::kPluginName)) { return true; } return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( - StubPlugin::plugin_name_, CreateStubPluginDeserialize, + StubPlugin::kPluginName, CreateStubPluginDeserialize, CreateStubPlugin); } }; TEST_F(PluginTest, Registration) { EXPECT_FALSE( - PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); EXPECT_TRUE(RegisterStubPlugin()); ASSERT_TRUE( - PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); } TEST_F(PluginTest, CreationDeletion) { EXPECT_TRUE(RegisterStubPlugin()); ASSERT_TRUE( - PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); ASSERT_TRUE(PluginFactoryTensorRT::GetInstance()->CreatePlugin( - StubPlugin::plugin_name_)); + StubPlugin::kPluginName)); ASSERT_EQ(1, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); ASSERT_EQ(0, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); -- GitLab From 03de4a4a6cbfab49c2921d0cac5ccac31c0815f8 Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 08:20:51 -0700 Subject: [PATCH 020/738] Move/rename the plugin factory test file; delete duplicate test file; fix minor formatting issues. --- tensorflow/contrib/tensorrt/BUILD | 10 ++- .../tensorrt/custom_plugin_examples/BUILD | 6 +- .../custom_plugin_examples/plugin_test.py | 5 -- .../contrib/tensorrt/plugin/trt_plugin.h | 6 +- .../tensorrt/plugin/trt_plugin_factory.h | 9 +- ...ins_test.cc => trt_plugin_factory_test.cc} | 6 +- .../tensorrt/plugin/trt_plugin_utils.h | 1 + tensorflow/contrib/tensorrt/plugin_test.py | 88 ------------------- 8 files changed, 26 insertions(+), 105 deletions(-) rename tensorflow/contrib/tensorrt/plugin/{trt_plugins_test.cc => trt_plugin_factory_test.cc} (96%) delete mode 100644 tensorflow/contrib/tensorrt/plugin_test.py diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 5fda11eccb..79e525edae 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -282,7 +282,7 @@ tf_cc_test( ], ) -# Library for the plugin factory +# Library for the plugin factory tf_cuda_library( name = "trt_plugins", srcs = [ @@ -304,9 +304,13 @@ tf_cuda_library( ) tf_cuda_cc_test( - name = "trt_plugins_test", + name = "trt_plugin_factory_test", size = "small", - srcs = ["plugin/trt_plugins_test.cc"], + srcs = ["plugin/trt_plugin_factory_test.cc"], + tags = [ + "manual", + "notap", + ], deps = [ ":trt_plugins", "//tensorflow/core:test", diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index a45d4423bb..c68e69457d 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -78,7 +78,7 @@ tf_kernel_library( ], gpu_srcs = [ "inc_op_kernel.h", - "inc_op_kernel.cu.cc" + "inc_op_kernel.cu.cc", ], deps = [ "//tensorflow/contrib/tensorrt:trt_plugins", @@ -120,4 +120,8 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:tf_optimizer", ], + tags = [ + "manual", + "notap", + ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index d1815fdf33..cb40e08493 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -18,11 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# normally we should do import tensorflow as tf and then -# tf.placeholder, tf.constant, tf.nn.conv2d etc but -# it looks like internal builds don't like it so -# importing every module individually - from tensorflow.contrib import tensorrt from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index dca377c2d2..d80ec44372 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include + #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA @@ -35,9 +36,11 @@ namespace tensorrt { // PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: - PluginTensorRT() {}; + PluginTensorRT() {} PluginTensorRT(const void* serialized_data, size_t length); + virtual const string& GetPluginName() const = 0; + virtual bool Finalize() = 0; virtual bool SetAttribute(const string& key, const void* ptr, @@ -53,6 +56,7 @@ class PluginTensorRT : public nvinfer1::IPlugin { const size_t size); virtual size_t getSerializationSize() override; + virtual void serialize(void* buffer) override; protected: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index a088ffb842..6d2992bbbb 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -19,8 +19,9 @@ limitations under the License. #include #include #include -#include "trt_plugin.h" -#include "trt_plugin_utils.h" + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -54,12 +55,12 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { protected: std::unordered_map > + std::pair> plugin_registry_; // TODO(jie): Owned plugin should be associated with different sessions; // should really hand ownership of plugins to resource management; - std::vector > owned_plugins_; + std::vector> owned_plugins_; std::mutex instance_m_; }; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc rename to tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc index ae5a3e8742..c5b0e75eb1 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc @@ -81,7 +81,7 @@ StubPlugin* CreateStubPluginDeserialize(const void* serialized_data, return new StubPlugin(serialized_data, length); } -class PluginTest : public ::testing::Test { +class TrtPluginFactoryTest : public ::testing::Test { public: bool RegisterStubPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin( @@ -94,7 +94,7 @@ class PluginTest : public ::testing::Test { } }; -TEST_F(PluginTest, Registration) { +TEST_F(TrtPluginFactoryTest, Registration) { EXPECT_FALSE( PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); EXPECT_TRUE(RegisterStubPlugin()); @@ -103,7 +103,7 @@ TEST_F(PluginTest, Registration) { PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); } -TEST_F(PluginTest, CreationDeletion) { +TEST_F(TrtPluginFactoryTest, CreationDeletion) { EXPECT_TRUE(RegisterStubPlugin()); ASSERT_TRUE( PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index a94c67bba0..4ff6fbedb4 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS #include + #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/contrib/tensorrt/plugin_test.py b/tensorflow/contrib/tensorrt/plugin_test.py deleted file mode 100644 index 7c3e765bff..0000000000 --- a/tensorflow/contrib/tensorrt/plugin_test.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Script to show usage of TensorRT custom op & plugin.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import tensorrt -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import importer -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -import numpy as np - -# import custom_op as plugin op -# the python api handles registration to the plugin factory -from tensorflow.contrib.tensorrt import custom_plugin_examples - -def get_plugin_graph_def(): - """Create a simple graph and return its graph_def.""" - g = ops.Graph() - with g.as_default(): - a = array_ops.placeholder( - dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") - relu = nn.relu(a, "relu") - v = nn_ops.max_pool( - relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - - # insert custom_op in the graph - v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") - - v = v*2.0 - v = nn.relu(v) - v = nn.relu(v) - array_ops.squeeze(v, name="output") - return g.as_graph_def() - -def run_graph(gdef, dumm_inp): - """Run given graphdef once.""" - gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - - with session.Session( - config=config_pb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: - val = sess.run(out, {inp: dumm_inp}) - return val - -if "__main__" in __name__: - inp_dims = (5, 24, 24, 2) - dummy_input = np.ones(inp_dims).astype(np.float32) - orig_graph = get_plugin_graph_def() # graph with plugin node - - # trigger conversion. - # plugin nodes have been registered during import, converter will be able to - # create corresponding plugin layer during conversion. - trt_graph = tensorrt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - precision_mode="FP32", - minimum_segment_size=2 - ) - o2 = run_graph(trt_graph, dummy_input) - print (o2) -- GitLab From eb88fd1ef5505e3f8617cc7105052fbce0e4af9e Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 11:17:10 -0700 Subject: [PATCH 021/738] Add a macro for registering the plugin so we don't need to depend on swig; remove the swig file; fix build dependencies; fix tf_custom_op_library by adding GOOGLE_TENSORRT macro when gpu_srcs is not empty. --- .../tensorrt/custom_plugin_examples/BUILD | 69 +++++++++---------- .../custom_plugin_examples/__init__.py | 2 - .../inc_op_kernel.cu.cc | 1 + .../custom_plugin_examples/inc_op_plugin.cc | 7 +- .../custom_plugin_examples/inc_op_plugin.h | 5 +- .../custom_plugin_examples/plugin_wrap.i | 31 --------- .../tensorrt/plugin/trt_plugin_factory.h | 25 +++++++ tensorflow/tensorflow.bzl | 2 +- 8 files changed, 61 insertions(+), 81 deletions(-) delete mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index c68e69457d..e623b54781 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -24,24 +24,48 @@ load( ) tf_gen_op_libs( - op_lib_names = [ - "inc_op", - ], + op_lib_names = ["inc_op"], ) tf_gen_op_wrapper_py( name = "inc_op", - deps = [ - ":inc_op_op_lib", - ], + deps = [":inc_op_op_lib"], ) tf_custom_op_library( name = "_inc_op.so", - srcs = ["ops/inc_op.cc"], + srcs = [ + "inc_op_kernel.h", + "inc_op_plugin.cc", + "inc_op_plugin.h", + "ops/inc_op.cc", + ], + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc", + ], deps = [ - "//tensorflow/core:lib_proto_parsing", + "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_kernel_library( + name = "inc_op_plugin_kernel", + srcs = ["inc_op_plugin.cc"], + hdrs = [ + "inc_op_plugin.h", + ], + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc", ], + deps = [ + "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), ) tf_custom_op_py_library( @@ -70,41 +94,12 @@ py_library( ], ) -tf_kernel_library( - name = "inc_op_plugin_kernel", - srcs = ["inc_op_plugin.cc"], - hdrs = [ - "inc_op_plugin.h", - ], - gpu_srcs = [ - "inc_op_kernel.h", - "inc_op_kernel.cu.cc", - ], - deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), -) - -tf_py_wrap_cc( - name = "plugin_wrap", - srcs = ["plugin_wrap.i"], - copts = tf_copts(), - deps = [ - ":inc_op_plugin_kernel", - "//tensorflow/core:framework_lite", - "//util/python:python_headers", - ], -) - py_library( name = "init_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ ":inc_op_py", - ":plugin_wrap", ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py index e4cd0ae8a0..e06904ab56 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py @@ -19,8 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.tensorrt.custom_plugin_examples.ops import gen_inc_op -from tensorflow.contrib.tensorrt.custom_plugin_examples.plugin_wrap import inc_op_register from tensorflow.contrib.tensorrt.custom_plugin_examples import inc_op as import_inc_op_so inc_op = gen_inc_op.inc_plugin_trt -inc_op_register() diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index ee9fbe0ea1..abbc0c5680 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -24,6 +24,7 @@ limitations under the License. #if GOOGLE_TENSORRT #include "cuda/include/cuda_runtime_api.h" + namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 489bc15def..d56aedc6d4 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -31,12 +31,7 @@ IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { return new IncOpPlugin(buffer, length); } -bool RegisterIncOpPlugin() { - if (PluginFactoryTensorRT::GetInstance()->IsPlugin(kPluginName)) - return false; - return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( - kPluginName, CreateIncPluginDeserialize, CreateIncPlugin); -} +REGISTER_TRT_PLUGIN(kPluginName, CreateIncPluginDeserialize, CreateIncPlugin); IncOpPlugin::IncOpPlugin() : plugin_name_(kPluginName) {} diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 0676abe768..60153546d2 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -18,6 +18,7 @@ limitations under the License. #include #include + #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA @@ -92,10 +93,6 @@ class IncOpPlugin : public PluginTensorRT { const string plugin_name_; }; -IncOpPlugin* CreateIncPlugin(); -IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); -bool RegisterIncOpPlugin(); - } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i deleted file mode 100644 index 9882daa842..0000000000 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/* Wrap inc_op_plugin */ -%module inc_op_plugin -%{ -#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" -extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); -%} - -%{ -bool inc_op_register() { - return tensorflow::tensorrt::RegisterIncOpPlugin(); -} -%} - -extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); - -bool inc_op_register(); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 6d2992bbbb..54fbca5930 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -64,6 +66,29 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { std::mutex instance_m_; }; +class TrtPluginRegistrar { + public: + TrtPluginRegistrar(const string& name, + PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func) { + auto factory = PluginFactoryTensorRT::GetInstance(); + QCHECK(factory->RegisterPlugin(name, deserialize_func, construct_func)) + << "Failed to register plugin: " << name; + } +}; + +#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \ + REGISTER_TRT_PLUGIN_UNIQ_HELPER( \ + __COUNTER__, name, deserialize_func, construct_func) +#define REGISTER_TRT_PLUGIN_UNIQ_HELPER( \ + ctr, name, deserialize_func, construct_func) \ + REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) +#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \ + static ::tensorflow::tensorrt::TrtPluginRegistrar \ + trt_plugin_registrar##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::tensorrt::TrtPluginRegistrar( \ + name, deserialize_func, construct_func) + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index e5cc886b32..c27f894365 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1309,7 +1309,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): native.cc_library( name=basename + "_gpu", srcs=gpu_srcs, - copts=_cuda_copts(), + copts=_cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]), deps=deps + if_cuda(cuda_deps)) cuda_deps.extend([":" + basename + "_gpu"]) -- GitLab From e629595e8f629f2de7db225463136b0e331bd71c Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 15:00:57 -0700 Subject: [PATCH 022/738] Simplify build dependencies; fix python import order; fix multiple singleton issues by inlining the singleton method. --- tensorflow/contrib/tensorrt/BUILD | 2 -- .../contrib/tensorrt/custom_plugin_examples/BUILD | 12 ++---------- .../tensorrt/custom_plugin_examples/plugin_test.py | 9 +++------ .../contrib/tensorrt/plugin/trt_plugin_factory.cc | 6 ------ .../contrib/tensorrt/plugin/trt_plugin_factory.h | 8 +++++++- 5 files changed, 12 insertions(+), 25 deletions(-) diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 79e525edae..5b56feed0f 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -259,7 +259,6 @@ cc_library( "segment/segment.h", "segment/union_find.h", ], - linkstatic = 1, deps = [ "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", @@ -295,7 +294,6 @@ tf_cuda_library( "plugin/trt_plugin_factory.h", "plugin/trt_plugin_utils.h", ], - linkstatic = 1, deps = [ "//tensorflow/core:platform_base", ] + if_tensorrt([ diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index e623b54781..6f81ac2b44 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -85,21 +85,13 @@ tf_custom_op_py_library( ], ) -py_library( - name = "inc_op_py", - srcs_version = "PY2AND3", - deps = [ - ":inc_op", - ":inc_op_loader", - ], -) - py_library( name = "init_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - ":inc_op_py", + ":inc_op", + ":inc_op_loader", ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index cb40e08493..aedfb16211 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy + from tensorflow.contrib import tensorrt +from tensorflow.contrib.tensorrt import custom_plugin_examples from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -27,12 +30,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops -from tensorflow.python.framework import errors -import numpy - -# import custom_op as plugin op -# the python api handles registration to the plugin factory -from tensorflow.contrib.tensorrt import custom_plugin_examples def get_plugin_graph_def(): diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index b608e602a7..736a1321fe 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -21,12 +21,6 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -PluginFactoryTensorRT* PluginFactoryTensorRT::GetInstance() { - static PluginFactoryTensorRT* factory_instance = - new PluginFactoryTensorRT(); - return factory_instance; -} - PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 54fbca5930..0eee705fb9 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -34,7 +34,13 @@ namespace tensorrt { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { public: - static PluginFactoryTensorRT* GetInstance(); + // TODO(aaroey): this static method has to be inlined to make the singleton a + // unique global symbol. Find a way to fix it. + static PluginFactoryTensorRT* GetInstance() { + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); + return factory_instance; + } // Deserialization method PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, -- GitLab From aaa345f5a662aab524bbee3912c605919239bef6 Mon Sep 17 00:00:00 2001 From: wangsiyu Date: Fri, 4 May 2018 10:52:26 +0800 Subject: [PATCH 023/738] refine by using iterator of partitioned variable --- tensorflow/python/layers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index c050e6be04..f7b2e471b2 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -358,7 +358,7 @@ def _add_elements_to_collection(elements, collection_list): def _should_add_regularizer(variable, existing_variable_set): result = True if isinstance(variable, tf_variables.PartitionedVariable): - for var in variable._get_variable_list(): + for var in variable: if var in existing_variable_set: result = False break -- GitLab From a2cba4a627f880cf8160de624fc1ad947c01e973 Mon Sep 17 00:00:00 2001 From: mbhuiyan Date: Fri, 4 May 2018 12:02:28 -0700 Subject: [PATCH 024/738] if MKL is used allocation id is set to 9 and 10 --- .../direct_session_with_tracking_alloc_test.cc | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index 0ff022a8bc..29c8c8daec 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -101,18 +101,21 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { EXPECT_EQ(2, shape.dim_size()); EXPECT_EQ(2, shape.dim(0).size()); EXPECT_EQ(1, shape.dim(1).size()); -#ifndef INTEL_MKL +#ifdef INTEL_MKL // if MKL is used, it goes through various additional // graph rewrite pass. In TF, everytime a graph pass // happens, "constant" nodes are allocated // and deallocated. Each allocation calls the - // (FindChunkPtr of BFCAllocator) - // , which increments the value of AllocationId. + // (FindChunkPtr of BFCAllocator), + // which increments the value of AllocationId. // Thus AllocationId becomes more than 3 and 4 if - // MKL is used, they can be 10 and 11 or - // other numbers. If MKL is used - // following check will not hold. - // Thus, skipping the check if MKL is used. + // MKL is used. Now they are 9 and 10 for MKL. + if (node->name() == y->name()) { + EXPECT_EQ(9, cm->AllocationId(node, 0)); + } else { + EXPECT_EQ(10, cm->AllocationId(node, 0)); + } +#else if (node->name() == y->name()) { EXPECT_EQ(3, cm->AllocationId(node, 0)); } else { -- GitLab From 2314acf98fb874317dd17ef3daf438d7af87f900 Mon Sep 17 00:00:00 2001 From: Anya Petrova Date: Fri, 4 May 2018 13:22:03 -0700 Subject: [PATCH 025/738] Fix a small typo. --- SECURITY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SECURITY.md b/SECURITY.md index a5ce3a62ee..01886b613e 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -173,7 +173,7 @@ the progress being made towards a fix and announcement. In addition, please include the following information along with your report: * Your name and affiliation (if any). -* A description the technical details of the vulnerabilities. It is very +* A description of the technical details of the vulnerabilities. It is very important to let us know how we can reproduce your findings. * An explanation who can exploit this vulnerability, and what they gain when doing so -- write an attack scenario. This will help us evaluate your report -- GitLab From f368558429f5ebdbc0a187c3801dccf1ca6963c7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 11:40:01 -0700 Subject: [PATCH 026/738] [XLA] Cleanup client_library_test_base: move definition of CreateParameterAndTransferLiteral to .cc file PiperOrigin-RevId: 195446864 --- .../xla/tests/client_library_test_base.cc | 29 +++++++++++++++++++ .../xla/tests/client_library_test_base.h | 29 ------------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index c09e7eaf2b..41f9a5f666 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -565,4 +565,33 @@ XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + XlaBuilder* builder, + XlaOp* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index e58979a303..16e838e60f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -616,35 +616,6 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( return result; } -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, - const Literal& literal, - const string& name, - XlaBuilder* builder, - XlaOp* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, XlaBuilder* builder, - XlaOp* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) - .ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ -- GitLab From be9b87375adecad9bd8bb12c81b2566c77a68ad7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 11:40:20 -0700 Subject: [PATCH 027/738] [XLA] Redesign: migrate the SWIG wrapped xla client. Added LocalOp that wraps XlaOp, so that it's fully visible to swig. PiperOrigin-RevId: 195446939 --- tensorflow/compiler/xla/python/BUILD | 3 +- .../xla/python/local_computation_builder.cc | 315 ++++++++------- .../xla/python/local_computation_builder.h | 206 +++++----- .../xla/python/local_computation_builder.i | 53 +-- tensorflow/compiler/xla/python/xla_client.py | 362 +++++++----------- 5 files changed, 415 insertions(+), 524 deletions(-) diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index ecb87bd889..932cce943f 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -49,9 +49,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 044458164f..df262c97bf 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/python/local_computation_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/default/thread_annotations.h" @@ -248,7 +249,7 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( return new LocalShapedBuffer(std::move(result_buffer)); } -LocalComputation::LocalComputation(Computation computation) +LocalComputation::LocalComputation(XlaComputation computation) : computation_(std::move(computation)) {} StatusOr LocalComputation::Compile( @@ -271,7 +272,7 @@ StatusOr LocalComputation::Compile( return new CompiledLocalComputation(std::move(local_executable)); } -const Computation& LocalComputation::computation() const { +const XlaComputation& LocalComputation::computation() const { return computation_; } @@ -281,8 +282,12 @@ StatusOr LocalComputation::GetReturnValueShape() const { return std::move(*program_shape.mutable_result()); } +LocalOp::LocalOp(const XlaOp& op) : op_(op) {} + +const XlaOp& LocalOp::op() const { return op_; } + LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) - : builder_(GetOrCreateLocalClient(), computation_name) {} + : builder_(computation_name) {} void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); @@ -291,19 +296,21 @@ void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } StatusOr LocalComputationBuilder::Build() { - TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); return new LocalComputation(std::move(computation)); } -ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { return builder_.Parameter(parameter_number, shape, name); } std::unique_ptr LocalComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - return builder_.GetShape(operand).ConsumeValueOrDie(); + const LocalOp& operand) { + auto result = MakeUnique(); + *result = builder_.GetShape(operand.op()).ValueOrDie(); + return result; } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -311,222 +318,236 @@ StatusOr LocalComputationBuilder::GetReturnValueShape() { return program_shape.result(); } -ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } -void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand, +void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand, shape, outfeed_config); + builder_.Outfeed(operand.op(), shape, outfeed_config); } -ComputationDataHandle LocalComputationBuilder::ConstantLiteral( - const Literal& literal) { +LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return builder_.ConstantLiteral(literal); } -ComputationDataHandle LocalComputationBuilder::Broadcast( - const ComputationDataHandle& operand, +LocalOp LocalComputationBuilder::Broadcast( + const LocalOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return builder_.Broadcast(operand, broadcast_sizes); + return builder_.Broadcast(operand.op(), broadcast_sizes); } -ComputationDataHandle LocalComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - return builder_.Pad(operand, padding_value, padding_config); +LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { + return builder_.Pad(operand.op(), padding_value.op(), padding_config); } -ComputationDataHandle LocalComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, +LocalOp LocalComputationBuilder::Reshape( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return builder_.Reshape(operand, dimensions, new_sizes); + return builder_.Reshape(operand.op(), dimensions, new_sizes); } -ComputationDataHandle LocalComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Collapse(operand, dimensions); +LocalOp LocalComputationBuilder::Collapse( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Collapse(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - return builder_.CrossReplicaSum(operand); +LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { + return builder_.CrossReplicaSum(operand.op()); } -ComputationDataHandle LocalComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, +LocalOp LocalComputationBuilder::Slice( + const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return builder_.Slice(operand, start_indices, limit_indices, strides); + return builder_.Slice(operand.op(), start_indices, limit_indices, strides); } -ComputationDataHandle LocalComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - return builder_.SliceInDim(operand, start_index, limit_index, stride, dimno); +LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, + int64 limit_index, int64 stride, + int64 dimno) { + return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, + dimno); } -ComputationDataHandle LocalComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, +LocalOp LocalComputationBuilder::DynamicSlice( + const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return builder_.DynamicSlice(operand, start_indices, slice_sizes); + return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - return builder_.DynamicUpdateSlice(operand, update, start_indices); +LocalOp LocalComputationBuilder::DynamicUpdateSlice( + const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices) { + return builder_.DynamicUpdateSlice(operand.op(), update.op(), + start_indices.op()); } -ComputationDataHandle LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - return builder_.ConcatInDim(operands, dimension); +LocalOp LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, int64 dimension) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.ConcatInDim(xla_ops, dimension); } -ComputationDataHandle -LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, +LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter) { + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter) { return builder_.SelectAndScatterWithGeneralPadding( - operand, select.computation(), window_dimensions, window_strides, padding, - source, init_value, scatter.computation()); + operand.op(), select.computation(), window_dimensions, window_strides, + padding, source.op(), init_value.op(), scatter.computation()); } -ComputationDataHandle LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - return builder_.Tuple(elements); +LocalOp LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + std::vector xla_ops; + xla_ops.reserve(elements.size()); + for (const auto& op : elements) { + xla_ops.push_back(op.op()); + } + + return builder_.Tuple(xla_ops); } -ComputationDataHandle LocalComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data, index); +LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { + return builder_.GetTupleElement(tuple_data.op(), index); } -ComputationDataHandle LocalComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return builder_.Dot(lhs, rhs); +LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { + return builder_.Dot(lhs.op(), rhs.op()); } -ComputationDataHandle LocalComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::DotGeneral( + const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return builder_.DotGeneral(lhs, rhs, dimension_numbers); + return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, + return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, + padding, lhs_dilation, rhs_dilation, dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand, new_element_type); +LocalOp LocalComputationBuilder::ConvertElementType( + const LocalOp& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand.op(), new_element_type); } -ComputationDataHandle LocalComputationBuilder::Call( +LocalOp LocalComputationBuilder::Call( const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { - return builder_.Call(local_computation.computation(), operands); + tensorflow::gtl::ArraySlice operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.Call(local_computation.computation(), xla_ops); } -ComputationDataHandle LocalComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - return builder_.Transpose(operand, permutation); +LocalOp LocalComputationBuilder::Transpose( + const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + return builder_.Transpose(operand.op(), permutation); } -ComputationDataHandle LocalComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Rev(operand, dimensions); +LocalOp LocalComputationBuilder::Rev( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Rev(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, +LocalOp LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - return builder_.Map(operands, local_computation.computation(), dimensions, - static_operands); + tensorflow::gtl::ArraySlice static_operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + + std::vector static_xla_ops; + static_xla_ops.reserve(static_operands.size()); + for (const auto& op : static_operands) { + static_xla_ops.push_back(op.op()); + } + + return builder_.Map(xla_ops, local_computation.computation(), dimensions, + static_xla_ops); } -ComputationDataHandle LocalComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::Reduce( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return builder_.Reduce(operand, init_value, local_computation.computation(), - dimensions_to_reduce); + return builder_.Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } -ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { return builder_.ReduceWindowWithGeneralPadding( - operand, init_value, local_computation.computation(), window_dimensions, - window_strides, padding); + operand.op(), init_value.op(), local_computation.computation(), + window_dimensions, window_strides, padding); } -ComputationDataHandle LocalComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return builder_.RngNormal(mu, sigma, shape); +LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, + const LocalOp& sigma, + const Shape& shape) { + return builder_.RngNormal(mu.op(), sigma.op(), shape); } -ComputationDataHandle LocalComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return builder_.RngUniform(a, b, shape); +LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { + return builder_.RngUniform(a.op(), b.op(), shape); } -ComputationDataHandle LocalComputationBuilder::While( - const LocalComputation& condition, const LocalComputation& body, - const ComputationDataHandle& init) { - return builder_.While(condition.computation(), body.computation(), init); +LocalOp LocalComputationBuilder::While(const LocalComputation& condition, + const LocalComputation& body, + const LocalOp& init) { + return builder_.While(condition.computation(), body.computation(), init.op()); } -ComputationDataHandle LocalComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, +LocalOp LocalComputationBuilder::Conditional( + const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional(predicate, true_operand, - true_computation.computation(), false_operand, - false_computation.computation()); + return builder_.Conditional( + predicate.op(), true_operand.op(), true_computation.computation(), + false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - return builder_.IsConstant(operand, num_parameters); +StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { + return builder_.IsConstant(operand.op()); } -StatusOr> LocalComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - return builder_.ComputeConstant(operand, output_layout, parameters); +StatusOr LocalComputationBuilder::BuildConstantSubGraph( + const LocalOp& operand) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, + builder_.BuildConstantSubGraph(operand.op())); + return new LocalComputation(std::move(computation)); } #define _FORWARD(method_name, return_sig, args_sig, args) \ @@ -534,23 +555,19 @@ StatusOr> LocalComputationBuilder::ComputeConstant( return builder_.method_name args; \ } -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand), (operand)) - -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ - (lhs, rhs, broadcast_dimensions)) - -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs), \ - (lhs, rhs, ehs)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions), \ + (lhs.op(), rhs.op(), broadcast_dimensions)) + +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \ + (lhs.op(), rhs.op(), ehs.op())) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 5ec097846a..a06b85b4ea 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -17,9 +17,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -97,25 +98,37 @@ class CompiledLocalComputation { std::unique_ptr executable_; }; -// Wraps a Computation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a LocalComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. class LocalComputation { public: - LocalComputation(Computation computation); + LocalComputation(XlaComputation computation); StatusOr Compile( const std::vector& argument_shapes, const ExecutableBuildOptions* build_options); - const Computation& computation() const; + const XlaComputation& computation() const; // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; private: - Computation computation_; + XlaComputation computation_; +}; + +// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// to be made available to Python via SWIG. +class LocalOp { + public: + LocalOp(const XlaOp& op); + + const XlaOp& op() const; + + private: + XlaOp op_; }; // Wraps the ComputationBuilder API in order to: @@ -135,166 +148,137 @@ class LocalComputationBuilder { // Returns an owned LocalComputation to the caller on success. StatusOr Build(); - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); + LocalOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); - std::unique_ptr GetShape(const ComputationDataHandle& operand); + std::unique_ptr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. StatusOr GetReturnValueShape(); - ComputationDataHandle Infeed(const Shape& shape); + LocalOp Infeed(const Shape& shape); - void Outfeed(const ComputationDataHandle& operand, const Shape& shape, + void Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config); - ComputationDataHandle ConstantLiteral(const Literal& literal); + LocalOp ConstantLiteral(const Literal& literal); - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + LocalOp Broadcast(const LocalOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); + LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, + const PaddingConfig& padding_config); - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Reshape(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Collapse(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + LocalOp CrossReplicaSum(const LocalOp& operand); - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp Slice(const LocalOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); + LocalOp SliceInDim(const LocalOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); + LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices); - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, + LocalOp SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter); + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter); - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(tensorflow::gtl::ArraySlice elements); - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); + LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); + LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs); - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); + LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs, + const DotDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + LocalOp ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); + LocalOp ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type); - ComputationDataHandle Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + LocalOp Call(const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands); - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation); + LocalOp Transpose(const LocalOp& operand, + tensorflow::gtl::ArraySlice permutation); - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + LocalOp Map(tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, + LocalOp ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding); - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); + LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape); - ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); + LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - ComputationDataHandle While(const LocalComputation& condition, - const LocalComputation& body, - const ComputationDataHandle& init); + LocalOp While(const LocalComputation& condition, const LocalComputation& body, + const LocalOp& init); - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, - const LocalComputation& false_computation); + LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, + const LocalOp& false_operand, + const LocalComputation& false_computation); - StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters); + StatusOr IsConstant(const LocalOp& operand); - StatusOr > ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand)) -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions)) -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs)) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) @@ -338,7 +322,7 @@ class LocalComputationBuilder { #undef _FORWARD_TRIOP private: - ComputationBuilder builder_; + XlaBuilder builder_; }; // Functions for freeing resources from the Python side. diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index b8cce5a5f7..04c56bbba9 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,9 +22,8 @@ limitations under the License. // // C++ Python // -------------------------------------+--------------------------------------- -// ComputationDataHandle <-> int // ArraySlice <- sequence of int -// ArraySlice <- sequence of int +// ArraySlice <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) @@ -91,12 +90,9 @@ limitations under the License. // One central reason for the Python-side indirection is that the // Python-side objects produced by the typemaps in this file are // further packaged up by xla_client before being passed on. For -// instance, xla_client wraps the long produced for a C++ -// ComputationDataHandle in a Python ComputationDataHandle proto, -// rather than exposing a raw long outside of the client. Similarly, -// the Python pair produced for a C++ Shape is further wrapped in a -// Python class (xla_client.Shape) so as not to expose the raw pair -// externally. +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. // // Other SWIG object wrappers (e.g. of LocalComputation) are further // wrapped by xla_client in order to set up a custom destructor that @@ -124,6 +120,7 @@ using namespace xla; using namespace xla::swig; namespace xla { + namespace swig { bool GetIntAttr(PyObject* o, const char* field, int64* result) { @@ -177,21 +174,6 @@ bool HandleStringAttribute(PyObject* o, tensorflow::ImportNumpy(); %} -// ComputationDataHandle - -%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { - const int64 handle = numpy::PyIntOrPyLongToLong($input); - if (handle == -1 && PyErr_Occurred()) { - SWIG_fail; - } - temp.set_handle(handle); - $1 = &temp; -} - -%typemap(out) ComputationDataHandle { - $result = numpy::LongToPyIntOrPyLong($1.handle()); -} - %typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); @@ -301,33 +283,23 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ComputationDataHandle +// ArraySlice -%typemap(in) tensorflow::gtl::ArraySlice - (std::vector temps) { +%typemap(in) tensorflow::gtl::ArraySlice( + std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; } const int size = PySequence_Size($input); - temps.resize(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - SWIG_fail; - } - const int64 handle = numpy::PyIntOrPyLongToLong(py_int); - if (handle == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); + LocalOp* op; + if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), + SWIG_POINTER_EXCEPTION)) == -1) { SWIG_fail; } - temps[i].set_handle(handle); - Py_DECREF(py_int); + temps.push_back(*op); Py_DECREF(o); } $1 = temps; @@ -934,6 +906,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::Build; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index f6809b6b87..1d5b75d1be 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -335,20 +335,6 @@ def _wrap_shape(shape_info): return Shape.array_shape(dtype, dims) -def _wrap_data_handle(handle): - cdh = xla_data_pb2.ComputationDataHandle() - cdh.handle = handle - return cdh - - -def _unwrap_data_handle(handle_proto): - return handle_proto.handle - - -def _unwrap_data_handles(handle_protos): - return [_unwrap_data_handle(cdh) for cdh in handle_protos] - - def require_numpy_array_layout(value): if isinstance(value, tuple): return tuple(require_numpy_array_layout(x) for x in value) @@ -535,9 +521,9 @@ class ComputationBuilder(object): queue for subsequent use in the computation. Returns: - A ComputationDataHandle message. + A LocalOp. """ - return _wrap_data_handle(self._client.Infeed(shape)) + return self._client.Infeed(shape) def Outfeed(self, operand): """Enqueues an outfeed op onto the computation. @@ -545,9 +531,7 @@ class ComputationBuilder(object): Outfeed operations enqueue data, using the given operand, onto the XLA outfeed queue for subsequent dequeue via the client API. """ - self._client.Outfeed( - _unwrap_data_handle(operand), self.GetShape(operand), - ''.encode('utf-8')) + self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8')) def Constant(self, value): """Enqueues a constant op onto the computation. @@ -557,10 +541,10 @@ class ComputationBuilder(object): to one of the supported types. Returns: - A ComputationDataHandle message. + A LocalOp. """ value = require_numpy_array_layout(value) - return _wrap_data_handle(self._client.ConstantLiteral(value)) + return self._client.ConstantLiteral(value) def ConstantF32Scalar(self, value): """Convenience method to enqueue a scalar F32 constant op. @@ -569,7 +553,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float32)) @@ -580,7 +564,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float64)) @@ -591,7 +575,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int32)) @@ -602,7 +586,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int64)) @@ -613,7 +597,7 @@ class ComputationBuilder(object): value: a boolean value. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.bool)) @@ -629,15 +613,14 @@ class ComputationBuilder(object): parameters, use it for *all* parameters to avoid clashes. Returns: - A ComputationDataHandle message. + A LocalOp. """ if name is None: name = '' if parameter_num is None: parameter_num = next(self._parameter_numbering) - return _wrap_data_handle( - self._client.Parameter(parameter_num, shape, name.encode('utf8'))) + return self._client.Parameter(parameter_num, shape, name.encode('utf8')) def ParameterFromNumpy(self, value, name=None, parameter_num=None): """Enqueues a Parameter op onto the computation. @@ -649,7 +632,7 @@ class ComputationBuilder(object): parameter_num: as in ParameterWithShape. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.ParameterWithShape( Shape.from_pyval(value), name=name, parameter_num=parameter_num) @@ -658,14 +641,13 @@ class ComputationBuilder(object): """Enqueues a broadcast operation onto the computation. Args: - operand: the operand ComputationDataHandle to broadcast. + operand: the operand LocalOp to broadcast. sizes: an iterable of broadcast sizes. Returns: - A ComputationDataHandle representing the added broadcast op. + A LocalOp representing the added broadcast op. """ - return _wrap_data_handle( - self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + return self._client.Broadcast(operand, sizes) def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -675,10 +657,9 @@ class ComputationBuilder(object): dimension: the dimension in which to perform the concatenation. Returns: - A ComputationDataHandle representing the added concatenate op. + A LocalOp representing the added concatenate op. """ - return _wrap_data_handle( - self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + return self._client.ConcatInDim(operands, dimension) def ConvertElementType(self, operand, new_element_type): """Enqueues an element type conversion operation onto the computation. @@ -688,14 +669,12 @@ class ComputationBuilder(object): new_element_type: the target primitive type. Returns: - A ComputationDataHandle representing the added conversion op. + A LocalOp representing the added conversion op. """ - return _wrap_data_handle( - self._client.ConvertElementType( - _unwrap_data_handle(operand), new_element_type)) + return self._client.ConvertElementType(operand, new_element_type) def GetShape(self, operand): - return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + return _wrap_shape(self._client.GetShape(operand)) def GetReturnValueShape(self): return _wrap_shape(self._client.GetReturnValueShape()) @@ -707,40 +686,35 @@ class ComputationBuilder(object): """Enqueues a Pad operation onto the computation. Args: - operand: ComputationDataHandle representing the array to pad. - padding_value: ComputationDataHandle representing the scalar pad value. + operand: LocalOp representing the array to pad. + padding_value: LocalOp representing the scalar pad value. padding_config: either an xla_data_pb2.PaddingConfig or a list of integer triples (edge_padding_low, edge_padding_high, interior_padding) representing the configuration of the padding operation. Returns: - A ComputationDataHandle representing the added Pad op. + A LocalOp representing the added Pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): padding_config = GetPaddingConfigFromTriples(padding_config) - return _wrap_data_handle( - self._client.Pad(_unwrap_data_handle(operand), - _unwrap_data_handle(padding_value), - padding_config)) + return self._client.Pad(operand, padding_value, padding_config) def Reshape(self, operand, dimensions, new_sizes): """Enqueues a reshape op onto the computation. Args: - operand: ComputationDataHandle representing the array to be reshaped. + operand: LocalOp representing the array to be reshaped. dimensions: sequence of integers encoding the order in which dimensions are collapsed or None, in which case dimensions are flattened in order. new_sizes: sequence of integers encoding the new dimension sizes (shape). Returns: - A ComputationDataHandle representing the added Reshape op. + A LocalOp representing the added Reshape op. """ if dimensions is None: ndim = len(self.GetShape(operand).dimensions()) dimensions = tuple(range(ndim)) - return _wrap_data_handle( - self._client.Reshape( - _unwrap_data_handle(operand), dimensions, new_sizes)) + return self._client.Reshape(operand, dimensions, new_sizes) def CrossReplicaSum(self, operand): """CrossReplicaSum op. @@ -749,67 +723,56 @@ class ComputationBuilder(object): operand: the operand to sum across replica instances. Returns: - A ComputationDataHandle that has the sum of the value among all replicas. + A LocalOp that has the sum of the value among all replicas. """ - return _wrap_data_handle( - self._client.CrossReplicaSum(_unwrap_data_handle(operand))) + return self._client.CrossReplicaSum(operand) def Collapse(self, operand, dimensions): """Collapse op.""" - return _wrap_data_handle( - self._client.Collapse(_unwrap_data_handle(operand), dimensions)) + return self._client.Collapse(operand, dimensions) def Trans(self, operand): """Specialized matrix transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + return self._client.Transpose(operand, [1, 0]) def Transpose(self, operand, permutation): """Transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), permutation)) + return self._client.Transpose(operand, permutation) def Rev(self, operand, dimensions): """Rev op.""" - return _wrap_data_handle( - self._client.Rev(_unwrap_data_handle(operand), dimensions)) + return self._client.Rev(operand, dimensions) def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin """Clamp op.""" - return _wrap_data_handle( - self._client.Clamp(_unwrap_data_handle(min), - _unwrap_data_handle(operand), - _unwrap_data_handle(max))) + return self._client.Clamp(min, operand, max) def SelectAndScatter(self, operand, select, window_dimensions, window_strides, padding, source, init_value, scatter): """Select and scatter op, used by the gradient of ReduceWindow. Args: - operand: ComputationDataHandle for array of dimension N and type T over + operand: LocalOp for array of dimension N and type T over which the windows slide. select: Computation of type (T, T) -> Pred to apply to the elements of each window to indicate which element is selected. window_dimensions: sequence of N integers for dimensions of the window. window_strides: sequence of N integers for the strides of the window. padding: PaddingType representing either 'SAME' or 'VALID ' padding. - source: ComputationDataHandle for array of type T with values to scatter. - init_value: ComputationDataHandle of scalar type T for initial out value. + source: LocalOp for array of type T with values to scatter. + init_value: LocalOp of scalar type T for initial out value. scatter: Computation of type (T, T) -> T to apply to each scatter source element with its destination element. Returns: - A ComputationDataHandle representing the added SelectAndScatter op. + A LocalOp representing the added SelectAndScatter op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.SelectAndScatterWithGeneralPadding( - _unwrap_data_handle(operand), select.c_local_computation, - window_dimensions, window_strides, pads, - _unwrap_data_handle(source), _unwrap_data_handle(init_value), - scatter.c_local_computation)) + return self._client.SelectAndScatterWithGeneralPadding( + operand, select.c_local_computation, window_dimensions, window_strides, + pads, source, init_value, scatter.c_local_computation) def Select(self, pred, on_true, on_false): """Element-wise selection op. @@ -817,17 +780,13 @@ class ComputationBuilder(object): Constructs an output array from elements of two input arrays, based on the values of a predicate array. """ - return _wrap_data_handle( - self._client.Select( - _unwrap_data_handle(pred), - _unwrap_data_handle(on_true), - _unwrap_data_handle(on_false))) + return self._client.Select(pred, on_true, on_false) def Slice(self, operand, start_indices, limit_indices, strides=None): """Enqueues a slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_indices: iterable of N integers containing the starting indices of the slice for each dimension. limit_indices: iterable of N integers containing the ending indices @@ -836,207 +795,177 @@ class ComputationBuilder(object): each dimension. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ if strides is None: start_indices = list(start_indices) strides = [1] * len(start_indices) - return _wrap_data_handle( - self._client.Slice( - _unwrap_data_handle(operand), start_indices, limit_indices, - strides)) + return self._client.Slice(operand, start_indices, limit_indices, strides) def SliceInDim(self, operand, start_index, limit_index, stride, dimno): """Enqueues a slice-in-dimension operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_index: an integer containing the start index of the slice. limit_index: an integer containing the end index of the slice. stride: an integer containing the stride size for the slice. dimno: an integer indicating the dimension along which to slice. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ - return _wrap_data_handle( - self._client.SliceInDim( - _unwrap_data_handle(operand), start_index, limit_index, stride, - dimno)) + return self._client.SliceInDim(operand, start_index, limit_index, stride, + dimno) def DynamicSlice(self, operand, start_indices, slice_sizes): """Enqueues a slice op with dynamic start indices onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. - start_indices: ComputationDataHandle for the 1D array of N integers + operand: LocalOp for the N dimensional array to be sliced. + start_indices: LocalOp for the 1D array of N integers containing the starting indices of the slice. slice_sizes: iterable of N integers containing the slice sizes in each dimension. Returns: - A ComputationDataHandle representing the added DynamicSlice op. + A LocalOp representing the added DynamicSlice op. """ - return _wrap_data_handle( - self._client.DynamicSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(start_indices), - slice_sizes)) + return self._client.DynamicSlice(operand, start_indices, slice_sizes) def DynamicUpdateSlice(self, operand, update, start_indices): """Enqueues a dynamic update slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be updated. + operand: LocalOp for the N dimensional array to be updated. update: N dimensional array comprising the slice update. start_indices: Rank-1 array of N integers comprising the starting indices of the slice along each dimension. Returns: - A ComputationDataHandle representing the added DynamicUpdateSlice op. + A LocalOp representing the added DynamicUpdateSlice op. """ - return _wrap_data_handle( - self._client.DynamicUpdateSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(update), - _unwrap_data_handle(start_indices))) + return self._client.DynamicUpdateSlice(operand, update, start_indices) def Tuple(self, *ops): """Enqueues a tuple operation onto the computation. Args: - ops: a sequence of tuple operands (each a ComputationDataHandle). + ops: a sequence of tuple operands (each a LocalOp). Returns: - A ComputationDataHandle representing the added Tuple op. + A LocalOp representing the added Tuple op. """ - return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + return self._client.Tuple(ops) def GetTupleElement(self, tup, index): """Enqueues a 'get tuple element' operation onto the computation. Args: - tup: the tuple operand (a ComputationDataHandle). + tup: the tuple operand (a LocalOp). index: numeric index to select from the tuple. Returns: - A ComputationDataHandle representing the added GetTupleElement op. + A LocalOp representing the added GetTupleElement op. """ - return _wrap_data_handle( - self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + return self._client.GetTupleElement(tup, index) def Call(self, computation_to_apply, operands): """Enqueues a call operation onto the computation. Args: computation_to_apply: a Computation object. - operands: an iterable of ComputationDataHandle. The number and types of + operands: an iterable of LocalOp. The number and types of operands must match the arity of computation_to_apply. Returns: - A ComputationDataHandle representing the added call op. + A LocalOp representing the added call op. """ - return _wrap_data_handle( - self._client.Call(computation_to_apply.c_local_computation, - _unwrap_data_handles(operands))) + return self._client.Call(computation_to_apply.c_local_computation, operands) def Map(self, operands, computation_to_apply, dimensions, static_operands=()): """Enqueues a map operation onto the computation. Args: - operands: an iterable of ComputationDataHandle. + operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. static_operands: auxiliary arguments passed to the applied computation. Returns: - A ComputationDataHandle representing the added Map op. + A LocalOp representing the added Map op. """ - return _wrap_data_handle( - self._client.Map( - _unwrap_data_handles(operands), - computation_to_apply.c_local_computation, - dimensions, - _unwrap_data_handles(static_operands))) + return self._client.Map(operands, computation_to_apply.c_local_computation, + dimensions, static_operands) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a Computation object - binary reduction function. dimensions: sequence of dimensions (integers) to reduce on. Returns: - A ComputationDataHandle representing the added Reduce op. + A LocalOp representing the added Reduce op. """ - return _wrap_data_handle( - self._client.Reduce( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - dimensions)) + return self._client.Reduce(operand, init_value, + computation_to_apply.c_local_computation, + dimensions) def ReduceWindow(self, operand, init_value, computation_to_apply, window_dimensions, window_strides, padding): """Enqueues a windowed reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a binary reduction function (Computation). window_dimensions: dimensions of window (sequence of integers). window_strides: strides for window (sequence of integers). padding: PaddingType representing either 'SAME' or 'VALID' padding. Returns: - A ComputationDataHandle representing the added ReduceWindow op. + A LocalOp representing the added ReduceWindow op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.ReduceWindowWithGeneralPadding( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads)) + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, pads) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. Args: - mu: A ComputationDataHandle to an F32 scalar specifying the mean. - sigma: A ComputationDataHandle to an F32 scalar specifying the standard + mu: A LocalOp to an F32 scalar specifying the mean. + sigma: A LocalOp to an F32 scalar specifying the standard deviation. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of F32 values. + Returns: a LocalOp to the generated array of F32 values. """ shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) - return _wrap_data_handle( - self._client.RngNormal( - _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) + return self._client.RngNormal(mu, sigma, shape) def RngUniform(self, a, b, dims): """Enqueues an RngUniform operation onto the computation. Args: - a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) specifying the low end of the interval [a, b) over which values are generated. - b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) specifying the high end of the interval [a, b) over which values are generated. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of values with the + Returns: a LocalOp to the generated array of values with the same numeric type (F32, S32, or U32) as the arguments a and b. """ shape = Shape.array_shape(self.GetShape(a).element_type(), dims) - return _wrap_data_handle( - self._client.RngUniform( - _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) + return self._client.RngUniform(a, b, shape) def While(self, cond, body, init): """Enqueues a While operation onto the computation. @@ -1044,112 +973,105 @@ class ComputationBuilder(object): Args: cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T - init: a ComputationDataHandle for the initial parameter, which has type T + init: a LocalOp for the initial parameter, which has type T - Returns: a ComputationDataHandle representing the While operation. + Returns: a LocalOp representing the While operation. """ - return _wrap_data_handle( - self._client.While(cond.c_local_computation, - body.c_local_computation, - _unwrap_data_handle(init))) + return self._client.While(cond.c_local_computation, + body.c_local_computation, init) def Conditional(self, pred, true_operand, true_computation, false_operand, false_computation): """Enqueues a Conditional operation onto the computation. Args: - predicate: a ComputationDataHandle to test, which has scalar type PRED - true_operand: a ComputationDataHandle of type T_0 + predicate: a LocalOp to test, which has scalar type PRED + true_operand: a LocalOp of type T_0 true_computation: a Computation to apply to true_operand, type T_0 -> S false_operand: a ComputationDatahandle of type T_1 false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a ComputationDataHandle representing the Conditional operation. + Returns: a LocalOp representing the Conditional operation. """ - return _wrap_data_handle( - self._client.Conditional( - _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), - true_computation.c_local_computation, - _unwrap_data_handle(false_operand), - false_computation.c_local_computation)) + return self._client.Conditional( + pred, true_operand, true_computation.c_local_computation, false_operand, + false_computation.c_local_computation) - def IsConstant(self, operand, num_parameters=0): - """Enqueues an IsConstant operation onto the computation. + def IsConstant(self, operand): + """Checks whether the given operand is a compile-time constant. Args: operand: a ComputationDataHandle to test. - num_parameters: optional int, number of computation parameters to treat as - constant (default 0). Returns: bool indicating whether `operand` is a compile-time constant, - meaning its value does not depend on parameters with index greater than or - equal to `num_parameters`. + meaning its value does not depend on any parametersor, or on stateful + operators such as `RngNormal` or `Infeed`. + """ + return self._client.IsConstant(operand) + + def BuildConstantSubGraph(self, operand): + """Builds a constant sub graph. + + Args: + operand: a LocalOp to test. + Returns: a LocalComputation that is rooted on the given `operand` which is a + compile-time constant. """ - return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters) + return self._client.BuildConstantSubGraph(operand) def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. Args: - lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. - rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. + rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. - Returns: a ComputationDataHandle representing the Dot operation. + Returns: a LocalOp representing the Dot operation. """ - return _wrap_data_handle( - self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + return self._client.Dot(lhs, rhs) def DotGeneral(self, lhs, rhs, dimension_numbers): """Enqueues a general dot operation onto the computation. Args: - lhs: ComputationDataHandle for the left-hand-side array. - rhs: ComputationDataHandle for the right-hand-side array. + lhs: LocalOp for the left-hand-side array. + rhs: LocalOp for the right-hand-side array. dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of integers representing the dimensions to treat as contracting dimensions and batch dimensions on each input operand. - Returns: a ComputationDataHandle representing the DotGeneral operation. + Returns: a LocalOp representing the DotGeneral operation. """ if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) - return _wrap_data_handle( - self._client.DotGeneral( - _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), - dimension_numbers)) + return self._client.DotGeneral(lhs, rhs, dimension_numbers) def Conv(self, lhs, rhs, window_strides, padding): """Enqueues a Conv operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. - Returns: a ComputationDataHandle representing the Conv operation. + Returns: a LocalOp representing the Conv operation. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(lhs).dimensions()[2:], self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - pads, - (), - (), - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), + (), dimension_numbers) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of kernel strides. padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. @@ -1159,14 +1081,9 @@ class ComputationBuilder(object): A ComputationdataHandle representing the added ConvWithGeneralPadding op. """ dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - padding, - lhs_dilation, - rhs_dilation, - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1196,15 +1113,14 @@ def _forward_methods_to_local_builder(): """Generate a forwarding method that wraps/unwraps data handles.""" def forward(self, *args, **kwargs): - unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + arg_list = list(args) - if is_binop and len(unwrapped_args) < 3: - unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + if is_binop and len(arg_list) < 3: + arg_list.append(kwargs.get('broadcast_dimensions', ())) - return _wrap_data_handle( - target_method( - self._client, # pylint: disable=protected-access - *unwrapped_args)) + return target_method( + self._client, # pylint: disable=protected-access + *arg_list) return forward -- GitLab From 5ca373b4b64167f8b0fcab96d7d2e7886ea31b6a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 12:28:42 -0700 Subject: [PATCH 028/738] Some fixes to support another TF graph: 1. Fix ResolveBatchNormalization to avoid deleting arrays that may still be used. 2. Correctly count the number of ops using a given array, even when some ops use the same array as more than one of their inputs. 3. In PropagateFixedSizes for Concatenation ops, when resolving a -1 wildcard to a fixed value, we were doing so in a local 'axis' variable without actually updating op->axis! The resulting -1 value still in op->axis tripped runtime code, causing the concatenation to misbehave during inference. PiperOrigin-RevId: 195454037 --- .../graph_transformations/propagate_fixed_sizes.cc | 11 +++++------ .../resolve_batch_normalization.cc | 6 +++--- tensorflow/contrib/lite/toco/tooling_util.cc | 4 ++++ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 4923f83d91..b02b02c5be 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -670,8 +670,7 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { const auto& first_input_array = model->GetArray(op->inputs[0]); output_array.copy_shape(first_input_array.shape()); // Negative axis means the count starts at the back of the dims(). - int axis = op->axis; - if (axis < 0) axis += first_input_array.shape().dims().size(); + if (op->axis < 0) op->axis += first_input_array.shape().dims().size(); // Determine the concat size, and enfore that all inputs have // the same dimensions count. int concat_size = 0; @@ -684,14 +683,14 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { CHECK_EQ(input_array.shape().dimensions_count(), output_array.shape().dimensions_count()); const std::vector& input_dims = input_array.shape().dims(); - CHECK_LT(axis, input_dims.size()); - concat_size += input_dims[axis]; + CHECK_LT(op->axis, input_dims.size()); + concat_size += input_dims[op->axis]; } // Write out the concat_size on the output array shape. auto& output_shape = *output_array.mutable_shape(); auto& output_dims = *output_shape.mutable_dims(); - CHECK_LT(axis, output_shape.dimensions_count()); - output_dims[axis] = concat_size; + CHECK_LT(op->axis, output_shape.dimensions_count()); + output_dims[op->axis] = concat_size; } void ProcessRangeOperator(Model* model, RangeOperator* op) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc index 2b3ee36ad1..8f2c1f8162 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -134,9 +134,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { } // Remove the old param arrays - model->EraseArray(bn_op->inputs[1]); - model->EraseArray(bn_op->inputs[2]); - model->EraseArray(bn_op->inputs[3]); + DeleteArrayIfUsedOnce(bn_op->inputs[1], model); + DeleteArrayIfUsedOnce(bn_op->inputs[2], model); + DeleteArrayIfUsedOnce(bn_op->inputs[3], model); // Remove the old operator DCHECK_EQ(bn_it->get(), bn_op); diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 86ee1f3761..341d45e753 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -143,6 +143,10 @@ int CountOpsWithInput(const Model& model, const string& array_name) { for (auto& input : op->inputs) { if (input == array_name) { count++; + // Breaking here is important: some graphs have ops that use the + // same array as more than one of their inputs, and in that case + // we want it counted only once. + break; } } } -- GitLab From 67b5e724121c5874425936fe01318642508d9975 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 4 May 2018 14:40:02 -0700 Subject: [PATCH 029/738] [XLA:GPU] Mark floating-point division as an inexpensive op. "Expensive" really means "so expensive you'd choose not to fuse in order to avoid doing it twice". FP division definitely isn't that expensive. PiperOrigin-RevId: 195473524 --- .../xla/service/gpu/instruction_fusion.cc | 13 +++++ .../xla/service/gpu/instruction_fusion.h | 2 + .../service/gpu/instruction_fusion_test.cc | 56 +++++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 85ecbe8fdb..c5eb721185 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -48,6 +48,19 @@ bool IsFusile(const HloInstruction& hlo) { } // namespace +/*static*/ bool GpuInstructionFusion::IsExpensive( + const HloInstruction& instruction) { + switch (instruction.opcode()) { + // We say that floating-point division is cheap on the GPU. + case HloOpcode::kDivide: + return !ShapeUtil::ElementIsFloating(instruction.shape()) && + InstructionFusion::IsExpensive(instruction); + + default: + return InstructionFusion::IsExpensive(instruction); + } +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index bb2990e6df..9fb06b0a24 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,6 +27,8 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + static bool IsExpensive(const HloInstruction& instruction); + bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; HloInstruction::FusionKind ChooseKind( diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 4b231c449f..6c9a805ad6 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -253,5 +253,61 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); } +// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is +// duplicated and fused into both reduces. +TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = f32[] constant(0) + one = f32[] constant(1) + p0 = f32[100] parameter(0) + recip = f32[100] divide(one, p0) + sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT root = (f32[], f32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())); +} + +// Compute sum(100/p0), where p0 has type s32, twice. Check that the division +// is *not* duplicated and fused into both reduces, because we say that integer +// division is not cheap. +TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = s32[] constant(0) + one_hundred = s32[] constant(100) + p0 = s32[100] parameter(0) + recip = s32[100] divide(one_hundred, p0) + sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT mul = (s32[], s32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + } // namespace gpu } // namespace xla -- GitLab From 4d0388d22060a61f40965127c153c681b2412c50 Mon Sep 17 00:00:00 2001 From: James Qin Date: Fri, 4 May 2018 14:53:58 -0700 Subject: [PATCH 030/738] Fix build failure for macos py3 PiperOrigin-RevId: 195475780 --- tensorflow/python/debug/examples/debug_tflearn_iris.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py index 00090b21fe..7cbaae46b4 100644 --- a/tensorflow/python/debug/examples/debug_tflearn_iris.py +++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py @@ -140,7 +140,7 @@ def main(_): # Make predictions, using tfdbg hook. predict_results = classifier.predict(test_input_fn, hooks=hooks) - print("A prediction result: %s" % predict_results.next()) + print("A prediction result: %s" % next(predict_results)) if __name__ == "__main__": -- GitLab From cb1775e9525ae621d23708a3d64a6cad897be95e Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 4 May 2018 15:14:00 -0700 Subject: [PATCH 031/738] Identify and prune nodes that can never be executed PiperOrigin-RevId: 195478951 --- tensorflow/core/grappler/optimizers/BUILD | 1 + .../grappler/optimizers/loop_optimizer.cc | 140 ++++++++++++++++++ .../core/grappler/optimizers/loop_optimizer.h | 1 + .../optimizers/loop_optimizer_test.cc | 107 +++++++++++++ tensorflow/core/grappler/utils.h | 4 +- 5 files changed, 251 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 5b5e1e024e..900dfa95c5 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -604,6 +604,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index 5adc5b9227..7d3520febc 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" @@ -504,6 +505,140 @@ Status RemoveStackOps(const std::unordered_set& nodes_to_preserve, return Status::OK(); } +Status RemoveDeadBranches(const std::unordered_set& nodes_to_preserve, + GraphDef* optimized_graph) { + std::unordered_set dead_nodes; + std::unordered_map> dead_merge_inputs; + // TODO(bsteiner): also rewrite switches as identity. For now we just record + // them + std::unordered_set + identity_switches; + + GraphView view(optimized_graph); + for (const NodeDef& node : optimized_graph->node()) { + if (!IsSwitch(node)) { + continue; + } + if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { + continue; + } + GraphView::InputPort ctrl_port(&node, 1); + GraphView::OutputPort ctrl_node = view.GetRegularFanin(ctrl_port); + if (!IsConstant(*ctrl_node.node)) { + continue; + } + Tensor selector; + CHECK(selector.FromProto(ctrl_node.node->attr().at("value").tensor())); + const int dead_fanout = selector.scalar()() ? 0 : 1; + GraphView::OutputPort dead(const_cast(&node), dead_fanout); + identity_switches.insert(dead); + + SetVector zombie_inputs; + for (const GraphView::InputPort& port : view.GetFanout(dead)) { + if (dead_nodes.find(port.node) == dead_nodes.end()) { + zombie_inputs.PushBack(port); + } + } + // If we encounter a single node that must be preserved in the fanout of the + // switch node we need to preserve the entire switch fanout: we therefore + // work on a local copy that only gets committed to the master copy once the + // whole fanout has been explored. + std::unordered_set local_dead_nodes = dead_nodes; + std::unordered_map> local_dead_merge_inputs = + dead_merge_inputs; + bool found_node_to_preserve = false; + while (!found_node_to_preserve && !zombie_inputs.Empty()) { + GraphView::InputPort dead = zombie_inputs.PopBack(); + if (nodes_to_preserve.find(dead.node->name()) != + nodes_to_preserve.end()) { + found_node_to_preserve = true; + break; + } + + if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) { + continue; + } + + if (IsMerge(*dead.node)) { + const int fanout = dead.node->attr().at("N").i(); + if (fanout > 2) { + // This never happens in practice, so we'll just skip these to + // simplify the code for now. + found_node_to_preserve = true; + break; + } + GraphView::OutputPort value_index(dead.node, 1); + const std::unordered_set& + index_fanout = view.GetFanout(value_index); + if (!index_fanout.empty()) { + // The 2nd output (that indicates which input is propagated) is + // connected. This never happens in practice, so we'll just skip this + // case to simplify the code for now. + found_node_to_preserve = true; + break; + } + + bool fully_dead = false; + if (dead.port_id < 0) { + // If the control dependency never gets triggered the merge will also + // never get triggered. + local_dead_nodes.insert(dead.node); + fully_dead = true; + } else { + local_dead_merge_inputs[dead.node].insert(dead.port_id); + if (local_dead_merge_inputs[dead.node].size() == + dead.node->attr().at("N").i()) { + fully_dead = true; + } + if (fully_dead) { + local_dead_nodes.insert(dead.node); + for (const GraphView::InputPort& port : + view.GetFanouts(*dead.node, true)) { + zombie_inputs.PushBack(port); + } + } + } + } else { + if (local_dead_nodes.insert(dead.node).second) { + for (const GraphView::InputPort& dead_fanout : + view.GetFanouts(*dead.node, true)) { + zombie_inputs.PushBack(dead_fanout); + } + } + } + } + if (!found_node_to_preserve) { + std::swap(dead_nodes, local_dead_nodes); + std::swap(dead_merge_inputs, local_dead_merge_inputs); + } + } + + int last = optimized_graph->node_size() - 1; + for (int i = optimized_graph->node_size() - 1; i >= 0; --i) { + NodeDef* node = optimized_graph->mutable_node(i); + if (dead_nodes.find(node) != dead_nodes.end()) { + optimized_graph->mutable_node()->SwapElements(i, last); + last--; + } + } + optimized_graph->mutable_node()->DeleteSubrange(last + 1, dead_nodes.size()); + + for (const auto& itr : dead_merge_inputs) { + NodeDef* dead_node = itr.first; + if (dead_nodes.find(dead_node) != dead_nodes.end()) { + // The node has been pruned since all its inputs are dead. + continue; + } + const std::set& dead_inputs = itr.second; + for (int index : dead_inputs) { + dead_node->mutable_input()->DeleteSubrange(index, 1); + } + dead_node->set_op("Identity"); + dead_node->mutable_attr()->erase("N"); + } + return Status::OK(); +} + } // namespace Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -517,6 +652,11 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (options_.enable_stack_push_removal) { TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph)); } + if (opt_level_ == RewriterConfig::AGGRESSIVE && + options_.enable_dead_branch_removal) { + TF_RETURN_IF_ERROR( + RemoveDeadBranches(item.NodesToPreserve(), optimized_graph)); + } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h index 764506f7c1..85b8e65543 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.h +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -54,6 +54,7 @@ class LoopOptimizer : public GraphOptimizer { struct LoopOptimizerOptions { bool enable_loop_invariant_node_motion = false; bool enable_stack_push_removal = true; + bool enable_dead_branch_removal = true; static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) { LoopOptimizerOptions options; diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc index 10ec544424..6fd177b710 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -589,5 +589,112 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) { } } +TEST_F(LoopOptimizerTest, RemoveDeadBranches) { + Scope scope = Scope::NewRootScope(); + Output v_in = ops::Variable(scope.WithOpName("v_in"), {3}, DT_FLOAT); + + Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), false, TensorShape({})); + ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1); + Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false); + Output sqrt1 = ops::Sqrt(scope.WithOpName("sqrt1"), s1.output_true); + + Output ctrl2 = ops::Const(scope.WithOpName("ctrl2"), true, TensorShape({})); + ops::Switch s2(scope.WithOpName("switch2"), v_in, ctrl2); + Output square2 = ops::Square(scope.WithOpName("square2"), s2.output_false); + Output sqrt2 = ops::Sqrt(scope.WithOpName("sqrt2"), s2.output_true); + + Output ctrl3 = ops::Const(scope.WithOpName("ctrl3"), false, TensorShape({})); + ops::Switch s3(scope.WithOpName("switch3"), v_in, ctrl3); + Output square3 = ops::Square(scope.WithOpName("square3"), s3.output_false); + Output sqrt3 = ops::Sqrt(scope.WithOpName("sqrt3"), s3.output_true); + + Output ctrl4 = ops::Const(scope.WithOpName("ctrl4"), false, TensorShape({})); + ops::Switch s4(scope.WithOpName("switch4"), v_in, ctrl4); + Output square4 = ops::Square(scope.WithOpName("square4"), s4.output_false); + Output sqrt4 = ops::Sqrt(scope.WithOpName("sqrt4"), s4.output_true); + + ops::Merge m1(scope.WithOpName("m1"), {square1, sqrt1}); + ops::Merge m2(scope.WithOpName("m2"), {v_in, square1}); + ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1}); + ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2}); + ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1}); + ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2), + {v_in, square1}); + ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1), + {v_in, square1}); + + ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1); + Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false); + Output id2 = ops::Identity(scope.WithOpName("id2"), s5.output_true); + ops::Merge m8(scope.WithOpName("m8"), {id1, id2}); + + ops::Switch s6(scope.WithOpName("switch6"), v_in, ctrl1); + Output id3 = ops::Identity(scope.WithOpName("id3"), s6.output_false); + Output id4 = ops::Identity(scope.WithOpName("id4"), s6.output_true); + ops::Merge m9(scope.WithOpName("m9"), {id3, id4}); + + GrapplerItem item; + item.fetch.push_back("m8"); + item.fetch.push_back("id4"); + + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_CHECK_OK(status); + + for (const NodeDef& node : output.node()) { + // These nodes should have been pruned + EXPECT_NE("Square1", node.name()); + EXPECT_NE("Sqrt2", node.name()); + EXPECT_NE("m5", node.name()); + EXPECT_NE("m7", node.name()); + + if (node.name() == "m1") { + // sqrt1 is dead + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("square1", node.input(0)); + } else if (node.name() == "m2") { + // both inputs are alive + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("v_in", node.input(0)); + EXPECT_EQ("square1", node.input(1)); + } else if (node.name() == "m3") { + // sqrt1 is dead + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("v_in", node.input(0)); + } else if (node.name() == "m4") { + // both inputs are alive + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("square1", node.input(0)); + EXPECT_EQ("sqrt2", node.input(1)); + } else if (node.name() == "m6") { + // both inputs are alive and the control dependency can get triggered + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("v_in", node.input(0)); + EXPECT_EQ("square1", node.input(1)); + EXPECT_EQ("^sqrt2", node.input(2)); + } else if (node.name() == "m8") { + // The node is to be preserved because of a fetch + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("id1", node.input(0)); + EXPECT_EQ("id2", node.input(1)); + } else if (node.name() == "m9") { + // The node is to be preserved because of a fetch + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("id3", node.input(0)); + EXPECT_EQ("id4", node.input(1)); + } + } +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index b87ae05546..1c6fef59ea 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -65,7 +65,7 @@ class NodeMap { // A vector with a set. The set stores the same elements as the vector, and // quickly answers whether a value is in the vector. Duplicated elements are not // allowed for now. -template +template > class SetVector { public: // Returns false if value already existed in the set, true otherwise. @@ -91,7 +91,7 @@ class SetVector { void Reserve(int64 size) { vector_.reserve(size); } private: - std::unordered_set set_; + std::unordered_set set_; std::vector vector_; }; -- GitLab From 77a866ced3ca76c96b74af2759e432bfe250566f Mon Sep 17 00:00:00 2001 From: manhyuk Date: Sat, 5 May 2018 21:01:01 +0900 Subject: [PATCH 032/738] fix typo --- .../hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 60281951dd..66939fbb0f 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -115,7 +115,7 @@ static void CheckOpsSupport(const GraphDef& graph_def, HexagonOpsDefinitions::getInstance(); LOG(INFO) << "Checking " << graph_def.node_size() << " nodes"; LOG(INFO) << "dump_all_nodes = " << dump_all_nodes - << ", dump_shape_and_tpye = " << dump_shape_and_type; + << ", dump_shape_and_type = " << dump_shape_and_type; std::unordered_set unsupported_ops; bool all_supported = true; -- GitLab From cad5d6694aced77ab3c9141be2eea121bc6c9cb7 Mon Sep 17 00:00:00 2001 From: manhyuk Date: Sat, 5 May 2018 21:02:58 +0900 Subject: [PATCH 033/738] fix typo --- tensorflow/compiler/xla/shape_util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index cb8bf5a2b9..82c75f85d8 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -231,7 +231,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that that they have the same element type + // point types; otherwise, checks that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { -- GitLab From ab48fb528221152299fb08da8116d2eca54b8423 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 4 May 2018 15:40:07 -0700 Subject: [PATCH 034/738] [XLA] Print allowed attributes when the user specifies an invalid attr. PiperOrigin-RevId: 195482974 --- .../compiler/xla/tools/parser/hlo_parser.cc | 30 +++++++++++++------ .../xla/tools/parser/hlo_parser_test.cc | 2 +- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 3a945fb3b1..40dc0730ce 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -30,6 +30,7 @@ namespace { using tensorflow::StringPiece; using tensorflow::gtl::optional; +using tensorflow::str_util::Join; using tensorflow::str_util::Split; using tensorflow::str_util::SplitAndParseAsInts; using tensorflow::strings::Printf; @@ -53,7 +54,7 @@ class HloParser { std::unique_ptr ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + string GetError() const { return Join(error_, "\n"); } private: // ParseXXX returns false if an error occurred. @@ -245,7 +246,7 @@ bool HloParser::Error(LocTy loc, StringPiece msg) { error_lines.push_back(std::string(lexer_.GetLine(loc))); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + error_.push_back(Join(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } @@ -1488,11 +1489,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - tensorflow::str_util::Join( - elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); - }), + Join(elems_seen_until_dim, ",", + [](string* out, const int64& num_elems) { + tensorflow::strings::StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1680,7 +1680,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", tensorflow::str_util::Join(index, ", "), "]")); + ": [", Join(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -1848,7 +1848,19 @@ bool HloParser::ParseAttributeHelper( } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - return Error(loc, Printf("unexpected attribute %s", name.c_str())); + string allowed_attrs; + if (attrs.empty()) { + allowed_attrs = "No attributes are allowed here."; + } else { + allowed_attrs = StrCat( + "Allowed attributes: ", + Join(attrs, ", ", + [&](string* out, const std::pair& kv) { + StrAppend(out, kv.first); + })); + } + return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), + allowed_attrs.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 4e085bc89c..d38d8907a6 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -1138,7 +1138,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { )"; ExpectHasSubstr(Parse(original).status().error_message(), - "unexpected attribute calls"); + "unexpected attribute \"calls\""); } TEST_F(HloParserTest, MissingAttribute) { -- GitLab From 008a3b69a601dc68fd940eb8a03b0c445714a339 Mon Sep 17 00:00:00 2001 From: Karmel Allison Date: Fri, 4 May 2018 16:01:02 -0700 Subject: [PATCH 035/738] Add the ability to export separate SavedModels for train and eval mode to Estimator with two new methods, available in tf.contrib: export_all_saved_models and export_saved_model_for_mode. PiperOrigin-RevId: 195485922 --- tensorflow/contrib/estimator/BUILD | 38 ++ tensorflow/contrib/estimator/__init__.py | 3 + .../estimator/python/estimator/export.py | 216 ++++++++++ .../estimator/python/estimator/export_test.py | 391 ++++++++++++++++++ tensorflow/python/estimator/BUILD | 1 + tensorflow/python/estimator/estimator.py | 346 +++++++++++++--- tensorflow/python/estimator/estimator_test.py | 336 ++++++++++++++- tensorflow/python/estimator/export/export.py | 325 +++++++++++---- .../python/estimator/export/export_output.py | 223 +++++++++- .../estimator/export/export_output_test.py | 110 +++++ .../python/estimator/export/export_test.py | 253 +++++++++++- tensorflow/python/estimator/model_fn.py | 8 + tensorflow/python/saved_model/builder_impl.py | 54 ++- tensorflow/python/saved_model/constants.py | 6 + .../python/saved_model/saved_model_test.py | 90 ++++ .../python/saved_model/signature_constants.py | 6 + .../python/saved_model/signature_def_utils.py | 2 + .../saved_model/signature_def_utils_impl.py | 56 +++ .../saved_model/signature_def_utils_test.py | 95 +++++ .../python/saved_model/tag_constants.py | 5 + 20 files changed, 2373 insertions(+), 191 deletions(-) create mode 100644 tensorflow/contrib/estimator/python/estimator/export.py create mode 100644 tensorflow/contrib/estimator/python/estimator/export_test.py diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 571e2e3a5d..e9a68801ef 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -17,6 +17,7 @@ py_library( ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":export", ":extenders", ":head", ":linear", @@ -180,6 +181,43 @@ py_test( ], ) +py_library( + name = "export", + srcs = [ + "python/estimator/export.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "export_test", + size = "medium", + srcs = ["python/estimator/export_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/62863147 + deps = [ + ":export", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", + ], +) + py_library( name = "head", srcs = [ diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index d43b3ea6bf..ec502f86dd 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * from tensorflow.contrib.estimator.python.estimator.linear import * @@ -56,6 +57,8 @@ _allowed_symbols = [ 'TowerOptimizer', 'RNNClassifier', 'RNNEstimator', + 'export_saved_model_for_mode', + 'export_all_saved_models', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py new file mode 100644 index 0000000000..e7e366a3f2 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export.py @@ -0,0 +1,216 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper for methods to export train/eval graphs from Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import model_fn as model_fn_lib + + +def export_saved_model_for_mode( + estimator, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn, steps=1000) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + export_dir = tf.contrib.estimator.export_saved_model_for_mode( + classifier, + export_dir_base='my_model/', + input_receiver_fn=train_rcvr_fn, + mode=model_fn_lib.ModeKeys.TRAIN) + + # export_dir is a timestamped directory with the SavedModel, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + ... + ``` + + This method takes an input_receiver_fn and mode. For the mode passed in, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Finally, it creates a timestamped export directory below the + export_dir_base, and writes a `SavedModel` into it containing + the `MetaGraphDef` for the given mode and its associated signatures. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating with mode will be exported. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + # pylint: enable=protected-access + + +def export_all_saved_models( + estimator, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. + + This is a wrapper around export_saved_model_for_mode that accepts + multiple modes simultaneously and creates directories for each under + export_dir_base. See `Estimator.export_saved_model_for_mode` for + further details as to how the export works for each mode. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + serve_rcvr_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( + feature_spec) + + rcvr_fn_map = { + model_fn_lib.ModeKeys.TRAIN: train_rcvr_fn, + model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn, + } + + export_dirs = tf.contrib.estimator.export_all_saved_models( + classifier, + export_dir_base='my_model/', + input_receiver_fn_map=rcvr_fn_map) + + # export_dirs is a dict of directories with SavedModels, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], + export_dirs[tf.estimator.ModeKeys.TRAIN]) + ... + ``` + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_all_saved_models( + export_dir_base, input_receiver_fn_map, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py new file mode 100644 index 0000000000..89d02582e1 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export_test.py @@ -0,0 +1,391 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib wrapping of export_saved_model_for_mode functionality. + +These are direct copies of the tests included in core, with import locations +changed. These should be removed when the functionality in core is part of the +public API. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.python.client import session +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import training +from tensorflow.python.util import compat + + +def _model_fn_for_export_tests(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + update_global_step = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([update_global_step]): + train_op = constant_op.constant(2.) + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=train_op, + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + + +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +class EstimatorExportTest(test.TestCase): + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_saved_model_for_mode( + est, export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 2) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items(): + export_dir = export_dirs[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dirs = contrib_export.export_all_saved_models( + est, export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + for _, export_dir in export_dirs.items(): + self._validate_exported_files(export_dir) + + return export_dirs, tmpdir + + def _validate_exported_files(self, export_dir): + self.assertTrue(gfile.Exists(export_dir)) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 56dec1eaa1..b25cc7aa26 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -91,6 +91,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python/saved_model:tag_constants", "@six_archive//:six", ], ) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 530a4a24ef..9ae64d230e 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -37,9 +37,8 @@ from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util -from tensorflow.python.estimator.export.export import build_all_signature_defs -from tensorflow.python.estimator.export.export import get_temp_export_dir -from tensorflow.python.estimator.export.export import get_timestamped_export_dir +from tensorflow.python.estimator.export import export as export_helpers +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -51,7 +50,6 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants -from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import device_setter @@ -609,73 +607,283 @@ class Estimator(object): are provided, or no checkpoint can be found. """ # pylint: enable=line-too-long + return self._export_saved_model_for_mode( + export_dir_base, + serving_input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=model_fn_lib.ModeKeys.PREDICT) + + def _export_all_saved_models( + self, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. + + This is a wrapper around export_saved_model_for_mode that accepts + multiple modes simultaneously and creates directories for each under + export_dir_base. See `Estimator.export_saved_model_for_mode` for + further details as to how the export works for each mode. + + See tf.contrib.estimator.export_all_saved_models for the currently + exposed version of this function. + + Args: + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + # TODO(b/65561022): Consider allowing multiple input_receiver_fns per mode. + exported = {} + for mode, input_receiver_fn in input_receiver_fn_map.items(): + export_mode_dir = os.path.join( + compat.as_bytes(export_dir_base), + compat.as_bytes(mode)) + gfile.MakeDirs(export_mode_dir) + + exported_path = self._export_saved_model_for_mode( + export_mode_dir, + input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + + exported[mode] = exported_path + + return exported + + def _export_saved_model_for_mode( + self, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + + See tf.contrib.estimator.export_saved_model_for_mode for the currently + exposed version of this function. + + This method takes an input_receiver_fn and mode. For the mode passed in, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Finally, it creates a timestamped export directory below the + export_dir_base, and writes a `SavedModel` into it containing + the `MetaGraphDef` for the given mode and its associated signatures. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + + Args: + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating with mode will be exported. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long with context.graph_mode(): - if serving_input_receiver_fn is None: - raise ValueError('serving_input_receiver_fn must be defined.') + if not input_receiver_fn: + raise ValueError('An input_receiver_fn must be defined.') - with ops.Graph().as_default() as g: - self._create_and_assert_global_step(g) - random_seed.set_random_seed(self._config.tf_random_seed) - serving_input_receiver = serving_input_receiver_fn() + if not checkpoint_path: + # Locate the latest checkpoint + checkpoint_path = saver.latest_checkpoint(self._model_dir) + if not checkpoint_path: + raise ValueError("Couldn't find trained model at %s." % self._model_dir) - # Call the model_fn and collect the export_outputs. - estimator_spec = self._call_model_fn( - features=serving_input_receiver.features, - labels=None, - mode=model_fn_lib.ModeKeys.PREDICT, - config=self.config) - - # Build the SignatureDefs from receivers and all outputs - signature_def_map = build_all_signature_defs( - serving_input_receiver.receiver_tensors, - estimator_spec.export_outputs, - serving_input_receiver.receiver_tensors_alternatives) - - if not checkpoint_path: - # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) - if not checkpoint_path: - raise ValueError( - "Couldn't find trained model at %s." % self._model_dir) - - export_dir = get_timestamped_export_dir(export_dir_base) - temp_export_dir = get_temp_export_dir(export_dir) - - # TODO(soergel): Consider whether MonitoredSession makes sense here - with tf_session.Session(config=self._session_config) as session: - - saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( - sharded=True) - saver_for_restore.restore(session, checkpoint_path) - - local_init_op = ( - estimator_spec.scaffold.local_init_op or - monitored_session.Scaffold.default_local_init_op()) - - # Perform the export - builder = saved_model_builder.SavedModelBuilder(temp_export_dir) - builder.add_meta_graph_and_variables( - session, [tag_constants.SERVING], - signature_def_map=signature_def_map, - assets_collection=ops.get_collection( - ops.GraphKeys.ASSET_FILEPATHS), - legacy_init_op=local_init_op, - strip_default_attrs=strip_default_attrs) - builder.save(as_text) - - # Add the extra assets - if assets_extra: - assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir), - compat.as_bytes('assets.extra')) - for dest_relative, source in assets_extra.items(): - dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), - compat.as_bytes(dest_relative)) - dest_path = os.path.dirname(dest_absolute) - gfile.MakeDirs(dest_path) - gfile.Copy(source, dest_absolute) - - gfile.Rename(temp_export_dir, export_dir) - return export_dir + export_dir = export_helpers.get_timestamped_export_dir(export_dir_base) + temp_export_dir = export_helpers.get_temp_export_dir(export_dir) + + builder = saved_model_builder.SavedModelBuilder(temp_export_dir) + + self._add_meta_graph_and_variables_for_mode( + builder, input_receiver_fn, checkpoint_path, + strip_default_attrs, mode) + + builder.save(as_text) + + # Add the extra assets + if assets_extra: + assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir), + compat.as_bytes('assets.extra')) + for dest_relative, source in assets_extra.items(): + dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), + compat.as_bytes(dest_relative)) + dest_path = os.path.dirname(dest_absolute) + gfile.MakeDirs(dest_path) + gfile.Copy(source, dest_absolute) + + gfile.Rename(temp_export_dir, export_dir) + return export_dir + + def _add_meta_graph_and_variables_for_mode( + self, builder, input_receiver_fn, checkpoint_path, strip_default_attrs, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Loads variables and adds them along with a MetaGraphDef for saving. + + Args: + builder: instance of SavedModelBuilder that will be used for saving. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating which mode will be exported. + """ + # pylint: enable=line-too-long + with ops.Graph().as_default() as g: + self._create_and_assert_global_step(g) + random_seed.set_random_seed(self._config.tf_random_seed) + + input_receiver = input_receiver_fn() + + # Call the model_fn and collect the export_outputs. + estimator_spec = self._call_model_fn( + features=input_receiver.features, + labels=getattr(input_receiver, 'labels', None), + mode=mode, + config=self.config) + + export_outputs = self._get_export_outputs_for_spec(estimator_spec) + + # Build the SignatureDefs from receivers and all outputs + signature_def_map = export_helpers.build_all_signature_defs( + input_receiver.receiver_tensors, + export_outputs, + getattr(input_receiver, 'receiver_tensors_alternatives', None), + serving_only=(mode == model_fn_lib.ModeKeys.PREDICT)) + + with tf_session.Session(config=self._session_config) as session: + + export_tags = model_fn_lib.EXPORT_TAG_MAP[mode] + + local_init_op = ( + estimator_spec.scaffold.local_init_op or + monitored_session.Scaffold.default_local_init_op()) + + saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( + sharded=True) + saver_for_restore.restore(session, checkpoint_path) + + # We add the train op explicitly for now, so that we don't have to + # change the Builder public interface. Note that this is a no-op + # for prediction, where train_op is None. + builder._add_train_op(estimator_spec.train_op) # pylint: disable=protected-access + + builder.add_meta_graph_and_variables( + session, + tags=export_tags, + signature_def_map=signature_def_map, + assets_collection=ops.get_collection( + ops.GraphKeys.ASSET_FILEPATHS), + strip_default_attrs=strip_default_attrs, + legacy_init_op=local_init_op) + + def _get_export_outputs_for_spec(self, estimator_spec): + """Given an EstimatorSpec, determine what our export outputs should be. + + EstimatorSpecs contain export_outputs that are used for serving, but for + training and eval graphs, we must wrap the tensors of interest in + appropriate ExportOutput objects. + + Args: + estimator_spec: EstimatorSpec object that will be exported. + + Returns: + a dict mapping export_output_name to ExportOutput object. + + Raises: + ValueError: if an appropriate ExportOutput cannot be found for the + passed EstimatorSpec.mode + """ + mode = estimator_spec.mode + if mode == model_fn_lib.ModeKeys.PREDICT: + outputs = estimator_spec.export_outputs + else: + if mode == model_fn_lib.ModeKeys.TRAIN: + output_class = export_output.TrainOutput + elif mode == model_fn_lib.ModeKeys.EVAL: + output_class = export_output.EvalOutput + else: + raise ValueError( + 'Export output type not found for mode: {}'.format(mode)) + + export_out = output_class( + loss=estimator_spec.loss, + predictions=estimator_spec.predictions, + metrics=estimator_spec.eval_metric_ops) + outputs = {mode: export_out} + + return outputs def _get_features_from_input_fn(self, input_fn, mode): """Extracts the `features` from return values of `input_fn`.""" @@ -1544,3 +1752,5 @@ def _get_default_warm_start_settings(warm_start_from): else: raise ValueError('warm_start_from must be a string or a WarmStartSettings, ' 'instead got {}'.format(type(warm_start_from))) + + diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 76b45b7f57..02088e5134 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -1865,6 +1865,41 @@ def _model_fn_for_export_tests(features, labels, mode): 'test': export_output.ClassificationOutput(scores, classes)}) +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + def _model_fn_with_saveables_for_export_tests(features, labels, mode): _, _ = features, labels table = saver_test_utils.CheckpointedOp(name='v2') @@ -1881,21 +1916,41 @@ def _model_fn_with_saveables_for_export_tests(features, labels, mode): 'test': export_output.PredictOutput({'prediction': prediction})}) +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn(feature_spec, label_spec) + + _VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n' _EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n' class EstimatorExportTest(test.TestCase): - def test_export_savedmodel_proto_roundtrip(self): - tmpdir = tempfile.mkdtemp() - est = estimator.Estimator(model_fn=_model_fn_for_export_tests) - est.train(input_fn=dummy_input_fn, steps=1) + def test_export_savedmodel_proto_roundtrip_raw_receiver(self): feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( feature_spec) + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=dummy_input_fn, steps=1) + # Perform the export. export_dir_base = os.path.join( compat.as_bytes(tmpdir), compat.as_bytes('export')) @@ -1904,6 +1959,266 @@ class EstimatorExportTest(test.TestCase): # Check that all the files are in the right places. self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertTrue('weight' in graph_ops) + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = est._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 2) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items(): + export_dir = export_dirs[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dirs = est._export_all_saved_models( + export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + for _, export_dir in export_dirs.items(): + self._validate_exported_files(export_dir) + + return export_dirs, tmpdir + + def _validate_exported_files(self, export_dir): self.assertTrue(gfile.Exists(export_dir)) self.assertTrue(gfile.Exists(os.path.join( compat.as_bytes(export_dir), @@ -1918,18 +2233,6 @@ class EstimatorExportTest(test.TestCase): compat.as_bytes(export_dir), compat.as_bytes('variables/variables.data-00000-of-00001')))) - # Restore, to validate that the export was well-formed. - with ops.Graph().as_default() as graph: - with session.Session(graph=graph) as sess: - loader.load(sess, [tag_constants.SERVING], export_dir) - graph_ops = [x.name for x in graph.get_operations()] - self.assertTrue('input_example_tensor' in graph_ops) - self.assertTrue('ParseExample/ParseExample' in graph_ops) - self.assertTrue('weight' in graph_ops) - - # Clean up. - gfile.DeleteRecursively(tmpdir) - def test_export_savedmodel_with_saveables_proto_roundtrip(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator( @@ -2485,5 +2788,6 @@ class EstimatorIntegrationTest(test.TestCase): serving_input_receiver_fn) self.assertTrue(gfile.Exists(export_dir)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 41c1f5a2e2..9aafb56679 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -40,6 +40,60 @@ from tensorflow.python.util.tf_export import tf_export _SINGLE_FEATURE_DEFAULT_NAME = 'feature' _SINGLE_RECEIVER_DEFAULT_NAME = 'input' +_SINGLE_LABEL_DEFAULT_NAME = 'label' + + +def _wrap_and_check_receiver_tensors(receiver_tensors): + """Ensure that receiver_tensors is a dict of str to Tensor mappings. + + Args: + receiver_tensors: dict of str to Tensors, or a single Tensor. + + Returns: + dict of str to Tensors; this is the original dict if one was passed, or + the original tensor wrapped in a dictionary. + + Raises: + ValueError: if receiver_tensors is None, or has non-string keys, + or non-Tensor values + """ + if receiver_tensors is None: + raise ValueError('receiver_tensors must be defined.') + if not isinstance(receiver_tensors, dict): + receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} + for name, tensor in receiver_tensors.items(): + _check_tensor_key(name, error_label='receiver_tensors') + _check_tensor(tensor, name, error_label='receiver_tensor') + return receiver_tensors + + +def _check_tensor(tensor, name, error_label='feature'): + """Check that passed `tensor` is a Tensor or SparseTensor.""" + if not (isinstance(tensor, ops.Tensor) + or isinstance(tensor, sparse_tensor.SparseTensor)): + fmt_name = ' {}'.format(name) if name else '' + value_error = ValueError( + '{}{} must be a Tensor or SparseTensor.'.format(error_label, fmt_name)) + # NOTE(ericmc): This if-else block is a specific carve-out for + # LabeledTensor, which has a `.tensor` attribute and which is + # convertible to tf.Tensor via ops.convert_to_tensor. + # Allowing all types convertible to tf.Tensor is considered by soergel@ + # to be too permissive. + # TODO(soergel): accept any type convertible to Tensor, + # as in cl/193238295 snapshot #6. + if hasattr(tensor, 'tensor'): + try: + ops.convert_to_tensor(tensor) + except TypeError: + raise value_error + else: + raise value_error + + +def _check_tensor_key(name, error_label='feature'): + if not isinstance(name, six.string_types): + raise ValueError( + '{} keys must be strings: {}.'.format(error_label, name)) @tf_export('estimator.export.ServingInputReceiver') @@ -51,16 +105,18 @@ class ServingInputReceiver(collections.namedtuple( The expected return values are: features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or `SparseTensor`, specifying the features to be passed to the model. - receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying - input nodes where this receiver expects to be fed by default. Typically, - this is a single placeholder expecting serialized `tf.Example` protos. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. receiver_tensors_alternatives: a dict of string to additional - groups of receiver tensors, each of which may be a `Tensor` or a dict of - string to `Tensor`. These named receiver tensor alternatives generate - additional serving signatures, which may be used to feed inputs at - different points within the input receiver subgraph. A typical usage is - to allow feeding raw feature `Tensor`s *downstream* of the - tf.parse_example() op. Defaults to None. + groups of receiver tensors, each of which may be a `Tensor`, + `SparseTensor`, or dict of string to `Tensor` or`SparseTensor`. + These named receiver tensor alternatives generate additional serving + signatures, which may be used to feed inputs at different points within + the input receiver subgraph. A typical usage is to allow feeding raw + feature `Tensor`s *downstream* of the tf.parse_example() op. + Defaults to None. """ def __new__(cls, features, receiver_tensors, @@ -70,36 +126,10 @@ class ServingInputReceiver(collections.namedtuple( if not isinstance(features, dict): features = {_SINGLE_FEATURE_DEFAULT_NAME: features} for name, tensor in features.items(): - if not isinstance(name, six.string_types): - raise ValueError('feature keys must be strings: {}.'.format(name)) - if not (isinstance(tensor, ops.Tensor) - or isinstance(tensor, sparse_tensor.SparseTensor)): - value_error = ValueError( - 'feature {} must be a Tensor or SparseTensor.'.format(name)) - # NOTE(ericmc): This if-else block is a specific carve-out for - # LabeledTensor, which has a `.tensor` attribute and which is - # convertible to tf.Tensor via ops.convert_to_tensor. - # Allowing all types convertible to tf.Tensor is considered by soergel@ - # to be too permissive. - if hasattr(tensor, 'tensor'): - try: - ops.convert_to_tensor(tensor) - except TypeError: - raise value_error - else: - raise value_error - - if receiver_tensors is None: - raise ValueError('receiver_tensors must be defined.') - if not isinstance(receiver_tensors, dict): - receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} - for name, tensor in receiver_tensors.items(): - if not isinstance(name, six.string_types): - raise ValueError( - 'receiver_tensors keys must be strings: {}.'.format(name)) - if not isinstance(tensor, ops.Tensor): - raise ValueError( - 'receiver_tensor {} must be a Tensor.'.format(name)) + _check_tensor_key(name) + _check_tensor(tensor, name) + + receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) if receiver_tensors_alternatives is not None: if not isinstance(receiver_tensors_alternatives, dict): @@ -115,14 +145,9 @@ class ServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives[alternative_name] = ( receiver_tensors_alt) for name, tensor in receiver_tensors_alt.items(): - if not isinstance(name, six.string_types): - raise ValueError( - 'receiver_tensors keys must be strings: {}.'.format(name)) - if not (isinstance(tensor, ops.Tensor) - or isinstance(tensor, sparse_tensor.SparseTensor)): - raise ValueError( - 'receiver_tensor {} must be a Tensor or SparseTensor.'.format( - name)) + _check_tensor_key(name, error_label='receiver_tensors_alternative') + _check_tensor( + tensor, name, error_label='receiver_tensors_alternative') return super(ServingInputReceiver, cls).__new__( cls, @@ -155,25 +180,25 @@ class TensorServingInputReceiver(collections.namedtuple( The expected return values are: features: A single `Tensor` or `SparseTensor`, representing the feature to be passed to the model. - receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying - input nodes where this receiver expects to be fed by default. Typically, - this is a single placeholder expecting serialized `tf.Example` protos. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. receiver_tensors_alternatives: a dict of string to additional - groups of receiver tensors, each of which may be a `Tensor` or a dict of - string to `Tensor`. These named receiver tensor alternatives generate - additional serving signatures, which may be used to feed inputs at - different points within the input receiver subgraph. A typical usage is - to allow feeding raw feature `Tensor`s *downstream* of the - tf.parse_example() op. Defaults to None. + groups of receiver tensors, each of which may be a `Tensor`, + `SparseTensor`, or dict of string to `Tensor` or`SparseTensor`. + These named receiver tensor alternatives generate additional serving + signatures, which may be used to feed inputs at different points within + the input receiver subgraph. A typical usage is to allow feeding raw + feature `Tensor`s *downstream* of the tf.parse_example() op. + Defaults to None. """ def __new__(cls, features, receiver_tensors, receiver_tensors_alternatives=None): if features is None: raise ValueError('features must be defined.') - if not (isinstance(features, ops.Tensor) - or isinstance(features, sparse_tensor.SparseTensor)): - raise ValueError('feature must be a Tensor or SparseTensor.') + _check_tensor(features, None) receiver = ServingInputReceiver( features=features, @@ -187,6 +212,49 @@ class TensorServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives=receiver.receiver_tensors_alternatives) +class SupervisedInputReceiver(collections.namedtuple( + 'SupervisedInputReceiver', + ['features', 'labels', 'receiver_tensors'])): + """A return type for a training_input_receiver_fn or eval_input_receiver_fn. + + This differs from a ServingInputReceiver in that (1) this receiver expects + a set of labels to be passed in with features, and (2) this receiver does + not support receiver_tensors_alternatives, which are primarily used for + serving. + + The expected return values are: + features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the features to be passed to the model. + labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the labels to be passed to the model. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. + + """ + + def __new__(cls, features, labels, receiver_tensors): + # Both features and labels can be dicts or raw tensors. + for input_vals, error_label in ((features, 'feature'), (labels, 'label')): + if input_vals is None: + raise ValueError('{}s must be defined.'.format(error_label)) + if isinstance(input_vals, dict): + for name, tensor in input_vals.items(): + _check_tensor_key(name, error_label=error_label) + _check_tensor(tensor, name, error_label=error_label) + else: + _check_tensor(input_vals, None, error_label=error_label) + + receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) + + return super(SupervisedInputReceiver, cls).__new__( + cls, + features=features, + labels=labels, + receiver_tensors=receiver_tensors) + + @tf_export('estimator.export.build_parsing_serving_input_receiver_fn') def build_parsing_serving_input_receiver_fn(feature_spec, default_batch_size=None): @@ -216,6 +284,23 @@ def build_parsing_serving_input_receiver_fn(feature_spec, return serving_input_receiver_fn +def _placeholder_from_tensor(t, default_batch_size=None): + shape_list = t.get_shape().as_list() + shape_list[0] = default_batch_size + shape = tensor_shape.TensorShape(shape_list) + + # Reuse the feature tensor's op name (t.op.name) for the placeholder, + # excluding the index from the tensor's name (t.name): + # t.name = "%s:%d" % (t.op.name, t._value_index) + return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name) + + +def _placeholders_from_receiver_tensors_dict( + input_vals, default_batch_size=None): + return {name: _placeholder_from_tensor(t, default_batch_size) + for name, t in input_vals.items()} + + @tf_export('estimator.export.build_raw_serving_input_receiver_fn') def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """Build a serving_input_receiver_fn expecting feature Tensors. @@ -233,17 +318,9 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """ def serving_input_receiver_fn(): """A serving_input_receiver_fn that expects features to be fed directly.""" - receiver_tensors = {} - for name, t in features.items(): - shape_list = t.get_shape().as_list() - shape_list[0] = default_batch_size - shape = tensor_shape.TensorShape(shape_list) - - # Reuse the feature tensor's op name (t.op.name) for the placeholder, - # excluding the index from the tensor's name (t.name): - # t.name = "%s:%d" % (t.op.name, t._value_index) - receiver_tensors[name] = array_ops.placeholder( - dtype=t.dtype, shape=shape, name=t.op.name) + receiver_tensors = _placeholders_from_receiver_tensors_dict( + features, default_batch_size) + # TODO(b/34885899): remove the unnecessary copy # The features provided are simply the placeholders, but we defensively copy # the dict because it may be mutated. @@ -252,13 +329,100 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None): return serving_input_receiver_fn +def build_raw_supervised_input_receiver_fn( + features, labels, default_batch_size=None): + """Build a supervised_input_receiver_fn for raw features and labels. + + This function wraps tensor placeholders in a supervised_receiver_fn + with the expectation that the features and labels appear precisely as + the model_fn expects them. Features and labels can therefore be dicts of + tensors, or raw tensors. + + Args: + features: a dict of string to `Tensor` or `Tensor`. + labels: a dict of string to `Tensor` or `Tensor`. + default_batch_size: the number of query examples expected per batch. + Leave unset for variable batch size (recommended). + + Returns: + A supervised_input_receiver_fn. + + Raises: + ValueError: if features and labels have overlapping keys. + """ + # Check for overlapping keys before beginning. + try: + feat_keys = features.keys() + except AttributeError: + feat_keys = [_SINGLE_RECEIVER_DEFAULT_NAME] + try: + label_keys = labels.keys() + except AttributeError: + label_keys = [_SINGLE_LABEL_DEFAULT_NAME] + + overlap_keys = set(feat_keys) & set(label_keys) + if overlap_keys: + raise ValueError('Features and labels must have distinct keys. ' + 'Found overlapping keys: {}'.format(overlap_keys)) + + def supervised_input_receiver_fn(): + """A receiver_fn that expects pass-through features and labels.""" + if not isinstance(features, dict): + features_cp = _placeholder_from_tensor(features, default_batch_size) + receiver_features = {_SINGLE_RECEIVER_DEFAULT_NAME: features_cp} + else: + receiver_features = _placeholders_from_receiver_tensors_dict( + features, default_batch_size) + features_cp = receiver_features + + if not isinstance(labels, dict): + labels_cp = _placeholder_from_tensor(labels, default_batch_size) + receiver_labels = {_SINGLE_LABEL_DEFAULT_NAME: labels_cp} + else: + receiver_labels = _placeholders_from_receiver_tensors_dict( + labels, default_batch_size) + labels_cp = receiver_labels + + receiver_tensors = dict(receiver_features) + receiver_tensors.update(receiver_labels) + return SupervisedInputReceiver(features_cp, labels_cp, receiver_tensors) + + return supervised_input_receiver_fn + + ### Below utilities are specific to SavedModel exports. def build_all_signature_defs(receiver_tensors, export_outputs, - receiver_tensors_alternatives=None): - """Build `SignatureDef`s for all export outputs.""" + receiver_tensors_alternatives=None, + serving_only=True): + """Build `SignatureDef`s for all export outputs. + + Args: + receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying + input nodes where this receiver expects to be fed by default. Typically, + this is a single placeholder expecting serialized `tf.Example` protos. + export_outputs: a dict of ExportOutput instances, each of which has + an as_signature_def instance method that will be called to retrieve + the signature_def for all export output tensors. + receiver_tensors_alternatives: a dict of string to additional + groups of receiver tensors, each of which may be a `Tensor` or a dict of + string to `Tensor`. These named receiver tensor alternatives generate + additional serving signatures, which may be used to feed inputs at + different points within the input receiver subgraph. A typical usage is + to allow feeding raw feature `Tensor`s *downstream* of the + tf.parse_example() op. Defaults to None. + serving_only: boolean; if true, resulting signature defs will only include + valid serving signatures. If false, all requested signatures will be + returned. + + Returns: + signature_def representing all passed args. + + Raises: + ValueError: if export_outputs is not a dict + """ if not isinstance(receiver_tensors, dict): receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} if export_outputs is None or not isinstance(export_outputs, dict): @@ -293,17 +457,24 @@ def build_all_signature_defs(receiver_tensors, _log_signature_report(signature_def_map, excluded_signatures) # The above calls to export_output.as_signature_def should return only - # valid signatures; if there is a validity problem, they raise ValueError, - # which we ignore above. Consequently the call to is_valid_signature here - # should not remove anything else; it's just an extra sanity check. - return {k: v for k, v in signature_def_map.items() - if signature_def_utils.is_valid_signature(v)} + # valid signatures; if there is a validity problem, they raise a ValueError, + # in which case we exclude that signature from signature_def_map above. + # The is_valid_signature check ensures that the signatures produced are + # valid for serving, and acts as an additional sanity check for export + # signatures produced for serving. We skip this check for training and eval + # signatures, which are not intended for serving. + if serving_only: + signature_def_map = {k: v for k, v in signature_def_map.items() + if signature_def_utils.is_valid_signature(v)} + return signature_def_map _FRIENDLY_METHOD_NAMES = { signature_constants.CLASSIFY_METHOD_NAME: 'Classify', signature_constants.REGRESS_METHOD_NAME: 'Regress', signature_constants.PREDICT_METHOD_NAME: 'Predict', + signature_constants.SUPERVISED_TRAIN_METHOD_NAME: 'Train', + signature_constants.SUPERVISED_EVAL_METHOD_NAME: 'Eval', } diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 87b964be37..d387ea2940 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -38,6 +38,8 @@ class ExportOutput(object): __metaclass__ = abc.ABCMeta + _SEPARATOR_CHAR = '/' + @abc.abstractmethod def as_signature_def(self, receiver_tensors): """Generate a SignatureDef proto for inclusion in a MetaGraphDef. @@ -51,6 +53,52 @@ class ExportOutput(object): """ pass + def _check_output_key(self, key, error_label): + # For multi-head models, the key can be a tuple. + if isinstance(key, tuple): + key = self._SEPARATOR_CHAR.join(key) + + if not isinstance(key, six.string_types): + raise ValueError( + '{} output key must be a string; got {}.'.format(error_label, key)) + return key + + def _wrap_and_check_outputs( + self, outputs, single_output_default_name, error_label=None): + """Wraps raw tensors as dicts and checks type. + + Note that we create a new dict here so that we can overwrite the keys + if necessary. + + Args: + outputs: A `Tensor` or a dict of string to `Tensor`. + single_output_default_name: A string key for use in the output dict + if the provided `outputs` is a raw tensor. + error_label: descriptive string for use in error messages. If none, + single_output_default_name will be used. + + Returns: + A dict of tensors + + Raises: + ValueError: if the outputs dict keys are not strings or tuples of strings + or the values are not Tensors. + """ + if not isinstance(outputs, dict): + outputs = {single_output_default_name: outputs} + + output_dict = {} + for key, value in outputs.items(): + error_name = error_label or single_output_default_name + key = self._check_output_key(key, error_name) + if not isinstance(value, ops.Tensor): + raise ValueError( + '{} output value must be a Tensor; got {}.'.format( + error_name, value)) + + output_dict[key] = value + return output_dict + @tf_export('estimator.export.ClassificationOutput') class ClassificationOutput(ExportOutput): @@ -154,9 +202,6 @@ class RegressionOutput(ExportOutput): return signature_def_utils.regression_signature_def(examples, self.value) -_SINGLE_OUTPUT_DEFAULT_NAME = 'output' - - @tf_export('estimator.export.PredictOutput') class PredictOutput(ExportOutput): """Represents the output of a generic prediction head. @@ -165,6 +210,7 @@ class PredictOutput(ExportOutput): Named outputs must be provided as a dict from string to `Tensor`, """ + _SINGLE_OUTPUT_DEFAULT_NAME = 'output' def __init__(self, outputs): """Constructor for PredictOutput. @@ -177,16 +223,9 @@ class PredictOutput(ExportOutput): ValueError: if the outputs is not dict, or any of its keys are not strings, or any of its values are not `Tensor`s. """ - if not isinstance(outputs, dict): - outputs = {_SINGLE_OUTPUT_DEFAULT_NAME: outputs} - for key, value in outputs.items(): - if not isinstance(key, six.string_types): - raise ValueError( - 'Prediction output key must be a string; got {}.'.format(key)) - if not isinstance(value, ops.Tensor): - raise ValueError( - 'Prediction output value must be a Tensor; got {}.'.format(value)) - self._outputs = outputs + + self._outputs = self._wrap_and_check_outputs( + outputs, self._SINGLE_OUTPUT_DEFAULT_NAME, error_label='Prediction') @property def outputs(self): @@ -195,3 +234,161 @@ class PredictOutput(ExportOutput): def as_signature_def(self, receiver_tensors): return signature_def_utils.predict_signature_def(receiver_tensors, self.outputs) + + +class _SupervisedOutput(ExportOutput): + """Represents the output of a supervised training or eval process.""" + __metaclass__ = abc.ABCMeta + + LOSS_NAME = 'loss' + PREDICTIONS_NAME = 'predictions' + METRICS_NAME = 'metrics' + + METRIC_VALUE_SUFFIX = 'value' + METRIC_UPDATE_SUFFIX = 'update_op' + + _loss = None + _predictions = None + _metrics = None + + def __init__(self, loss=None, predictions=None, metrics=None): + """Constructor for SupervisedOutput (ie, Train or Eval output). + + Args: + loss: dict of Tensors or single Tensor representing calculated loss. + predictions: dict of Tensors or single Tensor representing model + predictions. + metrics: dict of (metric_value, update_op) tuples, or a single tuple. + metric_value must be a Tensor, and update_op must be a Tensor or Op. + + Raises: + ValueError: if any of the outputs' dict keys are not strings or tuples of + strings or the values are not Tensors (or Operations in the case of + update_op). + """ + + if loss is not None: + loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME) + self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME) + if predictions is not None: + pred_dict = self._wrap_and_check_outputs( + predictions, self.PREDICTIONS_NAME) + self._predictions = self._prefix_output_keys( + pred_dict, self.PREDICTIONS_NAME) + if metrics is not None: + self._metrics = self._wrap_and_check_metrics(metrics) + + def _prefix_output_keys(self, output_dict, output_name): + """Prepend output_name to the output_dict keys if it doesn't exist. + + This produces predictable prefixes for the pre-determined outputs + of SupervisedOutput. + + Args: + output_dict: dict of string to Tensor, assumed valid. + output_name: prefix string to prepend to existing keys. + + Returns: + dict with updated keys and existing values. + """ + + new_outputs = {} + for key, val in output_dict.items(): + key = self._prefix_key(key, output_name) + new_outputs[key] = val + return new_outputs + + def _prefix_key(self, key, output_name): + if key.find(output_name) != 0: + key = output_name + self._SEPARATOR_CHAR + key + return key + + def _wrap_and_check_metrics(self, metrics): + """Handle the saving of metrics. + + Metrics is either a tuple of (value, update_op), or a dict of such tuples. + Here, we separate out the tuples and create a dict with names to tensors. + + Args: + metrics: dict of (metric_value, update_op) tuples, or a single tuple. + + Returns: + dict of output_names to tensors + + Raises: + ValueError: if the dict key is not a string, or the metric values or ops + are not tensors. + """ + if not isinstance(metrics, dict): + metrics = {self.METRICS_NAME: metrics} + + outputs = {} + for key, (metric_val, metric_op) in metrics.items(): + key = self._check_output_key(key, self.METRICS_NAME) + key = self._prefix_key(key, self.METRICS_NAME) + + val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX + op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX + if not isinstance(metric_val, ops.Tensor): + raise ValueError( + '{} output value must be a Tensor; got {}.'.format( + key, metric_val)) + if (not isinstance(metric_op, ops.Tensor) and + not isinstance(metric_op, ops.Operation)): + raise ValueError( + '{} update_op must be a Tensor or Operation; got {}.'.format( + key, metric_op)) + outputs[val_name] = metric_val + outputs[op_name] = metric_op + + return outputs + + @property + def loss(self): + return self._loss + + @property + def predictions(self): + return self._predictions + + @property + def metrics(self): + return self._metrics + + @abc.abstractmethod + def _get_signature_def_fn(self): + """Returns a function that produces a SignatureDef given desired outputs.""" + pass + + def as_signature_def(self, receiver_tensors): + signature_def_fn = self._get_signature_def_fn() + return signature_def_fn( + receiver_tensors, self.loss, self.predictions, self.metrics) + + +class TrainOutput(_SupervisedOutput): + """Represents the output of a supervised training process. + + This class generates the appropriate signature def for exporting + training output by type-checking and wrapping loss, predictions, and metrics + values. + """ + + def _get_signature_def_fn(self): + return signature_def_utils.supervised_train_signature_def + + +class EvalOutput(_SupervisedOutput): + """Represents the output of a supervised eval process. + + This class generates the appropriate signature def for exporting + eval output by type-checking and wrapping loss, predictions, and metrics + values. + """ + + def _get_signature_def_fn(self): + return signature_def_utils.supervised_eval_signature_def + + + + diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index 7090e53d80..b21ba91b0f 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -225,5 +225,115 @@ class ExportOutputTest(test.TestCase): }) +class MockSupervisedOutput(export_output_lib._SupervisedOutput): + """So that we can test the abstract class methods directly.""" + + def _get_signature_def_fn(self): + pass + + +class SupervisedOutputTest(test.TestCase): + + def test_supervised_outputs_valid(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), + constant_op.constant([10])), + "metrics2": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(outputter.loss["loss/my_loss"], loss["my_loss"]) + self.assertEqual( + outputter.predictions["predictions/output1"], predictions["output1"]) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + self.assertEqual( + outputter.metrics["metrics2/update_op"], metrics["metrics2"][1]) + + # Single Tensor is OK too + outputter = MockSupervisedOutput( + loss["my_loss"], predictions["output1"], metrics["metrics"]) + self.assertEqual(outputter.loss, {"loss": loss["my_loss"]}) + self.assertEqual( + outputter.predictions, {"predictions": predictions["output1"]}) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + + def test_supervised_outputs_none(self): + outputter = MockSupervisedOutput( + constant_op.constant([0]), None, None) + self.assertEqual(len(outputter.loss), 1) + self.assertEqual(outputter.predictions, None) + self.assertEqual(outputter.metrics, None) + + def test_supervised_outputs_invalid(self): + with self.assertRaisesRegexp(ValueError, "predictions output value must"): + MockSupervisedOutput(constant_op.constant([0]), [3], None) + with self.assertRaisesRegexp(ValueError, "loss output value must"): + MockSupervisedOutput("str", None, None) + with self.assertRaisesRegexp(ValueError, "metrics output value must"): + MockSupervisedOutput(None, None, (15.3, 4)) + with self.assertRaisesRegexp(ValueError, "loss output key must"): + MockSupervisedOutput({25: "Tensor"}, None, None) + + def test_supervised_outputs_tuples(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {("my", "loss"): constant_op.constant([0])} + predictions = {(u"output1", "2"): constant_op.constant(["foo"])} + metrics = {("metrics", "twice"): (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(set(outputter.loss.keys()), set(["loss/my/loss"])) + self.assertEqual(set(outputter.predictions.keys()), + set(["predictions/output1/2"])) + self.assertEqual(set(outputter.metrics.keys()), + set(["metrics/twice/value", "metrics/twice/update_op"])) + + def test_supervised_outputs_no_prepend(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {"loss": constant_op.constant([0])} + predictions = {u"predictions": constant_op.constant(["foo"])} + metrics = {u"metrics": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(set(outputter.loss.keys()), set(["loss"])) + self.assertEqual(set(outputter.predictions.keys()), set(["predictions"])) + self.assertEqual(set(outputter.metrics.keys()), + set(["metrics/value", "metrics/update_op"])) + + def test_train_signature_def(self): + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = export_output_lib.TrainOutput(loss, predictions, metrics) + + receiver = {u"features": constant_op.constant(100, shape=(100, 2)), + "labels": constant_op.constant(100, shape=(100, 1))} + sig_def = outputter.as_signature_def(receiver) + + self.assertTrue("loss/my_loss" in sig_def.outputs) + self.assertTrue("metrics/value" in sig_def.outputs) + self.assertTrue("predictions/output1" in sig_def.outputs) + self.assertTrue("features" in sig_def.inputs) + + def test_eval_signature_def(self): + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + + outputter = export_output_lib.EvalOutput(loss, predictions, None) + + receiver = {u"features": constant_op.constant(100, shape=(100, 2)), + "labels": constant_op.constant(100, shape=(100, 1))} + sig_def = outputter.as_signature_def(receiver) + + self.assertTrue("loss/my_loss" in sig_def.outputs) + self.assertFalse("metrics/value" in sig_def.outputs) + self.assertTrue("predictions/output1" in sig_def.outputs) + self.assertTrue("features" in sig_def.inputs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index c203be7dac..0af587f2a8 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -54,7 +54,7 @@ ops.register_tensor_conversion_function(LabeledTensorMock, _convert_labeled_tensor_mock_to_tensor) -class ExportTest(test_util.TensorFlowTestCase): +class ServingInputReceiverTest(test_util.TensorFlowTestCase): def test_serving_input_receiver_constructor(self): """Tests that no errors are raised when input is expected.""" @@ -161,6 +161,165 @@ class ExportTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): _ = export.ServingInputReceiver(feature, receiver_tensor) + +class SupervisedInputReceiverTest(test_util.TensorFlowTestCase): + + def test_input_receiver_constructor(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + labels = { + "classes": constant_op.constant([0] * 100), + } + + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + export.SupervisedInputReceiver(features, labels, receiver_tensors) + + def test_input_receiver_raw_values(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + + labels = { + "classes": constant_op.constant([0] * 100), + } + + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + rec = export.SupervisedInputReceiver( + features["feature2"], labels, receiver_tensors) + self.assertIsInstance(rec.features, sparse_tensor.SparseTensor) + + rec = export.SupervisedInputReceiver( + features, labels["classes"], receiver_tensors) + self.assertIsInstance(rec.labels, ops.Tensor) + + def test_input_receiver_features_invalid(self): + features = constant_op.constant([0] * 100) + labels = constant_op.constant([0]) + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + + with self.assertRaisesRegexp(ValueError, "features must be defined"): + export.SupervisedInputReceiver( + features=None, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "feature keys must be strings"): + export.SupervisedInputReceiver( + features={1: constant_op.constant([1])}, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "label keys must be strings"): + export.SupervisedInputReceiver( + features=features, + labels={1: constant_op.constant([1])}, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature feature1 must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features={"feature1": [1]}, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features=[1], + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "label must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features=features, + labels=100, + receiver_tensors=receiver_tensors) + + def test_input_receiver_receiver_tensors_invalid(self): + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + labels = constant_op.constant([0]) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors must be defined"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors=None) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors keys must be strings"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors={ + 1: array_ops.placeholder(dtypes.string, name="example0")}) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensor example1 must be a Tensor"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors={"example1": [1]}) + + def test_single_feature_single_receiver(self): + feature = constant_op.constant(5) + label = constant_op.constant(5) + receiver_tensor = array_ops.placeholder(dtypes.string) + input_receiver = export.SupervisedInputReceiver( + feature, label, receiver_tensor) + + # single receiver is automatically named + receiver_key, = input_receiver.receiver_tensors.keys() + self.assertEqual("input", receiver_key) + + def test_multi_feature_single_receiver(self): + features = {"foo": constant_op.constant(5), + "bar": constant_op.constant(6)} + labels = {"value": constant_op.constant(5)} + receiver_tensor = array_ops.placeholder(dtypes.string) + _ = export.SupervisedInputReceiver(features, labels, receiver_tensor) + + def test_multi_feature_multi_receiver(self): + features = {"foo": constant_op.constant(5), + "bar": constant_op.constant(6)} + labels = {"value": constant_op.constant(5)} + receiver_tensors = {"baz": array_ops.placeholder(dtypes.int64), + "qux": array_ops.placeholder(dtypes.float32)} + _ = export.SupervisedInputReceiver(features, labels, receiver_tensors) + + def test_feature_labeled_tensor(self): + feature = LabeledTensorMock() + label = constant_op.constant(5) + receiver_tensor = array_ops.placeholder(dtypes.string) + _ = export.SupervisedInputReceiver(feature, label, receiver_tensor) + + +class ExportTest(test_util.TensorFlowTestCase): + def test_build_parsing_serving_input_receiver_fn(self): feature_spec = {"int_feature": parsing_ops.VarLenFeature(dtypes.int64), "float_feature": parsing_ops.VarLenFeature(dtypes.float32)} @@ -237,6 +396,69 @@ class ExportTest(test_util.TensorFlowTestCase): dtypes.int32, serving_input_receiver.receiver_tensors["feature_2"].dtype) + def test_build_raw_supervised_input_receiver_fn(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn = export.build_raw_supervised_input_receiver_fn( + features, labels) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual(set(["feature_1", "feature_2"]), + set(input_receiver.features.keys())) + self.assertEqual(set(["foo", "bar"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["feature_1", "feature_2", "foo", "bar"]), + set(input_receiver.receiver_tensors.keys())) + self.assertEqual( + dtypes.string, input_receiver.receiver_tensors["feature_1"].dtype) + self.assertEqual( + dtypes.int32, input_receiver.receiver_tensors["feature_2"].dtype) + + def test_build_raw_supervised_input_receiver_fn_raw_tensors(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn1 = export.build_raw_supervised_input_receiver_fn( + features["feature_1"], labels) + input_receiver_fn2 = export.build_raw_supervised_input_receiver_fn( + features["feature_1"], labels["foo"]) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn1() + self.assertIsInstance(input_receiver.features, ops.Tensor) + self.assertEqual(set(["foo", "bar"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["input", "foo", "bar"]), + set(input_receiver.receiver_tensors.keys())) + + input_receiver = input_receiver_fn2() + self.assertIsInstance(input_receiver.features, ops.Tensor) + self.assertIsInstance(input_receiver.labels, ops.Tensor) + self.assertEqual(set(["input", "label"]), + set(input_receiver.receiver_tensors.keys())) + + def test_build_raw_supervised_input_receiver_fn_batch_size(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn = export.build_raw_supervised_input_receiver_fn( + features, labels, default_batch_size=10) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual([10], input_receiver.receiver_tensors["feature_1"].shape) + self.assertEqual([10], input_receiver.features["feature_1"].shape) + + def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"feature_1": constant_op.constant([5]), + "bar": constant_op.constant([6])} + with self.assertRaises(ValueError): + export.build_raw_supervised_input_receiver_fn(features, labels) + def test_build_all_signature_defs_without_receiver_alternatives(self): receiver_tensor = array_ops.placeholder(dtypes.string) output_1 = constant_op.constant([1.]) @@ -404,6 +626,35 @@ class ExportTest(test_util.TensorFlowTestCase): self.assertTrue(int(time_1) < int(time_2)) self.assertTrue(int(time_2) < int(time_3)) + def test_build_all_signature_defs_serving_only(self): + receiver_tensor = {"input": array_ops.placeholder(dtypes.string)} + output_1 = constant_op.constant([1.]) + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(outputs=output_1), + "train": export_output.TrainOutput(loss=output_1), + } + + signature_defs = export.build_all_signature_defs( + receiver_tensor, export_outputs) + + expected_signature_defs = { + "serving_default": signature_def_utils.predict_signature_def( + receiver_tensor, {"output": output_1}) + } + + self.assertDictEqual(expected_signature_defs, signature_defs) + + signature_defs = export.build_all_signature_defs( + receiver_tensor, export_outputs, serving_only=False) + + expected_signature_defs.update({ + "train": signature_def_utils.supervised_train_signature_def( + receiver_tensor, loss={"loss": output_1}) + }) + + self.assertDictEqual(expected_signature_defs, signature_defs) + class TensorServingReceiverTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 8111ab564c..4ab2578769 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest @@ -53,6 +54,13 @@ class ModeKeys(object): LOSS_METRIC_KEY = 'loss' AVERAGE_LOSS_METRIC_KEY = 'average_loss' +# Mapping of the modes to appropriate tag_constants that are used for saving. +EXPORT_TAG_MAP = { + ModeKeys.PREDICT: [tag_constants.SERVING], + ModeKeys.TRAIN: [tag_constants.TRAINING], + ModeKeys.EVAL: [tag_constants.EVAL], +} + @tf_export('estimator.EstimatorSpec') class EstimatorSpec( diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 3447d917e9..071033b066 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -168,6 +168,25 @@ class SavedModelBuilder(object): raise TypeError("main_op needs to be an Operation: %r" % main_op) ops.add_to_collection(constants.MAIN_OP_KEY, main_op) + def _add_train_op(self, train_op): + """Add train op to the SavedModel. + + Note that this functionality is in development, and liable to be + moved elsewhere. + + Args: + train_op: Op or group of ops that are used for training. These are + stored as a collection with key TRAIN_OP_KEY, but not executed. + + Raises: + TypeError if Train op is not of type `Operation`. + """ + if train_op is not None: + if (not isinstance(train_op, ops.Tensor) and + not isinstance(train_op, ops.Operation)): + raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op) + ops.add_to_collection(constants.TRAIN_OP_KEY, train_op) + def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map): """Tags the meta graph def and adds it to the SavedModel. @@ -238,6 +257,20 @@ class SavedModelBuilder(object): for outputs_key in outputs: self._validate_tensor_info(outputs[outputs_key]) + def _add_collections( + self, assets_collection, legacy_init_op, main_op, train_op): + """Add asset and op collections to be saved.""" + # Save asset files and write them to disk, if any. + self._save_and_write_assets(assets_collection) + + if main_op is None: + # Add legacy init op to the SavedModel. + self._maybe_add_legacy_init_op(legacy_init_op) + else: + self._add_main_op(main_op) + + self._add_train_op(train_op) + def add_meta_graph(self, tags, signature_def_map=None, @@ -285,14 +318,8 @@ class SavedModelBuilder(object): # properly populated. self._validate_signature_def_map(signature_def_map) - # Save asset files and write them to disk, if any. - self._save_and_write_assets(assets_collection) - - if main_op is None: - # Add legacy init op to the SavedModel. - self._maybe_add_legacy_init_op(legacy_init_op) - else: - self._add_main_op(main_op) + # Add assets and ops + self._add_collections(assets_collection, legacy_init_op, main_op, None) # Initialize a saver to generate a sharded output for all saveables in the # current scope. @@ -351,6 +378,7 @@ class SavedModelBuilder(object): strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + """ # pylint: enable=line-too-long if self._has_saved_variables: @@ -362,8 +390,8 @@ class SavedModelBuilder(object): # properly populated. self._validate_signature_def_map(signature_def_map) - # Save asset files and write them to disk, if any. - self._save_and_write_assets(assets_collection) + # Add assets and ops + self._add_collections(assets_collection, legacy_init_op, main_op, None) # Create the variables sub-directory, if it does not exist. variables_dir = os.path.join( @@ -376,12 +404,6 @@ class SavedModelBuilder(object): compat.as_text(variables_dir), compat.as_text(constants.VARIABLES_FILENAME)) - if main_op is None: - # Add legacy init op to the SavedModel. - self._maybe_add_legacy_init_op(legacy_init_op) - else: - self._add_main_op(main_op) - # Initialize a saver to generate a sharded output for all saveables in the # current scope. saver = tf_saver.Saver( diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py index 34206c6f6d..61c6ffbd0d 100644 --- a/tensorflow/python/saved_model/constants.py +++ b/tensorflow/python/saved_model/constants.py @@ -41,6 +41,10 @@ MAIN_OP_KEY = "saved_model_main_op" tf_export("saved_model.constants.MAIN_OP_KEY").export_constant( __name__, "MAIN_OP_KEY") +# CollectionDef key for the SavedModel train op. +# Not exported while export_all_saved_models is in contrib. +TRAIN_OP_KEY = "saved_model_train_op" + # Schema version for SavedModel. SAVED_MODEL_SCHEMA_VERSION = 1 tf_export("saved_model.constants.SAVED_MODEL_SCHEMA_VERSION").export_constant( @@ -65,3 +69,5 @@ tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant( VARIABLES_FILENAME = "variables" tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant( __name__, "VARIABLES_FILENAME") + + diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 804255375e..a4d994fd43 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -734,6 +734,96 @@ class SavedModelTest(test.TestCase): builder.add_meta_graph_and_variables( sess, ["foo"], legacy_init_op=legacy_init_op) + def testTrainOp(self): + export_dir = self._get_export_dir("test_train_op") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + # Add `v1` and `v2` variables to the graph. + v1 = variables.Variable(1, name="v1") + ops.add_to_collection("v", v1) + v2 = variables.Variable(2, name="v2") + ops.add_to_collection("v", v2) + + sess.run(variables.global_variables_initializer()) + train_op = state_ops.assign_add(v1, v2) + + sess.run(train_op) + # TODO(karmel): remove explicit call when in the public method. + builder._add_train_op(train_op) + builder.add_meta_graph_and_variables(sess, ["foo"]) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo"], export_dir) + self.assertEqual(3, ops.get_collection("v")[0].eval()) + self.assertEqual(2, ops.get_collection("v")[1].eval()) + self.assertIsInstance( + ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor) + + def testTrainOpGroup(self): + export_dir = self._get_export_dir("test_train_op_group") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + # Add `v1` and `v2` variables to the graph. + v1 = variables.Variable(1, name="v1") + ops.add_to_collection("v", v1) + v2 = variables.Variable(2, name="v2") + ops.add_to_collection("v", v2) + + sess.run(variables.global_variables_initializer()) + train_op = control_flow_ops.group() + + sess.run(train_op) + # TODO(karmel): remove explicit call when in the public method. + builder._add_train_op(train_op) + builder.add_meta_graph_and_variables(sess, ["foo"]) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo"], export_dir) + self.assertEqual(1, ops.get_collection("v")[0].eval()) + self.assertEqual(2, ops.get_collection("v")[1].eval()) + self.assertIsInstance( + ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Operation) + + def testTrainOpAfterVariables(self): + export_dir = self._get_export_dir("test_train_op_after_variables") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + # Add `v1` and `v2` variables to the graph. + v1 = variables.Variable(1, name="v1") + ops.add_to_collection("v", v1) + v2 = variables.Variable(2, name="v2") + ops.add_to_collection("v", v2) + + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph_and_variables(sess, ["pre_foo"]) + + train_op = state_ops.assign_add(v1, v2) + sess.run(train_op) + # TODO(karmel): remove explicit call when in the public method. + builder._add_train_op(train_op) + builder.add_meta_graph(["foo"]) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo"], export_dir) + self.assertIsInstance( + ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor) + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["pre_foo"], export_dir) + self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY)) + def testMultipleAssets(self): export_dir = self._get_export_dir("test_multiple_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py index 819f351291..99007a9634 100644 --- a/tensorflow/python/saved_model/signature_constants.py +++ b/tensorflow/python/saved_model/signature_constants.py @@ -94,3 +94,9 @@ tf_export("saved_model.signature_constants.REGRESS_OUTPUTS").export_constant( __name__, "REGRESS_OUTPUTS") ################################################################################ +# Train/Eval API constants. +# Not exported while export_all_saved_models is in contrib. + +SUPERVISED_TRAIN_METHOD_NAME = "tensorflow/supervised/training" + +SUPERVISED_EVAL_METHOD_NAME = "tensorflow/supervised/eval" diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py index ea0f52f17e..27d6b70e9d 100644 --- a/tensorflow/python/saved_model/signature_def_utils.py +++ b/tensorflow/python/saved_model/signature_def_utils.py @@ -26,6 +26,8 @@ from tensorflow.python.saved_model.signature_def_utils_impl import classificatio from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def +from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def +from tensorflow.python.saved_model.signature_def_utils_impl import supervised_train_signature_def # pylint: enable=unused-import del absolute_import diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py index d033159188..f8ad788f77 100644 --- a/tensorflow/python/saved_model/signature_def_utils_impl.py +++ b/tensorflow/python/saved_model/signature_def_utils_impl.py @@ -185,6 +185,62 @@ def predict_signature_def(inputs, outputs): return signature_def +def supervised_train_signature_def( + inputs, loss, predictions=None, metrics=None): + return _supervised_signature_def( + signature_constants.SUPERVISED_TRAIN_METHOD_NAME, inputs, loss=loss, + predictions=predictions, metrics=metrics) + + +def supervised_eval_signature_def( + inputs, loss, predictions=None, metrics=None): + return _supervised_signature_def( + signature_constants.SUPERVISED_EVAL_METHOD_NAME, inputs, loss=loss, + predictions=predictions, metrics=metrics) + + +def _supervised_signature_def( + method_name, inputs, loss=None, predictions=None, + metrics=None): + """Creates a signature for training and eval data. + + This function produces signatures that describe the inputs and outputs + of a supervised process, such as training or evaluation, that + results in loss, metrics, and the like. Note that this function only requires + inputs to be not None. + + Args: + method_name: Method name of the SignatureDef as a string. + inputs: dict of string to `Tensor`. + loss: dict of string to `Tensor` representing computed loss. + predictions: dict of string to `Tensor` representing the output predictions. + metrics: dict of string to `Tensor` representing metric ops. + + Returns: + A train- or eval-flavored signature_def. + + Raises: + ValueError: If inputs or outputs is `None`. + """ + if inputs is None or not inputs: + raise ValueError('{} inputs cannot be None or empty.'.format(method_name)) + + signature_inputs = {key: utils.build_tensor_info(tensor) + for key, tensor in inputs.items()} + + signature_outputs = {} + for output_set in (loss, predictions, metrics): + if output_set is not None: + sig_out = {key: utils.build_tensor_info(tensor) + for key, tensor in output_set.items()} + signature_outputs.update(sig_out) + + signature_def = build_signature_def( + signature_inputs, signature_outputs, method_name) + + return signature_def + + @tf_export('saved_model.signature_def_utils.is_valid_signature') def is_valid_signature(signature_def): """Determine whether a SignatureDef can be served by TensorFlow Serving.""" diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py index b2bd14db8c..ebc5450633 100644 --- a/tensorflow/python/saved_model/signature_def_utils_test.py +++ b/tensorflow/python/saved_model/signature_def_utils_test.py @@ -180,6 +180,101 @@ class SignatureDefUtilsTest(test.TestCase): self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype) self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim)) + def testTrainSignatureDef(self): + self._testSupervisedSignatureDef( + signature_def_utils_impl.supervised_train_signature_def, + signature_constants.SUPERVISED_TRAIN_METHOD_NAME) + + def testEvalSignatureDef(self): + self._testSupervisedSignatureDef( + signature_def_utils_impl.supervised_eval_signature_def, + signature_constants.SUPERVISED_EVAL_METHOD_NAME) + + def _testSupervisedSignatureDef(self, fn_to_test, method_name): + inputs = { + "input-1": constant_op.constant("a", name="input-1"), + "input-2": constant_op.constant("b", name="input-2"), + } + loss = {"loss-1": constant_op.constant(0.45, name="loss-1")} + predictions = { + "classes": constant_op.constant([100], name="classes"), + } + metrics_val = constant_op.constant(100.0, name="metrics_val") + metrics = { + "metrics/value": metrics_val, + "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"), + } + + signature_def = fn_to_test(inputs, loss, predictions, metrics) + + self.assertEqual(method_name, signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(2, len(signature_def.inputs)) + input1_tensor_info_actual = (signature_def.inputs["input-1"]) + self.assertEqual("input-1:0", input1_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype) + self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim)) + input2_tensor_info_actual = (signature_def.inputs["input-2"]) + self.assertEqual("input-2:0", input2_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype) + self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim)) + + # Check outputs in signature def. + self.assertEqual(4, len(signature_def.outputs)) + self.assertEqual("loss-1:0", signature_def.outputs["loss-1"].name) + self.assertEqual(types_pb2.DT_FLOAT, signature_def.outputs["loss-1"].dtype) + + self.assertEqual("classes:0", signature_def.outputs["classes"].name) + self.assertEqual(1, len(signature_def.outputs["classes"].tensor_shape.dim)) + + self.assertEqual( + "metrics_val:0", signature_def.outputs["metrics/value"].name) + self.assertEqual( + types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype) + + self.assertEqual( + "metrics_op:0", signature_def.outputs["metrics/update_op"].name) + self.assertEqual( + types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype) + + def testTrainSignatureDefMissingInputs(self): + self._testSupervisedSignatureDefMissingInputs( + signature_def_utils_impl.supervised_train_signature_def, + signature_constants.SUPERVISED_TRAIN_METHOD_NAME) + + def testEvalSignatureDefMissingInputs(self): + self._testSupervisedSignatureDefMissingInputs( + signature_def_utils_impl.supervised_eval_signature_def, + signature_constants.SUPERVISED_EVAL_METHOD_NAME) + + def _testSupervisedSignatureDefMissingInputs(self, fn_to_test, method_name): + inputs = { + "input-1": constant_op.constant("a", name="input-1"), + "input-2": constant_op.constant("b", name="input-2"), + } + loss = {"loss-1": constant_op.constant(0.45, name="loss-1")} + predictions = { + "classes": constant_op.constant([100], name="classes"), + } + metrics_val = constant_op.constant(100, name="metrics_val") + metrics = { + "metrics/value": metrics_val, + "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"), + } + + with self.assertRaises(ValueError): + signature_def = fn_to_test( + {}, loss=loss, predictions=predictions, metrics=metrics) + + signature_def = fn_to_test(inputs, loss=loss) + self.assertEqual(method_name, signature_def.method_name) + self.assertEqual(1, len(signature_def.outputs)) + + signature_def = fn_to_test(inputs, metrics=metrics, loss=loss) + self.assertEqual(method_name, signature_def.method_name) + self.assertEqual(3, len(signature_def.outputs)) + def testGetShapeAndTypes(self): inputs = { "input-1": constant_op.constant(["a", "b"]), diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py index 5a797da791..c82154e7b9 100644 --- a/tensorflow/python/saved_model/tag_constants.py +++ b/tensorflow/python/saved_model/tag_constants.py @@ -32,6 +32,9 @@ TRAINING = "train" tf_export("saved_model.tag_constants.TRAINING").export_constant( __name__, "TRAINING") +# Tag for the `eval` graph. Not exported while the export logic is in contrib. +EVAL = "eval" + # Tag for the `gpu` graph. GPU = "gpu" tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU") @@ -39,3 +42,5 @@ tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU") # Tag for the `tpu` graph. TPU = "tpu" tf_export("saved_model.tag_constants.TPU").export_constant(__name__, "TPU") + + -- GitLab From 037e52e20157985d3f385f8e0426cdde3f5aae2b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 16:37:27 -0700 Subject: [PATCH 036/738] Expose read-only versions of tensors in tflite. PiperOrigin-RevId: 195491701 --- tensorflow/contrib/lite/interpreter.h | 37 ++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 1074f64263..0450e86ae7 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -201,7 +201,7 @@ class Interpreter { // Overrides execution plan. This bounds checks indices sent in. TfLiteStatus SetExecutionPlan(const std::vector& new_plan); - // Get a tensor data structure. + // Get a mutable tensor data structure. // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { @@ -210,9 +210,14 @@ class Interpreter { return &context_.tensors[tensor_index]; } + // Get an immutable tensor data structure. + const TfLiteTensor* tensor(int tensor_index) const { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + // Get a pointer to an operation and registration data structure if in bounds. - // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this - // read/write access to structure const std::pair* node_and_registration( int node_index) const { if (node_index >= nodes_and_registration_.size() || node_index < 0) @@ -220,7 +225,8 @@ class Interpreter { return &nodes_and_registration_[node_index]; } - // Perform a checked cast to the appropriate tensor type. + // Perform a checked cast to the appropriate tensor type (mutable pointer + // version). template T* typed_tensor(int tensor_index) { if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { @@ -231,6 +237,18 @@ class Interpreter { return nullptr; } + // Perform a checked cast to the appropriate tensor type (immutable pointer + // version). + template + const T* typed_tensor(int tensor_index) const { + if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType()) { + return reinterpret_cast(tensor_ptr->data.raw); + } + } + return nullptr; + } + // Return a pointer into the data of a given input tensor. The given index // must be between 0 and inputs().size(). template @@ -238,13 +256,20 @@ class Interpreter { return typed_tensor(inputs_[index]); } - // Return a pointer into the data of a given output tensor. The given index - // must be between 0 and outputs().size(). + // Return a mutable pointer into the data of a given output tensor. The given + // index must be between 0 and outputs().size(). template T* typed_output_tensor(int index) { return typed_tensor(outputs_[index]); } + // Return an immutable pointer into the data of a given output tensor. The + // given index must be between 0 and outputs().size(). + template + const T* typed_output_tensor(int index) const { + return typed_tensor(outputs_[index]); + } + // Change the dimensionality of a given tensor. Note, this is only acceptable // for tensor indices that are inputs. // Returns status of failure or success. -- GitLab From fa1d92f70adf52d9258384e8528f9a7203a141dd Mon Sep 17 00:00:00 2001 From: Bjarke Hammersholt Roune Date: Fri, 4 May 2018 16:51:06 -0700 Subject: [PATCH 037/738] Add infrastructure for a backend-specific configuration for each op. This is intentionally not exposed in ComputationBuilder and is not intended for use or to be set at all prior to the last backend-specific part of compilation. PiperOrigin-RevId: 195493500 --- tensorflow/compiler/xla/service/hlo.proto | 3 + .../compiler/xla/service/hlo_computation.cc | 52 ++++----- .../compiler/xla/service/hlo_computation.h | 20 ++-- .../compiler/xla/service/hlo_graph_dumper.cc | 43 +++++--- .../compiler/xla/service/hlo_graph_dumper.h | 5 +- .../compiler/xla/service/hlo_instruction.cc | 100 +++++------------- .../compiler/xla/service/hlo_instruction.h | 59 +++++++++-- tensorflow/compiler/xla/service/hlo_module.cc | 12 +++ tensorflow/compiler/xla/service/hlo_module.h | 19 ++++ .../compiler/xla/service/hlo_verifier.cc | 71 ++++++------- tensorflow/compiler/xla/statusor.h | 11 +- tensorflow/compiler/xla/statusor_test.cc | 8 ++ .../compiler/xla/tools/parser/hlo_parser.cc | 10 +- .../xla/tools/parser/hlo_parser_test.cc | 22 +++- 14 files changed, 259 insertions(+), 176 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index aa6860880b..1f7c1cffd3 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -147,6 +147,9 @@ message HloInstructionProto { repeated int64 called_computation_ids = 38; xla.OpSharding sharding = 40; + + // Backend configuration for the instruction. Has backend-specific meaning. + string backend_config = 43; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 594413e88f..17e43c3cb8 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -347,6 +347,11 @@ std::list HloComputation::MakeEmbeddedComputationsList() // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after // construction. + // + // TODO(b/78350259): This violates const-correctness, since while the original + // computation is not returned, we still retrieve non-const computations from + // a const one. Consider also avoiding const for HloComputation, or review XLA + // for const-correctness of non-HloInstruction* types like this. ComputeComputationPostOrder(const_cast(this), &visited, &post_order); @@ -723,18 +728,25 @@ Status HloComputation::Accept( return this->Accept(&visitor); } -std::unique_ptr HloComputation::Clone(const string& suffix, - HloModule* module) { +std::unique_ptr HloComputation::Clone( + const string& suffix, HloModule* module, + HloInstruction::CloneMap* clone_map) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, suffix); + module, clone_map, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, const string& suffix) { + HloModule* module, HloInstruction::CloneMap* clone_map, + const string& suffix) { + HloInstruction::CloneMap local_clone_map; + if (clone_map == nullptr) { + clone_map = &local_clone_map; + } + // Look up instr in the replacements map, and return either the replacement, // or instr, if the replacement isn't present. // @@ -756,24 +768,19 @@ std::unique_ptr HloComputation::CloneWithReplacements( } } - std::unordered_map clone_map; std::vector> instructions; std::unique_ptr new_instr = nullptr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); - // If replaced_operand is null, that means 'replacements' asked us not to - // include operand in the new computation. But we can't do that, because - // operand is used by instr. CHECK_NE(replaced_operand, nullptr) - << "replacements map tried to eliminate a used instruction " - << operand->ToString() << ", used by " << instr->ToString(); - new_operands.push_back(FindOrDie(clone_map, replaced_operand)); + << "Replacements map specifies to leave out " << operand->ToString() + << ", but it is used by " << instr->ToString() << "."; + new_operands.push_back(FindOrDie(*clone_map, replaced_operand)); } - new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, module); - InsertOrDie(&clone_map, instr, new_instr.get()); + new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands, + module, clone_map); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -781,27 +788,24 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction()))); + /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(clone_map, instr); + HloInstruction* new_instr = FindOrDie(*clone_map, instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - - // successor may not be in clone_map, because it might have been - // removed by the replacements map. - if (replaced_successor == nullptr) { - continue; - } + CHECK_NE(replaced_successor, nullptr) + << "Replacements map specifies to leave out " << successor->ToString() + << ", but it is control-depended-on by " << instr->ToString() << "."; TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(clone_map, replaced_successor))); + FindOrDie(*clone_map, replaced_successor))); } } // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before + // destroyed. HloInstructions need to be detached from their operands before // they're destroyed, otherwise they stick around in the operands' users lists // and cause use-after-frees. for (auto& kv : replacements) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 9d3f6e9a2c..9898355625 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -291,11 +291,17 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr); + // + // If the module pointer is not nullptr, then the cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the computation is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. + std::unique_ptr Clone( + const string& suffix = "clone", HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -305,7 +311,9 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, const string& suffix = "clone"); + HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index bb4db89f0a..794f1b4682 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -322,11 +322,13 @@ class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, const DebugOptions& debug_options, bool show_metadata, - const HloExecutionProfile* profile, NodeFilter filter) + bool show_backend_config, const HloExecutionProfile* profile, + NodeFilter filter) : computation_(computation), label_(label.ToString()), debug_options_(debug_options), show_metadata_(show_metadata), + show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -365,6 +367,7 @@ class HloDotDumper { string GetInstructionNodeShape(const HloInstruction* instr); string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeMetadata(const HloInstruction* instr); + string GetInstructionNodeBackendConfig(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); @@ -393,6 +396,7 @@ class HloDotDumper { const string label_; // overall name for the graph const DebugOptions& debug_options_; const bool show_metadata_; + const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -611,6 +615,10 @@ tooltip = " "; if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); } + string node_backend_config = GetInstructionNodeBackendConfig(parent_instr); + if (!node_backend_config.empty()) { + StrAppend(&subcomp_label, "
", node_backend_config); + } bool highlight = filter_.Highlight(parent_instr); const char* fillcolor; @@ -765,6 +773,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string node_shape = GetInstructionNodeShape(instr); string node_label = GetInstructionNodeLabel(instr); string node_metadata = GetInstructionNodeMetadata(instr); + string node_backend_config = GetInstructionNodeBackendConfig(instr); string extra_info = GetInstructionNodeExtraInfo(instr); string inlined_constants = GetInstructionNodeInlinedOperands(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); @@ -782,8 +791,8 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. string node_body = node_label; - for (const string& s : - {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_metadata, + node_backend_config, extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } @@ -1078,6 +1087,15 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { return Join(lines, "
"); } +string HloDotDumper::GetInstructionNodeBackendConfig( + const HloInstruction* instr) { + if (!show_backend_config_ || instr->backend_config().empty()) { + return ""; + } + + return StrCat("backend_config=\"", instr->backend_config(), "\""); +} + string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { std::vector lines; @@ -1404,7 +1422,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1414,9 +1432,10 @@ string DumpGraph(const HloComputation& computation, const string& label, &graph)); graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = HloDotDumper(&computation, label, debug_options, show_metadata, - hlo_execution_profile, NodeFilter()) - .Dump(); + graph = + HloDotDumper(&computation, label, debug_options, show_metadata, + show_backend_config, hlo_execution_profile, NodeFilter()) + .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1427,15 +1446,15 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); - string graph = - HloDotDumper(node.parent(), label, debug_options, show_metadata, - /*profile=*/nullptr, filter) - .Dump(); + string graph = HloDotDumper(node.parent(), label, debug_options, + show_metadata, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 2704aae1e3..fc8e1468ac 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_metadata = false); + bool show_metadata = false, bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,8 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false); + bool show_metadata = false, + bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a714d0e114..2c733726a6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -109,6 +109,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->name_ = proto.name(); instruction->metadata_ = proto.metadata(); + instruction->set_backend_config(proto.backend_config()); if (proto.has_literal()) { TF_ASSIGN_OR_RETURN(instruction->literal_, Literal::CreateFromProto(proto.literal())); @@ -1231,12 +1232,15 @@ bool HloInstruction::HasSideEffect() const { std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module) const { + HloModule* module, CloneMap* clone_map) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " %" << new_operand->name(); } + if (module == nullptr) { + module = GetModule(); + } std::unique_ptr clone; @@ -1342,7 +1346,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kFft: CHECK_EQ(new_operands.size(), 1); - return CreateFft(shape, new_operands[0], fft_type_, fft_length_); + clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); + break; case HloOpcode::kCrossReplicaSum: clone = CreateCrossReplicaSum(shape, new_operands); break; @@ -1415,9 +1420,15 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kConstant: clone = CreateConstant(literal_->CloneToUnique()); break; - case HloOpcode::kFusion: - clone = CloneFusionWithNewOperands(shape, new_operands, module); + case HloOpcode::kFusion: { + CHECK_NE(module, nullptr); + auto new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", module, clone_map)); + clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), + /*operands=*/new_operands, + /*fusion_computation=*/new_fused_computation); break; + } case HloOpcode::kParameter: clone = CreateParameter(parameter_number_, shape, name_); break; @@ -1481,15 +1492,19 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); + clone->set_backend_config(backend_config()); + if (clone_map != nullptr) { + InsertOrDie(clone_map, this, clone.get()); + } return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr HloInstruction::Clone(const string& suffix, - HloModule* module) const { +std::unique_ptr HloInstruction::Clone( + const string& suffix, HloModule* module, CloneMap* clone_map) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module); + CloneWithNewOperands(shape_, operands_, module, clone_map); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1526,71 +1541,6 @@ std::unique_ptr HloInstruction::Clone(const string& suffix, return clone; } -std::unique_ptr HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(parent() != nullptr); - - auto new_instruction = - WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - // Add the operands to our new fusion instruction. - for (HloInstruction* new_operand : operands) { - new_instruction->AppendOperand(new_operand); - } - // Clone all the fused instructions for the new fusion instruction. - HloInstructionMap old_to_new; - std::list> new_fused_instructions; - // Create the list of fused parameters by mapping through the cloned, - // fused instructions. - for (HloInstruction* old_fused_parameter : - fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back( - old_fused_parameter->Clone("clone", module)); - HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); - InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); - } - for (auto old_fused_instruction : - fused_instructions_computation()->MakeInstructionPostOrder()) { - if (old_fused_instruction->opcode() == HloOpcode::kParameter) { - FindOrDie(old_to_new, old_fused_instruction); - continue; - } - std::vector new_operands; - for (int64 operand_idx = 0; - operand_idx < old_fused_instruction->operand_count(); ++operand_idx) { - HloInstruction* old_operand = - old_fused_instruction->mutable_operand(operand_idx); - new_operands.push_back(FindOrDie(old_to_new, old_operand)); - } - new_fused_instructions.push_back( - old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands, module)); - HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); - new_fused_instruction->set_parent(parent_); - InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); - } - new_instruction->fusion_kind_ = fusion_kind_; - auto computation_builder = HloComputation::Builder( - fused_instructions_computation()->name() + ".clone", - new_instruction.get()); - // We iterated the fusion instructions in reverse post order which means - // that we must reverse our new list of fusion instructions. - for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); - new_fused_instruction_iter != new_fused_instructions.rend(); - ++new_fused_instruction_iter) { - computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); - } - if (module == nullptr) { - module = GetModule(); - } - auto fused_root_ = fused_expression_root(); - new_instruction->called_computations_.push_back( - CHECK_NOTNULL(module)->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - return new_instruction; -} - std::pair HloInstruction::LatestNonGteAncestorAndIndex() const { const HloInstruction* hlo = this; @@ -2172,6 +2122,9 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } + if (options.print_backend_config() && !backend_config().empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + } return result; } @@ -2357,6 +2310,7 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); } + return extra; } @@ -2386,6 +2340,7 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; + proto.set_backend_config(backend_config()); if (literal_ != nullptr) { *proto.mutable_literal() = literal_->ToProto(); } @@ -2971,6 +2926,7 @@ Status HloInstruction::AcceptOrdered( continue; } + // TODO(b/78350259): Eliminate const laundering. HloInstruction* instruction = const_cast(const_instruction); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a5e9aecb9e..19c8c11453 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -66,6 +66,7 @@ class HloPrintOptions { : print_large_constants_(false), print_subcomputation_references_(true), print_metadata_(true), + print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), print_program_shape_(true), @@ -77,6 +78,7 @@ class HloPrintOptions { .set_print_large_constants(true) .set_print_subcomputation_references(true) .set_print_metadata(false) + .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) .set_print_percent(false); @@ -99,12 +101,18 @@ class HloPrintOptions { return *this; } - // If true, metatdata will be printed. + // If true, metadata will be printed. HloPrintOptions& set_print_metadata(bool value) { print_metadata_ = value; return *this; } + // If true, backend_config will be printed. + HloPrintOptions& set_print_backend_config(bool value) { + print_backend_config_ = value; + return *this; + } + // If true, operands' shapes will be printed. HloPrintOptions& set_print_operand_shape(bool value) { print_operand_shape_ = value; @@ -141,6 +149,7 @@ class HloPrintOptions { return print_subcomputation_references_; } bool print_metadata() const { return print_metadata_; } + bool print_backend_config() const { return print_metadata_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } @@ -151,6 +160,7 @@ class HloPrintOptions { bool print_large_constants_; bool print_subcomputation_references_; bool print_metadata_; + bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; bool print_program_shape_; @@ -643,6 +653,8 @@ class HloInstruction { // Detaches an instruction from its operands. That is, remove the instruction // from each operand's user set. This should only be called prior to // deallocating the instruction. + // + // TODO(b/78305363): Make this automatic when deleting an instruction. void DetachFromOperands(); // Performs a postorder DFS visit using this node as the root. If @@ -1157,23 +1169,30 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kRng RandomDistribution random_distribution() const; + // See documentation for Clone(). + using CloneMap = std::unordered_map; + // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. If the module - // pointer is not nullptr, it will be the module where the cloned computations - // will be added to (in order to support deep cloning). Ignores the control - // predecessors and successors of this HLO instruction. + // the instruction to form the name of the cloned instruction. Ignores the + // control predecessors and successors of this HLO instruction. + // + // If the module pointer is not nullptr, then any cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the instruction is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr) const; + HloModule* module = nullptr, + CloneMap* clone_map = nullptr) const; - // Clones the HLO instruction as above but with new shape and operands. If - // the module pointer is not nullptr, it will be the module where the cloned - // computations will be added to (in order to support deep cloning). Ignores - // the control predecessors and successors of this HLO instruction. + // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; + HloModule* module = nullptr, CloneMap* clone_map = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -1262,6 +1281,19 @@ class HloInstruction { // if no id has been assigned yet). int unique_id() const { return unique_id_; } + // Returns the backend-specific configuration for how a backend should compile + // this HLO. The meaning of the field is backend specific. Not for use before + // or during general HLO optimization, since HLO optimizations do not preserve + // this field and they cannot interpret it due to its meaning being backend + // specific. + // + // TODO(b/78194644): Introduce structured configuration format as per + // go/xla-heuristics. + const string& backend_config() const { return backend_config_; } + void set_backend_config(string backend_config) { + backend_config_ = std::move(backend_config); + } + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1283,6 +1315,7 @@ class HloInstruction { // Get/Set the number of partitions per outer dimension (in order, starting // with outer-most dimension first). Currently used by the parallel cpu // backend to partition HLOs into parallel tasks. + // // TODO(b/62783254) Replace these methods with a more general way to // annotate HLOs with backend-specific information. const std::vector& outer_dimension_partitions() const { @@ -1510,6 +1543,10 @@ class HloInstruction { // The string representation of the infeed configuration. string infeed_config_; + // The backend-specific configuration for how a backend should compile this + // HLO. See the documentation on backend_config(). + string backend_config_; + // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index c7a7192867..5308fb5848 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -46,6 +46,18 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) config_(config), unique_id_(next_unique_module_id_++) {} +StatusOr HloModule::LaunderConstInstructionFromModule( + const HloInstruction* hlo) { + if (hlo == nullptr) { + return nullptr; + } + + TF_RET_CHECK(hlo->GetModule() == this); + + // TODO(b/78350259): Eliminate const laundering. + return const_cast(hlo); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, bool uniquify_names) { diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index f9674df812..1604a72612 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -217,6 +217,25 @@ class HloModule { // the lifetime of this process. int unique_id() const { return unique_id_; } + // Returns a non-const version of the passed-in const HloInstruction*. This is + // safe on the argument that if you have a non-const module, then you can + // access all instructions in the module as non-const. + // + // Returns an error if the passed-in instruction is not from this module, + // except that it is allowed to pass in a null pointer. + // + // TODO(b/78350259): Eliminate const laundering. The argument above is not + // reliable since at any time someone could add or discover a way for a + // non-const module to transitively contain a const HloInstruction. The + // reliable way to do this would be to create a const laundering map from a + // module, mapping each encountered HloInstruction to its non-const version + // and then look up each instruction in need of laundering in that map, but + // this is much more expensive and complicated. This returns a Status instead + // of doing a CHECK-failure in part to make it strongly apparent that this is + // something that can fail. + StatusOr LaunderConstInstructionFromModule( + const HloInstruction* hlo); + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 8a30cbf9cd..096ebb7946 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -116,7 +116,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // produces no HLO value in the graph. if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { - return InvalidArgument( + return InternalError( "Expected outfeed to have shape compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), @@ -200,7 +200,7 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { transpose->operand(0)->shape(), transpose->dimensions())); } -Status ShapeVerifier::HandleParameter(HloInstruction*) { +Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { return tensorflow::Status::OK(); } @@ -410,7 +410,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { if (fp_type == PRIMITIVE_TYPE_INVALID) { fp_type = subshape.element_type(); } else if (fp_type != subshape.element_type()) { - return FailedPrecondition( + return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", instruction->ToString().c_str()); @@ -490,7 +490,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } } if (!compatible) { - return InvalidArgument( + return InternalError( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", ShapeUtil::HumanString(inferred_shape).c_str(), @@ -541,7 +541,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2) { if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( + return InternalError( "Expected to have the same channel id, actual channel ids are: %s " "(%lld), %s (%lld)", instr1->ToString().c_str(), instr1->channel_id(), @@ -571,22 +571,22 @@ string ComputationsToString( Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { - return FailedPrecondition("Computation %s has a null parent pointer", - computation->name().c_str()); + return InternalError("Computation %s has a null parent pointer", + computation->name().c_str()); } if (computation->parent() != module) { - return FailedPrecondition( + return InternalError( "Computation %s parent() does not point to parent module", computation->name().c_str()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { - return FailedPrecondition("Instruction %s has a null parent pointer", - instruction->name().c_str()); + return InternalError("Instruction %s has a null parent pointer", + instruction->name().c_str()); } if (instruction->parent() != computation) { - return FailedPrecondition( + return InternalError( "Instruction %s parent() does not point to parent computation", instruction->name().c_str()); } @@ -602,7 +602,7 @@ Status VerifyHloStructure(HloModule* module) { for (int i = 0; i < instruction->operand_count(); ++i) { const HloInstruction* operand = instruction->operand(i); if (operand->parent() != instruction->parent()) { - return FailedPrecondition( + return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", i, operand->name().c_str(), instruction->name().c_str(), @@ -619,7 +619,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { - return FailedPrecondition( + return InternalError( "Instruction of fused computation does not match expected instruction " "%s.", fusion->ToString().c_str()); @@ -635,37 +635,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto* instruction : fused_computation->instructions()) { if (fused_root == instruction) { if (root_owned) { - return FailedPrecondition("Root appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Root appears more than once in %s.", + fusion->ToString().c_str()); } root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { - return FailedPrecondition("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Parameter appears more than once in %s.", + fusion->ToString().c_str()); } parameter_owned[i] = true; } } } if (!root_owned) { - return FailedPrecondition("Root not found in computation of %s.", - fusion->ToString().c_str()); + return InternalError("Root not found in computation of %s.", + fusion->ToString().c_str()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { - return FailedPrecondition("Parameter %d not found in computation of %s.", - i, fusion->ToString().c_str()); + return InternalError("Parameter %d not found in computation of %s.", i, + fusion->ToString().c_str()); } } // Fused root must have no users. if (fused_root->user_count() != 0) { - return FailedPrecondition("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", + fusion->ToString().c_str()); } // All uses of fused instructions must be in the fusion computation, and every @@ -674,13 +674,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { if (instruction->user_count() == 0) { - return FailedPrecondition( - "Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + return InternalError("Non-root instruction %s in %s must have users.", + instruction->ToString().c_str(), + fusion->ToString().c_str()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { - return FailedPrecondition( + return InternalError( "Non-root instruction %s in %s may not have external users.", instruction->ToString().c_str(), fusion->ToString().c_str()); } @@ -695,34 +695,33 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return FailedPrecondition( - "Unexpected negative parameter number %lld in %s.", param_no, - fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %lld in %s.", + param_no, fusion->ToString().c_str()); } if (param_no >= fused_parameters.size()) { - return FailedPrecondition( + return InternalError( "Unexpected parameter number %lld in %s: higher then number of " "parameters %lu.", param_no, fusion->ToString().c_str(), fused_parameters.size()); } if (parameter_numbers[param_no]) { - return FailedPrecondition( + return InternalError( "Did not expect parameter number %lld more than once in %s.", param_no, fusion->ToString().c_str()); } parameter_numbers[param_no] = true; if (!ShapeUtil::Compatible(fused_param->shape(), fusion->operand(param_no)->shape())) { - return FailedPrecondition( + return InternalError( "Shape mismatch between parameter number %lld and its operand in %s.", param_no, fusion->ToString().c_str()); } } - // Make sure all the parameter_numbers entries were seen + // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { - return FailedPrecondition("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + return InternalError("Did not see parameter number %d in %s.", i, + fusion->ToString().c_str()); } } diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index cccbce5fc8..0e1387c939 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -13,13 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// StatusOr is the union of a Status object and a T -// object. StatusOr models the concept of an object that is either a -// usable value, or an error Status explaining why such a value is -// not present. To this end, StatusOr does not allow its Status -// value to be Status::OK. Furthermore, the value of a StatusOr -// must not be null. This is enforced by a debug check in most cases, -// but even when it is not, clients must not set the value to null. +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. // // The primary use-case for StatusOr is as the return value of a // function which may fail. diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index f9d25945bc..7d76370e85 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) { static_assert(std::is_same::element_type, char>(), ""); } +TEST(StatusOr, NullPointerStatusOr) { + // As a very special case, null-plain-pointer StatusOr used to be an + // error. Test that it no longer is. + StatusOr null_status(nullptr); + EXPECT_TRUE(null_status.ok()); + EXPECT_EQ(null_status.ValueOrDie(), nullptr); +} + TEST(StatusOr, TestNoDefaultConstructorInitialization) { // Explicitly initialize it with an error code. StatusOr statusor(tensorflow::errors::Cancelled("")); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 40dc0730ce..156a06c596 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -440,6 +440,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + optional backend_config; + attrs["backend_config"] = {/*required=*/false, AttrTy::kString, + &backend_config}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -1094,8 +1098,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_name(name); - // Add common attrs (sharding, control predecessors) to the instruction, if - // they were seen. + // Add shared attributes like metadata to the instruction, if they were seen. if (sharding) { instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); @@ -1112,6 +1115,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (metadata) { instruction->set_metadata(*metadata); } + if (backend_config) { + instruction->set_backend_config(std::move(*backend_config)); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index d38d8907a6..e100d8cda1 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -65,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar" } )" @@ -81,13 +81,14 @@ ENTRY %constant_s32 () -> s32[] { )" }, -// f32 constant, but the value is not a decimal +// f32 constant, but the value is not a decimal and there is a backend +// configuration { "ConstantF32", R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { - ROOT %constant = f32[] constant(42) + ROOT %constant = f32[] constant(42), backend_config="this is a configuration" } )" @@ -1013,6 +1014,19 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { // but the constant names will not be exactly the same. } +TEST_F(HloParserTest, ConfigurationField) { + const string original = R"(HloModule AModule +ENTRY %configuration_test() -> s32[] { + %constant = s32[] constant(42), backend_config="foo bar" +})"; + auto result = Parse(original); + TF_ASSERT_OK(result.status()); + EXPECT_EQ("foo bar", result.ValueOrDie() + ->entry_computation() + ->root_instruction() + ->backend_config()); +} + TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { const string original = R"(HloModule some_2_module @@ -1092,7 +1106,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; -- GitLab From bf228e1435da0032d2529de93661b742ee8a7048 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 4 May 2018 17:03:52 -0700 Subject: [PATCH 038/738] [tf.data] Adding `num_parallel_calls` to `map_and_batch`. PiperOrigin-RevId: 195495206 --- .../kernel_tests/batch_dataset_op_test.py | 44 +- .../contrib/data/python/ops/batching.py | 47 +- .../base_api/api_def_MapAndBatchDataset.pbtxt | 35 +- .../api_def_MapAndBatchDatasetV2.pbtxt | 54 ++ .../kernels/data/map_and_batch_dataset_op.cc | 773 +++++++++--------- tensorflow/core/ops/dataset_ops.cc | 13 + 6 files changed, 538 insertions(+), 428 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_MapAndBatchDatasetV2.pbtxt diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 6588fd04ac..2568b899d7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -427,7 +427,9 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, + num_parallel_calls=None, + num_parallel_batches=None): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). @@ -446,6 +448,7 @@ class BatchDatasetTest(test.TestCase): batching.map_and_batch( map_func=_map_fn, batch_size=batch_size, + num_parallel_calls=num_parallel_calls, num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer @@ -497,12 +500,18 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatchDataset(self): + def testMapAndBatch(self): return self._testMapAndBatchDatasetHelper() - def testMapAndBatchDatasetWithParallelBatching(self): + def testMapAndBatchWithParallelBatches(self): return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + def testMapAndBatchWithSequentialCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) + + def testMapAndBatchWithParallelCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): iterator = ( dataset_ops.Dataset.range(10).apply( @@ -682,7 +691,7 @@ class UnbatchDatasetSerializationTest( class MapAndBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): - def testSerializationCore(self): + def testNumParallelBatches(self): range_size = 11 num_repeats = 2 batch_size = 5 @@ -709,6 +718,33 @@ class MapAndBatchDatasetSerializationTest( self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), num_outputs_drop_remainder) + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 42ec2b0b01..b9393de4e9 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -466,14 +466,14 @@ def assert_element_shape(expected_shapes): class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, + def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, drop_remainder): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches_t = ops.convert_to_tensor( - num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._num_parallel_calls_t = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") self._drop_remainder_t = ops.convert_to_tensor( drop_remainder, dtype=dtypes.bool, name="drop_remainder") @@ -483,12 +483,12 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def _as_variant_tensor(self): # pylint: disable=protected-access input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.map_and_batch_dataset( + return gen_dataset_ops.map_and_batch_dataset_v2( input_resource, self._map_func.captured_inputs, f=self._map_func, batch_size=self._batch_size_t, - num_parallel_batches=self._num_parallel_batches_t, + num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), @@ -511,8 +511,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def map_and_batch(map_func, batch_size, - num_parallel_batches=1, - drop_remainder=False): + num_parallel_batches=None, + drop_remainder=False, + num_parallel_calls=None): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -528,21 +529,37 @@ def map_and_batch(map_func, nested structure of tensors. batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. - num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the - number of batches to create in parallel. On one hand, higher values can - help mitigate the effect of stragglers. On the other hand, higher values - can increase contention if CPU is scarce. - drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the - last batch should be dropped in case its size is smaller than desired; - the default behavior is not to drop the smaller batch. + num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, + representing the number of batches to create in parallel. On one hand, + higher values can help mitigate the effect of stragglers. On the other + hand, higher values can increase contention if CPU is scarce. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in case its size is smaller than + desired; the default behavior is not to drop the smaller batch. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number of elements to process in parallel. If not + specified, `batch_size * num_parallel_batches` elements will be + processed in parallel. Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. + + Raises: + ValueError: If both `num_parallel_batches` and `num_parallel_calls` are + specified. """ + if num_parallel_batches is None and num_parallel_calls is None: + num_parallel_calls = batch_size + elif num_parallel_batches is not None and num_parallel_calls is None: + num_parallel_calls = batch_size * num_parallel_batches + elif num_parallel_batches is not None and num_parallel_calls is not None: + raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " + "arguments are mutually exclusive.") + def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches, drop_remainder) + num_parallel_calls, drop_remainder) return _apply_fn diff --git a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt index bf544703de..e230c51edf 100644 --- a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt @@ -1,5 +1,19 @@ op { graph_op_name: "MapAndBatchDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <