From 75259500f80266998f232a94853b0bc08d2925cc Mon Sep 17 00:00:00 2001 From: KB Sriram Date: Wed, 28 Feb 2018 07:16:20 -0800 Subject: [PATCH 001/755] C++ gradients: Fractional*Pool, Soft{Plus,Sign} 1. Adds gradients for four nn ops: FractionalAvgPool FractionalMaxPool SoftPlus SoftSign 2. Update randomization to allow numeric gradient checks on max pooling algorithms with more than one pool. Resolves https://github.com/tensorflow/tensorflow/issues/17330 --- tensorflow/cc/gradients/nn_grad.cc | 47 ++++++++++++++ tensorflow/cc/gradients/nn_grad_test.cc | 84 ++++++++++++++++++++----- 2 files changed, 115 insertions(+), 16 deletions(-) diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 63a67f09f6..4b89dac4c0 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -272,6 +272,53 @@ Status LRNGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("LRN", LRNGradHelper); +Status SoftplusGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftplusGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softplus", SoftplusGradHelper); + +Status SoftsignGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto dx = internal::SoftsignGrad(scope, grad_inputs[0], op.input(0)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Softsign", SoftsignGradHelper); + +Status FractionalAvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalAvgPoolGrad( + scope, Shape(scope, op.input(0), Shape::OutType(DT_INT64)), + grad_inputs[0], 
op.output(1), op.output(2), + internal::FractionalAvgPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalAvgPool", FractionalAvgPoolGradHelper); + +Status FractionalMaxPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + bool overlapping; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "overlapping", &overlapping)); + auto dx = internal::FractionalMaxPoolGrad( + scope, op.input(0), op.output(0), grad_inputs[0], op.output(1), + op.output(2), internal::FractionalMaxPoolGrad::Overlapping(overlapping)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("FractionalMaxPool", FractionalMaxPoolGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index c4eba7ecb0..b4d457a9d1 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -28,6 +28,8 @@ namespace { using ops::BiasAdd; using ops::Conv2D; using ops::Elu; +using ops::FractionalAvgPool; +using ops::FractionalMaxPool; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; @@ -41,6 +43,8 @@ using ops::Relu; using ops::Relu6; using ops::Selu; using ops::Softmax; +using ops::Softplus; +using ops::Softsign; class NNGradTest : public ::testing::Test { protected: @@ -71,22 +75,30 @@ class NNGradTest : public ::testing::Test { EXPECT_LT(max_error, 1e-3); } - // Sets tensor with random values, ensuring that the max value is largest by - // a reasonable amount. - // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which - // perturbations by the numeric gradient computation in the gradient checker - // can change the max value if values are too close together. 
+ // Sets tensor with random values, ensuring that every pair of elements are at + // least a reasonable amount apart. + // This is an issue for max pooling operations, in which perturbations by the + // numeric gradient computation in the gradient checker can change the max + // value if a pool has values that are too close together. template - void SetRandomValuesWithBumpedMax(Tensor* tensor) { + void SetRandomValuesForMaxPooling(Tensor* tensor) { auto tensor_flat = tensor->flat(); - tensor_flat.setRandom(); - int32 max_index = 0; - for (size_t i = 1; i < tensor->NumElements(); i++) { - if (tensor_flat(i) > tensor_flat(max_index)) { - max_index = i; - } + // First set the array to an increasing sequence of values spaced + // a reasonable amount apart + T cur = 0; + for (size_t i = 0; i < tensor->NumElements(); i++) { + tensor_flat(i) = cur; + cur += 5e-2; + } + // Fischer-Yates shuffle the array + for (size_t i = tensor->NumElements() - 1; i >= 1; i--) { + // j <- random integer 0 <= j <= i + size_t j = random::New64() % (i + 1); + // swap values at i, j + T tmp = tensor_flat(i); + tensor_flat(i) = tensor_flat(j); + tensor_flat(j) = tmp; } - tensor_flat(max_index) += 1e-2; } Scope scope_; @@ -189,7 +201,7 @@ TEST_F(NNGradTest, MaxPoolGradHelper) { const std::vector strides{1, 2, 2, 1}; auto y = MaxPool(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -202,7 +214,7 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { Tensor strides = test::AsTensor({1, 2, 2, 1}, {4}); auto y = MaxPoolV2(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -215,7 +227,7 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) { const std::vector strides{1, 3, 3, 3, 
1}; auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); Tensor x_init_value = Tensor(DT_FLOAT, x_shape); - SetRandomValuesWithBumpedMax(&x_init_value); + SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } @@ -248,5 +260,45 @@ TEST_F(NNGradTest, LRN){ RunTest(x, x_shape, y, x_shape); } +TEST_F(NNGradTest, SoftplusGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softplus(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, SoftsignGrad) { + TensorShape shape({3, 7}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + auto y = Softsign(scope_, x); + RunTest(x, shape, y, shape); +} + +TEST_F(NNGradTest, FractionalAvgPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. + auto y = FractionalAvgPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalAvgPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_shape, y.output, y_shape); +} + +TEST_F(NNGradTest, FractionalMaxPoolGradHelper) { + TensorShape x_shape({1, 3, 7, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Force consistent pooling regions for unit testing. 
+ auto y = FractionalMaxPool( + scope_, x, {1, 1.2, 1.9, 1}, + FractionalMaxPool::Deterministic(true).Overlapping(true).Seed(1).Seed2( + 2)); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesForMaxPooling(&x_init_value); + TensorShape y_shape({1, 2, 3, 1}); + RunTest(x, x_init_value, y.output, y_shape); +} + } // namespace } // namespace tensorflow -- GitLab From 1da3a47287aa911287d6667dd837dc2a7ddaa8f1 Mon Sep 17 00:00:00 2001 From: Smit Shilu Date: Thu, 22 Mar 2018 10:58:51 -0400 Subject: [PATCH 002/755] Update BUILD exports_files(["LICENSE"]) gives error while building on Mac and Ubuntu --- tensorflow/contrib/lite/BUILD | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index dafe6f136e..1c5bc29763 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -6,8 +6,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") -exports_files(["LICENSE"]) - exports_files(glob([ "testdata/*.bin", "testdata/*.pb", -- GitLab From e7f3ed2477c7910e68573880efd2310e149ca785 Mon Sep 17 00:00:00 2001 From: mbhuiyan Date: Wed, 4 Apr 2018 10:52:49 -0700 Subject: [PATCH 003/755] Fixing a unit test failure for INTEL MKL where memory allocation check failed because of use of INTEL MKL --- .../direct_session_with_tracking_alloc_test.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index 31fb128f93..0ff022a8bc 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -101,11 +101,24 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { EXPECT_EQ(2, shape.dim_size()); EXPECT_EQ(2, shape.dim(0).size()); EXPECT_EQ(1, shape.dim(1).size()); 
+#ifndef INTEL_MKL + // if MKL is used, it goes through various additional + // graph rewrite passes. In TF, every time a graph pass + // happens, "constant" nodes are allocated + // and deallocated. Each allocation calls the + // (FindChunkPtr of BFCAllocator) + // , which increments the value of AllocationId. + // Thus AllocationId becomes more than 3 and 4 if + // MKL is used, they can be 10 and 11 or + // other numbers. If MKL is used + // following check will not hold. + // Thus, skipping the check if MKL is used. if (node->name() == y->name()) { EXPECT_EQ(3, cm->AllocationId(node, 0)); } else { EXPECT_EQ(4, cm->AllocationId(node, 0)); } +#endif } EXPECT_LE(0, cm->MaxExecutionTime(node)); EXPECT_GE(run_duration_micros, cm->MaxExecutionTime(node)); -- GitLab From 9d1aa895adda8644ddbb55b5e1dbb0797ea6cbb0 Mon Sep 17 00:00:00 2001 From: Jie Date: Wed, 11 Apr 2018 14:42:15 -0700 Subject: [PATCH 004/755] [tftrt update] Added support for TRT plugin during conversion - converter & shape inference are now aware of plugin factory. - each plugin does serialization of plugin type & input dimensions - wrapper for nvinfer1::IPlugin & nvinfer1::PluginFactory * compatible with TRT 3.0.4 plugin API. * future plugin API changes will be updated. 
--- tensorflow/contrib/tensorrt/BUILD | 26 ++++++ .../contrib/tensorrt/convert/convert_graph.cc | 4 +- .../contrib/tensorrt/convert/convert_nodes.cc | 84 ++++++++++++++--- .../contrib/tensorrt/kernels/trt_engine_op.cc | 4 +- .../contrib/tensorrt/plugin/trt_plugin.cc | 89 +++++++++++++++++++ .../contrib/tensorrt/plugin/trt_plugin.h | 81 +++++++++++++++++ .../tensorrt/plugin/trt_plugin_factory.cc | 81 +++++++++++++++++ .../tensorrt/plugin/trt_plugin_factory.h | 83 +++++++++++++++++ .../tensorrt/plugin/trt_plugin_utils.cc | 36 ++++++++ .../tensorrt/plugin/trt_plugin_utils.h | 51 +++++++++++ .../contrib/tensorrt/shape_fn/trt_shfn.cc | 4 +- 11 files changed, 528 insertions(+), 15 deletions(-) create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin.cc create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin.h create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 2f316767b3..98f18835b0 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -67,6 +67,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_plugins", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]) + tf_custom_op_library_additional_deps(), @@ -86,6 +87,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_plugins", ":trt_resources", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", @@ -222,6 +224,7 @@ tf_cuda_library( ], deps = [ ":segment", + ":trt_plugins", ":trt_logging", ":trt_resources", "//tensorflow/core/grappler:grappler_item", @@ -272,3 +275,26 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +# Library for the 
plugin factory +#cc_library( +tf_cuda_library( + name = "trt_plugins", + srcs = [ + "plugin/trt_plugin.cc", + "plugin/trt_plugin_factory.cc", + "plugin/trt_plugin_utils.cc", + ], + hdrs = [ + "plugin/trt_plugin.h", + "plugin/trt_plugin_factory.h", + "plugin/trt_plugin_utils.h", + ], + linkstatic = 1, + deps = [ + #"@protobuf_archive//:protobuf_headers", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index b412b296e0..899e1721e6 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_graph.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -75,7 +76,8 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(ben,jie): ... }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) - return candidate_ops.count(node->type_string()); + return (candidate_ops.count(node->type_string()) || + PluginFactoryTensorRT::GetInstance().IsPlugin(&node->type_string())); } void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 567b4af88d..a03c1e224a 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -246,6 +247,15 @@ class TFAttrs { return attrs_.count(key) ? this->get(key) : default_value; } + std::vector GetAllAttrKey() { + std::vector attr_list; + for (AttrMap::iterator iter = attrs_.begin(); iter != attrs_.end(); + iter++) { + attr_list.emplace_back(iter->first); + } + return attr_list; + } + private: typedef std::map AttrMap; AttrMap attrs_; @@ -262,6 +272,12 @@ std::vector TFAttrs::get>(string key) const { return std::vector(attr.begin(), attr.end()); } +template <> +std::vector TFAttrs::get>(string key) const { + auto attr = this->at(key)->list().f(); + return std::vector(attr.begin(), attr.end()); +} + template <> std::vector TFAttrs::get>(string key) const { auto attr = this->at(key)->list().s(); @@ -424,6 +440,7 @@ using OpConverter = class Converter { std::unordered_map trt_tensors_; std::unordered_map op_registry_; + OpConverter plugin_converter_; nvinfer1::INetworkDefinition* trt_network_; std::list> temp_bufs_; tensorflow::tensorrt::TRTWeightStore* weight_store_; @@ -444,8 +461,8 @@ class Converter { * remove this and annotate the edge as a control dependency. 
************************************************************************/ // skip control nodes - if (input_name[0] == '^' ) continue; - string name = input_name; + if (input_name[0] == '^') continue; + string name = input_name; auto first = name.find_first_of(':'); if (first != string::npos && first + 2 == name.size() && name[first + 1] == '0') @@ -490,13 +507,17 @@ class Converter { std::vector inputs; TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); string op = node_def.op(); - if (!op_registry_.count(op)) { - return tensorflow::errors::Unimplemented( - "No converter registered for op: " + op); - } - OpConverter op_converter = op_registry_.at(op); std::vector outputs; - TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); + if (PluginFactoryTensorRT::GetInstance().IsPlugin(&op)) { + TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); + } else { + if (!op_registry_.count(op)) { + return tensorflow::errors::Unimplemented( + "No converter registered for op: " + op); + } + OpConverter op_converter = op_registry_.at(op); + TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs)); + } for (size_t i = 0; i < outputs.size(); ++i) { TRT_TensorOrWeights output = outputs.at(i); // TODO(jie): tf protobuf seems to be omitting the :0 suffix @@ -1158,9 +1179,9 @@ tensorflow::Status BinaryTensorOpTensor( CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) - return tensorflow::errors::Unimplemented( - "binary op: " + node_def.op() + - " not supported at: " + node_def.name()); + return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + + " not supported at: " + + node_def.name()); nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( *const_cast(tensor_l), @@ -1173,6 +1194,43 @@ tensorflow::Status BinaryTensorOpTensor( return tensorflow::Status::OK(); } +tensorflow::Status ConvertPlugin(Converter& ctx, + const tensorflow::NodeDef& 
node_def, + const std::vector& inputs, + std::vector* outputs) { + // prepare input + std::vector all_inputs; + for (auto input : inputs) { + all_inputs.emplace_back(const_cast(input.tensor())); + } + + // plugin is owned by PluginFactory + // TODO(jie): destroy plugins later (resource management) + PluginTensorRT* plugin = + PluginFactoryTensorRT::GetInstance().CreatePlugin(&node_def.op()); + + // passing attributes + // TODO(jie): support more general attribute + TFAttrs attrs(node_def); + auto attr_key_vector = attrs.GetAllAttrKey(); + for (auto attr_key : attr_key_vector) { + std::cout << attr_key << std::endl; + // TODO(jie): support only list of float for toy example here. + auto data = attrs.get>(attr_key); + size_t size_data = data.size() * sizeof(float); + plugin->SetAttribute(attr_key, static_cast(data.data()), size_data); + } + + nvinfer1::IPluginLayer* layer = + ctx.network()->addPlugin(&all_inputs[0], int(inputs.size()), *plugin); + + for (int i = 0; i < layer->getNbOutputs(); i++) { + nvinfer1::ITensor* output_tensor = layer->getOutput(i); + outputs->push_back(TRT_TensorOrWeights(output_tensor)); + } + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertPlaceholder( Converter& ctx, const tensorflow::NodeDef& node_def, const std::vector& inputs, @@ -2073,6 +2131,8 @@ void Converter::register_op_converters() { op_registry_["Reshape"] = ConvertReshape; op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm; op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm; + + plugin_converter_ = ConvertPlugin; } } // namespace @@ -2511,7 +2571,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( std::vector input_names; std::vector input_dtypes; for (const std::pair& input : s.input_inds) { - VLOG(2) << "parsing input. Node id= " << input.first ; + VLOG(2) << "parsing input. 
Node id= " << input.first; int node_id = input.first; int output_idx = input.second; tensorflow::Node* node = s.graph.FindNodeId(node_id); diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index b32371b642..8881c48fe6 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/core/platform/logging.h" @@ -58,7 +59,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { IRuntime* infer = nvinfer1::createInferRuntime(logger); trt_engine_ptr_.reset(infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), nullptr)); + serialized_engine.c_str(), serialized_engine.size(), + &PluginFactoryTensorRT::GetInstance())); trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); // Runtime is safe to delete after engine creation infer->destroy(); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc new file mode 100644 index 0000000000..0e4a157d79 --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include +#include +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { + // sanity check. + assert(EncodeOpName(GetPluginName()) != + *static_cast(serialized_data)); + const char* buffer = static_cast(serialized_data) + + sizeof(input_dim_list_.size()); + + size_t count = *reinterpret_cast(buffer); + buffer += sizeof(size_t); + + for (int i = 0; i < count; i++) { + nvinfer1::Dims dim; + std::memcpy(&(dim.nbDims), buffer, sizeof(dim.nbDims)); + buffer += sizeof(dim.nbDims); + std::memcpy(dim.d, buffer, sizeof(dim.d)); + buffer += sizeof(dim.d); + std::memcpy(dim.type, buffer, sizeof(dim.type)); + buffer += sizeof(dim.type); + input_dim_list_.emplace_back(dim); + } +} + +size_t PluginTensorRT::getSerializationSize() { + nvinfer1::Dims dim; + return sizeof(size_t) + sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + + sizeof(dim.d) + sizeof(dim.type); +} + +void PluginTensorRT::serialize(void* serialized_data) { + size_t encode_op_name = EncodeOpName(GetPluginName()); + char* buffer = static_cast(serialized_data); + std::memcpy(buffer, &encode_op_name, sizeof(size_t)); + buffer += sizeof(size_t); + + auto list_size = input_dim_list_.size(); + std::memcpy(buffer, &list_size, sizeof(input_dim_list_.size())); + buffer += sizeof(input_dim_list_.size()); + 
+ for (int i = 0; i < input_dim_list_.size(); i++) { + auto dim = input_dim_list_[i]; + std::memcpy(buffer, &(dim.nbDims), sizeof(dim.nbDims)); + buffer += sizeof(dim.nbDims); + std::memcpy(buffer, dim.d, sizeof(dim.d)); + buffer += sizeof(dim.d); + std::memcpy(buffer, dim.type, sizeof(dim.type)); + buffer += sizeof(dim.type); + } +} + +bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr, + const size_t size) { + if (attr_map_.count(key) != 0) return false; + + attr_map_.emplace(key, std::vector(size)); + std::memcpy(attr_map_[key].data(), ptr, size); + return true; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h new file mode 100644 index 0000000000..1bbfe62a4e --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN + +#include +#include +#include +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +using std::string; +using std::unordered_map; + +class PluginTensorRT : public nvinfer1::IPlugin { + public: + PluginTensorRT(){}; + PluginTensorRT(const void* serialized_data, size_t length); + // PluginTensorRT(const void* serialized_data, size_t length, size_t + // &incremental); + virtual string GetPluginName() = 0; + virtual bool Finalize() = 0; + + virtual bool SetAttribute(const string& key, const void* ptr, + const size_t size) = 0; + virtual bool GetAttribute(const string& key, const void* ptr, + size_t& size) = 0; + + void configure(const nvinfer1::Dims* inputs, int nbInputs, + const nvinfer1::Dims* outputs, int nbOutputs, + int maxBatchSize) override { + for (int index = 0; index < nbInputs; index++) { + nvinfer1::Dims dim; + dim.nbDims = inputs[index].nbDims; + for (int i = 0; i < dim.nbDims; i++) { + dim.d[i] = inputs[index].d[i]; + dim.type[i] = inputs[index].type[i]; + } + input_dim_list_.emplace_back(dim); + } + return; + } + + virtual bool StoreAttribute(const string& key, const void* ptr, + const size_t size); + + virtual size_t getSerializationSize() override; + virtual void serialize(void* buffer) override; + + protected: + std::unordered_map > attr_map_; + + std::vector input_dim_list_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc new file mode 100644 index 0000000000..799c609a3e --- /dev/null +++ 
b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layerName, + const void* serial_data, + size_t serial_length) { + size_t parsed_byte = 0; + // extract op_name from serial_data + size_t encoded_op_name = + ExtractOpName(serial_data, serial_length, parsed_byte); + + if (!IsPlugin(encoded_op_name)) { + return nullptr; + } + + // should I lock plugins here? 
+ instance_m_.lock(); + auto plugin_ptr = + plugin_registry_[encoded_op_name].first(serial_data, serial_length); + // string op_name = "IncPluginTRT"; + // auto plugin_ptr = plugin_registry_[EncodeLayerName(&op_name)].second(); + // auto plugin_ptr = plugin_registry_.begin()->second.second(); + owned_plugins_.emplace_back(plugin_ptr); + instance_m_.unlock(); + + return plugin_ptr; +} + +PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string* op_name) { + if (!IsPlugin(op_name)) return nullptr; + + instance_m_.lock(); + auto plugin_ptr = plugin_registry_[EncodeLayerName(op_name)].second(); + owned_plugins_.emplace_back(plugin_ptr); + instance_m_.unlock(); + + return plugin_ptr; +} + +bool PluginFactoryTensorRT::RegisterPlugin( + const string* op_name, PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func) { + if (IsPlugin(op_name)) return false; + + // get instance_m_ first before write to registry; + instance_m_.lock(); + auto ret = plugin_registry_.emplace( + EncodeLayerName(op_name), + std::make_pair(deserialize_func, construct_func)); + instance_m_.unlock(); + + return ret.second; +} + +void PluginFactoryTensorRT::DestroyPlugins() { return; } + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h new file mode 100644 index 0000000000..e68f4629d0 --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -0,0 +1,83 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY + +#include +#include +#include +#include "trt_plugin.h" +#include "trt_plugin_utils.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { + public: + // deserialization method + // virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const void* + // serialData, size_t serialLength) override; + PluginTensorRT* createPlugin(const char* layerName, const void* serialData, + size_t serialLength) override; + + // construction + PluginTensorRT* CreatePlugin(const string* op_name); + + static PluginFactoryTensorRT& GetInstance() { + static PluginFactoryTensorRT factory_instance; + return factory_instance; + } + + bool RegisterPlugin(const string* op_name, + PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func); + + bool IsPlugin(const size_t encode_name) { + return plugin_registry_.find(encode_name) != plugin_registry_.end(); + } + + bool IsPlugin(const string* op_name) { + return IsPlugin(EncodeLayerName(op_name)); + } + + size_t EncodeLayerName(const string* op_name) { + return EncodeOpName(*op_name); + } + + void DestroyPlugins(); + + protected: + std::unordered_map > + plugin_registry_; + + // TODO(jie): Owned plugin should be associated with different sessions; + // should really hand 
ownership of plugins to resource management; + std::vector > owned_plugins_; + std::mutex instance_m_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_FACTORY diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc new file mode 100644 index 0000000000..b14480cfa6 --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +size_t ExtractOpName(const void* serial_data, size_t serial_length, + size_t& incremental) { + incremental = sizeof(size_t); + if (serial_length < incremental) return 0; + size_t encoded_op_name = *static_cast(serial_data); + return encoded_op_name; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h new file mode 100644 index 0000000000..e9675d84cd --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -0,0 +1,51 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS +#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS + +#include +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +typedef std::function + PluginDeserializeFunc; + +typedef std::function PluginConstructFunc; + +inline size_t EncodeOpName(std::string str) { + return std::hash{}(str); +} + +// TODO(jie): work on error handling here +size_t ExtractOpName(const void* serial_data, size_t serial_length, + size_t& incremental); + +// size_t Deserialize(const char* serial_data, size_t serial_length, size_t +// &incremental); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 8b475177bc..30b5616475 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include #include @@ -33,7 +34,8 @@ tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine)); nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), nullptr); + serialized_engine.c_str(), serialized_engine.size(), + &tensorrt::PluginFactoryTensorRT::GetInstance()); int num_batch = -1; std::vector<::tensorflow::DataType> input_type; -- GitLab From 0eb443db1a5654168f396702cae39f5dc3fc7e2e Mon Sep 17 00:00:00 2001 From: imsheridan Date: Tue, 17 Apr 2018 20:25:51 +0800 Subject: [PATCH 005/755] Add deprecated_args decoration to array_ops/ sparse_ops --- tensorflow/python/ops/array_ops.py | 2 ++ tensorflow/python/ops/sparse_ops.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index ceeabe090d..06da2485c3 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2690,6 +2690,8 @@ reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ # pylint: disable=redefined-builtin @tf_export("reverse_sequence") +@deprecation.deprecated_args(None, "Use the `seq_axis` argument instead", "seq_dim") +@deprecation.deprecated_args(None, "Use the `batch_axis` argument instead", "batch_dim") def reverse_sequence(input, seq_lengths, seq_axis=None, diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index c580052c32..73ab216f35 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -110,6 +110,7 @@ def _convert_to_sparse_tensors(sp_inputs): # pylint: disable=protected-access 
@tf_export("sparse_concat") +@deprecation.deprecated_args(None, "concat_dim is deprecated, use axis instead", "concat_dim") def sparse_concat(axis, sp_inputs, name=None, @@ -616,6 +617,7 @@ class KeywordRequired(object): @tf_export("sparse_split") +@deprecation.deprecated_args(None, "split_dim is deprecated, use axis instead", "split_dim") def sparse_split(keyword_required=KeywordRequired(), sp_input=None, num_split=None, -- GitLab From d6838f52c7daea81c57cdeab8e98c4cd617e5f8b Mon Sep 17 00:00:00 2001 From: imsheridan Date: Tue, 17 Apr 2018 20:30:13 +0800 Subject: [PATCH 006/755] fix typo to keep consistence --- tensorflow/python/ops/array_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 06da2485c3..b6a1f5a272 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -2690,8 +2690,8 @@ reverse.__doc__ = gen_array_ops.reverse_v2.__doc__ # pylint: disable=redefined-builtin @tf_export("reverse_sequence") -@deprecation.deprecated_args(None, "Use the `seq_axis` argument instead", "seq_dim") -@deprecation.deprecated_args(None, "Use the `batch_axis` argument instead", "batch_dim") +@deprecation.deprecated_args(None, "seq_dim is deprecated, use seq_axis instead", "seq_dim") +@deprecation.deprecated_args(None, "batch_dim is deprecated, use batch_axis instead", "batch_dim") def reverse_sequence(input, seq_lengths, seq_axis=None, -- GitLab From df0ce53aee6f4e14b3f1c9e0e772a1f7bd1bb95a Mon Sep 17 00:00:00 2001 From: Jie Date: Mon, 16 Apr 2018 17:47:00 -0700 Subject: [PATCH 007/755] [PR comment addressed] adding plugin test for registration updating plugin API wrapper addressing comments in the PR addressing coding style issues removing commented code --- tensorflow/contrib/tensorrt/BUILD | 15 ++- .../contrib/tensorrt/convert/convert_graph.cc | 2 +- .../contrib/tensorrt/convert/convert_nodes.cc | 10 +- 
.../custom_plugin_examples/inc_op_plugin.cc | 89 ++++++++++++++ .../custom_plugin_examples/inc_op_plugin.h | 114 ++++++++++++++++++ .../contrib/tensorrt/kernels/trt_engine_op.cc | 2 +- .../contrib/tensorrt/plugin/trt_plugin.cc | 37 ++++-- .../contrib/tensorrt/plugin/trt_plugin.h | 41 +++---- .../tensorrt/plugin/trt_plugin_factory.cc | 28 +++-- .../tensorrt/plugin/trt_plugin_factory.h | 33 +++-- .../tensorrt/plugin/trt_plugin_utils.cc | 18 ++- .../tensorrt/plugin/trt_plugin_utils.h | 11 +- .../tensorrt/plugin/trt_plugins_test.cc | 112 +++++++++++++++++ .../contrib/tensorrt/shape_fn/trt_shfn.cc | 2 +- 14 files changed, 423 insertions(+), 91 deletions(-) create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h create mode 100644 tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 98f18835b0..751f1d3482 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -277,7 +277,6 @@ tf_cc_test( ) # Library for the plugin factory -#cc_library( tf_cuda_library( name = "trt_plugins", srcs = [ @@ -292,9 +291,21 @@ tf_cuda_library( ], linkstatic = 1, deps = [ - #"@protobuf_archive//:protobuf_headers", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), ) +tf_cuda_cc_test( + name = "trt_plugins_test", + size = "small", + srcs = ["plugin/trt_plugins_test.cc"], + deps = [ + ":trt_plugins", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:nv_infer", + ]), +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 899e1721e6..91faba7e21 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -77,7 
+77,7 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) { }; // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h) return (candidate_ops.count(node->type_string()) || - PluginFactoryTensorRT::GetInstance().IsPlugin(&node->type_string())); + PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string())); } void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index a03c1e224a..d02c1ebf50 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -249,9 +249,8 @@ class TFAttrs { std::vector GetAllAttrKey() { std::vector attr_list; - for (AttrMap::iterator iter = attrs_.begin(); iter != attrs_.end(); - iter++) { - attr_list.emplace_back(iter->first); + for (auto & attr_item : attrs_) { + attr_list.emplace_back(attr_item.first); } return attr_list; } @@ -508,7 +507,7 @@ class Converter { TF_RETURN_IF_ERROR(this->get_inputs(node_def, &inputs)); string op = node_def.op(); std::vector outputs; - if (PluginFactoryTensorRT::GetInstance().IsPlugin(&op)) { + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(op)) { TF_RETURN_IF_ERROR(plugin_converter_(*this, node_def, inputs, &outputs)); } else { if (!op_registry_.count(op)) { @@ -1207,14 +1206,13 @@ tensorflow::Status ConvertPlugin(Converter& ctx, // plugin is owned by PluginFactory // TODO(jie): destroy plugins later (resource management) PluginTensorRT* plugin = - PluginFactoryTensorRT::GetInstance().CreatePlugin(&node_def.op()); + PluginFactoryTensorRT::GetInstance()->CreatePlugin(node_def.op()); // passing attributes // TODO(jie): support more general attribute TFAttrs attrs(node_def); auto attr_key_vector = attrs.GetAllAttrKey(); for (auto attr_key : attr_key_vector) { - std::cout << attr_key << std::endl; // TODO(jie): support only list of float for toy example here. 
auto data = attrs.get>(attr_key); size_t size_data = data.size() * sizeof(float); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc new file mode 100644 index 0000000000..2155079e8b --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "inc_op_plugin.h" + +namespace tensorflow { +namespace tensorrt { + +const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; + +IncOpPlugin* CreateIncPlugin() { + return new IncOpPlugin(); +} + + +IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { + return new IncOpPlugin(buffer, length); +} + +bool RegisterIncOpPlugin() { + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(IncOpPlugin::plugin_name_)) + return false; + return PluginFactoryTensorRT::GetInstance()->RegisterPlugin(IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); +} + + +IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) : + 
PluginTensorRT(serialized_data, length) +{ + // account for the consumed pointer. + size_t consumed_data = PluginTensorRT::getSerializationSize(); + assert(length-consumed_data >= sizeof(float)); + SetAttribute("inc", serialized_data+consumed_data, sizeof(float)); +} + +bool IncOpPlugin::SetAttribute(const string &key, const void *ptr, const size_t size) { + if (strcmp(key.c_str(), "inc")==0 && size == sizeof(float)) { + StoreAttribute(key, ptr, size); // save the attribute to own the data; + inc_ = *static_cast(ptr); + return true; + } + return false; +} + +bool IncOpPlugin::GetAttribute(const string &key, const void *ptr, size_t &size) { + if (attr_map_.find(key) != attr_map_.end()) { + ptr = attr_map_[key].data(); + size = attr_map_[key].size(); + return true; + } + return false; +} + +int IncOpPlugin::enqueue(int batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream) { + int count = 1; + for (int i=0; i(inputs[0]); + float *output = reinterpret_cast(outputs[0]); + IncrementKernel(input, inc_, output, count, stream); + return 0; +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h new file mode 100644 index 0000000000..52b68487e6 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -0,0 +1,114 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN +#define TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include +#include +#include +#include +#include +#include + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +using std::string; +using std::unordered_map; + +class IncOpPlugin : public PluginTensorRT +{ +public: + static const string plugin_name_; + IncOpPlugin() {}; + IncOpPlugin(const void* serialized_data, size_t length); + const string GetPluginName() override {return plugin_name_;}; + bool Finalize() override {return true;}; + bool SetAttribute(const string &key, const void *ptr, const size_t size) override; + bool GetAttribute(const string &key, const void *ptr, size_t &size) override; + + // TRT IPlugin methods + int getNbOutputs() const override {return 1;} + + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override { + assert(index==0); + assert(nbInputDims==1); + return inputs[0]; + } + + // no configure needed + // use configure to setup input dimensions + void configure(const nvinfer1::Dims *inputs, int nbInputs, const nvinfer1::Dims *outputs, int nbOutputs, int maxBatchSize) override { + assert(nbInputs==1); + PluginTensorRT::configure(inputs, nbInputs, outputs, nbOutputs, maxBatchSize); + return; + } + + int initialize() override { + return 0; + } + + void terminate() override { + return; + } + + size_t getWorkspaceSize(int maxBatchSize) const override { + return 0; + } + + int enqueue(int batchSize, const void*const *inputs, void** outputs, void* workspace, cudaStream_t stream) override; + + size_t getSerializationSize() override { + return PluginTensorRT::getSerializationSize() + 
sizeof(float); + } + + void serialize(void* buffer) override { + // serializa parent stuff + // OpName + PluginTensorRT::serialize(buffer); + + // incremented buffer after parent serialization; + buffer = static_cast(buffer) + PluginTensorRT::getSerializationSize(); + + std::memcpy(buffer, &inc_, sizeof(float)); + buffer = static_cast(buffer) + sizeof(float); + return; + } + +protected: + float inc_; + nvinfer1::Dims dim_; + // std::unordered_map > attr_map_; +}; + +IncOpPlugin* CreateIncPlugin(); +IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); +bool RegisterIncOpPlugin(); +void IncrementKernel(const float* d_input, float inc, float* d_output, int count, cudaStream_t stream); + + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 8881c48fe6..162301fb52 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -60,7 +60,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) { IRuntime* infer = nvinfer1::createInferRuntime(logger); trt_engine_ptr_.reset(infer->deserializeCudaEngine( serialized_engine.c_str(), serialized_engine.size(), - &PluginFactoryTensorRT::GetInstance())); + PluginFactoryTensorRT::GetInstance())); trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext()); // Runtime is safe to delete after engine creation infer->destroy(); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc index 0e4a157d79..7600703775 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -26,10 +26,10 @@ namespace tensorrt { PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t 
length) { // sanity check. - assert(EncodeOpName(GetPluginName()) != - *static_cast(serialized_data)); - const char* buffer = static_cast(serialized_data) + - sizeof(input_dim_list_.size()); + const char* buffer = static_cast(serialized_data); + size_t op_name_char_count = *reinterpret_cast(buffer); + buffer += sizeof(size_t); + buffer += op_name_char_count; size_t count = *reinterpret_cast(buffer); buffer += sizeof(size_t); @@ -46,18 +46,37 @@ PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { } } +void PluginTensorRT::configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) { + for (int index = 0; index < num_inputs; index++) { + nvinfer1::Dims dim; + dim.nbDims = inputs[index].nbDims; + for (int i = 0; i < dim.nbDims; i++) { + dim.d[i] = inputs[index].d[i]; + dim.type[i] = inputs[index].type[i]; + } + input_dim_list_.emplace_back(dim); + } + return; +} + size_t PluginTensorRT::getSerializationSize() { nvinfer1::Dims dim; - return sizeof(size_t) + sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + - sizeof(dim.d) + sizeof(dim.type); + return sizeof(size_t) + GetPluginName().size() + + sizeof(input_dim_list_.size()) + sizeof(dim.nbDims) + sizeof(dim.d) + + sizeof(dim.type); } void PluginTensorRT::serialize(void* serialized_data) { - size_t encode_op_name = EncodeOpName(GetPluginName()); + size_t op_name_size = GetPluginName().size(); char* buffer = static_cast(serialized_data); - std::memcpy(buffer, &encode_op_name, sizeof(size_t)); + std::memcpy(buffer, &op_name_size, sizeof(size_t)); buffer += sizeof(size_t); + std::memcpy(buffer, GetPluginName().data(), op_name_size); + buffer += op_name_size; + auto list_size = input_dim_list_.size(); std::memcpy(buffer, &list_size, sizeof(input_dim_list_.size())); buffer += sizeof(input_dim_list_.size()); @@ -73,7 +92,7 @@ void PluginTensorRT::serialize(void* serialized_data) { } } -bool PluginTensorRT::StoreAttribute(const 
string& key, const void* ptr, +bool PluginTensorRT::StoreAttribute(const std::string& key, const void* ptr, const size_t size) { if (attr_map_.count(key) != 0) return false; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index 1bbfe62a4e..59b92657f6 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -28,46 +28,37 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -using std::string; -using std::unordered_map; - +// A wrapper class for TensorRT plugin +// User application should inherit from this class to write custom kernels. +// Allows user to insert custom op in TensorRT engine +// To register plugin in converter, user should also register custom +// tensorflow::tensorrt::PluginDeserializeFunc & +// tensorflow::tensorrt::PluginConstructFunc through +// tensorflow::tensorrt::PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: PluginTensorRT(){}; PluginTensorRT(const void* serialized_data, size_t length); - // PluginTensorRT(const void* serialized_data, size_t length, size_t - // &incremental); - virtual string GetPluginName() = 0; + virtual const std::string& GetPluginName() = 0; virtual bool Finalize() = 0; - virtual bool SetAttribute(const string& key, const void* ptr, + virtual bool SetAttribute(const std::string& key, const void* ptr, const size_t size) = 0; - virtual bool GetAttribute(const string& key, const void* ptr, + virtual bool GetAttribute(const std::string& key, const void* ptr, size_t& size) = 0; - void configure(const nvinfer1::Dims* inputs, int nbInputs, - const nvinfer1::Dims* outputs, int nbOutputs, - int maxBatchSize) override { - for (int index = 0; index < nbInputs; index++) { - nvinfer1::Dims dim; - dim.nbDims = inputs[index].nbDims; - for (int i = 0; i < dim.nbDims; i++) { - dim.d[i] = inputs[index].d[i]; - dim.type[i] = inputs[index].type[i]; - } - 
input_dim_list_.emplace_back(dim); - } - return; - } - - virtual bool StoreAttribute(const string& key, const void* ptr, + void configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) override; + + virtual bool StoreAttribute(const std::string& key, const void* ptr, const size_t size); virtual size_t getSerializationSize() override; virtual void serialize(void* buffer) override; protected: - std::unordered_map > attr_map_; + std::unordered_map > attr_map_; std::vector input_dim_list_; }; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 799c609a3e..44b10394c8 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -21,13 +21,13 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layerName, +PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) { size_t parsed_byte = 0; // extract op_name from serial_data - size_t encoded_op_name = - ExtractOpName(serial_data, serial_length, parsed_byte); + std::string encoded_op_name = + ExtractOpName(serial_data, serial_length, &parsed_byte); if (!IsPlugin(encoded_op_name)) { return nullptr; @@ -37,20 +37,18 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layerName, instance_m_.lock(); auto plugin_ptr = plugin_registry_[encoded_op_name].first(serial_data, serial_length); - // string op_name = "IncPluginTRT"; - // auto plugin_ptr = plugin_registry_[EncodeLayerName(&op_name)].second(); - // auto plugin_ptr = plugin_registry_.begin()->second.second(); owned_plugins_.emplace_back(plugin_ptr); instance_m_.unlock(); return plugin_ptr; } -PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string* op_name) { 
+PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( + const std::string& op_name) { if (!IsPlugin(op_name)) return nullptr; instance_m_.lock(); - auto plugin_ptr = plugin_registry_[EncodeLayerName(op_name)].second(); + auto plugin_ptr = plugin_registry_[op_name].second(); owned_plugins_.emplace_back(plugin_ptr); instance_m_.unlock(); @@ -58,21 +56,27 @@ PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string* op_name) { } bool PluginFactoryTensorRT::RegisterPlugin( - const string* op_name, PluginDeserializeFunc deserialize_func, + const std::string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func) { if (IsPlugin(op_name)) return false; // get instance_m_ first before write to registry; instance_m_.lock(); auto ret = plugin_registry_.emplace( - EncodeLayerName(op_name), - std::make_pair(deserialize_func, construct_func)); + op_name, std::make_pair(deserialize_func, construct_func)); instance_m_.unlock(); return ret.second; } -void PluginFactoryTensorRT::DestroyPlugins() { return; } +void PluginFactoryTensorRT::DestroyPlugins() { + instance_m_.lock(); + for (auto& owned_plugin_ptr : owned_plugins_) { + owned_plugin_ptr.release(); + } + owned_plugins_.clear(); + instance_m_.unlock(); +} } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index e68f4629d0..824efcff35 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -32,39 +32,34 @@ namespace tensorrt { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { public: // deserialization method - // virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const void* - // serialData, size_t serialLength) override; - PluginTensorRT* createPlugin(const char* layerName, const void* serialData, - size_t serialLength) override; + PluginTensorRT* 
createPlugin(const char* layer_name, const void* serial_data, + size_t serial_length) override; - // construction - PluginTensorRT* CreatePlugin(const string* op_name); + // plugin construction, PluginFactoryTensorRT owns the plugin; + PluginTensorRT* CreatePlugin(const std::string& op_name); - static PluginFactoryTensorRT& GetInstance() { - static PluginFactoryTensorRT factory_instance; + static PluginFactoryTensorRT* GetInstance() { + static PluginFactoryTensorRT* factory_instance = nullptr; + if (factory_instance == nullptr) { + factory_instance = new PluginFactoryTensorRT(); + } return factory_instance; } - bool RegisterPlugin(const string* op_name, + bool RegisterPlugin(const std::string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func); - bool IsPlugin(const size_t encode_name) { - return plugin_registry_.find(encode_name) != plugin_registry_.end(); + bool IsPlugin(const std::string& op_name) { + return plugin_registry_.find(op_name) != plugin_registry_.end(); } - bool IsPlugin(const string* op_name) { - return IsPlugin(EncodeLayerName(op_name)); - } - - size_t EncodeLayerName(const string* op_name) { - return EncodeOpName(*op_name); - } + size_t CountOwnedPlugins() { return owned_plugins_.size(); } void DestroyPlugins(); protected: - std::unordered_map > plugin_registry_; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc index b14480cfa6..8b65e8b41c 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -21,12 +22,17 @@ limitations under the License. 
namespace tensorflow { namespace tensorrt { -size_t ExtractOpName(const void* serial_data, size_t serial_length, - size_t& incremental) { - incremental = sizeof(size_t); - if (serial_length < incremental) return 0; - size_t encoded_op_name = *static_cast(serial_data); - return encoded_op_name; +std::string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental) { + size_t op_name_char_count = *static_cast(serial_data); + *incremental = sizeof(size_t) + op_name_char_count; + + assert(serial_length >= *incremental); + + const char* buffer = static_cast(serial_data) + sizeof(size_t); + std::string op_name(buffer, op_name_char_count); + + return op_name; } } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index e9675d84cd..d4da8b261e 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -31,16 +31,9 @@ typedef std::function typedef std::function PluginConstructFunc; -inline size_t EncodeOpName(std::string str) { - return std::hash{}(str); -} - // TODO(jie): work on error handling here -size_t ExtractOpName(const void* serial_data, size_t serial_length, - size_t& incremental); - -// size_t Deserialize(const char* serial_data, size_t serial_length, size_t -// &incremental); +std::string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc new file mode 100644 index 0000000000..2856b0f87d --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +namespace test { + +class StubPlugin : public PluginTensorRT { + public: + static const std::string plugin_name_; + StubPlugin(){}; + StubPlugin(const void* serialized_data, size_t length) + : PluginTensorRT(serialized_data, length){}; + const std::string& GetPluginName() override { return plugin_name_; }; + virtual bool Finalize() { return true; }; + virtual bool SetAttribute(const std::string& key, const void* ptr, + const size_t size) { + return true; + }; + virtual bool GetAttribute(const std::string& key, const void* ptr, + size_t& size) { + return true; + }; + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, + int nbInputDims) override { + return inputs[0]; + } + int initialize() override { return 0; } + void terminate() override { return; } + size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } + int enqueue(int 
batchSize, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override { + return 0; + } +}; + +const std::string StubPlugin::plugin_name_ = "StubPlugin"; + +StubPlugin* CreateStubPlugin() { return new StubPlugin(); } + +StubPlugin* CreateStubPluginDeserialize(const void* serialized_data, + size_t length) { + return new StubPlugin(serialized_data, length); +} + +class PluginTest : public ::testing::Test { + public: + bool RegisterStubPlugin() { + if (PluginFactoryTensorRT::GetInstance()->IsPlugin( + StubPlugin::plugin_name_)) + return true; + return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( + StubPlugin::plugin_name_, CreateStubPluginDeserialize, + CreateStubPlugin); + } + + protected: +}; + +TEST_F(PluginTest, Registration) { + EXPECT_FALSE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + EXPECT_TRUE(RegisterStubPlugin()); + + ASSERT_TRUE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); +} + +TEST_F(PluginTest, CreationDeletion) { + EXPECT_TRUE(RegisterStubPlugin()); + ASSERT_TRUE( + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + + PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); + ASSERT_TRUE(PluginFactoryTensorRT::GetInstance()->CreatePlugin( + StubPlugin::plugin_name_)); + ASSERT_EQ(1, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); + PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); + ASSERT_EQ(0, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); +} + +} // namespace test +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 30b5616475..f36495f6b6 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -35,7 +35,7 @@ tensorflow::Status 
TRTEngineOpShapeInference(InferenceContext* context) { nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( serialized_engine.c_str(), serialized_engine.size(), - &tensorrt::PluginFactoryTensorRT::GetInstance()); + tensorrt::PluginFactoryTensorRT::GetInstance()); int num_batch = -1; std::vector<::tensorflow::DataType> input_type; -- GitLab From 758f25e8168bf1ff76c63a5b54dfd50ff54e4e27 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 17 Apr 2018 15:47:33 -0700 Subject: [PATCH 008/755] Fix calculation of the histogram buckets and writing to the tensor and add a unit test --- .../tensorboard/db/summary_db_writer.cc | 21 +++++--- .../tensorboard/db/summary_db_writer_test.cc | 49 +++++++++++++++++++ 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index 6590d6f7df..046a2d3884 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -1182,14 +1182,19 @@ class SummaryDbWriter : public SummaryWriterInterface { // See tensorboard/plugins/histogram/summary.py and data_compat.py Tensor t{DT_DOUBLE, {k, 3}}; auto data = t.flat(); - for (int i = 0; i < k; ++i) { - double left_edge = ((i - 1 >= 0) ? histo.bucket_limit(i - 1) - : std::numeric_limits::min()); - double right_edge = ((i + 1 < k) ? histo.bucket_limit(i + 1) - : std::numeric_limits::max()); - data(i + 0) = left_edge; - data(i + 1) = right_edge; - data(i + 2) = histo.bucket(i); + for (int i = 0, j = 0; i < k; ++i) { + // From summary.proto + // Parallel arrays encoding the bucket boundaries and the bucket values. + // bucket(i) is the count for the bucket i. The range for + // a bucket is: + // i == 0: -DBL_MAX .. bucket_limit(0) + // i != 0: bucket_limit(i-1) .. bucket_limit(i) + double left_edge = (i == 0) ? 
std::numeric_limits::min() + : histo.bucket_limit(i - 1); + + data(j++) = left_edge; + data(j++) = histo.bucket_limit(i); + data(j++) = histo.bucket(i); } int64 tag_id; PatchPluginName(s->mutable_metadata(), kHistogramPluginName); diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index 29b8063218..cb51325d15 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -100,6 +100,55 @@ class SummaryDbWriterTest : public ::testing::Test { SummaryWriterInterface* writer_ = nullptr; }; +TEST_F(SummaryDbWriterTest, WriteHistogram_VerifyTensorValues) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "histtest", "test1", "user1", &env_, + &writer_)); + int step = 0; + std::unique_ptr e{new Event}; + e->set_step(step); + e->set_wall_time(123); + Summary::Value* s = e->mutable_summary()->add_value(); + s->set_tag("normal/myhisto"); + + double dummy_value = 10.123; + HistogramProto* proto = s->mutable_histo(); + proto->Clear(); + proto->set_min(dummy_value); + proto->set_max(dummy_value); + proto->set_num(dummy_value); + proto->set_sum(dummy_value); + proto->set_sum_squares(dummy_value); + + int size = 3; + double bucket_limits[] = {-30.5, -10.5, -5.5}; + double bucket[] = {-10, 10, 20}; + for (int i = 0; i < size; i++) { + proto->add_bucket_limit(bucket_limits[i]); + proto->add_bucket(bucket[i]); + } + TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); + TF_ASSERT_OK(writer_->Flush()); + writer_->Unref(); + writer_ = nullptr; + + // Verify the data + string result = QueryString("SELECT data FROM Tensors"); + const double* val = reinterpret_cast(result.data()); + double histarray[] = {std::numeric_limits::min(), + -30.5, + -10, + -30.5, + -10.5, + 10, + -10.5, + -5.5, + 20}; + int histarray_size = 9; + for (int i = 0; i < histarray_size; i++) { + EXPECT_EQ(histarray[i], val[i]); + } +} + TEST_F(SummaryDbWriterTest, 
NothingWritten_NoRowsCreated) { TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, &writer_)); -- GitLab From 419dbc8f44efe06612845ec291b98bb49e873639 Mon Sep 17 00:00:00 2001 From: Jie Date: Wed, 18 Apr 2018 14:42:42 -0700 Subject: [PATCH 009/755] [PR comment addressed] Added custom plugin example registered tensorflow custom op & plugin kernel python wrapper to import custom op & register plugin clang-format --- tensorflow/contrib/tensorrt/BUILD | 1 + .../contrib/tensorrt/convert/convert_nodes.cc | 2 +- .../tensorrt/custom_plugin_examples/BUILD | 110 ++++++++++++++++++ .../custom_plugin_examples/__init__.py | 24 ++++ .../tensorrt/custom_plugin_examples/inc_op.py | 30 +++++ .../inc_op_kernel.cu.cc | 44 +++++++ .../custom_plugin_examples/inc_op_kernel.h | 34 ++++++ .../custom_plugin_examples/inc_op_plugin.cc | 55 +++++---- .../custom_plugin_examples/inc_op_plugin.h | 85 +++++++------- .../custom_plugin_examples/ops/inc_op.cc | 34 ++++++ .../custom_plugin_examples/plugin_wrap.i | 31 +++++ .../test/plugin_test.py | 93 +++++++++++++++ .../contrib/tensorrt/plugin/trt_plugin.cc | 1 - .../contrib/tensorrt/plugin/trt_plugin.h | 10 +- .../tensorrt/plugin/trt_plugin_factory.cc | 14 +-- .../tensorrt/plugin/trt_plugin_factory.h | 6 +- .../tensorrt/plugin/trt_plugin_utils.cc | 4 +- .../tensorrt/plugin/trt_plugin_utils.h | 5 +- .../tensorrt/plugin/trt_plugins_test.cc | 6 +- 19 files changed, 485 insertions(+), 104 deletions(-) create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc create mode 100644 
tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i create mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 751f1d3482..9c81c12705 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -291,6 +291,7 @@ tf_cuda_library( ], linkstatic = 1, deps = [ + "//tensorflow/core:platform_base", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index d02c1ebf50..874be96c78 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -249,7 +249,7 @@ class TFAttrs { std::vector GetAllAttrKey() { std::vector attr_list; - for (auto & attr_item : attrs_) { + for (const auto& attr_item : attrs_) { attr_list.emplace_back(attr_item.first); } return attr_list; diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD new file mode 100644 index 0000000000..5603ed0ccf --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -0,0 +1,110 @@ +package(default_visibility = ["//tensorflow:__subpackages__"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_cuda_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_py_wrap_cc", + "tf_copts", +) +load( + "@local_config_tensorrt//:build_defs.bzl", + "if_tensorrt", +) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") + +tf_kernel_library( + name = "_inc_op_plugin_kernel", + srcs = [ + "inc_op_plugin.cc", + ], + hdrs = [ + ], + gpu_srcs = [ + "inc_op_kernel.cu.cc", + "inc_op_kernel.h", + "inc_op_plugin.h", + ], + deps = if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + 
"//tensorflow/contrib/tensorrt:trt_plugins", + ]), +) + +tf_gen_op_libs( + op_lib_names = [ + "inc_op", + ], + deps = if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + "//tensorflow/contrib/tensorrt:trt_plugins", + ]), +) + +tf_gen_op_wrapper_py( + name = "inc_op", + gen_locally = True, + deps = [ + ":inc_op_op_lib", + ], +) + +tf_py_wrap_cc( + name = "plugin_wrap", + srcs = [ + "plugin_wrap.i", + ], + copts = tf_copts(), + deps = [ + ":_inc_op_plugin_kernel", + "//tensorflow/core:framework_lite", + "//util/python:python_headers", + ], +) + +tf_custom_op_library( + name = "_inc_op.so", + srcs = ["ops/inc_op.cc"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "//tensorflow/contrib/tensorrt:trt_plugins", + ]), +) + +tf_custom_op_py_library( + name = "inc_op_loader", + srcs = ["inc_op.py"], + dso = [ + ":_inc_op.so", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:resources", + ], +) + +py_library( + name = "inc_op_py", + srcs_version = "PY2AND3", + deps = [ + ":inc_op", + ":inc_op_loader", + ], +) + +py_library( + name = "init_py", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":inc_op_py", + ":plugin_wrap", + ], +) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py new file mode 100644 index 0000000000..a61d008941 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Import custom op for plugin and register it in plugin factory registry.""" + +from ops import gen_inc_op +from plugin_wrap import inc_op_register +from inc_op import * + +# pylint: disable=unused-import,wildcard-import,g-import-not-at-top +inc_op = gen_inc_op.inc_plugin_trt +inc_op_register() +# pylint: enable=unused-import,wildcard-import,g-import-not-at-top diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py new file mode 100644 index 0000000000..ef8e26fbde --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py @@ -0,0 +1,30 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import platform +import os + +if platform.system() != "Windows": + from tensorflow.contrib.util import loader + from tensorflow.python.platform import resource_loader + + _inc_op = loader.load_op_library( + os.path.join(os.path.dirname(os.path.realpath(__file__)),"_inc_op.so")) +else: + raise RuntimeError("Windows not supported") diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc new file mode 100644 index 0000000000..5dd6b9bf94 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +__global__ void VecInc(const float* vec, float inc, float* dest, int n) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i < n) dest[i] = vec[i] + inc; +} + +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream) { + int threads_per_block = 256; + int blocks_per_grid = (count + threads_per_block - 1) / threads_per_block; + + VecInc<<>>(d_input, inc, + d_output, count); +} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h new file mode 100644 index 0000000000..ec269143e8 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP +#define TENSORFLOW_CONTRIB_TENSORRT_INC_OP + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#if GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +__global__ void VecInc(float* vec, float inc, float* dest, int n); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CONTRIB_TENSORRT_INC_OP diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 2155079e8b..21617fa8b5 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -13,24 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "inc_op_plugin.h" namespace tensorflow { namespace tensorrt { -const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; - -IncOpPlugin* CreateIncPlugin() { - return new IncOpPlugin(); -} +const std::string IncOpPlugin::plugin_name_ = "IncPluginTRT"; +IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); } IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { return new IncOpPlugin(buffer, length); @@ -39,45 +34,49 @@ IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { bool RegisterIncOpPlugin() { if 
(PluginFactoryTensorRT::GetInstance()->IsPlugin(IncOpPlugin::plugin_name_)) return false; - return PluginFactoryTensorRT::GetInstance()->RegisterPlugin(IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); + return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( + IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); } - -IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) : - PluginTensorRT(serialized_data, length) -{ +IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) + : PluginTensorRT(serialized_data, length) { // account for the consumed pointer. size_t consumed_data = PluginTensorRT::getSerializationSize(); - assert(length-consumed_data >= sizeof(float)); - SetAttribute("inc", serialized_data+consumed_data, sizeof(float)); + assert(length - consumed_data >= sizeof(float)); + const char* buffer = reinterpret_cast(serialized_data); + SetAttribute("inc", buffer + consumed_data, sizeof(float)); } -bool IncOpPlugin::SetAttribute(const string &key, const void *ptr, const size_t size) { - if (strcmp(key.c_str(), "inc")==0 && size == sizeof(float)) { - StoreAttribute(key, ptr, size); // save the attribute to own the data; +bool IncOpPlugin::SetAttribute(const std::string& key, const void* ptr, + const size_t size) { + if (strcmp(key.c_str(), "inc") == 0 && size == sizeof(float)) { + StoreAttribute(key, ptr, size); // save the attribute to own the data; inc_ = *static_cast(ptr); return true; } return false; } -bool IncOpPlugin::GetAttribute(const string &key, const void *ptr, size_t &size) { - if (attr_map_.find(key) != attr_map_.end()) { - ptr = attr_map_[key].data(); - size = attr_map_[key].size(); +bool IncOpPlugin::GetAttribute(const std::string& key, const void** ptr, + size_t* size) const { + const auto& iter = attr_map_.find(key); + if (iter != attr_map_.end()) { + *ptr = iter->second.data(); + *size = iter->second.size(); return true; } return false; } -int IncOpPlugin::enqueue(int 
batchSize, const void*const *inputs, void** outputs, void*, cudaStream_t stream) { +int IncOpPlugin::enqueue(int batch_size, const void* const* inputs, + void** outputs, void*, cudaStream_t stream) { int count = 1; - for (int i=0; i(inputs[0]); - float *output = reinterpret_cast(outputs[0]); + count *= batch_size; + const float* input = reinterpret_cast(inputs[0]); + float* output = reinterpret_cast(outputs[0]); IncrementKernel(input, inc_, output, count, stream); return 0; } diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 52b68487e6..a4774d354c 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -16,13 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN #define TENSORFLOW_CONTRIB_TENSORRT_INC_OP_PLUGIN -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include -#include -#include -#include #include +#include #include +#include +#include +#include +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -31,50 +31,44 @@ limitations under the License. 
namespace tensorflow { namespace tensorrt { -using std::string; -using std::unordered_map; - -class IncOpPlugin : public PluginTensorRT -{ -public: - static const string plugin_name_; - IncOpPlugin() {}; +class IncOpPlugin : public PluginTensorRT { + public: + static const std::string plugin_name_; + IncOpPlugin(){}; IncOpPlugin(const void* serialized_data, size_t length); - const string GetPluginName() override {return plugin_name_;}; - bool Finalize() override {return true;}; - bool SetAttribute(const string &key, const void *ptr, const size_t size) override; - bool GetAttribute(const string &key, const void *ptr, size_t &size) override; - - // TRT IPlugin methods - int getNbOutputs() const override {return 1;} - - nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override { - assert(index==0); - assert(nbInputDims==1); + const std::string& GetPluginName() const override { return plugin_name_; }; + bool Finalize() override { return true; }; + bool SetAttribute(const std::string& key, const void* ptr, + const size_t size) override; + bool GetAttribute(const std::string& key, const void** ptr, + size_t* size) const override; + + int getNbOutputs() const override { return 1; } + + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, + int num_input_dims) override { + assert(index == 0); + assert(num_input_dims == 1); return inputs[0]; } - // no configure needed // use configure to setup input dimensions - void configure(const nvinfer1::Dims *inputs, int nbInputs, const nvinfer1::Dims *outputs, int nbOutputs, int maxBatchSize) override { - assert(nbInputs==1); - PluginTensorRT::configure(inputs, nbInputs, outputs, nbOutputs, maxBatchSize); - return; + void configure(const nvinfer1::Dims* inputs, int num_inputs, + const nvinfer1::Dims* outputs, int num_outputs, + int max_batch_size) override { + assert(nb_inputs == 1); + PluginTensorRT::configure(inputs, num_inputs, outputs, num_outputs, + 
max_batch_size); } - int initialize() override { - return 0; - } + int initialize() override { return 0; } - void terminate() override { - return; - } + void terminate() override {} - size_t getWorkspaceSize(int maxBatchSize) const override { - return 0; - } + size_t getWorkspaceSize(int max_batch_size) const override { return 0; } - int enqueue(int batchSize, const void*const *inputs, void** outputs, void* workspace, cudaStream_t stream) override; + int enqueue(int batch_size, const void* const* inputs, void** outputs, + void* workspace, cudaStream_t stream) override; size_t getSerializationSize() override { return PluginTensorRT::getSerializationSize() + sizeof(float); @@ -86,24 +80,23 @@ public: PluginTensorRT::serialize(buffer); // incremented buffer after parent serialization; - buffer = static_cast(buffer) + PluginTensorRT::getSerializationSize(); + buffer = + static_cast(buffer) + PluginTensorRT::getSerializationSize(); std::memcpy(buffer, &inc_, sizeof(float)); buffer = static_cast(buffer) + sizeof(float); - return; } -protected: + protected: float inc_; nvinfer1::Dims dim_; - // std::unordered_map > attr_map_; }; -IncOpPlugin* CreateIncPlugin(); +IncOpPlugin* CreateIncPlugin(); IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); bool RegisterIncOpPlugin(); -void IncrementKernel(const float* d_input, float inc, float* d_output, int count, cudaStream_t stream); - +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc new file mode 100644 index 0000000000..0dfead8f57 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT + +using namespace tensorflow; + +REGISTER_OP("IncPluginTRT") + .Attr("inc: list(float)") + .Input("input: float32") + .Output("output: float32") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + +#endif // GOOGLE_CUDA +#endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i new file mode 100644 index 0000000000..9882daa842 --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +/* Wrap inc_op_plugin */ +%module inc_op_plugin +%{ +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); +%} + +%{ +bool inc_op_register() { + return tensorflow::tensorrt::RegisterIncOpPlugin(); +} +%} + +extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); + +bool inc_op_register(); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py new file mode 100644 index 0000000000..52f49ae00e --- /dev/null +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py @@ -0,0 +1,93 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Script to show usage of TensorRT custom op & plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# normally we should do import tensorflow as tf and then +# tf.placeholder, tf.constant, tf.nn.conv2d etc but +# it looks like internal builds don't like it so +# importing every module individually + +from tensorflow.contrib import tensorrt as trt +from tensorflow.core.protobuf import config_pb2 as cpb2 +from tensorflow.python.client import session as csess +from tensorflow.python.framework import dtypes as dtypes +from tensorflow.python.framework import importer as importer +from tensorflow.python.framework import ops as ops +from tensorflow.python.ops import array_ops as aops +from tensorflow.python.ops import nn as nn +from tensorflow.python.ops import nn_ops as nn_ops +import numpy as np + +# import custom_op as plugin op +# the python api handles registration to the plugin factory +from tensorflow.contrib.tensorrt import custom_plugin_examples as cpe + +def get_plugin_graph_def(): + """Create a simple graph and return its graph_def.""" + g = ops.Graph() + with g.as_default(): + a = aops.placeholder( + dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") + relu = nn.relu(a, "relu") + v = nn_ops.max_pool( + relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + + # insert custom_op in the graph + v = cpe.inc_op(v, inc=[16.5], name="plugin_test") + + v = v*2.0 + v = nn.relu(v) + v = nn.relu(v) + aops.squeeze(v, name="output") + return g.as_graph_def() + +def run_graph(gdef, dumm_inp): + """Run given graphdef once.""" + gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + ops.reset_default_graph() + g = ops.Graph() + with g.as_default(): + inp, out = importer.import_graph_def( + graph_def=gdef, return_elements=["input", "output"]) + inp = inp.outputs[0] + out = out.outputs[0] + 
+ with csess.Session( + config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + val = sess.run(out, {inp: dumm_inp}) + return val + +if "__main__" in __name__: + inp_dims = (5, 24, 24, 2) + dummy_input = np.ones(inp_dims).astype(np.float32) + orig_graph = get_plugin_graph_def() # graph with plugin node + + # trigger conversion. + # plugin nodes have been registered during import, converter will be able to + # create corresponding plugin layer during conversion. + trt_graph = trt.create_inference_graph( + input_graph_def=orig_graph, + outputs=["output"], + max_batch_size=inp_dims[0], + max_workspace_size_bytes=1 << 25, + precision_mode="FP32", + minimum_segment_size=2 + ) + o2 = run_graph(trt_graph, dummy_input) + print (o2) diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc index 7600703775..82c549dbf5 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -58,7 +58,6 @@ void PluginTensorRT::configure(const nvinfer1::Dims* inputs, int num_inputs, } input_dim_list_.emplace_back(dim); } - return; } size_t PluginTensorRT::getSerializationSize() { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index 59b92657f6..772974a769 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -32,20 +32,18 @@ namespace tensorrt { // User application should inherit from this class to write custom kernels. 
// Allows user to insert custom op in TensorRT engine // To register plugin in converter, user should also register custom -// tensorflow::tensorrt::PluginDeserializeFunc & -// tensorflow::tensorrt::PluginConstructFunc through -// tensorflow::tensorrt::PluginFactoryTensorRT +// PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: PluginTensorRT(){}; PluginTensorRT(const void* serialized_data, size_t length); - virtual const std::string& GetPluginName() = 0; + virtual const std::string& GetPluginName() const = 0; virtual bool Finalize() = 0; virtual bool SetAttribute(const std::string& key, const void* ptr, const size_t size) = 0; - virtual bool GetAttribute(const std::string& key, const void* ptr, - size_t& size) = 0; + virtual bool GetAttribute(const std::string& key, const void** ptr, + size_t* size) const = 0; void configure(const nvinfer1::Dims* inputs, int num_inputs, const nvinfer1::Dims* outputs, int num_outputs, diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 44b10394c8..776bce119d 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -33,12 +33,10 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, return nullptr; } - // should I lock plugins here? 
- instance_m_.lock(); + std::lock_guard lock(instance_m_); auto plugin_ptr = plugin_registry_[encoded_op_name].first(serial_data, serial_length); owned_plugins_.emplace_back(plugin_ptr); - instance_m_.unlock(); return plugin_ptr; } @@ -47,10 +45,9 @@ PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( const std::string& op_name) { if (!IsPlugin(op_name)) return nullptr; - instance_m_.lock(); + std::lock_guard lock(instance_m_); auto plugin_ptr = plugin_registry_[op_name].second(); owned_plugins_.emplace_back(plugin_ptr); - instance_m_.unlock(); return plugin_ptr; } @@ -60,22 +57,19 @@ bool PluginFactoryTensorRT::RegisterPlugin( PluginConstructFunc construct_func) { if (IsPlugin(op_name)) return false; - // get instance_m_ first before write to registry; - instance_m_.lock(); + std::lock_guard lock(instance_m_); auto ret = plugin_registry_.emplace( op_name, std::make_pair(deserialize_func, construct_func)); - instance_m_.unlock(); return ret.second; } void PluginFactoryTensorRT::DestroyPlugins() { - instance_m_.lock(); + std::lock_guard lock(instance_m_); for (auto& owned_plugin_ptr : owned_plugins_) { owned_plugin_ptr.release(); } owned_plugins_.clear(); - instance_m_.unlock(); } } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 824efcff35..08fd376844 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -39,10 +39,8 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { PluginTensorRT* CreatePlugin(const std::string& op_name); static PluginFactoryTensorRT* GetInstance() { - static PluginFactoryTensorRT* factory_instance = nullptr; - if (factory_instance == nullptr) { - factory_instance = new PluginFactoryTensorRT(); - } + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); return factory_instance; } diff --git 
a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc index 8b65e8b41c..c5d3f38280 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -22,8 +22,8 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -std::string ExtractOpName(const void* serial_data, size_t serial_length, - size_t* incremental) { +string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental) { size_t op_name_char_count = *static_cast(serial_data); *incremental = sizeof(size_t) + op_name_char_count; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index d4da8b261e..a94c67bba0 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -32,8 +33,8 @@ typedef std::function typedef std::function PluginConstructFunc; // TODO(jie): work on error handling here -std::string ExtractOpName(const void* serial_data, size_t serial_length, - size_t* incremental); +string ExtractOpName(const void* serial_data, size_t serial_length, + size_t* incremental); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc index 2856b0f87d..9ef0fce972 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -51,9 +51,9 @@ class StubPlugin : public PluginTensorRT { return inputs[0]; } int initialize() override { return 0; } - void terminate() override { return; } + void terminate() override {} size_t 
getWorkspaceSize(int maxBatchSize) const override { return 0; } - int enqueue(int batchSize, const void* const* inputs, void** outputs, + int enqueue(int batch_size, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override { return 0; } @@ -78,8 +78,6 @@ class PluginTest : public ::testing::Test { StubPlugin::plugin_name_, CreateStubPluginDeserialize, CreateStubPlugin); } - - protected: }; TEST_F(PluginTest, Registration) { -- GitLab From abfbbb86295c67eb1ac7c92235dbd5fb4b707169 Mon Sep 17 00:00:00 2001 From: Haggai Date: Wed, 18 Apr 2018 22:23:35 -0700 Subject: [PATCH 010/755] Remove reliance on TF core in XLA CPU Fft --- tensorflow/compiler/xla/service/cpu/BUILD | 1 - .../xla/service/cpu/runtime_fft_impl.h | 18 +++--------------- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 246b802861..6428ca528c 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -513,7 +513,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 984cb0616e..4f6b363364 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -21,8 +21,6 @@ limitations under the License. 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" // 'tensorflow' namespace is used so that int64 and other types don't require @@ -71,11 +69,9 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -88,8 +84,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(in_dims); + const Eigen::DSizes zero_start_indices; full_fft.device(device) = input.template fft(axes); @@ -112,11 +108,9 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -129,8 +123,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). 
- Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. @@ -179,7 +172,6 @@ template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, int32 fft_type, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { - CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; switch (fft_type) { case ::xla::FftType::FFT: EigenFftC2C( @@ -203,8 +195,6 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, device, static_cast(out), static_cast(operand), input_batch, fft_length0, fft_length1, fft_length2); break; - default: - LOG(FATAL) << "Unsupported FFT type: " << fft_type; } } @@ -229,8 +219,6 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; - default: - LOG(FATAL) << "Unsupported FFT rank " << fft_rank; } } -- GitLab From 6343b8dd77ba94c74acc3c04c985a5535b2b8169 Mon Sep 17 00:00:00 2001 From: Haggai Date: Wed, 18 Apr 2018 22:26:42 -0700 Subject: [PATCH 011/755] Add single-threaded support for XLA CPU Fft --- tensorflow/compiler/xla/service/cpu/BUILD | 17 ++++++++++ .../compiler/xla/service/cpu/cpu_runtime.cc | 2 ++ .../compiler/xla/service/cpu/cpu_runtime.h | 1 + .../compiler/xla/service/cpu/ir_emitter.cc | 8 ++++- .../cpu/runtime_single_threaded_fft.cc | 32 +++++++++++++++++++ .../service/cpu/runtime_single_threaded_fft.h | 31 ++++++++++++++++++ .../xla/service/cpu/simple_orc_jit.cc | 2 ++ 7 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc create mode 100644 tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 6428ca528c..4862f9e2f9 100644 --- 
a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -176,6 +176,7 @@ cc_library( ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", + ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", "@llvm//:core", @@ -574,6 +575,22 @@ cc_library( ], ) +cc_library( + name = "runtime_single_threaded_fft", + srcs = [ + "runtime_fft_impl.h", + "runtime_single_threaded_fft.cc", + ], + hdrs = ["runtime_single_threaded_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_single_threaded_matmul", srcs = ["runtime_single_threaded_matmul.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 872b0be1f8..4fcab483d6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -50,6 +50,8 @@ extern const char* const kEigenConvF16SymbolName = extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedFftSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedFft"; extern const char* const kEigenSingleThreadedMatMulF16SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index e392e231b4..0cc45dac61 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -51,6 +51,7 @@ extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const 
kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 3405277d44..8c2ca7104c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1171,7 +1171,13 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - const char* fn_name = runtime::kEigenFftSymbolName; + + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + const char* fn_name = multi_threaded_eigen + ? runtime::kEigenFftSymbolName + : runtime::kEigenSingleThreadedFftSymbolName; + llvm::Function* fft_func = llvm::cast( module_->getOrInsertFunction(fn_name, fft_type)); fft_func->setCallingConv(llvm::CallingConv::C); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc new file mode 100644 index 0000000000..2613ddb127 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type, + fft_rank, input_batch, fft_length0, fft_length1, + fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h new file mode 100644 index 0000000000..dcd133d012 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index b7ce5bbe47..7bd17002e3 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -37,6 +37,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" @@ -190,6 +191,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); -- GitLab From 8149077125f6c2701713fef12fe0f0caac729e27 Mon Sep 17 00:00:00 2001 From: Krish Ravindranath Date: Thu, 19 Apr 2018 14:53:10 -0400 Subject: [PATCH 012/755] changes error to ValueError, notes that shuffle must be provided and should be set True for training --- tensorflow/python/estimator/inputs/numpy_io.py | 5 +++-- tensorflow/python/estimator/inputs/pandas_io.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py index a6f4712910..5b5eb41466 100644 --- a/tensorflow/python/estimator/inputs/numpy_io.py +++ b/tensorflow/python/estimator/inputs/numpy_io.py @@ -139,8 +139,9 @@ def numpy_input_fn(x, TypeError: `x` is not a dict or array, or if `shuffle` is not bool. 
""" if not isinstance(shuffle, bool): - raise TypeError('shuffle must be explicitly set as boolean; ' - 'got {}'.format(shuffle)) + raise ValueError('shuffle must be provided and explicitly set as boolean ' + '(it is recommended to set it as True for training); ' + 'got {}'.format(shuffle)) def input_fn(): """Numpy input function.""" diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py index bd06843021..16825e09de 100644 --- a/tensorflow/python/estimator/inputs/pandas_io.py +++ b/tensorflow/python/estimator/inputs/pandas_io.py @@ -75,8 +75,9 @@ def pandas_input_fn(x, 'pandas_input_fn should not be called without pandas installed') if not isinstance(shuffle, bool): - raise TypeError('shuffle must be explicitly set as boolean; ' - 'got {}'.format(shuffle)) + raise ValueError('shuffle must be provided and explicitly set as boolean ' + '(it is recommended to set it as True for training); ' + 'got {}'.format(shuffle)) x = x.copy() if y is not None: -- GitLab From 459d61cbe8ab9cbb86b2bb7eac602ff565d54fde Mon Sep 17 00:00:00 2001 From: Jie Date: Thu, 19 Apr 2018 13:48:14 -0700 Subject: [PATCH 013/755] [PR comment addressed] switched from std::string to TF string custom_plugin_examples python test added (bazel) style guide violation addressed --- .../contrib/tensorrt/convert/convert_nodes.cc | 22 ++--- .../tensorrt/custom_plugin_examples/BUILD | 42 ++++++--- .../custom_plugin_examples/__init__.py | 12 +-- .../inc_op_kernel.cu.cc | 2 - .../custom_plugin_examples/inc_op_kernel.h | 3 +- .../{inc_op_plugin.cc => inc_op_plugin.cu.cc} | 9 +- .../custom_plugin_examples/inc_op_plugin.h | 18 ++-- .../custom_plugin_examples/ops/inc_op.cc | 4 +- .../{test => }/plugin_test.py | 46 +++++----- tensorflow/contrib/tensorrt/log/trt_logger.h | 2 +- .../contrib/tensorrt/plugin/trt_plugin.cc | 3 +- .../contrib/tensorrt/plugin/trt_plugin.h | 14 +-- .../tensorrt/plugin/trt_plugin_factory.cc | 7 +- .../tensorrt/plugin/trt_plugin_factory.h 
| 8 +- .../tensorrt/plugin/trt_plugin_utils.cc | 2 +- .../tensorrt/plugin/trt_plugins_test.cc | 19 ++-- tensorflow/contrib/tensorrt/plugin_test.py | 88 +++++++++++++++++++ .../tensorrt/resources/trt_resources.h | 2 +- 18 files changed, 205 insertions(+), 98 deletions(-) rename tensorflow/contrib/tensorrt/custom_plugin_examples/{inc_op_plugin.cc => inc_op_plugin.cu.cc} (91%) rename tensorflow/contrib/tensorrt/custom_plugin_examples/{test => }/plugin_test.py (67%) create mode 100644 tensorflow/contrib/tensorrt/plugin_test.py diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 874be96c78..c8a96e5dba 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -241,9 +241,9 @@ class TFAttrs { return attrs_.at(key); } template - T get(string key) const; + T get(const string& key) const; template - T get(string key, const T& default_value) const { + T get(const string& key, const T& default_value) const { return attrs_.count(key) ? 
this->get(key) : default_value; } @@ -261,29 +261,29 @@ class TFAttrs { }; template <> -string TFAttrs::get(string key) const { +string TFAttrs::get(const string& key) const { return this->at(key)->s(); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().i(); return std::vector(attr.begin(), attr.end()); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().f(); return std::vector(attr.begin(), attr.end()); } template <> -std::vector TFAttrs::get>(string key) const { +std::vector TFAttrs::get>(const string& key) const { auto attr = this->at(key)->list().s(); return std::vector(attr.begin(), attr.end()); } template <> -nvinfer1::Dims TFAttrs::get(string key) const { +nvinfer1::Dims TFAttrs::get(const string& key) const { auto values = this->get>(key); nvinfer1::Dims dims; dims.nbDims = values.size(); @@ -293,24 +293,24 @@ nvinfer1::Dims TFAttrs::get(string key) const { } template <> -nvinfer1::DataType TFAttrs::get(string key) const { +nvinfer1::DataType TFAttrs::get(const string& key) const { nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT); TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype)); return trt_dtype; } template <> -tensorflow::DataType TFAttrs::get(string key) const { +tensorflow::DataType TFAttrs::get(const string& key) const { return this->at(key)->type(); } template <> -float TFAttrs::get(string key) const { +float TFAttrs::get(const string& key) const { return this->at(key)->f(); } template <> -bool TFAttrs::get(string key) const { +bool TFAttrs::get(const string& key) const { return this->at(key)->b(); } diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 5603ed0ccf..3b1a7fb6f3 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ 
b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -1,3 +1,9 @@ +# Description: +# Example for plugin support in TensorRT(http://developer.nvidia.com/tensorrt) +# through TensorFlow integration. Targeting TensorRT 3.0.4 +# APIs are meant to change while upgrading TRT. +# add init_py into pip package BUILD dependency to install it. + package(default_visibility = ["//tensorflow:__subpackages__"]) load( @@ -8,6 +14,7 @@ load( "tf_gen_op_wrapper_py", "tf_py_wrap_cc", "tf_copts", + "tf_py_test", ) load( "@local_config_tensorrt//:build_defs.bzl", @@ -18,19 +25,16 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library") tf_kernel_library( name = "_inc_op_plugin_kernel", - srcs = [ - "inc_op_plugin.cc", - ], - hdrs = [ - ], gpu_srcs = [ "inc_op_kernel.cu.cc", "inc_op_kernel.h", + "inc_op_plugin.cu.cc", "inc_op_plugin.h", ], - deps = if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + deps = [ "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", ]), ) @@ -38,9 +42,10 @@ tf_gen_op_libs( op_lib_names = [ "inc_op", ], - deps = if_tensorrt([ - "@local_config_tensorrt//:nv_infer", + deps = [ "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", ]), ) @@ -70,9 +75,8 @@ tf_custom_op_library( srcs = ["ops/inc_op.cc"], deps = [ "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ "//tensorflow/contrib/tensorrt:trt_plugins", - ]), + ], ) tf_custom_op_py_library( @@ -97,6 +101,22 @@ py_library( ], ) +tf_py_test( + name = "plugin_test", + size = "small", + srcs = [ + "plugin_test.py", + ], + additional_deps = [ + ":init_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/contrib/tensorrt:init_py", + "//tensorflow/python:platform", + "//tensorflow/python:client_testlib", + "//tensorflow/python:tf_optimizer", + ], +) + py_library( name = "init_py", srcs = [ diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py 
b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py index a61d008941..e4cd0ae8a0 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py @@ -14,11 +14,13 @@ # ============================================================================= """Import custom op for plugin and register it in plugin factory registry.""" -from ops import gen_inc_op -from plugin_wrap import inc_op_register -from inc_op import * +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tensorrt.custom_plugin_examples.ops import gen_inc_op +from tensorflow.contrib.tensorrt.custom_plugin_examples.plugin_wrap import inc_op_register +from tensorflow.contrib.tensorrt.custom_plugin_examples import inc_op as import_inc_op_so -# pylint: disable=unused-import,wildcard-import,g-import-not-at-top inc_op = gen_inc_op.inc_plugin_trt inc_op_register() -# pylint: enable=unused-import,wildcard-import,g-import-not-at-top diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 5dd6b9bf94..38e1e01d95 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -14,10 +14,8 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" -#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" #if GOOGLE_CUDA -#define EIGEN_USE_GPU #if GOOGLE_TENSORRT namespace tensorflow { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h index ec269143e8..13156dad8f 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -17,13 +17,14 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_INC_OP #if GOOGLE_CUDA -#define EIGEN_USE_GPU #if GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { __global__ void VecInc(float* vec, float inc, float* dest, int n); +void IncrementKernel(const float* d_input, float inc, float* d_output, + int count, cudaStream_t stream); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc similarity index 91% rename from tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc rename to tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc index 21617fa8b5..508ced587b 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" +#include +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA @@ -23,7 +24,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -const std::string IncOpPlugin::plugin_name_ = "IncPluginTRT"; +const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); } @@ -47,7 +48,7 @@ IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) SetAttribute("inc", buffer + consumed_data, sizeof(float)); } -bool IncOpPlugin::SetAttribute(const std::string& key, const void* ptr, +bool IncOpPlugin::SetAttribute(const string& key, const void* ptr, const size_t size) { if (strcmp(key.c_str(), "inc") == 0 && size == sizeof(float)) { StoreAttribute(key, ptr, size); // save the attribute to own the data; @@ -57,7 +58,7 @@ bool IncOpPlugin::SetAttribute(const std::string& key, const void* ptr, return false; } -bool IncOpPlugin::GetAttribute(const std::string& key, const void** ptr, +bool IncOpPlugin::GetAttribute(const string& key, const void** ptr, size_t* size) const { const auto& iter = attr_map_.find(key); if (iter != attr_map_.end()) { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index a4774d354c..87404a755c 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -18,10 +18,6 @@ limitations under the License. 
#include #include -#include -#include -#include -#include #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA @@ -33,14 +29,14 @@ namespace tensorrt { class IncOpPlugin : public PluginTensorRT { public: - static const std::string plugin_name_; - IncOpPlugin(){}; + static const string plugin_name_; + IncOpPlugin() {}; IncOpPlugin(const void* serialized_data, size_t length); - const std::string& GetPluginName() const override { return plugin_name_; }; + const string& GetPluginName() const override { return plugin_name_; }; bool Finalize() override { return true; }; - bool SetAttribute(const std::string& key, const void* ptr, + bool SetAttribute(const string& key, const void* ptr, const size_t size) override; - bool GetAttribute(const std::string& key, const void** ptr, + bool GetAttribute(const string& key, const void** ptr, size_t* size) const override; int getNbOutputs() const override { return 1; } @@ -56,7 +52,7 @@ class IncOpPlugin : public PluginTensorRT { void configure(const nvinfer1::Dims* inputs, int num_inputs, const nvinfer1::Dims* outputs, int num_outputs, int max_batch_size) override { - assert(nb_inputs == 1); + assert(num_inputs == 1); PluginTensorRT::configure(inputs, num_inputs, outputs, num_outputs, max_batch_size); } @@ -95,8 +91,6 @@ class IncOpPlugin : public PluginTensorRT { IncOpPlugin* CreateIncPlugin(); IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); bool RegisterIncOpPlugin(); -void IncrementKernel(const float* d_input, float inc, float* d_output, - int count, cudaStream_t stream); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc index 0dfead8f57..7466e59090 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/ops/inc_op.cc @@ -19,7 +19,7 @@ limitations under the License. 
#if GOOGLE_CUDA #if GOOGLE_TENSORRT -using namespace tensorflow; +namespace tensorflow { REGISTER_OP("IncPluginTRT") .Attr("inc: list(float)") @@ -30,5 +30,7 @@ REGISTER_OP("IncPluginTRT") return Status::OK(); }); +} // namespace tensorflow + #endif // GOOGLE_CUDA #endif // GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py similarity index 67% rename from tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py rename to tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index 52f49ae00e..9f773c66a9 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/test/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -23,43 +23,44 @@ from __future__ import print_function # it looks like internal builds don't like it so # importing every module individually -from tensorflow.contrib import tensorrt as trt -from tensorflow.core.protobuf import config_pb2 as cpb2 -from tensorflow.python.client import session as csess -from tensorflow.python.framework import dtypes as dtypes -from tensorflow.python.framework import importer as importer -from tensorflow.python.framework import ops as ops -from tensorflow.python.ops import array_ops as aops -from tensorflow.python.ops import nn as nn -from tensorflow.python.ops import nn_ops as nn_ops -import numpy as np +from tensorflow.contrib import tensorrt +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.framework import errors +import numpy # import custom_op as plugin op -# the python api handles registration to the 
plugin factory -from tensorflow.contrib.tensorrt import custom_plugin_examples as cpe +# the python api handles registration to the plugin factory +from tensorflow.contrib.tensorrt import custom_plugin_examples def get_plugin_graph_def(): """Create a simple graph and return its graph_def.""" g = ops.Graph() with g.as_default(): - a = aops.placeholder( + a = array_ops.placeholder( dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") relu = nn.relu(a, "relu") v = nn_ops.max_pool( relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") # insert custom_op in the graph - v = cpe.inc_op(v, inc=[16.5], name="plugin_test") + v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") v = v*2.0 v = nn.relu(v) v = nn.relu(v) - aops.squeeze(v, name="output") + array_ops.squeeze(v, name="output") return g.as_graph_def() def run_graph(gdef, dumm_inp): """Run given graphdef once.""" - gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) ops.reset_default_graph() g = ops.Graph() with g.as_default(): @@ -68,20 +69,20 @@ def run_graph(gdef, dumm_inp): inp = inp.outputs[0] out = out.outputs[0] - with csess.Session( - config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + with session.Session( + config=config_pb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: val = sess.run(out, {inp: dumm_inp}) return val if "__main__" in __name__: inp_dims = (5, 24, 24, 2) - dummy_input = np.ones(inp_dims).astype(np.float32) + dummy_input = numpy.ones(inp_dims).astype(numpy.float32) orig_graph = get_plugin_graph_def() # graph with plugin node # trigger conversion. # plugin nodes have been registered during import, converter will be able to # create corresponding plugin layer during conversion. 
- trt_graph = trt.create_inference_graph( + trt_graph = tensorrt.create_inference_graph( input_graph_def=orig_graph, outputs=["output"], max_batch_size=inp_dims[0], @@ -90,4 +91,7 @@ if "__main__" in __name__: minimum_segment_size=2 ) o2 = run_graph(trt_graph, dummy_input) - print (o2) + if o2.reshape([-1])[0] == 35: + print("pass") + else: + raise RuntimeError("contrib/tensorrt/custom_plugin_examples wrong result") diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h index 7f3544f8cf..3495dc6318 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -28,7 +28,7 @@ namespace tensorrt { // Logger for GIE info/warning/errors class Logger : public nvinfer1::ILogger { public: - Logger(string name = "DefaultLogger") : name_(name){}; + Logger(string name = "DefaultLogger") : name_(name) {}; void log(nvinfer1::ILogger::Severity severity, const char* msg) override; private: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc index 82c549dbf5..062f86e8bb 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.cc @@ -25,7 +25,6 @@ namespace tensorflow { namespace tensorrt { PluginTensorRT::PluginTensorRT(const void* serialized_data, size_t length) { - // sanity check. 
const char* buffer = static_cast(serialized_data); size_t op_name_char_count = *reinterpret_cast(buffer); buffer += sizeof(size_t); @@ -91,7 +90,7 @@ void PluginTensorRT::serialize(void* serialized_data) { } } -bool PluginTensorRT::StoreAttribute(const std::string& key, const void* ptr, +bool PluginTensorRT::StoreAttribute(const string& key, const void* ptr, const size_t size) { if (attr_map_.count(key) != 0) return false; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index 772974a769..dca377c2d2 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -17,9 +17,9 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN #include -#include #include #include +#include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -35,28 +35,28 @@ namespace tensorrt { // PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: - PluginTensorRT(){}; + PluginTensorRT() {}; PluginTensorRT(const void* serialized_data, size_t length); - virtual const std::string& GetPluginName() const = 0; + virtual const string& GetPluginName() const = 0; virtual bool Finalize() = 0; - virtual bool SetAttribute(const std::string& key, const void* ptr, + virtual bool SetAttribute(const string& key, const void* ptr, const size_t size) = 0; - virtual bool GetAttribute(const std::string& key, const void** ptr, + virtual bool GetAttribute(const string& key, const void** ptr, size_t* size) const = 0; void configure(const nvinfer1::Dims* inputs, int num_inputs, const nvinfer1::Dims* outputs, int num_outputs, int max_batch_size) override; - virtual bool StoreAttribute(const std::string& key, const void* ptr, + virtual bool StoreAttribute(const string& key, const void* ptr, const size_t size); virtual size_t getSerializationSize() override; virtual void 
serialize(void* buffer) override; protected: - std::unordered_map > attr_map_; + std::unordered_map > attr_map_; std::vector input_dim_list_; }; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 776bce119d..736a1321fe 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -26,7 +26,7 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, size_t serial_length) { size_t parsed_byte = 0; // extract op_name from serial_data - std::string encoded_op_name = + string encoded_op_name = ExtractOpName(serial_data, serial_length, &parsed_byte); if (!IsPlugin(encoded_op_name)) { @@ -41,8 +41,7 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, return plugin_ptr; } -PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( - const std::string& op_name) { +PluginTensorRT* PluginFactoryTensorRT::CreatePlugin(const string& op_name) { if (!IsPlugin(op_name)) return nullptr; std::lock_guard lock(instance_m_); @@ -53,7 +52,7 @@ PluginTensorRT* PluginFactoryTensorRT::CreatePlugin( } bool PluginFactoryTensorRT::RegisterPlugin( - const std::string& op_name, PluginDeserializeFunc deserialize_func, + const string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func) { if (IsPlugin(op_name)) return false; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 08fd376844..4e4a3af4ca 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -36,7 +36,7 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { size_t serial_length) override; // plugin construction, PluginFactoryTensorRT owns the plugin; - PluginTensorRT* CreatePlugin(const std::string& op_name); + PluginTensorRT* 
CreatePlugin(const string& op_name); static PluginFactoryTensorRT* GetInstance() { static PluginFactoryTensorRT* factory_instance = @@ -44,11 +44,11 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { return factory_instance; } - bool RegisterPlugin(const std::string& op_name, + bool RegisterPlugin(const string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func); - bool IsPlugin(const std::string& op_name) { + bool IsPlugin(const string& op_name) { return plugin_registry_.find(op_name) != plugin_registry_.end(); } @@ -57,7 +57,7 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { void DestroyPlugins(); protected: - std::unordered_map > plugin_registry_; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc index c5d3f38280..a8f60886c0 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.cc @@ -30,7 +30,7 @@ string ExtractOpName(const void* serial_data, size_t serial_length, assert(serial_length >= *incremental); const char* buffer = static_cast(serial_data) + sizeof(size_t); - std::string op_name(buffer, op_name_char_count); + string op_name(buffer, op_name_char_count); return op_name; } diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc index 9ef0fce972..b834c5511f 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" @@ -31,18 +30,17 @@ namespace test { class StubPlugin : public PluginTensorRT { public: - static const std::string plugin_name_; - StubPlugin(){}; + static const string plugin_name_; + StubPlugin() {}; StubPlugin(const void* serialized_data, size_t length) - : PluginTensorRT(serialized_data, length){}; - const std::string& GetPluginName() override { return plugin_name_; }; + : PluginTensorRT(serialized_data, length) {}; + const string& GetPluginName() override { return plugin_name_; }; virtual bool Finalize() { return true; }; - virtual bool SetAttribute(const std::string& key, const void* ptr, + virtual bool SetAttribute(const string& key, const void* ptr, const size_t size) { return true; }; - virtual bool GetAttribute(const std::string& key, const void* ptr, - size_t& size) { + virtual bool GetAttribute(const string& key, const void* ptr, size_t& size) { return true; }; int getNbOutputs() const override { return 1; } @@ -59,7 +57,7 @@ class StubPlugin : public PluginTensorRT { } }; -const std::string StubPlugin::plugin_name_ = "StubPlugin"; +const string StubPlugin::plugin_name_ = "StubPlugin"; StubPlugin* CreateStubPlugin() { return new StubPlugin(); } @@ -72,8 +70,9 @@ class PluginTest : public ::testing::Test { public: bool RegisterStubPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin( - StubPlugin::plugin_name_)) + StubPlugin::plugin_name_)) { return true; + } return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( StubPlugin::plugin_name_, CreateStubPluginDeserialize, CreateStubPlugin); diff --git a/tensorflow/contrib/tensorrt/plugin_test.py b/tensorflow/contrib/tensorrt/plugin_test.py new file mode 100644 index 0000000000..7c3e765bff --- /dev/null +++ b/tensorflow/contrib/tensorrt/plugin_test.py @@ -0,0 +1,88 @@ +# 
Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to show usage of TensorRT custom op & plugin.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import tensorrt +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +import numpy as np + +# import custom_op as plugin op +# the python api handles registration to the plugin factory +from tensorflow.contrib.tensorrt import custom_plugin_examples + +def get_plugin_graph_def(): + """Create a simple graph and return its graph_def.""" + g = ops.Graph() + with g.as_default(): + a = array_ops.placeholder( + dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") + relu = nn.relu(a, "relu") + v = nn_ops.max_pool( + relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") + + # insert custom_op in the graph + v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") + + v = v*2.0 + v = nn.relu(v) + v = nn.relu(v) + array_ops.squeeze(v, name="output") + return 
g.as_graph_def() + +def run_graph(gdef, dumm_inp): + """Run given graphdef once.""" + gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) + ops.reset_default_graph() + g = ops.Graph() + with g.as_default(): + inp, out = importer.import_graph_def( + graph_def=gdef, return_elements=["input", "output"]) + inp = inp.outputs[0] + out = out.outputs[0] + + with session.Session( + config=config_pb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: + val = sess.run(out, {inp: dumm_inp}) + return val + +if "__main__" in __name__: + inp_dims = (5, 24, 24, 2) + dummy_input = np.ones(inp_dims).astype(np.float32) + orig_graph = get_plugin_graph_def() # graph with plugin node + + # trigger conversion. + # plugin nodes have been registered during import, converter will be able to + # create corresponding plugin layer during conversion. + trt_graph = tensorrt.create_inference_graph( + input_graph_def=orig_graph, + outputs=["output"], + max_batch_size=inp_dims[0], + max_workspace_size_bytes=1 << 25, + precision_mode="FP32", + minimum_segment_size=2 + ) + o2 = run_graph(trt_graph, dummy_input) + print (o2) diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index 3c85968ae7..5164247f93 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -82,7 +82,7 @@ class TRTWeightStore : public tensorflow::ResourceBase { class TRTEngineResource : public tensorflow::ResourceBase { public: - TRTEngineResource() : runtime_(nullptr), ctx_(nullptr){}; + TRTEngineResource() : runtime_(nullptr), ctx_(nullptr) {}; string DebugString() override { return string(""); } nvinfer1::IRuntime* runtime_; nvinfer1::IExecutionContext* ctx_; -- GitLab From 2ef955b6d354378a7ca19f1f3cafccfc17f79013 Mon Sep 17 00:00:00 2001 From: Haggai Date: Fri, 20 Apr 2018 18:57:12 -0700 Subject: [PATCH 014/755] Abort on invalid fft type or rank --- 
tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 4f6b363364..0bf693edd0 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -195,6 +195,9 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, device, static_cast(out), static_cast(operand), input_batch, fft_length0, fft_length1, fft_length2); break; + default: + // Unsupported FFT type + abort(); } } @@ -219,6 +222,9 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; + default: + // Unsupported FFT rank + abort(); } } -- GitLab From 364f6eae07fa8f0e2f89a9f665d0af430ea96669 Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Sat, 21 Apr 2018 14:45:30 -0300 Subject: [PATCH 015/755] Create pull_request_template.md --- pull_request_template.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 pull_request_template.md diff --git a/pull_request_template.md b/pull_request_template.md new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/pull_request_template.md @@ -0,0 +1 @@ + -- GitLab From ea3d7ab5455f54a67e24428f159e9170be408d71 Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Sat, 21 Apr 2018 14:57:38 -0300 Subject: [PATCH 016/755] Create Pull Request Template --- PULL_REQUEST_TEMPLATE.md | 20 ++++++++++++++++++++ pull_request_template.md | 1 - 2 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 PULL_REQUEST_TEMPLATE.md delete mode 100644 pull_request_template.md diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000..075bbc9945 --- /dev/null +++ b/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,20 @@ + + +##### Pull Request Checklist + +- [ ] Read [contributing 
guideline](CONTRIBUTING.md). +- [ ] Read [code of conduct](CODE_OF_CONDUCT.md). +- [ ] Fill [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style) +- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). + +##### Issue Fix + +- [ ] Yes +- [ ] No + +Fixed issue: + +##### Description + diff --git a/pull_request_template.md b/pull_request_template.md deleted file mode 100644 index 8b13789179..0000000000 --- a/pull_request_template.md +++ /dev/null @@ -1 +0,0 @@ - -- GitLab From 7f78414776718a350b1beb612dd8b1c26ff3f6a4 Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Tue, 24 Apr 2018 22:52:29 -0300 Subject: [PATCH 017/755] Merge PR Template to Contributing - Remove pull request template. - Add check list in contributing as a kind of TL;DR for that file. --- CONTRIBUTING.md | 11 +++++++++++ PULL_REQUEST_TEMPLATE.md | 20 -------------------- 2 files changed, 11 insertions(+), 20 deletions(-) delete mode 100644 PULL_REQUEST_TEMPLATE.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3dad41a88c..2e9d8c65e2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,16 @@ # Contributing guidelines +## Pull Request Checklist + +Before sending your pull requests, make sure you followed this list. + +- [ ] Read [contributing guidelines](CONTRIBUTING.md). +- [ ] Read [Code of Conduct](CODE_OF_CONDUCT.md). +- [ ] Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). 
+- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). +- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). + ## How to become a contributor and submit your own code ### Contributor License Agreements diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index 075bbc9945..0000000000 --- a/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,20 +0,0 @@ - - -##### Pull Request Checklist - -- [ ] Read [contributing guideline](CONTRIBUTING.md). -- [ ] Read [code of conduct](CODE_OF_CONDUCT.md). -- [ ] Fill [Contributor License Agreement (CLA)](https://cla.developers.google.com/). -- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). -- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style) -- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). - -##### Issue Fix - -- [ ] Yes -- [ ] No - -Fixed issue: - -##### Description - -- GitLab From 7f70c7a38fc2f4aaa9ceb52240c9112886adda5c Mon Sep 17 00:00:00 2001 From: Filipe Filardi Date: Tue, 24 Apr 2018 23:00:05 -0300 Subject: [PATCH 018/755] Make more like a table of contents --- CONTRIBUTING.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2e9d8c65e2..8669c25c45 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,12 +4,12 @@ Before sending your pull requests, make sure you followed this list. 
-- [ ] Read [contributing guidelines](CONTRIBUTING.md). -- [ ] Read [Code of Conduct](CODE_OF_CONDUCT.md). -- [ ] Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). -- [ ] Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). -- [ ] Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). -- [ ] Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). +- Read [contributing guidelines](CONTRIBUTING.md). +- Read [Code of Conduct](CODE_OF_CONDUCT.md). +- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- Check if my changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- Changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). +- Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). 
## How to become a contributor and submit your own code -- GitLab From df5ae5ac2a58131737a11e417ac34a663efb3574 Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Wed, 2 May 2018 17:52:38 -0700 Subject: [PATCH 019/755] Add some todo's --- tensorflow/contrib/tensorboard/db/summary_db_writer.cc | 1 + tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index 046a2d3884..630c0607ae 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -1183,6 +1183,7 @@ class SummaryDbWriter : public SummaryWriterInterface { Tensor t{DT_DOUBLE, {k, 3}}; auto data = t.flat(); for (int i = 0, j = 0; i < k; ++i) { + // TODO(nickfelt): reconcile with TensorBoard's data_compat.py // From summary.proto // Parallel arrays encoding the bucket boundaries and the bucket values. // bucket(i) is the count for the bucket i. 
The range for diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index cb51325d15..2044692b6e 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -131,6 +131,7 @@ TEST_F(SummaryDbWriterTest, WriteHistogram_VerifyTensorValues) { writer_->Unref(); writer_ = nullptr; + // TODO(nickfelt): implement QueryTensor() to encapsulate this // Verify the data string result = QueryString("SELECT data FROM Tensors"); const double* val = reinterpret_cast(result.data()); -- GitLab From 090d21c18f16303e740136e8a4e0f62c63df4e10 Mon Sep 17 00:00:00 2001 From: wangsiyu Date: Thu, 3 May 2018 18:31:29 +0800 Subject: [PATCH 020/755] fix bug of declaring regularization loss mutiple times when reusing partitioned variables in tf.layers --- tensorflow/python/layers/base.py | 13 ++++++++++++- tensorflow/python/layers/base_test.py | 15 +++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 64db49c900..c050e6be04 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -233,7 +233,8 @@ class Layer(base_layer.Layer): getter=vs.get_variable) if regularizer: - if context.executing_eagerly() or variable not in existing_variables: + if context.executing_eagerly() or _should_add_regularizer( + variable, existing_variables): self._handle_weight_regularization(name, variable, regularizer) if init_graph is not None: @@ -354,3 +355,13 @@ def _add_elements_to_collection(elements, collection_list): if element not in collection_set: collection.append(element) +def _should_add_regularizer(variable, existing_variable_set): + result = True + if isinstance(variable, tf_variables.PartitionedVariable): + for var in variable._get_variable_list(): + if var in existing_variable_set: + result = False + break + else: + 
result = variable not in existing_variable_set + return result diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index f08b552840..361e3de7aa 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -30,6 +30,7 @@ from tensorflow.python.layers import core as core_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope @@ -95,6 +96,20 @@ class BaseLayerTest(test.TestCase): regularizer=regularizer) self.assertEqual(len(layer.losses), 1) + def testReusePartitionedVaraiblesAndRegularizers(self): + regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3 + partitioner = partitioned_variables.fixed_size_partitioner(3) + for i in xrange(2): + with variable_scope.variable_scope(variable_scope.get_variable_scope(), + partitioner=partitioner, + reuse=False if i == 0 else True): + layer = base_layers.Layer(name='my_layer') + variable = layer.add_variable( + 'reg_part_var', [4, 4], + initializer=init_ops.zeros_initializer(), + regularizer=regularizer) + self.assertEqual(len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 3) + def testNoEagerActivityRegularizer(self): with context.eager_mode(): with self.assertRaisesRegexp(ValueError, 'activity_regularizer'): -- GitLab From a2d35bddc7a1ab58b859ef396501472d7986ff0f Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 07:50:13 -0700 Subject: [PATCH 021/755] Fix build dependency; add missing OpKernel; fix some formatting issues --- .../tensorrt/custom_plugin_examples/BUILD | 105 ++++++++---------- .../tensorrt/custom_plugin_examples/inc_op.py | 5 +- .../inc_op_kernel.cu.cc | 42 +++++++ 
.../custom_plugin_examples/inc_op_kernel.h | 2 +- .../{inc_op_plugin.cu.cc => inc_op_plugin.cc} | 13 ++- .../custom_plugin_examples/inc_op_plugin.h | 18 +-- .../custom_plugin_examples/plugin_test.py | 10 +- .../contrib/tensorrt/kernels/trt_engine_op.cc | 2 +- tensorflow/contrib/tensorrt/log/trt_logger.h | 2 +- .../tensorrt/plugin/trt_plugin_factory.cc | 6 + .../tensorrt/plugin/trt_plugin_factory.h | 12 +- .../tensorrt/plugin/trt_plugins_test.cc | 47 +++++--- 12 files changed, 162 insertions(+), 102 deletions(-) rename tensorflow/contrib/tensorrt/custom_plugin_examples/{inc_op_plugin.cu.cc => inc_op_plugin.cc} (90%) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index 3b1a7fb6f3..a45d4423bb 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -8,74 +8,39 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load( "//tensorflow:tensorflow.bzl", + "tf_copts", "tf_custom_op_library", - "tf_cuda_library", + "tf_custom_op_library_additional_deps", "tf_gen_op_libs", "tf_gen_op_wrapper_py", - "tf_py_wrap_cc", - "tf_copts", - "tf_py_test", + "tf_kernel_library", ) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", ) -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow:tensorflow.bzl", "tf_kernel_library") - -tf_kernel_library( - name = "_inc_op_plugin_kernel", - gpu_srcs = [ - "inc_op_kernel.cu.cc", - "inc_op_kernel.h", - "inc_op_plugin.cu.cc", - "inc_op_plugin.h", - ], - deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), -) tf_gen_op_libs( op_lib_names = [ "inc_op", ], - deps = [ - 
"//tensorflow/contrib/tensorrt:trt_plugins", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), ) tf_gen_op_wrapper_py( name = "inc_op", - gen_locally = True, deps = [ ":inc_op_op_lib", ], ) -tf_py_wrap_cc( - name = "plugin_wrap", - srcs = [ - "plugin_wrap.i", - ], - copts = tf_copts(), - deps = [ - ":_inc_op_plugin_kernel", - "//tensorflow/core:framework_lite", - "//util/python:python_headers", - ], -) - tf_custom_op_library( name = "_inc_op.so", srcs = ["ops/inc_op.cc"], deps = [ "//tensorflow/core:lib_proto_parsing", - "//tensorflow/contrib/tensorrt:trt_plugins", ], ) @@ -85,6 +50,10 @@ tf_custom_op_py_library( dso = [ ":_inc_op.so", ], + kernels = [ + ":inc_op_op_lib", + ":inc_op_plugin_kernel", + ], srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_for_generated_wrappers", @@ -101,30 +70,54 @@ py_library( ], ) -tf_py_test( - name = "plugin_test", - size = "small", - srcs = [ - "plugin_test.py", +tf_kernel_library( + name = "inc_op_plugin_kernel", + srcs = ["inc_op_plugin.cc"], + hdrs = [ + "inc_op_plugin.h", ], - additional_deps = [ - ":init_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/contrib/tensorrt:init_py", - "//tensorflow/python:platform", - "//tensorflow/python:client_testlib", - "//tensorflow/python:tf_optimizer", + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc" + ], + deps = [ + "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), +) + +tf_py_wrap_cc( + name = "plugin_wrap", + srcs = ["plugin_wrap.i"], + copts = tf_copts(), + deps = [ + ":inc_op_plugin_kernel", + "//tensorflow/core:framework_lite", + "//util/python:python_headers", ], ) py_library( name = "init_py", - srcs = [ - "__init__.py", - ], + srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ ":inc_op_py", ":plugin_wrap", ], ) + +tf_py_test( + name = "plugin_test", + size = "small", + srcs = ["plugin_test.py"], + additional_deps = [ + 
":init_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/contrib/tensorrt:init_py", + "//tensorflow/python:platform", + "//tensorflow/python:client_testlib", + "//tensorflow/python:tf_optimizer", + ], +) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py index ef8e26fbde..47fd55e2f6 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op.py @@ -18,13 +18,14 @@ from __future__ import division from __future__ import print_function import platform -import os if platform.system() != "Windows": + # pylint: disable=g-import-not-at-top from tensorflow.contrib.util import loader from tensorflow.python.platform import resource_loader + # pylint: enable=g-import-not-at-top _inc_op = loader.load_op_library( - os.path.join(os.path.dirname(os.path.realpath(__file__)),"_inc_op.so")) + resource_loader.get_path_to_datafile("_inc_op.so")) else: raise RuntimeError("Windows not supported") diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index 38e1e01d95..ee9fbe0ea1 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -15,8 +15,14 @@ limitations under the License. 
#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/stream_executor.h" + #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" namespace tensorflow { namespace tensorrt { @@ -35,6 +41,42 @@ void IncrementKernel(const float* d_input, float inc, float* d_output, d_output, count); } +// Note: this kernel definition is not needed in the plugin_test rule, but it is +// required for correctness of the TF program, i.e. if not using plugin or when +// run with trt optimization pass, the test should work. +class IncPluginTRT : public OpKernel { + public: + explicit IncPluginTRT(OpKernelConstruction* context) : OpKernel(context) { + std::vector inc_list; + OP_REQUIRES_OK(context, context->GetAttr("inc", &inc_list)); + OP_REQUIRES(context, inc_list.size() == 1, + errors::InvalidArgument( + "The increment list should contain single element.")); + inc_ = inc_list[0]; + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_tensor = context->input(0); + const TensorShape& input_shape = input_tensor.shape(); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &output_tensor)); + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast(context->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + IncrementKernel(input_tensor.flat().data(), inc_, + output_tensor->flat().data(), + input_shape.num_elements(), *stream); + } + + private: + float inc_; +}; + +REGISTER_KERNEL_BUILDER(Name("IncPluginTRT").Device(DEVICE_GPU), IncPluginTRT); + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h index 13156dad8f..1d0ec0b6b0 100644 --- 
a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h @@ -18,11 +18,11 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT +#include "cuda/include/cuda_runtime_api.h" namespace tensorflow { namespace tensorrt { -__global__ void VecInc(float* vec, float inc, float* dest, int n); void IncrementKernel(const float* d_input, float inc, float* d_output, int count, cudaStream_t stream); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc similarity index 90% rename from tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc rename to tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 508ced587b..489bc15def 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" -#include #include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h" +#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #if GOOGLE_CUDA @@ -24,7 +23,7 @@ limitations under the License. 
namespace tensorflow { namespace tensorrt { -const string IncOpPlugin::plugin_name_ = "IncPluginTRT"; +const char* kPluginName = "IncPluginTRT"; IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); } @@ -33,14 +32,16 @@ IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { } bool RegisterIncOpPlugin() { - if (PluginFactoryTensorRT::GetInstance()->IsPlugin(IncOpPlugin::plugin_name_)) + if (PluginFactoryTensorRT::GetInstance()->IsPlugin(kPluginName)) return false; return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( - IncOpPlugin::plugin_name_, CreateIncPluginDeserialize, CreateIncPlugin); + kPluginName, CreateIncPluginDeserialize, CreateIncPlugin); } +IncOpPlugin::IncOpPlugin() : plugin_name_(kPluginName) {} + IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length) - : PluginTensorRT(serialized_data, length) { + : PluginTensorRT(serialized_data, length), plugin_name_(kPluginName) { // account for the consumed pointer. size_t consumed_data = PluginTensorRT::getSerializationSize(); assert(length - consumed_data >= sizeof(float)); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 87404a755c..0676abe768 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -29,13 +29,17 @@ namespace tensorrt { class IncOpPlugin : public PluginTensorRT { public: - static const string plugin_name_; - IncOpPlugin() {}; + IncOpPlugin(); + IncOpPlugin(const void* serialized_data, size_t length); + const string& GetPluginName() const override { return plugin_name_; }; + bool Finalize() override { return true; }; + bool SetAttribute(const string& key, const void* ptr, const size_t size) override; + bool GetAttribute(const string& key, const void** ptr, size_t* size) const override; @@ -71,14 +75,11 @@ class IncOpPlugin : public 
PluginTensorRT { } void serialize(void* buffer) override { - // serializa parent stuff - // OpName + // Serialize parent data. PluginTensorRT::serialize(buffer); - - // incremented buffer after parent serialization; + // Incremented buffer after parent serialization. buffer = static_cast(buffer) + PluginTensorRT::getSerializationSize(); - std::memcpy(buffer, &inc_, sizeof(float)); buffer = static_cast(buffer) + sizeof(float); } @@ -86,6 +87,9 @@ class IncOpPlugin : public PluginTensorRT { protected: float inc_; nvinfer1::Dims dim_; + + private: + const string plugin_name_; }; IncOpPlugin* CreateIncPlugin(); diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index 9f773c66a9..d1815fdf33 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -39,6 +39,7 @@ import numpy # the python api handles registration to the plugin factory from tensorflow.contrib.tensorrt import custom_plugin_examples + def get_plugin_graph_def(): """Create a simple graph and return its graph_def.""" g = ops.Graph() @@ -49,15 +50,16 @@ def get_plugin_graph_def(): v = nn_ops.max_pool( relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - # insert custom_op in the graph + # insert custom_op in the graph v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") - v = v*2.0 + v = v * 2.0 v = nn.relu(v) v = nn.relu(v) array_ops.squeeze(v, name="output") return g.as_graph_def() + def run_graph(gdef, dumm_inp): """Run given graphdef once.""" gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) @@ -74,6 +76,7 @@ def run_graph(gdef, dumm_inp): val = sess.run(out, {inp: dumm_inp}) return val + if "__main__" in __name__: inp_dims = (5, 24, 24, 2) dummy_input = numpy.ones(inp_dims).astype(numpy.float32) @@ -88,8 +91,7 @@ if "__main__" in __name__: max_batch_size=inp_dims[0], 
max_workspace_size_bytes=1 << 25, precision_mode="FP32", - minimum_segment_size=2 - ) + minimum_segment_size=2) o2 = run_graph(trt_graph, dummy_input) if o2.reshape([-1])[0] == 35: print("pass") diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index c39bc12f73..71453631e2 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h index 3495dc6318..96ccacb791 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -28,7 +28,7 @@ namespace tensorrt { // Logger for GIE info/warning/errors class Logger : public nvinfer1::ILogger { public: - Logger(string name = "DefaultLogger") : name_(name) {}; + Logger(string name = "DefaultLogger") : name_(name) {} void log(nvinfer1::ILogger::Severity severity, const char* msg) override; private: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index 736a1321fe..b608e602a7 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -21,6 +21,12 @@ limitations under the License. 
namespace tensorflow { namespace tensorrt { +PluginFactoryTensorRT* PluginFactoryTensorRT::GetInstance() { + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); + return factory_instance; +} + PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 4e4a3af4ca..a088ffb842 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -31,19 +31,15 @@ namespace tensorrt { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { public: - // deserialization method + static PluginFactoryTensorRT* GetInstance(); + + // Deserialization method PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) override; - // plugin construction, PluginFactoryTensorRT owns the plugin; + // Plugin construction, PluginFactoryTensorRT owns the plugin. PluginTensorRT* CreatePlugin(const string& op_name); - static PluginFactoryTensorRT* GetInstance() { - static PluginFactoryTensorRT* factory_instance = - new PluginFactoryTensorRT(); - return factory_instance; - } - bool RegisterPlugin(const string& op_name, PluginDeserializeFunc deserialize_func, PluginConstructFunc construct_func); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc index b834c5511f..ae5a3e8742 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/test.h" @@ -20,8 +22,6 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" -#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorrt/include/NvInfer.h" namespace tensorflow { @@ -30,34 +30,49 @@ namespace test { class StubPlugin : public PluginTensorRT { public: - static const string plugin_name_; - StubPlugin() {}; + static const char* kPluginName; + + StubPlugin() : plugin_name_(kPluginName) {} + StubPlugin(const void* serialized_data, size_t length) - : PluginTensorRT(serialized_data, length) {}; - const string& GetPluginName() override { return plugin_name_; }; - virtual bool Finalize() { return true; }; + : PluginTensorRT(serialized_data, length) {} + + const string& GetPluginName() override { return plugin_name_; } + + virtual bool Finalize() { return true; } + virtual bool SetAttribute(const string& key, const void* ptr, const size_t size) { return true; - }; + } + virtual bool GetAttribute(const string& key, const void* ptr, size_t& size) { return true; - }; + } + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims) override { return inputs[0]; } + int initialize() override { return 0; } + void terminate() override {} + size_t getWorkspaceSize(int maxBatchSize) const override { return 0; } + int enqueue(int batch_size, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override { return 0; } + + private: + const string plugin_name_; }; -const string StubPlugin::plugin_name_ = "StubPlugin"; +const char* 
StubPlugin::kPluginName = "StubPlugin"; StubPlugin* CreateStubPlugin() { return new StubPlugin(); } @@ -70,32 +85,32 @@ class PluginTest : public ::testing::Test { public: bool RegisterStubPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin( - StubPlugin::plugin_name_)) { + StubPlugin::kPluginName)) { return true; } return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( - StubPlugin::plugin_name_, CreateStubPluginDeserialize, + StubPlugin::kPluginName, CreateStubPluginDeserialize, CreateStubPlugin); } }; TEST_F(PluginTest, Registration) { EXPECT_FALSE( - PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); EXPECT_TRUE(RegisterStubPlugin()); ASSERT_TRUE( - PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); } TEST_F(PluginTest, CreationDeletion) { EXPECT_TRUE(RegisterStubPlugin()); ASSERT_TRUE( - PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::plugin_name_)); + PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); ASSERT_TRUE(PluginFactoryTensorRT::GetInstance()->CreatePlugin( - StubPlugin::plugin_name_)); + StubPlugin::kPluginName)); ASSERT_EQ(1, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); PluginFactoryTensorRT::GetInstance()->DestroyPlugins(); ASSERT_EQ(0, PluginFactoryTensorRT::GetInstance()->CountOwnedPlugins()); -- GitLab From 03de4a4a6cbfab49c2921d0cac5ccac31c0815f8 Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 08:20:51 -0700 Subject: [PATCH 022/755] Move/rename the plugin factory test file; delete duplicate test file; fix minor formatting issues. 
--- tensorflow/contrib/tensorrt/BUILD | 10 ++- .../tensorrt/custom_plugin_examples/BUILD | 6 +- .../custom_plugin_examples/plugin_test.py | 5 -- .../contrib/tensorrt/plugin/trt_plugin.h | 6 +- .../tensorrt/plugin/trt_plugin_factory.h | 9 +- ...ins_test.cc => trt_plugin_factory_test.cc} | 6 +- .../tensorrt/plugin/trt_plugin_utils.h | 1 + tensorflow/contrib/tensorrt/plugin_test.py | 88 ------------------- 8 files changed, 26 insertions(+), 105 deletions(-) rename tensorflow/contrib/tensorrt/plugin/{trt_plugins_test.cc => trt_plugin_factory_test.cc} (96%) delete mode 100644 tensorflow/contrib/tensorrt/plugin_test.py diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 5fda11eccb..79e525edae 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -282,7 +282,7 @@ tf_cc_test( ], ) -# Library for the plugin factory +# Library for the plugin factory tf_cuda_library( name = "trt_plugins", srcs = [ @@ -304,9 +304,13 @@ tf_cuda_library( ) tf_cuda_cc_test( - name = "trt_plugins_test", + name = "trt_plugin_factory_test", size = "small", - srcs = ["plugin/trt_plugins_test.cc"], + srcs = ["plugin/trt_plugin_factory_test.cc"], + tags = [ + "manual", + "notap", + ], deps = [ ":trt_plugins", "//tensorflow/core:test", diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index a45d4423bb..c68e69457d 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -78,7 +78,7 @@ tf_kernel_library( ], gpu_srcs = [ "inc_op_kernel.h", - "inc_op_kernel.cu.cc" + "inc_op_kernel.cu.cc", ], deps = [ "//tensorflow/contrib/tensorrt:trt_plugins", @@ -120,4 +120,8 @@ tf_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:tf_optimizer", ], + tags = [ + "manual", + "notap", + ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py 
b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index d1815fdf33..cb40e08493 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -18,11 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# normally we should do import tensorflow as tf and then -# tf.placeholder, tf.constant, tf.nn.conv2d etc but -# it looks like internal builds don't like it so -# importing every module individually - from tensorflow.contrib import tensorrt from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index dca377c2d2..d80ec44372 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include #include + #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA @@ -35,9 +36,11 @@ namespace tensorrt { // PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: - PluginTensorRT() {}; + PluginTensorRT() {} PluginTensorRT(const void* serialized_data, size_t length); + virtual const string& GetPluginName() const = 0; + virtual bool Finalize() = 0; virtual bool SetAttribute(const string& key, const void* ptr, @@ -53,6 +56,7 @@ class PluginTensorRT : public nvinfer1::IPlugin { const size_t size); virtual size_t getSerializationSize() override; + virtual void serialize(void* buffer) override; protected: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index a088ffb842..6d2992bbbb 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -19,8 +19,9 @@ limitations under the License. 
#include #include #include -#include "trt_plugin.h" -#include "trt_plugin_utils.h" + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -54,12 +55,12 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { protected: std::unordered_map > + std::pair> plugin_registry_; // TODO(jie): Owned plugin should be associated with different sessions; // should really hand ownership of plugins to resource management; - std::vector > owned_plugins_; + std::vector> owned_plugins_; std::mutex instance_m_; }; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc similarity index 96% rename from tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc rename to tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc index ae5a3e8742..c5b0e75eb1 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc @@ -81,7 +81,7 @@ StubPlugin* CreateStubPluginDeserialize(const void* serialized_data, return new StubPlugin(serialized_data, length); } -class PluginTest : public ::testing::Test { +class TrtPluginFactoryTest : public ::testing::Test { public: bool RegisterStubPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin( @@ -94,7 +94,7 @@ class PluginTest : public ::testing::Test { } }; -TEST_F(PluginTest, Registration) { +TEST_F(TrtPluginFactoryTest, Registration) { EXPECT_FALSE( PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); EXPECT_TRUE(RegisterStubPlugin()); @@ -103,7 +103,7 @@ TEST_F(PluginTest, Registration) { PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); } -TEST_F(PluginTest, CreationDeletion) { +TEST_F(TrtPluginFactoryTest, CreationDeletion) { EXPECT_TRUE(RegisterStubPlugin()); ASSERT_TRUE( 
PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index a94c67bba0..4ff6fbedb4 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS #include + #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/contrib/tensorrt/plugin_test.py b/tensorflow/contrib/tensorrt/plugin_test.py deleted file mode 100644 index 7c3e765bff..0000000000 --- a/tensorflow/contrib/tensorrt/plugin_test.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Script to show usage of TensorRT custom op & plugin.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import tensorrt -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import importer -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -import numpy as np - -# import custom_op as plugin op -# the python api handles registration to the plugin factory -from tensorflow.contrib.tensorrt import custom_plugin_examples - -def get_plugin_graph_def(): - """Create a simple graph and return its graph_def.""" - g = ops.Graph() - with g.as_default(): - a = array_ops.placeholder( - dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input") - relu = nn.relu(a, "relu") - v = nn_ops.max_pool( - relu, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool") - - # insert custom_op in the graph - v = custom_plugin_examples.inc_op(v, inc=[16.5], name="plugin_test") - - v = v*2.0 - v = nn.relu(v) - v = nn.relu(v) - array_ops.squeeze(v, name="output") - return g.as_graph_def() - -def run_graph(gdef, dumm_inp): - """Run given graphdef once.""" - gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.50) - ops.reset_default_graph() - g = ops.Graph() - with g.as_default(): - inp, out = importer.import_graph_def( - graph_def=gdef, return_elements=["input", "output"]) - inp = inp.outputs[0] - out = out.outputs[0] - - with session.Session( - config=config_pb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess: - val = sess.run(out, {inp: dumm_inp}) - return val - -if "__main__" in __name__: - inp_dims = (5, 24, 24, 2) - dummy_input = 
np.ones(inp_dims).astype(np.float32) - orig_graph = get_plugin_graph_def() # graph with plugin node - - # trigger conversion. - # plugin nodes have been registered during import, converter will be able to - # create corresponding plugin layer during conversion. - trt_graph = tensorrt.create_inference_graph( - input_graph_def=orig_graph, - outputs=["output"], - max_batch_size=inp_dims[0], - max_workspace_size_bytes=1 << 25, - precision_mode="FP32", - minimum_segment_size=2 - ) - o2 = run_graph(trt_graph, dummy_input) - print (o2) -- GitLab From eb88fd1ef5505e3f8617cc7105052fbce0e4af9e Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 11:17:10 -0700 Subject: [PATCH 023/755] Add a macro for registering the plugin so we don't need to depend on swig; remove the swig file; fix build dependencies; fix tf_custom_op_library by adding GOOGLE_TENSORRT macro when gpu_srcs is not empty. --- .../tensorrt/custom_plugin_examples/BUILD | 69 +++++++++---------- .../custom_plugin_examples/__init__.py | 2 - .../inc_op_kernel.cu.cc | 1 + .../custom_plugin_examples/inc_op_plugin.cc | 7 +- .../custom_plugin_examples/inc_op_plugin.h | 5 +- .../custom_plugin_examples/plugin_wrap.i | 31 --------- .../tensorrt/plugin/trt_plugin_factory.h | 25 +++++++ tensorflow/tensorflow.bzl | 2 +- 8 files changed, 61 insertions(+), 81 deletions(-) delete mode 100644 tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index c68e69457d..e623b54781 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -24,24 +24,48 @@ load( ) tf_gen_op_libs( - op_lib_names = [ - "inc_op", - ], + op_lib_names = ["inc_op"], ) tf_gen_op_wrapper_py( name = "inc_op", - deps = [ - ":inc_op_op_lib", - ], + deps = [":inc_op_op_lib"], ) tf_custom_op_library( 
name = "_inc_op.so", - srcs = ["ops/inc_op.cc"], + srcs = [ + "inc_op_kernel.h", + "inc_op_plugin.cc", + "inc_op_plugin.h", + "ops/inc_op.cc", + ], + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc", + ], deps = [ - "//tensorflow/core:lib_proto_parsing", + "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_kernel_library( + name = "inc_op_plugin_kernel", + srcs = ["inc_op_plugin.cc"], + hdrs = [ + "inc_op_plugin.h", + ], + gpu_srcs = [ + "inc_op_kernel.h", + "inc_op_kernel.cu.cc", ], + deps = [ + "//tensorflow/contrib/tensorrt:trt_plugins", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), ) tf_custom_op_py_library( @@ -70,41 +94,12 @@ py_library( ], ) -tf_kernel_library( - name = "inc_op_plugin_kernel", - srcs = ["inc_op_plugin.cc"], - hdrs = [ - "inc_op_plugin.h", - ], - gpu_srcs = [ - "inc_op_kernel.h", - "inc_op_kernel.cu.cc", - ], - deps = [ - "//tensorflow/contrib/tensorrt:trt_plugins", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), -) - -tf_py_wrap_cc( - name = "plugin_wrap", - srcs = ["plugin_wrap.i"], - copts = tf_copts(), - deps = [ - ":inc_op_plugin_kernel", - "//tensorflow/core:framework_lite", - "//util/python:python_headers", - ], -) - py_library( name = "init_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ ":inc_op_py", - ":plugin_wrap", ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py index e4cd0ae8a0..e06904ab56 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/__init__.py @@ -19,8 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.tensorrt.custom_plugin_examples.ops import gen_inc_op -from 
tensorflow.contrib.tensorrt.custom_plugin_examples.plugin_wrap import inc_op_register from tensorflow.contrib.tensorrt.custom_plugin_examples import inc_op as import_inc_op_so inc_op = gen_inc_op.inc_plugin_trt -inc_op_register() diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc index ee9fbe0ea1..abbc0c5680 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.cu.cc @@ -24,6 +24,7 @@ limitations under the License. #if GOOGLE_TENSORRT #include "cuda/include/cuda_runtime_api.h" + namespace tensorflow { namespace tensorrt { diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc index 489bc15def..d56aedc6d4 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.cc @@ -31,12 +31,7 @@ IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) { return new IncOpPlugin(buffer, length); } -bool RegisterIncOpPlugin() { - if (PluginFactoryTensorRT::GetInstance()->IsPlugin(kPluginName)) - return false; - return PluginFactoryTensorRT::GetInstance()->RegisterPlugin( - kPluginName, CreateIncPluginDeserialize, CreateIncPlugin); -} +REGISTER_TRT_PLUGIN(kPluginName, CreateIncPluginDeserialize, CreateIncPlugin); IncOpPlugin::IncOpPlugin() : plugin_name_(kPluginName) {} diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h index 0676abe768..60153546d2 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h @@ -18,6 +18,7 @@ limitations under the License. 
#include #include + #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #if GOOGLE_CUDA @@ -92,10 +93,6 @@ class IncOpPlugin : public PluginTensorRT { const string plugin_name_; }; -IncOpPlugin* CreateIncPlugin(); -IncOpPlugin* CreateIncPluginDeserialize(const void*, size_t); -bool RegisterIncOpPlugin(); - } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i deleted file mode 100644 index 9882daa842..0000000000 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_wrap.i +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -/* Wrap inc_op_plugin */ -%module inc_op_plugin -%{ -#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h" -extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); -%} - -%{ -bool inc_op_register() { - return tensorflow::tensorrt::RegisterIncOpPlugin(); -} -%} - -extern bool tensorflow::tensorrt::RegisterIncOpPlugin(); - -bool inc_op_register(); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 6d2992bbbb..54fbca5930 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -64,6 +66,29 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { std::mutex instance_m_; }; +class TrtPluginRegistrar { + public: + TrtPluginRegistrar(const string& name, + PluginDeserializeFunc deserialize_func, + PluginConstructFunc construct_func) { + auto factory = PluginFactoryTensorRT::GetInstance(); + QCHECK(factory->RegisterPlugin(name, deserialize_func, construct_func)) + << "Failed to register plugin: " << name; + } +}; + +#define REGISTER_TRT_PLUGIN(name, deserialize_func, construct_func) \ + REGISTER_TRT_PLUGIN_UNIQ_HELPER( \ + __COUNTER__, name, deserialize_func, construct_func) +#define REGISTER_TRT_PLUGIN_UNIQ_HELPER( \ + ctr, name, deserialize_func, construct_func) \ + REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) +#define REGISTER_TRT_PLUGIN_UNIQ(ctr, name, deserialize_func, construct_func) \ + static ::tensorflow::tensorrt::TrtPluginRegistrar \ + trt_plugin_registrar##ctr 
TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::tensorrt::TrtPluginRegistrar( \ + name, deserialize_func, construct_func) + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index e5cc886b32..c27f894365 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1309,7 +1309,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]): native.cc_library( name=basename + "_gpu", srcs=gpu_srcs, - copts=_cuda_copts(), + copts=_cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]), deps=deps + if_cuda(cuda_deps)) cuda_deps.extend([":" + basename + "_gpu"]) -- GitLab From e629595e8f629f2de7db225463136b0e331bd71c Mon Sep 17 00:00:00 2001 From: gracehoney <31743510+aaroey@users.noreply.github.com> Date: Thu, 3 May 2018 15:00:57 -0700 Subject: [PATCH 024/755] Simplify build dependencies; fix python import order; fix multiple singleton issues by inlining the singleton method. --- tensorflow/contrib/tensorrt/BUILD | 2 -- .../contrib/tensorrt/custom_plugin_examples/BUILD | 12 ++---------- .../tensorrt/custom_plugin_examples/plugin_test.py | 9 +++------ .../contrib/tensorrt/plugin/trt_plugin_factory.cc | 6 ------ .../contrib/tensorrt/plugin/trt_plugin_factory.h | 8 +++++++- 5 files changed, 12 insertions(+), 25 deletions(-) diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 79e525edae..5b56feed0f 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -259,7 +259,6 @@ cc_library( "segment/segment.h", "segment/union_find.h", ], - linkstatic = 1, deps = [ "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", @@ -295,7 +294,6 @@ tf_cuda_library( "plugin/trt_plugin_factory.h", "plugin/trt_plugin_utils.h", ], - linkstatic = 1, deps = [ "//tensorflow/core:platform_base", ] + if_tensorrt([ diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD 
b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD index e623b54781..6f81ac2b44 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/BUILD @@ -85,21 +85,13 @@ tf_custom_op_py_library( ], ) -py_library( - name = "inc_op_py", - srcs_version = "PY2AND3", - deps = [ - ":inc_op", - ":inc_op_loader", - ], -) - py_library( name = "init_py", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ - ":inc_op_py", + ":inc_op", + ":inc_op_loader", ], ) diff --git a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py index cb40e08493..aedfb16211 100644 --- a/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py +++ b/tensorflow/contrib/tensorrt/custom_plugin_examples/plugin_test.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy + from tensorflow.contrib import tensorrt +from tensorflow.contrib.tensorrt import custom_plugin_examples from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -27,12 +30,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops -from tensorflow.python.framework import errors -import numpy - -# import custom_op as plugin op -# the python api handles registration to the plugin factory -from tensorflow.contrib.tensorrt import custom_plugin_examples def get_plugin_graph_def(): diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc index b608e602a7..736a1321fe 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.cc @@ -21,12 +21,6 @@ 
limitations under the License. namespace tensorflow { namespace tensorrt { -PluginFactoryTensorRT* PluginFactoryTensorRT::GetInstance() { - static PluginFactoryTensorRT* factory_instance = - new PluginFactoryTensorRT(); - return factory_instance; -} - PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, const void* serial_data, size_t serial_length) { diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index 54fbca5930..0eee705fb9 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -34,7 +34,13 @@ namespace tensorrt { class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { public: - static PluginFactoryTensorRT* GetInstance(); + // TODO(aaroey): this static method has to be inlined to make the singleton a + // unique global symbol. Find a way to fix it. + static PluginFactoryTensorRT* GetInstance() { + static PluginFactoryTensorRT* factory_instance = + new PluginFactoryTensorRT(); + return factory_instance; + } // Deserialization method PluginTensorRT* createPlugin(const char* layer_name, const void* serial_data, -- GitLab From fe9b2637cfe39cf11eb3d0494948a733b7fc1d7d Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Thu, 29 Mar 2018 05:28:16 +0800 Subject: [PATCH 025/755] Parse op definition and generate a Java Op class. 
--- tensorflow/java/BUILD | 4 + tensorflow/java/src/gen/cc/java_defs.h | 76 ++-- tensorflow/java/src/gen/cc/op_gen_main.cc | 22 +- tensorflow/java/src/gen/cc/op_generator.cc | 406 +++++++++++++++-- tensorflow/java/src/gen/cc/op_generator.h | 42 +- tensorflow/java/src/gen/cc/op_parser.cc | 417 ++++++++++++++++++ tensorflow/java/src/gen/cc/op_parser.h | 137 ++++++ tensorflow/java/src/gen/cc/source_writer.cc | 127 +++--- tensorflow/java/src/gen/cc/source_writer.h | 55 ++- .../java/src/gen/cc/source_writer_test.cc | 82 ++-- tensorflow/java/src/gen/gen_ops.bzl | 29 +- .../src/gen/resources/license.snippet.java | 14 + 12 files changed, 1201 insertions(+), 210 deletions(-) create mode 100644 tensorflow/java/src/gen/cc/op_parser.cc create mode 100644 tensorflow/java/src/gen/cc/op_parser.h create mode 100644 tensorflow/java/src/gen/resources/license.snippet.java diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index ab7d698a45..635a4e807d 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -70,6 +70,7 @@ filegroup( tf_java_op_gen_srcjar( name = "java_op_gen_sources", + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], gen_base_package = "org.tensorflow.op", gen_tool = "java_op_gen_tool", ops_libs = [ @@ -111,11 +112,13 @@ cc_library( name = "java_op_gen_lib", srcs = [ "src/gen/cc/op_generator.cc", + "src/gen/cc/op_parser.cc", "src/gen/cc/source_writer.cc", ], hdrs = [ "src/gen/cc/java_defs.h", "src/gen/cc/op_generator.h", + "src/gen/cc/op_parser.h", "src/gen/cc/source_writer.h", ], copts = tf_copts(), @@ -124,6 +127,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", ], ) diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 59f8beaee7..2065477f58 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -18,12 +18,15 @@ limitations under the License. 
#include #include +#include +#include namespace tensorflow { namespace java { // An enumeration of different modifiers commonly used in Java enum Modifier { + PACKAGE = 0, PUBLIC = (1 << 0), PROTECTED = (1 << 1), PRIVATE = (1 << 2), @@ -72,6 +75,12 @@ class Type { // Reflection API does return Type(Type::PRIMITIVE, "void"); } + static Type Generic(const string& name) { + return Type(Type::GENERIC, name); + } + static Type Wildcard() { + return Type(Type::GENERIC, ""); + } static Type Class(const string& name, const string& package = "") { return Type(Type::CLASS, name, package); } @@ -81,9 +90,6 @@ class Type { static Type Enum(const string& name, const string& package = "") { return Type(Type::ENUM, name, package); } - static Type Generic(const string& name = "") { - return Type(Type::GENERIC, name); - } static Type ClassOf(const Type& type) { return Class("Class").add_parameter(type); } @@ -96,11 +102,10 @@ class Type { const Kind& kind() const { return kind_; } const string& name() const { return name_; } const string& package() const { return package_; } - const string& description() const { return description_; } - Type& description(const string& description) { - description_ = description; - return *this; + const string full_name() const { + return package_.empty() ? name_ : package_ + "." 
+ name_; } + bool unknown() const { return name_.empty(); } // only wildcards has no name const std::list& parameters() const { return parameters_; } Type& add_parameter(const Type& parameter) { parameters_.push_back(parameter); @@ -120,14 +125,6 @@ class Type { } return *this; } - // Returns true if "type" is of a known collection type (only a few for now) - bool IsCollection() const { - return name_ == "List" || name_ == "Iterable"; - } - // Returns true if this instance is a wildcard () - bool IsWildcard() const { - return kind_ == GENERIC && name_.empty(); - } protected: Type(Kind kind, const string& name, const string& package = "") @@ -137,7 +134,6 @@ class Type { Kind kind_; string name_; string package_; - string description_; std::list parameters_; std::list annotations_; std::list supertypes_; @@ -180,16 +176,11 @@ class Variable { const string& name() const { return name_; } const Type& type() const { return type_; } bool variadic() const { return variadic_; } - const string& description() const { return description_; } - Variable& description(const string& description) { - description_ = description; - return *this; - } + private: string name_; Type type_; bool variadic_; - string description_; Variable(const string& name, const Type& type, bool variadic) : name_(name), type_(type), variadic_(variadic) {} @@ -210,16 +201,6 @@ class Method { bool constructor() const { return constructor_; } const string& name() const { return name_; } const Type& return_type() const { return return_type_; } - const string& description() const { return description_; } - Method& description(const string& description) { - description_ = description; - return *this; - } - const string& return_description() const { return return_description_; } - Method& return_description(const string& description) { - return_description_ = description; - return *this; - } const std::list& arguments() const { return arguments_; } Method& add_argument(const Variable& var) { 
arguments_.push_back(var); @@ -235,8 +216,6 @@ class Method { string name_; Type return_type_; bool constructor_; - string description_; - string return_description_; std::list arguments_; std::list annotations_; @@ -244,6 +223,35 @@ class Method { : name_(name), return_type_(return_type), constructor_(constructor) {} }; +// A definition of a documentation bloc for a Java element (JavaDoc) +class Javadoc { + public: + static Javadoc Create(const string& brief = "") { + return Javadoc(brief); + } + const string& brief() const { return brief_; } + const string& details() const { return description_; } + Javadoc& details(const string description) { + description_ = description; + return *this; + } + const std::list> tags() const { return tags_; } + Javadoc& add_tag(const string& tag, const string& text) { + tags_.push_back(std::make_pair(tag, text)); + return *this; + } + Javadoc& add_param_tag(const string& name, const string& text) { + return add_tag("param", name + " " + text); + } + + private: + string brief_; + string description_; + std::list> tags_; + + explicit Javadoc(const string& brief) : brief_(brief) {} +}; + } // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index bea99f3d7f..015200023f 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -48,8 +48,11 @@ const char kUsageHeader[] = "through\n" "the 'org.tensorflow.op.Ops' API as a group until the generated classes " "are compiled using an appropriate annotation processor.\n\n" - "Finally, the '--base_package' overrides the default parent package " - "under which the generated subpackage and classes are to be located.\n\n"; + "The '--base_package' overrides the default parent package under which " + "the generated subpackage and classes are to be located.\n\n" + "Finally, a list of directories of API proto definitions can be provided " + "to override default 
values found in the ops definitions, ordered by\n" + "priority (the last having precedence over the first).\n\n"; } // namespace java } // namespace tensorflow @@ -60,7 +63,7 @@ int main(int argc, char* argv[]) { tensorflow::string base_package = "org.tensorflow.op"; std::vector flag_list = { tensorflow::Flag("output_dir", &output_dir, - "Root directory into which output files are generated"), + "Root directory into which output files are generated"), tensorflow::Flag( "lib_name", &lib_name, "A name, in snake_case, used to classify this set of operations"), @@ -72,12 +75,15 @@ int main(int argc, char* argv[]) { bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(usage.c_str(), &argc, &argv); QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage; - - tensorflow::java::OpGenerator generator; + std::vector api_dirs; + if (argc > 1) { + api_dirs = tensorflow::str_util::Split(argv[1], ",", + tensorflow::str_util::SkipEmpty()); + } + tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs); tensorflow::OpList ops; - tensorflow::OpRegistry::Global()->Export(true, &ops); - tensorflow::Status status = - generator.Run(ops, lib_name, base_package, output_dir); + tensorflow::OpRegistry::Global()->Export(false, &ops); + tensorflow::Status status = generator.Run(ops, lib_name); TF_QCHECK_OK(status); return 0; diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index def06baf2d..c9b57f5706 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -14,53 +14,409 @@ limitations under the License. 
==============================================================================*/ #include +#include +#include +#include +#include +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" +#include "tensorflow/java/src/gen/cc/source_writer.h" +#include "tensorflow/java/src/gen/cc/op_parser.h" #include "tensorflow/java/src/gen/cc/op_generator.h" namespace tensorflow { namespace java { namespace { -string CamelCase(const string& str, char delimiter, bool upper) { - string result; - bool cap = upper; - for (string::const_iterator it = str.begin(); it != str.end(); ++it) { - const char c = *it; - if (c == delimiter) { - cap = true; - } else if (cap) { - result += toupper(c); - cap = false; +const char* kLicenseSnippet = + "tensorflow/java/src/gen/resources/license.snippet.java"; + +const std::map kPrimitiveAttrTypes = { + { "Boolean", Type::Boolean() }, + { "Byte", Type::Byte() }, + { "Character", Type::Byte() }, + { "Float", Type::Float() }, + { "Integer", Type::Long() }, + { "Long", Type::Long() }, + { "Short", Type::Long() }, + { "Double", Type::Float() }, +}; + +enum RenderMode { + DEFAULT, + SINGLE_OUTPUT, + SINGLE_LIST_OUTPUT +}; + +void CollectOpDependencies(const OpSpec& op, RenderMode mode, + std::list* out) { + out->push_back(Type::Class("Operation", "org.tensorflow")); + out->push_back(Type::Class("OperationBuilder", "org.tensorflow")); + out->push_back(Type::Class("Scope", "org.tensorflow.op")); + if (mode == SINGLE_OUTPUT) { + out->push_back(Type::Class("Output", "org.tensorflow")); + } else if (mode == SINGLE_LIST_OUTPUT) { + out->push_back(Type::Interface("Iterator", "java.util")); + } + // Don't pay attention to duplicate types in the dependency list, they will + // be filtered out by 
the SourceWriter. + for (const OpSpec::Operand& input : op.inputs()) { + out->push_back(input.var().type()); + if (input.iterable()) { + out->push_back(Type::Class("Operands", "org.tensorflow.op")); + } + } + for (const OpSpec::Operand& output : op.outputs()) { + out->push_back(output.var().type()); + if (output.iterable()) { + out->push_back(Type::Class("Arrays", "java.util")); + } + } + for (const OpSpec::Operand& attribute : op.attributes()) { + out->push_back(attribute.var().type()); + if (attribute.var().type().name() == "Class") { + out->push_back(Type::Enum("DataType", "org.tensorflow")); + } + } + for (const OpSpec::Operand& option : op.options()) { + out->push_back(option.var().type()); + } +} + +void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, + SourceWriter* writer) { + string var = optional ? "opts." + attr.var().name() : attr.var().name(); + if (attr.iterable()) { + const Type& type = attr.data_type(); + std::map::const_iterator it = + kPrimitiveAttrTypes.find(type.name()); + if (it != kPrimitiveAttrTypes.end()) { + string array = attr.var().name() + "Array"; + writer->AppendType(it->second) + .Append("[] " + array + " = new ") + .AppendType(it->second) + .Append("[" + var + ".size()];") + .EndLine(); + writer->BeginBlock("for (int i = 0; i < " + array + ".length; ++i)") + .Append(array + "[i] = " + var + ".get(i);") + .EndLine() + .EndBlock() + .Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + array) + .Append(");") + .EndLine(); } else { - result += c; + writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + var) + .Append(".toArray(new ") + .AppendType(type) + .Append("[" + var + ".size()]));") + .EndLine(); } + } else { + Type type = attr.var().type(); + writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", "); + if (type.name() == "Class") { + writer->Append("DataType.fromClass(" + attr.var().name() + "));"); + } else { + writer->Append(var + ");"); + } + writer->EndLine(); } - return 
result; } -} // namespace +void RenderFactoryMethod(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Method factory = Method::Create("create", op_class); + Javadoc factory_doc = Javadoc::Create( + "Factory method to create a class to wrap a new " + op_class.name() + + " operation to the graph."); + Variable scope = + Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); + factory.add_argument(scope); + factory_doc.add_param_tag(scope.name(), "Current graph scope"); + for (const OpSpec::Operand& input : op.inputs()) { + factory.add_argument(input.var()); + factory_doc.add_param_tag(input.var().name(), input.description()); + } + for (const OpSpec::Operand& attribute : op.attributes()) { + factory.add_argument(attribute.var()); + factory_doc.add_param_tag(attribute.var().name(), attribute.description()); + } + if (!op.options().empty()) { + factory.add_argument(Variable::Varargs("options", Type::Class("Options"))); + factory_doc.add_param_tag("options", "carries optional attributes values"); + } + factory_doc.add_tag("return", "a new instance of " + op_class.name()); + writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc); + writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" + + op.graph_name() + "\", scope.makeOpName(\"" + + op_class.name() + "\"));"); + writer->EndLine(); -OpGenerator::OpGenerator() : env(Env::Default()) {} + for (const OpSpec::Operand& input : op.inputs()) { + if (input.iterable()) { + writer->Append("opBuilder.addInputList(Operands.asOutputs(" + + input.var().name() + "));"); + writer->EndLine(); + } else { + writer->Append("opBuilder.addInput(" + input.var().name() + + ".asOutput());"); + writer->EndLine(); + } + } + for (const OpSpec::Operand& attribute : op.attributes()) { + WriteSetAttrDirective(attribute, false, writer); + } + if (!op.options().empty()) { + writer->BeginBlock("if (options != null)") + .BeginBlock("for (Options opts : options)"); + for (const OpSpec::Operand& option 
: op.options()) { + writer->BeginBlock("if (opts." + option.var().name() + " != null)"); + WriteSetAttrDirective(option, true, writer); + writer->EndBlock(); + } + writer->EndBlock().EndBlock(); + } + writer->Append("return new ") + .AppendType(op_class) + .Append("(opBuilder.build());") + .EndLine(); + writer->EndMethod(); +} -OpGenerator::~OpGenerator() {} +void RenderConstructor(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Method constructor = Method::ConstructorFor(op_class) + .add_argument( + Variable::Create("operation", + Type::Class("Operation", "org.tensorflow"))); + for (const OpSpec::Operand& output : op.outputs()) { + if (output.iterable() && !output.data_type().unknown()) { + constructor.add_annotation( + Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); + break; + } + } + writer->BeginMethod(constructor, PRIVATE) + .Append("super(operation);") + .EndLine(); + if (op.outputs().size() > 0) { + writer->Append("int outputIdx = 0;") + .EndLine(); + for (const OpSpec::Operand& output : op.outputs()) { + if (output.iterable()) { + string var_length = output.var().name() + "Length"; + writer->Append("int " + var_length) + .Append(" = operation.outputListLength(\"" + output.graph_name() + + "\");") + .EndLine() + .Append(output.var().name() + " = Arrays.asList("); + if (!output.data_type().unknown()) { + writer->Append("(") + .AppendType(output.var().type().parameters().front()) + .Append("[])"); + } + writer->Append("operation.outputList(outputIdx, " + var_length + "));") + .EndLine() + .Append("outputIdx += " + var_length + ";") + .EndLine(); + } else { + writer->Append(output.var().name() + + " = operation.output(outputIdx++);") + .EndLine(); + } + } + } + writer->EndMethod(); +} -Status OpGenerator::Run(const OpList& ops, const string& lib_name, - const string& base_package, const string& output_dir) { - const string package = - base_package + '.' 
+ str_util::StringReplace(lib_name, "_", "", true); - const string package_path = - output_dir + '/' + str_util::StringReplace(package, ".", "/", true); - const string group = CamelCase(lib_name, '_', false); +void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { + for (const OpSpec::Operand& option : op.options()) { + Method setter = Method::Create(option.var().name(), Type::Class("Options")) + .add_argument(option.var()); + Javadoc setter_doc = Javadoc::Create() + .add_param_tag(option.var().name(), option.description()); + writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc) + .Append("return new Options()." + option.var().name() + "(" + + option.var().name() + ");") + .EndLine() + .EndMethod(); + } + for (const OpSpec::Operand& output : op.outputs()) { + Method getter = Method::Create(output.var().name(), output.var().type()); + Javadoc getter_doc = Javadoc::Create(output.description()); + writer->BeginMethod(getter, PUBLIC, &getter_doc) + .Append("return " + output.var().name() + ";") + .EndLine() + .EndMethod(); + } +} + +void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, + SourceWriter* writer) { + OpSpec::Operand output = op.outputs().front(); + + if (mode == SINGLE_OUTPUT) { + bool cast2obj = output.data_type().unknown(); + Type return_type = Type::Class("Output", "org.tensorflow") + .add_parameter(cast2obj ? 
Type::Class("Object") : output.data_type()); + Method as_output = Method::Create("asOutput", return_type) + .add_annotation(Annotation::Create("Override")); + if (cast2obj) { + as_output.add_annotation( + Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); + } + writer->BeginMethod(as_output, PUBLIC); + if (cast2obj) { + writer->Append("return (").AppendType(return_type).Append(") "); + } else { + writer->Append("return "); + } + writer->Append(output.var().name() + ";") + .EndLine() + .EndMethod(); + + } else if (mode == SINGLE_LIST_OUTPUT) { + Type operand = Type::Interface("Operand", "org.tensorflow"); + if (output.data_type().unknown()) { + operand.add_parameter(Type::Class("Object")); + } else { + operand.add_parameter(output.data_type()); + } + Type return_type = Type::Interface("Iterator", "java.util") + .add_parameter(operand); + Method iterator = Method::Create("iterator", return_type) + .add_annotation(Annotation::Create("Override")) + .add_annotation(Annotation::Create("SuppressWarnings") + .attributes("{\"rawtypes\", \"unchecked\"}")); + // cast the output list using a raw List + writer->BeginMethod(iterator, PUBLIC) + .Append("return (" + return_type.name() + ") ") + .Append(output.var().name() + ".iterator();") + .EndLine() + .EndMethod(); + } +} + +void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) { + Type options_class = Type::Class("Options"); + Javadoc options_doc = Javadoc::Create( + "Class holding optional attributes of this operation"); + writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); + for (const OpSpec::Operand& option : op.options()) { + Method setter = Method::Create(option.var().name(), options_class) + .add_argument(option.var()); + Javadoc setter_doc = Javadoc::Create() + .add_param_tag(option.var().name(), option.description()); + writer->BeginMethod(setter, PUBLIC, &setter_doc) + .Append("this." 
+ option.var().name() + " = " + option.var().name() + + ";") + .EndLine() + .Append("return this;") + .EndLine() + .EndMethod(); + } + writer->EndLine(); + for (const OpSpec::Operand& option : op.options()) { + writer->WriteField(option.var(), PRIVATE); + } + Method constructor = Method::ConstructorFor(options_class); + writer->BeginMethod(constructor, PRIVATE).EndMethod(); + writer->EndType(); +} - if (!env->FileExists(package_path).ok()) { - TF_CHECK_OK(env->RecursivelyCreateDir(package_path)); +void RenderEndpoint(const OpSpec& op, const OpSpec::Endpoint& endpoint, + SourceWriter* writer) { + RenderMode mode = DEFAULT; + if (op.outputs().size() == 1) { + mode = op.outputs().front().iterable() ? SINGLE_LIST_OUTPUT : SINGLE_OUTPUT; + } + std::list dependencies; + CollectOpDependencies(op, mode, &dependencies); + const Type& op_class = endpoint.type(); + writer->WriteFromFile(kLicenseSnippet) + .EndLine() + .Append("// This file is machine generated, DO NOT EDIT!") + .EndLine() + .EndLine() + .BeginType(op_class, PUBLIC|FINAL, &dependencies, &endpoint.javadoc()); + if (!op.options().empty()) { + RenderOptionsClass(op, writer); } + RenderFactoryMethod(op, op_class, writer); + RenderGettersAndSetters(op, writer); + if (mode != DEFAULT) { + RenderInterfaceImpl(op, mode, writer); + } + writer->EndLine(); + for (const OpSpec::Operand& output : op.outputs()) { + writer->WriteField(output.var(), PRIVATE); + } + RenderConstructor(op, op_class, writer); + writer->EndType(); +} + +} // namespace + +OpGenerator::OpGenerator(const string& base_package, const string& output_dir, + const std::vector& api_dirs, Env* env) + : base_package_(base_package), output_dir_(output_dir), api_dirs_(api_dirs), + env_(env) { +} +Status OpGenerator::Run(const OpList& op_list, const string& lib_name) { LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations"; - // TODO(karllessard) generate wrappers from list of ops + ApiDefMap api_map(op_list); + if (!api_dirs_.empty()) { + 
// Only load api files that correspond to the requested "op_list" + for (const auto& op : op_list.op()) { + for (const auto& api_def_dir : api_dirs_) { + const std::string api_def_file_pattern = + io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); + if (env_->FileExists(api_def_file_pattern).ok()) { + TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)); + } + } + } + } + api_map.UpdateDocs(); + for (const auto& op_def : op_list.op()) { + const ApiDef* api_def = api_map.GetApiDef(op_def.name()); + if (api_def->visibility() != ApiDef::SKIP) { + Status status = GenerateOp(op_def, *api_def, lib_name); + if (status != Status::OK()) { + LOG(ERROR) << "Fail to generate Java wrapper for operation \"" + << op_def.name() << "\""; + } + } + } + return Status::OK(); +} + +Status OpGenerator::GenerateOp(const OpDef& op_def, const ApiDef& api_def, + const string& lib_name) { + std::unique_ptr op; + OpParser op_parser(op_def, api_def, lib_name, base_package_); + op_parser.Parse(&op); + for (const OpSpec::Endpoint& endpoint : op->endpoints()) { + string package_path = io::JoinPath(output_dir_, + str_util::StringReplace(endpoint.type().package(), ".", "/", true)); + if (!env_->FileExists(package_path).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(package_path)); + } + string file_path = + io::JoinPath(package_path, endpoint.type().name() + ".java"); + std::unique_ptr file; + TF_CHECK_OK(env_->NewWritableFile(file_path, &file)); + SourceFileWriter writer(file.get()); + RenderEndpoint(*op, endpoint, &writer); + } return Status::OK(); } diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 4b55ed3ed9..19d8db95fb 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -17,34 +17,42 @@ limitations under the License. 
#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_ #include +#include -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace java { -/// \brief A generator of Java operation wrappers. -/// -/// Such generator is normally ran only once per executable, outputting -/// wrappers for the all registered operations it has been compiled with. -/// Nonetheless, it is designed to support multiple runs, giving a different -/// list of operations on each cycle. +// A generator of Java operation wrappers. +// +// Such generator is normally ran only once per executable, outputting +// wrappers for the all registered operations it has been compiled with. +// Nonetheless, it is designed to support multiple runs, giving a different +// list of operations on each cycle. class OpGenerator { public: - OpGenerator(); - virtual ~OpGenerator(); + OpGenerator(const string& base_package, const string& output_dir, + const std::vector& api_dirs, Env* env = Env::Default()); + virtual ~OpGenerator() = default; - /// \brief Generates wrappers for the given list of 'ops'. - /// - /// Output files are generated in //, - /// where 'lib_package' is derived from 'lib_name'. - Status Run(const OpList& ops, const string& lib_name, - const string& base_package, const string& output_dir); + // Generates wrappers for the given list of 'ops'. + // + // Output files are generated in //, + // where 'lib_package' is derived from 'lib_name'. 
+ Status Run(const OpList& op_list, const string& lib_name); private: - Env* env; + string base_package_; + string output_dir_; + std::vector api_dirs_; + Env* env_; + + Status GenerateOp(const OpDef& op_def, const ApiDef& api_def, + const string& lib_name); }; } // namespace java diff --git a/tensorflow/java/src/gen/cc/op_parser.cc b/tensorflow/java/src/gen/cc/op_parser.cc new file mode 100644 index 0000000000..0541e343d8 --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_parser.cc @@ -0,0 +1,417 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/java/src/gen/cc/op_parser.h" + +namespace tensorflow { +namespace java { +namespace { + +string SnakeToCamelCase(const string& str, bool upper = false) { + string result; + bool cap = upper; + for (string::const_iterator it = str.begin(); it != str.end(); ++it) { + const char c = *it; + if (c == '_') { + cap = true; + } else if (cap) { + result += toupper(c); + cap = false; + } else { + result += c; + } + } + return result; +} + +bool IsRealNumber(DataType type) { + for (DataType dt : RealNumberTypes()) { + if (type == dt) { + return true; + } + } + return false; +} + +bool IsRealNumbers(const AttrValue& values) { + if (values.has_list()) { + for (int i = 0; i < values.list().type_size(); ++i) { + if (!IsRealNumber(values.list().type(i))) { + return false; + } + } + return true; + } + return IsRealNumber(values.type()); +} + +string ParseDocumentation(const string& text) { + std::stringstream javadoc_text; + string::const_iterator c_iter = text.cbegin(); + bool code = false; + bool emphasis = false; + bool list = false; + while (c_iter != text.cend()) { + char c = *c_iter++; + int count = 1; + switch (c) { + case '\n': + if (!code) { + // consumes all subsequent newlines, if there are more than one, + // then there are two choices: + // - if the next line starts with an asterisk, we are enumerating + // a list of items + // - otherwise, we are starting a new paragraph + for (; c_iter != text.cend() && *c_iter == '\n'; ++count, ++c_iter) {} + if (c_iter != text.cend()) { + if (count > 1) { + if (*c_iter != '*' && list) { + javadoc_text << "\n\n"; + list = false; + } else if (*c_iter == '*' && !list) { + javadoc_text << "\n
    \n
  • "; + list = true; + c_iter++; + } else { + javadoc_text << "\n

    \n"; + } + } else if (list && *c_iter == '*') { + javadoc_text << "

  • \n
  • "; + c_iter++; + } else { + javadoc_text << '\n'; + } + } + } + break; + case '`': + // consumes all subsequent backquotes, those are use enclose code. + // if there are more than 3, we are dealing with a pre-formatted block, + // otherwise it is a single-line code snippet + for (; c_iter != text.cend() && *c_iter == '`'; ++count, ++c_iter) {} + if (count >= 3) { + javadoc_text << (code ? "\n}" : "
    {@code\n");
    +      } else {
    +        javadoc_text << (code ? "}" : "{@code ");
    +      }
    +      code = !code;
    +      break;
    +    case '*':
    +      if (!code) {
    +        // consumes all subsequent asterisks, if there are more than one, then
    +        // we put the text in bold, otherwise in italic
    +        for (; c_iter != text.cend() && *c_iter == '*'; ++count, ++c_iter) {}
    +        if (count > 1) {
    +          javadoc_text << (emphasis ? "" : "");
    +        } else {
    +          javadoc_text << (emphasis ? "" : "");
    +        }
    +        emphasis = !emphasis;
    +      } else {
    +        javadoc_text << '*';
    +      }
    +      break;
    +    default:
    +      javadoc_text << c;
    +      break;
    +    }
    +  }
    +  return javadoc_text.str();
    +}
    +
    +}  // namespace
    +
    +OpParser::OpParser(const OpDef& op_def, const ApiDef& api_def,
    +    const string& lib_name, const string& base_package)
    +  : op_def_(op_def), op_api_(api_def), lib_name_(lib_name),
    +    base_package_(base_package) {  // every argument is copied into a member
    +}
    +
    +void OpParser::Parse(std::unique_ptr<OpSpec>* op_ptr) {
    +  visited_attrs_.clear();  // drop state left by a previous Parse() call
    +  next_generic_ = 'T';  // generic type parameters are named from 'T' on
    +  op_ptr->reset(new OpSpec(op_api_.graph_op_name()));
    +  for (const string& next_input_name : op_api_.arg_order()) {
    +    for (int i = 0; i < op_def_.input_arg().size(); ++i) {
    +      if (op_def_.input_arg(i).name() == next_input_name) {
    +        ParseInput(op_def_.input_arg(i), op_api_.in_arg(i), op_ptr->get());
    +        break;  // inputs are added in the order listed by arg_order()
    +      }
    +    }
    +  }
    +  for (int i = 0; i < op_def_.attr().size(); ++i) {
    +    ParseAttribute(op_def_.attr(i), op_api_.attr(i), op_ptr->get());
    +  }
    +  for (int i = 0; i < op_def_.output_arg().size(); ++i) {
    +    ParseOutput(op_def_.output_arg(i), op_api_.out_arg(i), op_ptr->get());
    +  }
    +  BuildEndpoints(op_ptr->get());  // needs the outputs parsed just above
    +}
    +
    +void OpParser::BuildEndpoints(OpSpec* op) {
    +  Javadoc op_doc = Javadoc::Create(ParseDocumentation(op_api_.summary()))
    +    .details(ParseDocumentation(op_api_.description()));
    +  std::vector<Type> op_supertypes;
    +  op_supertypes.push_back(Type::Class("PrimitiveOp", "org.tensorflow.op"));
    +  std::map<string, const Type*> op_generics;  // points into op->outputs()
    +  for (const OpSpec::Operand& output : op->outputs()) {
    +    // declare generic output parameters at the Op class level
    +    const Type& data_type = output.data_type();
    +    if (data_type.kind() == Type::GENERIC && !data_type.unknown()
    +        && op_generics.find(data_type.name()) == op_generics.end()) {
    +      op_generics.insert(std::make_pair(data_type.name(), &data_type));
    +      op_doc.add_param_tag("<" + data_type.name() + ">",
    +          "data type of output '" + output.var().name() + "'");
    +    }
    +    // implement the Op as an (iteration of) Operand if it has only one output
    +    if (op->outputs().size() == 1) {
    +      Type operand_inf(Type::Interface("Operand", "org.tensorflow"));
    +      operand_inf.add_parameter(data_type.unknown() ?
    +          Type::Class("Object") : data_type);
    +      op_supertypes.push_back(output.iterable() ?
    +          Type::IterableOf(operand_inf) : operand_inf);
    +    }
    +  }
    +  for (const auto& endpoint_def : op_api_.endpoint()) {
    +    std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
    +    // if the endpoint specifies a package, use it, otherwise derive it from the
    +    // op library name.
    +    string name;
    +    string package;
    +    if (name_tokens.size() > 1) {
    +      package = str_util::Lowercase(name_tokens.at(0));
    +      name = name_tokens.at(1);
    +    } else {
    +      package = str_util::StringReplace(lib_name_, "_", "", true);
    +      name = name_tokens.at(0);
    +    }
    +    Type endpoint(Type::Class(name, base_package_ + "." + package));
    +    Javadoc endpoint_doc(op_doc);  // each endpoint gets its own copy
    +    for (const auto& parameter : op_generics) {
    +      endpoint.add_parameter(*parameter.second);
    +    }
    +    for (const Type& supertype : op_supertypes) {
    +      endpoint.add_supertype(supertype);
    +    }
    +    if (endpoint_def.deprecation_version() > 0) {
    +      string explanation;
    +      // refer to the first endpoint as a replacement unless it is deprecated too
    +      if (op_api_.endpoint(0).deprecation_version() == 0) {
    +        explanation = ", use {@link "
    +            + op->endpoints().at(0).type().full_name()
    +            + "} instead";
    +      } else {
    +        explanation = op_def_.deprecation().explanation();
    +      }
    +      endpoint_doc.add_tag("deprecated", explanation);
    +      endpoint.add_annotation(Annotation::Create("Deprecated"));
    +    }
    +    // only visible ops should be annotated for exposure in the Ops Graph API
    +    if (op_api_.visibility() != ApiDef::HIDDEN) {
    +      string group_name = SnakeToCamelCase(lib_name_);
    +      endpoint.add_annotation(
    +          Annotation::Create("Operator", "org.tensorflow.op.annotation")
    +            .attributes("group = \"" + group_name + "\""));
    +    }
    +    op->add_endpoint(endpoint, endpoint_doc);
    +  }
    +}
    +
    +void OpParser::ParseInput(const OpDef_ArgDef& input_def,
    +    const ApiDef::Arg& input_api, OpSpec* op) {
    +  // DataTypeOf() raises the flag when the input must be wrapped as iterable
    +  bool is_sequence = false;
    +  Type tensor_type = DataTypeOf(input_def, &is_sequence);
    +  Type operand = Type::Interface("Operand", "org.tensorflow")
    +    .add_parameter(tensor_type);
    +  Type java_type =
    +      is_sequence ? Type::IterableOf(operand) : operand;
    +  Variable java_var =
    +      Variable::Create(SnakeToCamelCase(input_api.rename_to()), java_type);
    +  string doc = ParseDocumentation(input_api.description());
    +  op->add_input(OpSpec::Operand(input_api.name(), java_var, tensor_type,
    +      doc, is_sequence));
    +}
    +
    +void OpParser::ParseOutput(const OpDef_ArgDef& output_def,
    +    const ApiDef::Arg& output_api, OpSpec* op) {
    +  // DataTypeOf() raises the flag when the output must be wrapped as a list
    +  bool is_sequence = false;
    +  Type tensor_type = DataTypeOf(output_def, &is_sequence);
    +  Type output = Type::Class("Output", "org.tensorflow")
    +    .add_parameter(tensor_type);
    +  Type java_type =
    +      is_sequence ? Type::ListOf(output) : output;
    +  Variable java_var =
    +      Variable::Create(SnakeToCamelCase(output_api.rename_to()), java_type);
    +  string doc = ParseDocumentation(output_api.description());
    +  op->add_output(OpSpec::Operand(output_api.name(), java_var, tensor_type,
    +      doc, is_sequence));
    +}
    +
    +void OpParser::ParseAttribute(const OpDef_AttrDef& attr_def,
    +    const ApiDef::Attr& attr_api, OpSpec* op) {
    +  // skip attributes already visited (e.g. inferred as an input argument type)
    +  if (visited_attrs_.count(attr_def.name()) > 0) {
    +    return;
    +  }
    +  bool is_list = false;
    +  const Type data_type = DataTypeOf(attr_def, &is_list);
    +  // a single generic attribute is passed explicitly as a typed Class argument
    +  const bool explicit_type = !is_list && data_type.kind() == Type::GENERIC;
    +  Type java_type = data_type;
    +  if (is_list) {
    +    java_type = Type::ListOf(data_type);
    +  } else if (explicit_type) {
    +    java_type = Type::Class("Class").add_parameter(data_type);
    +  }
    +  const OpSpec::Operand attr(attr_api.name(),
    +      Variable::Create(SnakeToCamelCase(attr_api.rename_to()), java_type),
    +      data_type,
    +      ParseDocumentation(attr_api.description()),
    +      is_list);
    +  // attributes with a default value are optional, unless typed explicitly
    +  if (explicit_type || !attr_api.has_default_value()) {
    +    op->add_attribute(attr);
    +  } else {
    +    op->add_option(attr);
    +  }
    +  visited_attrs_.insert(std::make_pair(attr_api.name(), data_type));
    +}
    +
    +Type OpParser::DataTypeOf(const OpDef_ArgDef& arg, bool* iterable_out) {
    +  if (!arg.number_attr().empty()) {
    +    visited_attrs_.insert(std::make_pair(arg.number_attr(), Type::Int()));
    +    *iterable_out = true;  // an argument with a number attribute is iterable
    +  }
    +  if (arg.type() == DataType::DT_INVALID) {
    +    // no fixed type: resolve it from the attribute naming the type
    +    string attr_name = arg.type_attr();
    +    if (attr_name.empty()) {
    +      attr_name = arg.type_list_attr();
    +      if (!attr_name.empty()) {
    +        Type list_type = Type::Wildcard();
    +        *iterable_out = true;
    +        visited_attrs_.insert(std::make_pair(attr_name, list_type));
    +        return list_type;
    +      }
    +    }
    +    for (const auto& attr : op_def_.attr()) {
    +      if (attr.name() == attr_name) {
    +        Type attr_type = DataTypeOf(attr, iterable_out);
    +        visited_attrs_.insert(std::make_pair(attr_name, attr_type));
    +        return attr_type;
    +      }
    +    }
    +  } else {
    +    // resolve type from DataType
    +    switch (arg.type()) {
    +      case DataType::DT_BOOL:
    +        return Type::Class("Boolean");
    +
    +      case DataType::DT_STRING:
    +        return Type::Class("String");
    +
    +      case DataType::DT_FLOAT:
    +        return Type::Class("Float");
    +
    +      case DataType::DT_DOUBLE:
    +        return Type::Class("Double");
    +
    +      case DataType::DT_UINT8:
    +        return Type::Class("UInt8", "org.tensorflow.types");
    +
    +      case DataType::DT_INT32:
    +        return Type::Class("Integer");
    +
    +      case DataType::DT_INT64:
    +        return Type::Class("Long");
    +
    +      case DataType::DT_RESOURCE:
    +        // TODO(karllessard) create a Resource utility class that could be
    +        // used to store a resource and its type (passed in a second argument).
    +        // For now, we need to force a wildcard and we will unfortunately lose
    +        // track of the resource type.
    +        return Type::Wildcard();
    +
    +      default:
    +        break;
    +    }
    +  }
    +  LOG(WARNING) << "Data type for arg \"" << arg.name() << "\" is unknown";
    +  return Type::Wildcard();
    +}
    +
    +Type OpParser::DataTypeOf(const OpDef_AttrDef& attr, bool* iterable_out) {
    +  const auto it = visited_attrs_.find(attr.name());
    +  if (it != visited_attrs_.cend()) {
    +    return it->second;  // already resolved; iterable_out is left untouched
    +  }
    +  string attr_type = attr.type();
    +  if (attr.type().compare(0, 5, "list(") == 0) {
    +    attr_type = attr_type.substr(5, attr.type().find_last_of(')') - 5);
    +    *iterable_out = true;  // keep only the element type of "list(...)"
    +  }
    +  if (attr_type == "type") {
    +    if (*iterable_out) {
    +      return Type::Enum("DataType", "org.tensorflow");
    +    }
    +    return GetNextGenericTensorType(attr.allowed_values());
    +  }
    +  if (attr_type == "string") {
    +    return Type::Class("String");
    +  }
    +  if (attr_type == "int") {
    +    return Type::Class("Integer");
    +  }
    +  if (attr_type == "float") {
    +    return Type::Class("Float");
    +  }
    +  if (attr_type == "bool") {
    +    return Type::Class("Boolean");
    +  }
    +  if (attr_type == "shape") {
    +    return Type::Class("Shape", "org.tensorflow");
    +  }
    +  if (attr_type == "tensor") {
    +    return Type::Class("Tensor", "org.tensorflow")
    +      .add_parameter(Type::Wildcard());
    +  }
    +  LOG(WARNING) << "Data type for attribute \"" << attr_type << "\" is unknown";
    +  return *iterable_out ? Type::Wildcard() : Type::Class("Object");
    +}
    +
    +Type OpParser::GetNextGenericTensorType(const AttrValue& allowed_values)  {
    +  // take the current letter, then advance it, wrapping from 'Z' to 'A'
    +  const char letter = next_generic_;
    +  next_generic_ = (letter == 'Z') ? 'A' : letter + 1;
    +  Type generic = Type::Generic(string(1, letter));
    +  // restrict the generic to java.lang.Number if only real numbers are allowed
    +  if (IsRealNumbers(allowed_values)) {
    +    generic.add_supertype(Type::Class("Number"));
    +  }
    +  return generic;
    +}
    +
    +}  // namespace java
    +}  // namespace tensorflow
    diff --git a/tensorflow/java/src/gen/cc/op_parser.h b/tensorflow/java/src/gen/cc/op_parser.h
    new file mode 100644
    index 0000000000..42855127cc
    --- /dev/null
    +++ b/tensorflow/java/src/gen/cc/op_parser.h
    @@ -0,0 +1,137 @@
    +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
    +
    +Licensed under the Apache License, Version 2.0 (the "License");
    +you may not use this file except in compliance with the License.
    +You may obtain a copy of the License at
    +
    +    http://www.apache.org/licenses/LICENSE-2.0
    +
    +Unless required by applicable law or agreed to in writing, software
    +distributed under the License is distributed on an "AS IS" BASIS,
    +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +See the License for the specific language governing permissions and
    +limitations under the License.
    +==============================================================================*/
    +
    +#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    +#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    +
    +#include <map>
    +#include <memory>
    +#include <string>
    +#include <vector>
    +
    +#include "tensorflow/core/framework/op_def.pb.h"
    +#include "tensorflow/core/framework/api_def.pb.h"
    +#include "tensorflow/java/src/gen/cc/java_defs.h"
    +
    +namespace tensorflow {
    +namespace java {
    +
    +// Specification of a TensorFlow operation to generate.
    +//
    +// This is the result of parsing an operation definition, see OpParser::Parse().
    +class OpSpec {
    + public:
    +  class Endpoint {  // a Java type to generate, paired with its Javadoc
    +   public:
    +    Endpoint(const Type& type, const Javadoc& javadoc)
    +      : type_(type), javadoc_(javadoc) {}
    +    const Type& type() const { return type_; }
    +    const Javadoc& javadoc() const { return javadoc_; }
    +
    +   private:
    +    Type type_;
    +    Javadoc javadoc_;
    +  };
    +
    +  class Operand {  // describes one input, output, attribute or option
    +   public:
    +    Operand(const string& graph_name, const Variable& var,
    +        const Type& data_type, const string& description, bool iterable)
    +     : graph_name_(graph_name), var_(var), data_type_(data_type),
    +       description_(description), iterable_(iterable) {}
    +    const string& graph_name() const { return graph_name_; }
    +    const Variable& var() const { return var_; }
    +    Variable* var_ptr() { return &var_; }  // mutable access to the variable
    +    const Type& data_type() const { return data_type_; }
    +    const string& description() const { return description_; }
    +    bool iterable() const { return iterable_; }
    +
    +   private:
    +    string graph_name_;
    +    Variable var_;
    +    Type data_type_;
    +    string description_;
    +    bool iterable_;
    +  };
    +
    +  explicit OpSpec(const string& graph_name) : graph_name_(graph_name) {}
    +  const string& graph_name() const { return graph_name_; }
    +  const std::vector<Endpoint>& endpoints() const { return endpoints_; }
    +  void add_endpoint(const Type& type, const Javadoc& javadoc) {
    +    endpoints_.push_back(Endpoint(type, javadoc));
    +  }
    +  const std::vector<Operand>& inputs() const { return inputs_; }
    +  void add_input(const Operand& input) {
    +    inputs_.push_back(input);
    +  }
    +  const std::vector<Operand>& outputs() const { return outputs_; }
    +  void add_output(const Operand& output) {
    +    outputs_.push_back(output);
    +  }
    +  const std::vector<Operand>& attributes() const { return attributes_; }
    +  void add_attribute(const Operand& attribute) {
    +    attributes_.push_back(attribute);
    +  }
    +  const std::vector<Operand>& options() const { return options_; }
    +  void add_option(const Operand& option) {
    +    options_.push_back(option);
    +  }
    +
    + private:
    +  string graph_name_;
    +  std::vector<Endpoint> endpoints_;
    +  std::vector<Operand> inputs_;
    +  std::vector<Operand> outputs_;
    +  std::vector<Operand> attributes_;
    +  std::vector<Operand> options_;
    +};
    +
    +// A parser of ops proto definitions.
    +//
    +// This object parses the definition and the api of a TensorFlow operation to
    +// produce a specification that can be used for Java source code rendering.
    +class OpParser {
    + public:
    +  OpParser(const OpDef& op_def, const ApiDef& api_def, const string& lib_name,
    +      const string& base_package);
    +  virtual ~OpParser() = default;
    +
    +  // Produces an operation specification from its proto definitions.
    +  void Parse(std::unique_ptr<OpSpec>* op_ptr);
    +
    + private:
    +  OpDef op_def_;
    +  ApiDef op_api_;
    +  string lib_name_;
    +  string base_package_;
    +  std::map<string, Type> visited_attrs_;  // attr name -> resolved type
    +  char next_generic_ = 0;  // letter for the next generic; set by Parse()
    +
    +  void BuildEndpoints(OpSpec* op);
    +  void ParseInput(const OpDef_ArgDef& input_def,
    +      const ApiDef::Arg& input_api, OpSpec* op);
    +  void ParseOutput(const OpDef_ArgDef& output_def,
    +      const ApiDef::Arg& output_api, OpSpec* op);
    +  void ParseAttribute(const OpDef_AttrDef& attr_def,
    +      const ApiDef::Attr& attr_api, OpSpec* op);
    +  Type DataTypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
    +  Type DataTypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
    +  Type GetNextGenericTensorType(const AttrValue& allowed_values);
    +};
    +
    +}  // namespace java
    +}  // namespace tensorflow
    +
    +#endif  // TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
    index a02f75ad6e..b1de5af6ba 100644
    --- a/tensorflow/java/src/gen/cc/source_writer.cc
    +++ b/tensorflow/java/src/gen/cc/source_writer.cc
    @@ -15,7 +15,7 @@ limitations under the License.
     
     #include 
     #include 
    -#include 
    +#include 
     
     #include "tensorflow/java/src/gen/cc/source_writer.h"
     
    @@ -83,20 +83,20 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
     }
     
     SourceWriter& SourceWriter::AppendType(const Type& type) {
    -  if (type.kind() == Type::Kind::GENERIC && type.name().empty()) {
    +  if (type.unknown()) {
         Append("?");
       } else {
         Append(type.name());
    -  }
    -  if (!type.parameters().empty()) {
    -    Append("<");
    -    for (const Type& t : type.parameters()) {
    -      if (&t != &type.parameters().front()) {
    -        Append(", ");
    +    if (!type.parameters().empty()) {
    +      Append("<");
    +      for (const Type& t : type.parameters()) {
    +        if (&t != &type.parameters().front()) {
    +          Append(", ");
    +        }
    +        AppendType(t);
           }
    -      AppendType(t);
    +      Append(">");
         }
    -    Append(">");
       }
       return *this;
     }
    @@ -107,7 +107,21 @@ SourceWriter& SourceWriter::EndLine() {
       return *this;
     }
     
    -SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
    +SourceWriter& SourceWriter::BeginBlock(const string& expression) {
    +  if (expression.empty()) {
    +    Append(newline_ ? "{" : " {");  // pad only when continuing a line
    +  } else {
    +    Append(expression + " {");
    +  }
    +  return EndLine().Indent(2);
    +}
    +
    +SourceWriter& SourceWriter::EndBlock() {
    +  return Indent(-2).Append("}").EndLine();  // undoes BeginBlock()'s Indent(2)
    +}
    +
    +SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers,
    +    const Javadoc* javadoc) {
       GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
       if (!method.constructor()) {
         generic_namespace->Visit(method.return_type());
    @@ -116,8 +130,9 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
         generic_namespace->Visit(v.type());
       }
       EndLine();
    -  WriteDoc(method.description(), method.return_description(),
    -      &method.arguments());
    +  if (javadoc != nullptr) {
    +    WriteJavadoc(*javadoc);
    +  }
       if (!method.annotations().empty()) {
         WriteAnnotations(method.annotations());
       }
    @@ -145,29 +160,35 @@ SourceWriter& SourceWriter::EndMethod() {
       return *this;
     }
     
    -SourceWriter& SourceWriter::BeginType(const Type& type,
    -    const std::list* dependencies, int modifiers) {
    +SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers,
    +    const std::list* extra_dependencies, const Javadoc* javadoc) {
       if (!type.package().empty()) {
         Append("package ").Append(type.package()).Append(";").EndLine();
       }
    -  if (dependencies != nullptr && !dependencies->empty()) {
    -    TypeImporter type_importer(type.package());
    -    for (const Type& t : *dependencies) {
    +  TypeImporter type_importer(type.package());
    +  type_importer.Visit(type);
    +  if (extra_dependencies != nullptr) {
    +    for (const Type& t : *extra_dependencies) {
           type_importer.Visit(t);
         }
    +  }
    +  if (!type_importer.imports().empty()) {
         EndLine();
         for (const string& s : type_importer.imports()) {
           Append("import ").Append(s).Append(";").EndLine();
         }
       }
    -  return BeginInnerType(type, modifiers);
    +  return BeginInnerType(type, modifiers, javadoc);
     }
     
    -SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers) {
    +SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers,
    +    const Javadoc* javadoc) {
       GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
       generic_namespace->Visit(type);
       EndLine();
    -  WriteDoc(type.description());
    +  if (javadoc != nullptr) {
    +    WriteJavadoc(*javadoc);
    +  }
       if (!type.annotations().empty()) {
         WriteAnnotations(type.annotations());
       }
    @@ -200,14 +221,15 @@ SourceWriter& SourceWriter::EndType() {
       return *this;
     }
     
    -SourceWriter& SourceWriter::WriteFields(const std::list& fields,
    -    int modifiers) {
    -  EndLine();
    -  for (const Variable& v : fields) {
    -    WriteModifiers(modifiers);
    -    AppendType(v.type()).Append(" ").Append(v.name()).Append(";");
    -    EndLine();
    +SourceWriter& SourceWriter::WriteField(const Variable& field, int modifiers,
    +    const Javadoc* javadoc) {
    +  // If present, write field javadoc only as one brief line
    +  if (javadoc != nullptr && !javadoc->brief().empty()) {
    +    Append("/** ").Append(javadoc->brief()).Append(" */").EndLine();
       }
    +  WriteModifiers(modifiers);
    +  AppendType(field.type()).Append(" ").Append(field.name()).Append(";");
    +  EndLine();
       return *this;
     }
     
    @@ -228,39 +250,33 @@ SourceWriter& SourceWriter::WriteModifiers(int modifiers) {
       return *this;
     }
     
    -SourceWriter& SourceWriter::WriteDoc(const string& description,
    -    const string& return_description, const std::list* parameters) {
    -  if (description.empty() && return_description.empty()
    -      && (parameters == nullptr || parameters->empty())) {
    -    return *this;  // no doc to write
    -  }
    +SourceWriter& SourceWriter::WriteJavadoc(const Javadoc& javadoc) {
    +  Append("/**").Prefix(" * ").EndLine();
       bool do_line_break = false;
    -  Append("/**").EndLine().Prefix(" * ");
    -  if (!description.empty()) {
    -    Write(description).EndLine();
    +  if (!javadoc.brief().empty()) {
    +    Write(javadoc.brief()).EndLine();
         do_line_break = true;
       }
    -  if (parameters != nullptr && !parameters->empty()) {
    +  if (!javadoc.details().empty()) {
         if (do_line_break) {
    -      EndLine();
    -      do_line_break = false;
    -    }
    -    for (const Variable& v : *parameters) {
    -      Append("@param ").Append(v.name());
    -      if (!v.description().empty()) {
    -        Append(" ").Write(v.description());
    -      }
    -      EndLine();
    +      Append("

    ").EndLine(); } + Write(javadoc.details()).EndLine(); + do_line_break = true; } - if (!return_description.empty()) { + if (!javadoc.tags().empty()) { if (do_line_break) { EndLine(); - do_line_break = false; } - Append("@return ").Write(return_description).EndLine(); + for (const auto& p : javadoc.tags()) { + Append("@" + p.first); + if (!p.second.empty()) { + Append(" ").Write(p.second); + } + EndLine(); + } } - return Prefix("").Append(" **/").EndLine(); + return Prefix("").Append(" */").EndLine(); } SourceWriter& SourceWriter::WriteAnnotations( @@ -311,20 +327,19 @@ void SourceWriter::PopGenericNamespace() { void SourceWriter::TypeVisitor::Visit(const Type& type) { DoVisit(type); for (const Type& t : type.parameters()) { - DoVisit(t); + Visit(t); } for (const Annotation& t : type.annotations()) { DoVisit(t); } for (const Type& t : type.supertypes()) { - DoVisit(t); + Visit(t); } } void SourceWriter::GenericNamespace::DoVisit(const Type& type) { // ignore non-generic parameters, wildcards and generics already declared - if (type.kind() == Type::GENERIC - && !type.IsWildcard() + if (type.kind() == Type::GENERIC && !type.unknown() && generic_names_.find(type.name()) == generic_names_.end()) { declared_types_.push_back(&type); generic_names_.insert(type.name()); @@ -333,7 +348,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) { void SourceWriter::TypeImporter::DoVisit(const Type& type) { if (!type.package().empty() && type.package() != current_package_) { - imports_.insert(type.package() + '.' 
+ type.name()); + imports_.insert(type.full_name()); } } diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index f011acd30a..1f0febe9a3 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -93,25 +93,22 @@ class SourceWriter { // This method appends a new opening brace to the current data and indent the // next lines according to Google Java Style Guide. The block can optionally // be preceded by an expression (e.g. Append("if(true)").BeginBlock();) - SourceWriter& BeginBlock() { - return Append(newline_ ? "{" : " {").EndLine().Indent(2); - } + SourceWriter& BeginBlock(const string& expr = ""); // Ends the current block of source code. // // This method appends a new closing brace to the current data and outdent the // next lines back to the margin used before BeginBlock() was invoked. - SourceWriter& EndBlock() { - return Indent(-2).Append("}").EndLine(); - } + SourceWriter& EndBlock(); // Begins to write a method. // // This method outputs the signature of the Java method from the data passed - // in the 'method' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // method. - SourceWriter& BeginMethod(const Method& method, int modifiers = 0); + // in the 'method' parameter and starts a new block. Modifiers are also passed + // in parameter to define the access scope of this method and, optionally, + // a Javadoc. + SourceWriter& BeginMethod(const Method& method, int modifiers, + const Javadoc* javadoc = nullptr); // Ends the current method. // @@ -122,22 +119,24 @@ class SourceWriter { // Begins to write the main type of a source file. // // This method outputs the declaration of the Java type from the data passed - // in the 'type' parameter and starts a new block. 
Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // type. + // in the 'type' parameter and starts a new block. Modifiers are also passed + // in parameter to define the access scope of this type and, optionally, + // a Javadoc. // - // If not null, all types found in the 'dependencies' list will be imported - // before declaring the new type. - SourceWriter& BeginType(const Type& clazz, - const std::list* dependencies, int modifiers = 0); + // If not null, all types found in the 'extra_dependencies' list will be + // imported before declaring the new type. + SourceWriter& BeginType(const Type& clazz, int modifiers, + const std::list* extra_dependencies = nullptr, + const Javadoc* javadoc = nullptr); // Begins to write a new inner type. // // This method outputs the declaration of the Java type from the data passed - // in the 'type' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // type. - SourceWriter& BeginInnerType(const Type& type, int modifiers = 0); + // in the 'type' parameter and starts a new block. Modifiers are also passed + // in parameter to define the accesses and the scope of this type and, + // optionally, a Javadoc. + SourceWriter& BeginInnerType(const Type& type, int modifiers, + const Javadoc* javadoc = nullptr); // Ends the current type. // @@ -145,13 +144,13 @@ class SourceWriter { // BeginType() or BeginInnerType() prior to this. SourceWriter& EndType(); - // Writes a list of variables as fields of a type. + // Writes a variable as fields of a type. // // This method must be called within the definition of a type (see BeginType() - // or BeginInnerType()). Additional modifiers can also be passed in parameter - // to define the accesses and the scope of those fields. - SourceWriter& WriteFields(const std::list& fields, - int modifiers = 0); + // or BeginInnerType()). 
Modifiers are also be passed in parameter to define + // the accesses and the scope of this field and, optionally, a Javadoc. + SourceWriter& WriteField(const Variable& field, int modifiers, + const Javadoc* javadoc = nullptr); protected: virtual void DoAppend(const StringPiece& str) = 0; @@ -207,9 +206,7 @@ class SourceWriter { std::stack generic_namespaces_; SourceWriter& WriteModifiers(int modifiers); - SourceWriter& WriteDoc(const string& description, - const string& return_description = "", - const std::list* parameters = nullptr); + SourceWriter& WriteJavadoc(const Javadoc& javadoc); SourceWriter& WriteAnnotations(const std::list& annotations); SourceWriter& WriteGenerics(const std::list& generics); GenericNamespace* PushGenericNamespace(int modifiers); diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index 4bce2fea70..8bd42d9d0e 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -250,7 +250,7 @@ TEST(StreamTest, Types) { .AppendType(generic).Append(", ") .AppendType(Type::ListOf(generic)).Append(", ") .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ") - .AppendType(Type::ListOf(Type::Generic())); + .AppendType(Type::ListOf(Type::Wildcard())); const char* expected = "int, String, T, List, List>, List"; @@ -282,7 +282,7 @@ TEST(WriteType, SimpleClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -300,7 +300,7 @@ TEST(WriteType, SimpleClassWithDependencies) { deps.push_back(Type::Class("SamePackageType", "org.tensorflow")); deps.push_back(Type::Class("NoPackageType")); - writer.BeginType(clazz, &deps, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC, &deps).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ 
-313,18 +313,21 @@ TEST(WriteType, SimpleClassWithDependencies) { TEST(WriteType, AnnotatedAndDocumentedClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - clazz.description("This class has a\n

    \nmultiline description."); + Javadoc clazz_doc; + clazz_doc.brief("Javadoc test") + .details("This is a\nmultiline description."); clazz.add_annotation(Annotation::Create("Bean")); clazz.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC, nullptr, &clazz_doc).EndType(); const char* expected = "package org.tensorflow;\n\n" "/**\n" - " * This class has a\n" + " * Javadoc test\n" " *

    \n" + " * This is a\n" " * multiline description.\n" " **/\n" "@Bean\n" @@ -339,7 +342,7 @@ TEST(WriteType, ParameterizedClass) { clazz.add_parameter(Type::Generic("T")); clazz.add_parameter(Type::Generic("U").add_supertype(Type::Class("Number"))); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -358,7 +361,7 @@ TEST(WriteType, ParameterizedClassAndSupertypes) { clazz.add_supertype(Type::Interface("Runnable")); clazz.add_supertype(Type::Class("SuperTest").add_parameter(type_t)); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -372,24 +375,24 @@ TEST(WriteType, ParameterizedClassFields) { Type clazz = Type::Class("Test", "org.tensorflow"); Type type_t = Type::Generic("T").add_supertype(Type::Class("Number")); clazz.add_parameter(type_t); - std::list static_fields; - static_fields.push_back(Variable::Create("field1", Type::Class("String"))); - std::list member_fields; - member_fields.push_back(Variable::Create("field2", Type::Class("String"))); - member_fields.push_back(Variable::Create("field3", type_t)); - - writer.BeginType(clazz, nullptr, PUBLIC) - .WriteFields(static_fields, STATIC | PUBLIC | FINAL) - .WriteFields(member_fields, PRIVATE) + Variable field1 = Variable::Create("field1", Type::Class("String")); + Variable field2 = Variable::Create("field2", Type::Class("String")); + Variable field3 = Variable::Create("field3", type_t); + Javadoc field3_doc; + field3_doc.brief("This variable is documented"); + + writer.BeginType(clazz, PUBLIC) + .WriteField(field1, STATIC | PUBLIC | FINAL) + .WriteField(field2, PRIVATE) + .WriteField(field3, PRIVATE, &field3_doc) .EndType(); const char* expected = "package org.tensorflow;\n\n" "public class Test {\n" - " \n" " public static final String field1;\n" - " \n" " private String field2;\n" + " /** This 
variable is documented */\n" " private T field3;\n" "}\n"; ASSERT_STREQ(expected, writer.str().data()); @@ -400,7 +403,7 @@ TEST(WriteType, SimpleInnerClass) { Type clazz = Type::Class("Test", "org.tensorflow"); Type inner_class = Type::Class("InnerTest"); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginInnerType(inner_class, PUBLIC) .EndType() .EndType(); @@ -423,7 +426,7 @@ TEST(WriteType, StaticParameterizedInnerClass) { Type inner_class = Type::Class("InnerTest"); inner_class.add_parameter(type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginInnerType(inner_class, PUBLIC | STATIC) .EndType() .EndType(); @@ -443,7 +446,7 @@ TEST(WriteMethod, SimpleMethod) { Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC).EndMethod() .EndType(); @@ -461,13 +464,15 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - method.description("This method has a\n

    \nmultiline description."); + Javadoc method_doc; + method_doc.brief("Javadoc test") + .details("This method has a\nmultiline description."); method.add_annotation(Annotation::Create("Override")); method.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC).EndMethod() + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC, &method_doc).EndMethod() .EndType(); const char* expected = @@ -475,8 +480,9 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { "public class Test {\n" " \n" " /**\n" - " * This method has a\n" + " * Javadoc test\n" " *

    \n" + " * This method has a\n" " * multiline description.\n" " **/\n" " @Override\n" @@ -490,16 +496,18 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { TEST(WriteMethod, DocumentedMethodWithArguments) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); + Variable reverse = Variable::Create("reverse", Type::Boolean()); Method method = Method::Create("boolToInt", Type::Int()); - method.description("Converts a boolean to an int"); - method.return_description("int value for this boolean"); method.add_argument(Variable::Create("b", Type::Boolean())); - Variable reverse = Variable::Create("reverse", Type::Boolean()); - reverse.description("if true, value is reversed"); method.add_argument(reverse); - - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC) + Javadoc method_doc; + method_doc.brief("Converts a boolean to an int") + .details("This method will convert\na boolean to an int") + .add_param_tag(reverse.name(), "if true, value is reversed") + .add_tag("return", "int value for this boolean"); + + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC, &method_doc) .Append("if (b && !reverse)") .BeginBlock() .Append("return 1;").EndLine() @@ -514,8 +522,10 @@ TEST(WriteMethod, DocumentedMethodWithArguments) { " \n" " /**\n" " * Converts a boolean to an int\n" + " *

    \n" + " * This method will convert\n" + " * a boolean to an int\n" " * \n" - " * @param b\n" " * @param reverse if true, value is reversed\n" " * @return int value for this boolean\n" " **/\n" @@ -536,7 +546,7 @@ TEST(WriteMethod, ParameterizedMethod) { clazz.add_parameter(type_t); Method method = Method::Create("doNothing", type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC) .Append("return null;").EndLine() .EndMethod() @@ -560,7 +570,7 @@ TEST(WriteMethod, StaticParameterizedMethod) { clazz.add_parameter(type_t); Method method = Method::Create("doNothing", type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC | STATIC) .Append("return null;").EndLine() .EndMethod() diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index a6650fc4ea..1e7899cf7a 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -1,9 +1,11 @@ # -*- Python -*- -load("//tensorflow:tensorflow.bzl", - "tf_binary_additional_srcs", - "tf_cc_binary", - "tf_copts") +load( + "//tensorflow:tensorflow.bzl", + "tf_binary_additional_srcs", + "tf_cc_binary", + "tf_copts", +) # Given a list of "ops_libs" (a list of files in the core/ops directory # without their .cc extensions), generate Java wrapper code for all operations @@ -27,16 +29,31 @@ def tf_java_op_gen_srcjar(name, ops_libs_pkg="//tensorflow/core", out_dir="ops/", out_src_dir="src/main/java/", + api_def_srcs=[], visibility=["//tensorflow/java:__pkg__"]): gen_tools = [] gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files + srcs = api_def_srcs[:] # Construct an op generator binary for each ops library. 
for ops_lib in ops_libs: gen_lib = ops_lib[:ops_lib.rfind("_")] out_gen_tool = out_dir + ops_lib + "_gen_tool" + if not api_def_srcs: + api_def_args_str = "," + else: + api_def_args = [] + for api_def_src in api_def_srcs: + # Add directory of the first ApiDef source to args. + # We are assuming all ApiDefs in a single api_def_src are in the + # same directory. + api_def_args.append( + " $$(dirname $$(echo $(locations " + api_def_src + + ") | cut -d\" \" -f1))") + api_def_args_str = ",".join(api_def_args) + tf_cc_binary( name=out_gen_tool, copts=tf_copts(), @@ -48,7 +65,8 @@ def tf_java_op_gen_srcjar(name, gen_cmds += ["$(location :" + out_gen_tool + ")" + " --output_dir=$(@D)/" + out_src_dir + " --lib_name=" + gen_lib + - " --base_package=" + gen_base_package] + " --base_package=" + gen_base_package + + " " + api_def_args_str] # Generate a source archive containing generated code for these ops. gen_srcjar = out_dir + name + ".srcjar" @@ -57,6 +75,7 @@ def tf_java_op_gen_srcjar(name, gen_tools += tf_binary_additional_srcs() native.genrule( name=name, + srcs=srcs, outs=[gen_srcjar], tools=gen_tools, cmd="&&".join(gen_cmds)) diff --git a/tensorflow/java/src/gen/resources/license.snippet.java b/tensorflow/java/src/gen/resources/license.snippet.java new file mode 100644 index 0000000000..90285ec669 --- /dev/null +++ b/tensorflow/java/src/gen/resources/license.snippet.java @@ -0,0 +1,14 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ -- GitLab From 7e80197f020895fea41eda36b08135b747a9a4f1 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Fri, 6 Apr 2018 08:56:54 -0400 Subject: [PATCH 026/755] Improve Javadoc and include first code review --- tensorflow/java/BUILD | 23 +- tensorflow/java/src/gen/cc/java_defs.h | 12 +- tensorflow/java/src/gen/cc/op_gen_main.cc | 48 +- tensorflow/java/src/gen/cc/op_generator.cc | 224 ++++++---- tensorflow/java/src/gen/cc/op_generator.h | 25 +- tensorflow/java/src/gen/cc/op_parser.cc | 417 ------------------ tensorflow/java/src/gen/cc/op_parser.h | 137 ------ tensorflow/java/src/gen/cc/op_specs.cc | 390 ++++++++++++++++ tensorflow/java/src/gen/cc/op_specs.h | 152 +++++++ tensorflow/java/src/gen/cc/source_writer.cc | 2 +- tensorflow/java/src/gen/cc/source_writer.h | 2 +- .../java/src/gen/cc/source_writer_test.cc | 20 +- tensorflow/java/src/gen/gen_ops.bzl | 68 +-- ...ense.snippet.java => license.java.snippet} | 0 14 files changed, 760 insertions(+), 760 deletions(-) delete mode 100644 tensorflow/java/src/gen/cc/op_parser.cc delete mode 100644 tensorflow/java/src/gen/cc/op_parser.h create mode 100644 tensorflow/java/src/gen/cc/op_specs.cc create mode 100644 tensorflow/java/src/gen/cc/op_specs.h rename tensorflow/java/src/gen/resources/{license.snippet.java => license.java.snippet} (100%) diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 635a4e807d..17566e1a9c 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -68,9 +68,13 @@ filegroup( ], ) +# Build the gen tool as a library, as it will be linked to a core/ops binary +# files before making it an executable. 
tf_java_op_gen_srcjar( name = "java_op_gen_sources", - api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], + api_def_srcs = [ + "//tensorflow/core/api_def:base_api_def", + ], gen_base_package = "org.tensorflow.op", gen_tool = "java_op_gen_tool", ops_libs = [ @@ -95,30 +99,17 @@ tf_java_op_gen_srcjar( ], ) -# Build the gen tool as a library, as it will be linked to a core/ops binary -# file before making it an executable. See tf_java_op_gen_srcjar(). -cc_library( - name = "java_op_gen_tool", - srcs = [ - "src/gen/cc/op_gen_main.cc", - ], - copts = tf_copts(), - deps = [ - ":java_op_gen_lib", - ], -) - cc_library( name = "java_op_gen_lib", srcs = [ "src/gen/cc/op_generator.cc", - "src/gen/cc/op_parser.cc", + "src/gen/cc/op_specs.cc", "src/gen/cc/source_writer.cc", ], hdrs = [ "src/gen/cc/java_defs.h", "src/gen/cc/op_generator.h", - "src/gen/cc/op_parser.h", + "src/gen/cc/op_specs.h", "src/gen/cc/source_writer.h", ], copts = tf_copts(), diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 2065477f58..81ac67eb2f 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -230,12 +230,12 @@ class Javadoc { return Javadoc(brief); } const string& brief() const { return brief_; } - const string& details() const { return description_; } - Javadoc& details(const string description) { - description_ = description; + const string& details() const { return details_; } + Javadoc& details(const string& details) { + details_ = details; return *this; } - const std::list> tags() const { return tags_; } + const std::list>& tags() const { return tags_; } Javadoc& add_tag(const string& tag, const string& text) { tags_.push_back(std::make_pair(tag, text)); return *this; @@ -246,7 +246,7 @@ class Javadoc { private: string brief_; - string description_; + string details_; std::list> tags_; explicit Javadoc(const string& brief) : brief_(brief) {} diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index 015200023f..458141b877 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,55 +36,41 @@ const char kUsageHeader[] = "Operation wrappers are generated under the path specified by the " "'--output_dir' argument. This path can be absolute or relative to the\n" "current working directory and will be created if it does not exists.\n\n" - "The '--lib_name' argument is used to classify the set of operations. If " - "the chosen name contains more than one word, it must be provided in \n" - "snake_case. This value is declined into other meaningful names, such as " - "the group and package of the generated operations. 
For example,\n" - "'--lib_name=my_lib' generates the operations under the " - "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n" - "group.\n\n" - "Note that the operator group assigned to the generated wrappers is just " - "an annotation tag at this stage. Operations will not be available " - "through\n" - "the 'org.tensorflow.op.Ops' API as a group until the generated classes " - "are compiled using an appropriate annotation processor.\n\n" + "Note that the operations will not be available through the " + "'org.tensorflow.op.Ops' API until the generated classes are compiled\n" + "using an appropriate annotation processor.\n\n" "The '--base_package' overrides the default parent package under which " "the generated subpackage and classes are to be located.\n\n" - "Finally, a list of directories of API proto definitions can be provided " - "to override default values found in the ops definitions, ordered by\n" - "priority (the last having precedence over the first).\n\n"; + "Finally, the `--api_dirs` argument takes a list of comma-seperated " + "directories of API definitions can be provided to override default\n" + "values found in the ops definitions. 
Directories are ordered by priority " + "(the last having precedence over the first).\n\n"; } // namespace java } // namespace tensorflow int main(int argc, char* argv[]) { - tensorflow::string lib_name; tensorflow::string output_dir; tensorflow::string base_package = "org.tensorflow.op"; + tensorflow::string api_dirs_str; std::vector flag_list = { tensorflow::Flag("output_dir", &output_dir, "Root directory into which output files are generated"), - tensorflow::Flag( - "lib_name", &lib_name, - "A name, in snake_case, used to classify this set of operations"), - tensorflow::Flag( - "base_package", &base_package, - "Package parent to the generated subpackage and classes")}; + tensorflow::Flag("base_package", &base_package, + "Package parent to the generated subpackage and classes"), + tensorflow::Flag("api_dirs", &api_dirs_str, + "List of directories that contains the ops api definitions")}; tensorflow::string usage = tensorflow::java::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(usage.c_str(), &argc, &argv); - QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage; - std::vector api_dirs; - if (argc > 1) { - api_dirs = tensorflow::str_util::Split(argv[1], ",", - tensorflow::str_util::SkipEmpty()); - } + QCHECK(parsed_flags_ok && !output_dir.empty()) << usage; + std::vector api_dirs = tensorflow::str_util::Split( + api_dirs_str, ",", tensorflow::str_util::SkipEmpty()); tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs); tensorflow::OpList ops; tensorflow::OpRegistry::Global()->Export(false, &ops); - tensorflow::Status status = generator.Run(ops, lib_name); - TF_QCHECK_OK(status); + TF_CHECK_OK(generator.Run(ops)); return 0; } diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index c9b57f5706..c32ad3b109 100644 --- 
a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -27,15 +28,15 @@ limitations under the License. #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/java/src/gen/cc/java_defs.h" #include "tensorflow/java/src/gen/cc/source_writer.h" -#include "tensorflow/java/src/gen/cc/op_parser.h" #include "tensorflow/java/src/gen/cc/op_generator.h" +#include "tensorflow/java/src/gen/cc/op_specs.h" namespace tensorflow { namespace java { namespace { const char* kLicenseSnippet = - "tensorflow/java/src/gen/resources/license.snippet.java"; + "tensorflow/java/src/gen/resources/license.java.snippet"; const std::map kPrimitiveAttrTypes = { { "Boolean", Type::Boolean() }, @@ -66,34 +67,34 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, } // Don't pay attention to duplicate types in the dependency list, they will // be filtered out by the SourceWriter. 
- for (const OpSpec::Operand& input : op.inputs()) { + for (const ArgumentSpec& input : op.inputs()) { out->push_back(input.var().type()); if (input.iterable()) { out->push_back(Type::Class("Operands", "org.tensorflow.op")); } } - for (const OpSpec::Operand& output : op.outputs()) { + for (const ArgumentSpec& output : op.outputs()) { out->push_back(output.var().type()); if (output.iterable()) { out->push_back(Type::Class("Arrays", "java.util")); } } - for (const OpSpec::Operand& attribute : op.attributes()) { + for (const AttributeSpec& attribute : op.attributes()) { out->push_back(attribute.var().type()); if (attribute.var().type().name() == "Class") { out->push_back(Type::Enum("DataType", "org.tensorflow")); } } - for (const OpSpec::Operand& option : op.options()) { - out->push_back(option.var().type()); + for (const AttributeSpec& optional_attribute : op.optional_attributes()) { + out->push_back(optional_attribute.var().type()); } } -void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, +void WriteSetAttrDirective(const AttributeSpec& attr, bool optional, SourceWriter* writer) { string var = optional ? "opts." 
+ attr.var().name() : attr.var().name(); if (attr.iterable()) { - const Type& type = attr.data_type(); + const Type& type = attr.type(); std::map::const_iterator it = kPrimitiveAttrTypes.find(type.name()); if (it != kPrimitiveAttrTypes.end()) { @@ -107,11 +108,11 @@ void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, .Append(array + "[i] = " + var + ".get(i);") .EndLine() .EndBlock() - .Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + array) + .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + array) .Append(");") .EndLine(); } else { - writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + var) + writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + var) .Append(".toArray(new ") .AppendType(type) .Append("[" + var + ".size()]));") @@ -119,7 +120,7 @@ void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, } } else { Type type = attr.var().type(); - writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", "); + writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", "); if (type.name() == "Class") { writer->Append("DataType.fromClass(" + attr.var().name() + "));"); } else { @@ -139,26 +140,26 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class, Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); factory.add_argument(scope); factory_doc.add_param_tag(scope.name(), "Current graph scope"); - for (const OpSpec::Operand& input : op.inputs()) { + for (const ArgumentSpec& input : op.inputs()) { factory.add_argument(input.var()); factory_doc.add_param_tag(input.var().name(), input.description()); } - for (const OpSpec::Operand& attribute : op.attributes()) { + for (const AttributeSpec& attribute : op.attributes()) { factory.add_argument(attribute.var()); factory_doc.add_param_tag(attribute.var().name(), attribute.description()); } - if (!op.options().empty()) { + if (!op.optional_attributes().empty()) { 
factory.add_argument(Variable::Varargs("options", Type::Class("Options"))); factory_doc.add_param_tag("options", "carries optional attributes values"); } factory_doc.add_tag("return", "a new instance of " + op_class.name()); writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc); writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" - + op.graph_name() + "\", scope.makeOpName(\"" + + op.graph_op_name() + "\", scope.makeOpName(\"" + op_class.name() + "\"));"); writer->EndLine(); - for (const OpSpec::Operand& input : op.inputs()) { + for (const ArgumentSpec& input : op.inputs()) { if (input.iterable()) { writer->Append("opBuilder.addInputList(Operands.asOutputs(" + input.var().name() + "));"); @@ -169,15 +170,15 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class, writer->EndLine(); } } - for (const OpSpec::Operand& attribute : op.attributes()) { + for (const AttributeSpec& attribute : op.attributes()) { WriteSetAttrDirective(attribute, false, writer); } - if (!op.options().empty()) { + if (!op.optional_attributes().empty()) { writer->BeginBlock("if (options != null)") .BeginBlock("for (Options opts : options)"); - for (const OpSpec::Operand& option : op.options()) { - writer->BeginBlock("if (opts." + option.var().name() + " != null)"); - WriteSetAttrDirective(option, true, writer); + for (const AttributeSpec& attribute : op.optional_attributes()) { + writer->BeginBlock("if (opts." 
+ attribute.var().name() + " != null)"); + WriteSetAttrDirective(attribute, true, writer); writer->EndBlock(); } writer->EndBlock().EndBlock(); @@ -195,8 +196,8 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, .add_argument( Variable::Create("operation", Type::Class("Operation", "org.tensorflow"))); - for (const OpSpec::Operand& output : op.outputs()) { - if (output.iterable() && !output.data_type().unknown()) { + for (const ArgumentSpec& output : op.outputs()) { + if (output.iterable() && !output.type().unknown()) { constructor.add_annotation( Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); break; @@ -208,15 +209,15 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, if (op.outputs().size() > 0) { writer->Append("int outputIdx = 0;") .EndLine(); - for (const OpSpec::Operand& output : op.outputs()) { + for (const ArgumentSpec& output : op.outputs()) { if (output.iterable()) { string var_length = output.var().name() + "Length"; writer->Append("int " + var_length) - .Append(" = operation.outputListLength(\"" + output.graph_name() + .Append(" = operation.outputListLength(\"" + output.op_def_name() + "\");") .EndLine() .Append(output.var().name() + " = Arrays.asList("); - if (!output.data_type().unknown()) { + if (!output.type().unknown()) { writer->Append("(") .AppendType(output.var().type().parameters().front()) .Append("[])"); @@ -236,18 +237,19 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, } void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { - for (const OpSpec::Operand& option : op.options()) { - Method setter = Method::Create(option.var().name(), Type::Class("Options")) - .add_argument(option.var()); + for (const AttributeSpec& attribute : op.optional_attributes()) { + Method setter = + Method::Create(attribute.var().name(), Type::Class("Options")) + .add_argument(attribute.var()); Javadoc setter_doc = Javadoc::Create() - .add_param_tag(option.var().name(), 
option.description()); + .add_param_tag(attribute.var().name(), attribute.description()); writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc) - .Append("return new Options()." + option.var().name() + "(" - + option.var().name() + ");") + .Append("return new Options()." + attribute.var().name() + "(" + + attribute.var().name() + ");") .EndLine() .EndMethod(); } - for (const OpSpec::Operand& output : op.outputs()) { + for (const ArgumentSpec& output : op.outputs()) { Method getter = Method::Create(output.var().name(), output.var().type()); Javadoc getter_doc = Javadoc::Create(output.description()); writer->BeginMethod(getter, PUBLIC, &getter_doc) @@ -259,12 +261,12 @@ void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, SourceWriter* writer) { - OpSpec::Operand output = op.outputs().front(); + ArgumentSpec output = op.outputs().front(); if (mode == SINGLE_OUTPUT) { - bool cast2obj = output.data_type().unknown(); + bool cast2obj = output.type().unknown(); Type return_type = Type::Class("Output", "org.tensorflow") - .add_parameter(cast2obj ? Type::Class("Object") : output.data_type()); + .add_parameter(cast2obj ? 
Type::Class("Object") : output.type()); Method as_output = Method::Create("asOutput", return_type) .add_annotation(Annotation::Create("Override")); if (cast2obj) { @@ -283,10 +285,10 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, } else if (mode == SINGLE_LIST_OUTPUT) { Type operand = Type::Interface("Operand", "org.tensorflow"); - if (output.data_type().unknown()) { + if (output.type().unknown()) { operand.add_parameter(Type::Class("Object")); } else { - operand.add_parameter(output.data_type()); + operand.add_parameter(output.type()); } Type return_type = Type::Interface("Iterator", "java.util") .add_parameter(operand); @@ -308,57 +310,119 @@ void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) { Javadoc options_doc = Javadoc::Create( "Class holding optional attributes of this operation"); writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); - for (const OpSpec::Operand& option : op.options()) { - Method setter = Method::Create(option.var().name(), options_class) - .add_argument(option.var()); + for (const AttributeSpec& attribute : op.optional_attributes()) { + Method setter = Method::Create(attribute.var().name(), options_class) + .add_argument(attribute.var()); Javadoc setter_doc = Javadoc::Create() - .add_param_tag(option.var().name(), option.description()); + .add_param_tag(attribute.var().name(), attribute.description()); writer->BeginMethod(setter, PUBLIC, &setter_doc) - .Append("this." + option.var().name() + " = " + option.var().name() - + ";") + .Append("this." 
+ attribute.var().name() + " = " + + attribute.var().name() + ";") .EndLine() .Append("return this;") .EndLine() .EndMethod(); } writer->EndLine(); - for (const OpSpec::Operand& option : op.options()) { - writer->WriteField(option.var(), PRIVATE); + for (const AttributeSpec& optional_attribute : op.optional_attributes()) { + writer->WriteField(optional_attribute.var(), PRIVATE); } Method constructor = Method::ConstructorFor(options_class); writer->BeginMethod(constructor, PRIVATE).EndMethod(); writer->EndType(); } -void RenderEndpoint(const OpSpec& op, const OpSpec::Endpoint& endpoint, - SourceWriter* writer) { +inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) { + return Type::Class(endpoint.name(), + base_package + "." + str_util::Lowercase(endpoint.package())); +} + +void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, + const string& base_package, const string& output_dir, Env* env) { + Type op_class(ClassOf(endpoint, base_package) + .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); + Javadoc op_javadoc(endpoint.javadoc()); + + // implement Operand (or Iterable) if the op has only one output RenderMode mode = DEFAULT; if (op.outputs().size() == 1) { - mode = op.outputs().front().iterable() ? SINGLE_LIST_OUTPUT : SINGLE_OUTPUT; + const ArgumentSpec& output = op.outputs().front(); + Type operand_type(output.type().unknown() ? 
+ Type::Class("Object") : output.type()); + Type operand_inf(Type::Interface("Operand", "org.tensorflow") + .add_parameter(operand_type)); + if (output.iterable()) { + mode = SINGLE_LIST_OUTPUT; + op_class.add_supertype(Type::IterableOf(operand_inf)); + } else { + mode = SINGLE_OUTPUT; + op_class.add_supertype(operand_inf); + } + } + // declare all outputs generics at the op class level + std::set generics; + for (const ArgumentSpec& output : op.outputs()) { + if (output.type().kind() == Type::GENERIC && !output.type().unknown() + && generics.find(output.type().name()) == generics.end()) { + op_class.add_parameter(output.type()); + op_javadoc.add_param_tag("<" + output.type().name() + ">", + "data type of output {@code " + output.var().name() + "}"); + generics.insert(output.type().name()); + } + } + // handle endpoint deprecation + if (endpoint.deprecated()) { + op_class.add_annotation(Annotation::Create("Deprecated")); + string explanation; + if (!op.endpoints().front().deprecated()) { + explanation = "use {@link " + + ClassOf(op.endpoints().front(), base_package).full_name() + + "} instead"; + } else { + explanation = op.deprecation_explanation(); + } + op_javadoc.add_tag("deprecated", explanation); } + // expose the op in the Ops Graph API only if it is visible + if (!op.hidden()) { + op_class.add_annotation( + Annotation::Create("Operator", "org.tensorflow.op.annotation") + .attributes("group = \"" + endpoint.package() + "\"")); + } + // create op class file + string op_dir = io::JoinPath(output_dir, + str_util::StringReplace(op_class.package(), ".", "/", true)); + if (!env->FileExists(op_dir).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir)); + } + std::unique_ptr op_file; + TF_CHECK_OK(env->NewWritableFile( + io::JoinPath(op_dir, op_class.name() + ".java"), &op_file)); + + // render endpoint source code + SourceFileWriter writer(op_file.get()); std::list dependencies; CollectOpDependencies(op, mode, &dependencies); - const Type& op_class = 
endpoint.type(); - writer->WriteFromFile(kLicenseSnippet) + writer.WriteFromFile(kLicenseSnippet) .EndLine() .Append("// This file is machine generated, DO NOT EDIT!") .EndLine() .EndLine() - .BeginType(op_class, PUBLIC|FINAL, &dependencies, &endpoint.javadoc()); - if (!op.options().empty()) { - RenderOptionsClass(op, writer); + .BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc); + if (!op.optional_attributes().empty()) { + RenderOptionsClass(op, &writer); } - RenderFactoryMethod(op, op_class, writer); - RenderGettersAndSetters(op, writer); + RenderFactoryMethod(op, op_class, &writer); + RenderGettersAndSetters(op, &writer); if (mode != DEFAULT) { - RenderInterfaceImpl(op, mode, writer); + RenderInterfaceImpl(op, mode, &writer); } - writer->EndLine(); - for (const OpSpec::Operand& output : op.outputs()) { - writer->WriteField(output.var(), PRIVATE); + writer.EndLine(); + for (const ArgumentSpec& output : op.outputs()) { + writer.WriteField(output.var(), PRIVATE); } - RenderConstructor(op, op_class, writer); - writer->EndType(); + RenderConstructor(op, op_class, &writer); + writer.EndType(); } } // namespace @@ -369,8 +433,7 @@ OpGenerator::OpGenerator(const string& base_package, const string& output_dir, env_(env) { } -Status OpGenerator::Run(const OpList& op_list, const string& lib_name) { - LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations"; +Status OpGenerator::Run(const OpList& op_list) { ApiDefMap api_map(op_list); if (!api_dirs_.empty()) { // Only load api files that correspond to the requested "op_list" @@ -388,37 +451,14 @@ Status OpGenerator::Run(const OpList& op_list, const string& lib_name) { for (const auto& op_def : op_list.op()) { const ApiDef* api_def = api_map.GetApiDef(op_def.name()); if (api_def->visibility() != ApiDef::SKIP) { - Status status = GenerateOp(op_def, *api_def, lib_name); - if (status != Status::OK()) { - LOG(ERROR) << "Fail to generate Java wrapper for operation \"" - << op_def.name() << "\""; + 
OpSpec op(OpSpec::Create(op_def, *api_def)); + for (const EndpointSpec& endpoint : op.endpoints()) { + GenerateOp(op, endpoint, base_package_, output_dir_, env_); } } } return Status::OK(); } -Status OpGenerator::GenerateOp(const OpDef& op_def, const ApiDef& api_def, - const string& lib_name) { - std::unique_ptr op; - OpParser op_parser(op_def, api_def, lib_name, base_package_); - op_parser.Parse(&op); - for (const OpSpec::Endpoint& endpoint : op->endpoints()) { - string package_path = io::JoinPath(output_dir_, - str_util::StringReplace(endpoint.type().package(), ".", "/", true)); - if (!env_->FileExists(package_path).ok()) { - TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(package_path)); - } - string file_path = - io::JoinPath(package_path, endpoint.type().name() + ".java"); - std::unique_ptr file; - TF_CHECK_OK(env_->NewWritableFile(file_path, &file)); - - SourceFileWriter writer(file.get()); - RenderEndpoint(*op, endpoint, &writer); - } - return Status::OK(); -} - } // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 19d8db95fb..06b08e852a 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,36 +23,33 @@ limitations under the License. #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/java/src/gen/cc/op_specs.h" namespace tensorflow { namespace java { // A generator of Java operation wrappers. 
// -// Such generator is normally ran only once per executable, outputting -// wrappers for the all registered operations it has been compiled with. -// Nonetheless, it is designed to support multiple runs, giving a different -// list of operations on each cycle. +// This generator takes a list of ops definitions in input and outputs +// a Java Op wrapper for each of them in the provided directory. The same +// generator instance can be invoked multiple times with a different list of +// ops definitions. class OpGenerator { public: OpGenerator(const string& base_package, const string& output_dir, const std::vector& api_dirs, Env* env = Env::Default()); - virtual ~OpGenerator() = default; // Generates wrappers for the given list of 'ops'. // // Output files are generated in //, - // where 'lib_package' is derived from 'lib_name'. - Status Run(const OpList& op_list, const string& lib_name); + // where 'lib_package' is derived from ops endpoints. + Status Run(const OpList& op_list); private: - string base_package_; - string output_dir_; - std::vector api_dirs_; + const string base_package_; + const string output_dir_; + const std::vector api_dirs_; Env* env_; - - Status GenerateOp(const OpDef& op_def, const ApiDef& api_def, - const string& lib_name); }; } // namespace java diff --git a/tensorflow/java/src/gen/cc/op_parser.cc b/tensorflow/java/src/gen/cc/op_parser.cc deleted file mode 100644 index 0541e343d8..0000000000 --- a/tensorflow/java/src/gen/cc/op_parser.cc +++ /dev/null @@ -1,417 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/java/src/gen/cc/op_parser.h" - -namespace tensorflow { -namespace java { -namespace { - -string SnakeToCamelCase(const string& str, bool upper = false) { - string result; - bool cap = upper; - for (string::const_iterator it = str.begin(); it != str.end(); ++it) { - const char c = *it; - if (c == '_') { - cap = true; - } else if (cap) { - result += toupper(c); - cap = false; - } else { - result += c; - } - } - return result; -} - -bool IsRealNumber(DataType type) { - for (DataType dt : RealNumberTypes()) { - if (type == dt) { - return true; - } - } - return false; -} - -bool IsRealNumbers(const AttrValue& values) { - if (values.has_list()) { - for (int i = 0; i < values.list().type_size(); ++i) { - if (!IsRealNumber(values.list().type(i))) { - return false; - } - } - return true; - } - return IsRealNumber(values.type()); -} - -string ParseDocumentation(const string& text) { - std::stringstream javadoc_text; - string::const_iterator c_iter = text.cbegin(); - bool code = false; - bool emphasis = false; - bool list = false; - while (c_iter != text.cend()) { - char c = *c_iter++; - int count = 1; - switch (c) { - case '\n': - if (!code) { - // consumes all subsequent newlines, if there are more than one, - // then there are two choices: - // - if 
the next line starts with an asterisk, we are enumerating - // a list of items - // - otherwise, we are starting a new paragraph - for (; c_iter != text.cend() && *c_iter == '\n'; ++count, ++c_iter) {} - if (c_iter != text.cend()) { - if (count > 1) { - if (*c_iter != '*' && list) { - javadoc_text << "

  • \n
\n"; - list = false; - } else if (*c_iter == '*' && !list) { - javadoc_text << "\n
    \n
  • "; - list = true; - c_iter++; - } else { - javadoc_text << "\n

    \n"; - } - } else if (list && *c_iter == '*') { - javadoc_text << "

  • \n
  • "; - c_iter++; - } else { - javadoc_text << '\n'; - } - } - } - break; - case '`': - // consumes all subsequent backquotes, those are use enclose code. - // if there are more than 3, we are dealing with a pre-formatted block, - // otherwise it is a single-line code snippet - for (; c_iter != text.cend() && *c_iter == '`'; ++count, ++c_iter) {} - if (count >= 3) { - javadoc_text << (code ? "\n}" : "
    {@code\n");
    -      } else {
    -        javadoc_text << (code ? "}" : "{@code ");
    -      }
    -      code = !code;
    -      break;
    -    case '*':
    -      if (!code) {
    -        // consumes all subsequent asterisks, if there are more than one, then
    -        // we put the text in bold, otherwise in italic
    -        for (; c_iter != text.cend() && *c_iter == '*'; ++count, ++c_iter) {}
    -        if (count > 1) {
    -          javadoc_text << (emphasis ? "" : "");
    -        } else {
    -          javadoc_text << (emphasis ? "" : "");
    -        }
    -        emphasis = !emphasis;
    -      } else {
    -        javadoc_text << '*';
    -      }
    -      break;
    -    default:
    -      javadoc_text << c;
    -      break;
    -    }
    -  }
    -  return javadoc_text.str();
    -}
    -
    -}  // namespace
    -
    -OpParser::OpParser(const OpDef& op_def, const ApiDef& api_def,
    -    const string& lib_name, const string& base_package)
    -  : op_def_(op_def), op_api_(api_def), lib_name_(lib_name),
    -    base_package_(base_package) {
    -}
    -
    -void OpParser::Parse(std::unique_ptr* op_ptr) {
    -  visited_attrs_.clear();
    -  next_generic_ = 'T';
    -  op_ptr->reset(new OpSpec(op_api_.graph_op_name()));
    -  for (const string& next_input_name : op_api_.arg_order()) {
    -    for (int i = 0; i < op_def_.input_arg().size(); ++i) {
    -      if (op_def_.input_arg(i).name() == next_input_name) {
    -        ParseInput(op_def_.input_arg(i), op_api_.in_arg(i), op_ptr->get());
    -        break;
    -      }
    -    }
    -  }
    -  for (int i = 0; i < op_def_.attr().size(); ++i) {
    -    ParseAttribute(op_def_.attr(i), op_api_.attr(i), op_ptr->get());
    -  }
    -  for (int i = 0; i < op_def_.output_arg().size(); ++i) {
    -    ParseOutput(op_def_.output_arg(i), op_api_.out_arg(i), op_ptr->get());
    -  }
    -  BuildEndpoints(op_ptr->get());
    -}
    -
    -void OpParser::BuildEndpoints(OpSpec* op) {
    -  Javadoc op_doc = Javadoc::Create(ParseDocumentation(op_api_.summary()))
    -    .details(ParseDocumentation(op_api_.description()));
    -  std::vector op_supertypes;
    -  op_supertypes.push_back(Type::Class("PrimitiveOp", "org.tensorflow.op"));
    -  std::map op_generics;
    -  for (const OpSpec::Operand& output : op->outputs()) {
    -    // declare generic output parameters at the Op class level
    -    const Type& data_type = output.data_type();
    -    if (data_type.kind() == Type::GENERIC && !data_type.unknown()
    -        && op_generics.find(data_type.name()) == op_generics.end()) {
    -      op_generics.insert(std::make_pair(data_type.name(), &data_type));
    -      op_doc.add_param_tag("<" + data_type.name() + ">",
    -          "data type of output '" + output.var().name() + "'");
    -    }
    -    // implement the Op as an (iteration of) Operand if it has only one output
    -    if (op->outputs().size() == 1) {
    -      Type operand_inf(Type::Interface("Operand", "org.tensorflow"));
    -      operand_inf.add_parameter(data_type.unknown() ?
    -          Type::Class("Object") : data_type);
    -      op_supertypes.push_back(output.iterable() ?
    -          Type::IterableOf(operand_inf) : operand_inf);
    -    }
    -  }
    -  for (const auto& endpoint_def : op_api_.endpoint()) {
    -    std::vector name_tokens = str_util::Split(endpoint_def.name(), ".");
    -    // if the endpoint specifies a package, use it, otherwise derive it from the
    -    // op library name.
    -    string name;
    -    string package;
    -    if (name_tokens.size() > 1) {
    -      package = str_util::Lowercase(name_tokens.at(0));
    -      name = name_tokens.at(1);
    -    } else {
    -      package = str_util::StringReplace(lib_name_, "_", "", true);
    -      name = name_tokens.at(0);
    -    }
    -    Type endpoint(Type::Class(name, base_package_ + "." + package));
    -    Javadoc endpoint_doc(op_doc);
    -    for (const auto& parameter : op_generics) {
    -      endpoint.add_parameter(*parameter.second);
    -    }
    -    for (const Type& supertype : op_supertypes) {
    -      endpoint.add_supertype(supertype);
    -    }
    -    if (endpoint_def.deprecation_version() > 0) {
    -      string explanation;
    -      if (op_api_.endpoint(0).deprecation_version() == 0) {
    -        explanation = ", use {@link "
    -            + op->endpoints().at(0).type().full_name()
    -            + "} instead";
    -      } else {
    -        explanation = op_def_.deprecation().explanation();
    -      }
    -      endpoint_doc.add_tag("deprecated", explanation);
    -      endpoint.add_annotation(Annotation::Create("Deprecated"));
    -    }
    -    // only visible ops should be annotated for exposure in the Ops Graph API
    -    if (op_api_.visibility() != ApiDef::HIDDEN) {
    -      string group_name = SnakeToCamelCase(lib_name_);
    -      endpoint.add_annotation(
    -          Annotation::Create("Operator", "org.tensorflow.op.annotation")
    -            .attributes("group = \"" + group_name + "\""));
    -    }
    -    op->add_endpoint(endpoint, endpoint_doc);
    -  }
    -}
    -
    -void OpParser::ParseInput(const OpDef_ArgDef& input_def,
    -    const ApiDef::Arg& input_api, OpSpec* op) {
    -  bool iterable = false;
    -  Type data_type = DataTypeOf(input_def, &iterable);
    -  Type type = Type::Interface("Operand", "org.tensorflow")
    -    .add_parameter(data_type);
    -  if (iterable) {
    -    type = Type::IterableOf(type);
    -  }
    -  op->add_input(OpSpec::Operand(input_api.name(),
    -      Variable::Create(SnakeToCamelCase(input_api.rename_to()), type),
    -      data_type,
    -      ParseDocumentation(input_api.description()),
    -      iterable));
    -}
    -
    -void OpParser::ParseOutput(const OpDef_ArgDef& output_def,
    -    const ApiDef::Arg& output_api, OpSpec* op) {
    -  bool iterable = false;
    -  Type data_type = DataTypeOf(output_def, &iterable);
    -  Type type = Type::Class("Output", "org.tensorflow")
    -    .add_parameter(data_type);
    -  if (iterable) {
    -    type = Type::ListOf(type);
    -  }
    -  op->add_output(OpSpec::Operand(output_api.name(),
    -      Variable::Create(SnakeToCamelCase(output_api.rename_to()), type),
    -      data_type,
    -      ParseDocumentation(output_api.description()),
    -      iterable));
    -}
    -
    -void OpParser::ParseAttribute(const OpDef_AttrDef& attr_def,
    -    const ApiDef::Attr& attr_api, OpSpec* op) {
    -  // do not parse attributes already visited, they have probably been inferred
    -  // before as an input argument type
    -  if (visited_attrs_.find(attr_def.name()) != visited_attrs_.cend()) {
    -    return;
    -  }
    -  bool iterable = false;
    -  Type data_type = DataTypeOf(attr_def, &iterable);
    -  // generic attributes should be passed as an explicit type
    -  bool explicit_type = data_type.kind() == Type::GENERIC && !iterable;
    -  Type type = explicit_type ?
    -      Type::Class("Class").add_parameter(data_type) : data_type;
    -  if (iterable) {
    -    type = Type::ListOf(data_type);
    -  }
    -  OpSpec::Operand attr(attr_api.name(),
    -      Variable::Create(SnakeToCamelCase(attr_api.rename_to()), type),
    -      data_type,
    -      ParseDocumentation(attr_api.description()),
    -      iterable);
    -  // attributes with a default value are optional
    -  if (attr_api.has_default_value() && !explicit_type) {
    -    op->add_option(attr);
    -  } else {
    -    op->add_attribute(attr);
    -  }
    -  visited_attrs_.insert(std::make_pair(attr_api.name(), data_type));
    -}
    -
    -Type OpParser::DataTypeOf(const OpDef_ArgDef& arg, bool* iterable_out) {
    -  if (!arg.number_attr().empty()) {
    -    visited_attrs_.insert(std::make_pair(arg.number_attr(), Type::Int()));
    -    *iterable_out = true;
    -  }
    -  if (arg.type() != DataType::DT_INVALID) {
    -    // resolve type from DataType
    -    switch (arg.type()) {
    -      case DataType::DT_BOOL:
    -        return Type::Class("Boolean");
    -
    -      case DataType::DT_STRING:
    -        return Type::Class("String");
    -
    -      case DataType::DT_FLOAT:
    -        return Type::Class("Float");
    -
    -      case DataType::DT_DOUBLE:
    -        return Type::Class("Double");
    -
    -      case DataType::DT_UINT8:
    -        return Type::Class("UInt8", "org.tensorflow.types");
    -
    -      case DataType::DT_INT32:
    -        return Type::Class("Integer");
    -
    -      case DataType::DT_INT64:
    -        return Type::Class("Long");
    -
    -      case DataType::DT_RESOURCE:
    -        // TODO(karllessard) create a Resource utility class that could be
    -        // used to store a resource and its type (passed in a second argument).
    -        // For now, we need to force a wildcard and we will unfortunately lose
    -        // track of the resource type.
    -        return Type::Wildcard();
    -
    -      default:
    -        break;
    -    }
    -  } else {
    -    // resolve type from type attribute
    -    string attr_name = arg.type_attr();
    -    if (attr_name.empty()) {
    -      attr_name = arg.type_list_attr();
    -      if (!attr_name.empty()) {
    -        *iterable_out = true;
    -        Type type = Type::Wildcard();
    -        visited_attrs_.insert(std::make_pair(attr_name, type));
    -        return type;
    -      }
    -    }
    -    for (const auto& attr : op_def_.attr()) {
    -      if (attr.name() == attr_name) {
    -        Type type = DataTypeOf(attr, iterable_out);
    -        visited_attrs_.insert(std::make_pair(attr_name, type));
    -        return type;
    -      }
    -    }
    -  }
    -  LOG(WARNING) << "Data type for arg \"" << arg.name() << "\" is unknown";
    -  return Type::Wildcard();
    -}
    -
    -Type OpParser::DataTypeOf(const OpDef_AttrDef& attr, bool* iterable_out) {
    -  std::map::const_iterator it = visited_attrs_.find(attr.name());
    -  if (it != visited_attrs_.cend()) {
    -    return it->second;
    -  }
    -  string attr_type = attr.type();
    -  if (attr.type().compare(0, 5, "list(") == 0) {
    -    attr_type = attr_type.substr(5, attr.type().find_last_of(')') - 5);
    -    *iterable_out = true;
    -  }
    -  if (attr_type == "type") {
    -    if (*iterable_out) {
    -      return Type::Enum("DataType", "org.tensorflow");
    -    }
    -    return GetNextGenericTensorType(attr.allowed_values());
    -  }
    -  if (attr_type == "string") {
    -    return Type::Class("String");
    -  }
    -  if (attr_type == "int") {
    -    return Type::Class("Integer");
    -  }
    -  if (attr_type == "float") {
    -    return Type::Class("Float");
    -  }
    -  if (attr_type == "bool") {
    -    return Type::Class("Boolean");
    -  }
    -  if (attr_type == "shape") {
    -    return Type::Class("Shape", "org.tensorflow");
    -  }
    -  if (attr_type == "tensor") {
    -    return Type::Class("Tensor", "org.tensorflow")
    -      .add_parameter(Type::Wildcard());
    -  }
    -  LOG(WARNING) << "Data type for attribute \"" << attr_type << "\" is unknown";
    -  return *iterable_out ? Type::Wildcard() : Type::Class("Object");
    -}
    -
    -Type OpParser::GetNextGenericTensorType(const AttrValue& allowed_values)  {
    -  Type generic = Type::Generic(string(1, next_generic_));
    -  next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
    -
    -  // when only real numbers are allowed, enforce that restriction in the Java by
    -  // extending the generic from java.lang.Number
    -  if (IsRealNumbers(allowed_values)) {
    -    generic.add_supertype(Type::Class("Number"));
    -  }
    -  return generic;
    -}
    -
    -}  // namespace java
    -}  // namespace tensorflow
    diff --git a/tensorflow/java/src/gen/cc/op_parser.h b/tensorflow/java/src/gen/cc/op_parser.h
    deleted file mode 100644
    index 42855127cc..0000000000
    --- a/tensorflow/java/src/gen/cc/op_parser.h
    +++ /dev/null
    @@ -1,137 +0,0 @@
    -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
    -
    -Licensed under the Apache License, Version 2.0 (the "License");
    -you may not use this file except in compliance with the License.
    -You may obtain a copy of the License at
    -
    -    http://www.apache.org/licenses/LICENSE-2.0
    -
    -Unless required by applicable law or agreed to in writing, software
    -distributed under the License is distributed on an "AS IS" BASIS,
    -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    -See the License for the specific language governing permissions and
    -limitations under the License.
    -==============================================================================*/
    -
    -#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    -#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    -
    -#include 
    -#include 
    -#include 
    -#include 
    -
    -#include "tensorflow/core/framework/op_def.pb.h"
    -#include "tensorflow/core/framework/api_def.pb.h"
    -#include "tensorflow/java/src/gen/cc/java_defs.h"
    -
    -namespace tensorflow {
    -namespace java {
    -
    -// Specification of a TensorFlow operation to generate.
    -//
    -// This is the result of an operation definition parsing, see OpParser::Parse().
    -class OpSpec {
    - public:
    -  class Endpoint {
    -   public:
    -    Endpoint(const Type& type, const Javadoc& javadoc)
    -      : type_(type), javadoc_(javadoc) {}
    -    const Type& type() const { return type_; }
    -    const Javadoc& javadoc() const { return javadoc_; }
    -
    -   private:
    -    Type type_;
    -    Javadoc javadoc_;
    -  };
    -
    -  class Operand {
    -   public:
    -    Operand(const string& graph_name, const Variable& var,
    -        const Type& data_type, const string& description, bool iterable)
    -     : graph_name_(graph_name), var_(var), data_type_(data_type),
    -       description_(description), iterable_(iterable) {}
    -    const string& graph_name() const { return graph_name_; }
    -    const Variable& var() const { return var_; }
    -    Variable* var_ptr() { return &var_; }
    -    const Type& data_type() const { return data_type_; }
    -    const string& description() const { return description_; }
    -    bool iterable() const { return iterable_; }
    -
    -   private:
    -    string graph_name_;
    -    Variable var_;
    -    Type data_type_;
    -    string description_;
    -    bool iterable_;
    -  };
    -
    -  explicit OpSpec(const string& graph_name) : graph_name_(graph_name) {}
    -  const string& graph_name() const { return graph_name_; }
    -  const std::vector endpoints() const { return endpoints_; }
    -  void add_endpoint(const Type& type, const Javadoc& javadoc) {
    -    endpoints_.push_back(Endpoint(type, javadoc));
    -  }
    -  const std::vector& inputs() const { return inputs_; }
    -  void add_input(const Operand& input) {
    -    inputs_.push_back(input);
    -  }
    -  const std::vector& outputs() const { return outputs_; }
    -  void add_output(const Operand& output) {
    -    outputs_.push_back(output);
    -  }
    -  const std::vector& attributes() const { return attributes_; }
    -  void add_attribute(const Operand& attribute) {
    -    attributes_.push_back(attribute);
    -  }
    -  const std::vector& options() const { return options_; }
    -  void add_option(const Operand& option) {
    -    options_.push_back(option);
    -  }
    -
    - private:
    -  string graph_name_;
    -  std::vector endpoints_;
    -  std::vector inputs_;
    -  std::vector outputs_;
    -  std::vector attributes_;
    -  std::vector options_;
    -};
    -
    -// A parser of ops proto definitions.
    -//
    -// This object parses the definition and the api of an TensorFlow operation to
    -// produce a specification that can be used for Java source code rendering.
    -class OpParser {
    - public:
    -  OpParser(const OpDef& op_def, const ApiDef& api_def, const string& lib_name,
    -      const string& base_package);
    -  virtual ~OpParser() = default;
    -
    -  // Produces an operation specification from its proto definitions.
    -  void Parse(std::unique_ptr* op_ptr);
    -
    - private:
    -  OpDef op_def_;
    -  ApiDef op_api_;
    -  string lib_name_;
    -  string base_package_;
    -  std::map visited_attrs_;
    -  char next_generic_ = 0;
    -
    -  void BuildEndpoints(OpSpec* op);
    -  void ParseInput(const OpDef_ArgDef& input_def,
    -      const ApiDef::Arg& input_api, OpSpec* op);
    -  void ParseOutput(const OpDef_ArgDef& output_def,
    -      const ApiDef::Arg& output_api, OpSpec* op);
    -  void ParseAttribute(const OpDef_AttrDef& attr_def,
    -      const ApiDef::Attr& attr_api, OpSpec* op);
    -  Type DataTypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
    -  Type DataTypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
    -  Type GetNextGenericTensorType(const AttrValue& allowed_values);
    -};
    -
    -}  // namespace java
    -}  // namespace tensorflow
    -
    -#endif  // TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
    new file mode 100644
    index 0000000000..a727f7ae90
    --- /dev/null
    +++ b/tensorflow/java/src/gen/cc/op_specs.cc
    @@ -0,0 +1,390 @@
    +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
    +
    +Licensed under the Apache License, Version 2.0 (the "License");
    +you may not use this file except in compliance with the License.
    +You may obtain a copy of the License at
    +
    +    http://www.apache.org/licenses/LICENSE-2.0
    +
    +Unless required by applicable law or agreed to in writing, software
    +distributed under the License is distributed on an "AS IS" BASIS,
    +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +See the License for the specific language governing permissions and
    +limitations under the License.
    +==============================================================================*/
    +
    +#include 
    +#include 
    +#include 
    +#include 
    +
    +#include "re2/re2.h"
    +#include "tensorflow/core/framework/op.h"
    +#include "tensorflow/core/framework/types.h"
    +#include "tensorflow/core/lib/strings/str_util.h"
    +#include "tensorflow/core/platform/logging.h"
    +#include "tensorflow/java/src/gen/cc/op_specs.h"
    +
    +namespace tensorflow {
    +namespace java {
    +namespace {
    +
+inline bool IsRealNumbers(const AttrValue& values) {  // all types are real?
+  if (!values.has_list()) {  // no list set: test the single `type` field
+    return RealNumberTypes().Contains(values.type());
+  }
+  for (int i = 0; i < values.list().type_size(); ++i) {
+    if (!RealNumberTypes().Contains(values.list().type(i))) {
+      return false;  // one non-real type is enough to reject the whole list
+    }
+  }
+  return true;  // every listed type is real (also true for an empty list)
+}
    +
    +class TypeResolver {
    + public:
    +  explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
    +
    +  Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
    +  Type TypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
    +  bool IsAttributeVisited(const string& attr_name) {
    +    return visited_attrs_.find(attr_name) != visited_attrs_.cend();
    +  }
    + private:
    +  const OpDef op_def_;
    +  std::map visited_attrs_;
    +  char next_generic_ = 'T';
    +};
    +
    +// Resolves the Java type used to represent the tensors of `arg_def`.
    +//
    +// `*iterable_out` is set to true when the argument is a list of tensors,
    +// i.e. when it declares a `number_attr` or a `type_list_attr`. Attributes
    +// referenced by the argument are recorded as visited. Data types that have
    +// no Java equivalent resolve to a wildcard. Aborts via LOG(FATAL) when the
    +// argument carries no type information at all.
    +Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
    +    bool* iterable_out) {
    +  *iterable_out = false;
    +  if (!arg_def.number_attr().empty()) {
    +    // when number_attr is set, argument has to be a list of tensors
    +    *iterable_out = true;
    +    visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
    +  }
    +  Type type = Type::Wildcard();
    +  if (arg_def.type() != DataType::DT_INVALID) {
    +    // resolve type from DataType
    +    switch (arg_def.type()) {
    +      case DataType::DT_BOOL:
    +        type = Type::Class("Boolean");
    +        break;
    +      case DataType::DT_STRING:
    +        type = Type::Class("String");
    +        break;
    +      case DataType::DT_FLOAT:
    +        type = Type::Class("Float");
    +        break;
    +      case DataType::DT_DOUBLE:
    +        type = Type::Class("Double");
    +        break;
    +      case DataType::DT_UINT8:
    +        type = Type::Class("UInt8", "org.tensorflow.types");
    +        break;
    +      case DataType::DT_INT32:
    +        type = Type::Class("Integer");
    +        break;
    +      case DataType::DT_INT64:
    +        type = Type::Class("Long");
    +        break;
    +      case DataType::DT_RESOURCE:
    +        // TODO(karllessard) create a Resource utility class that could be
    +        // used to store a resource and its type (passed in a second argument).
    +        // For now, we need to force a wildcard and we will unfortunately lose
    +        // track of the resource type.
    +        break;
    +      default:
    +        // Any other datatype does not have an equivalent in Java and must
    +        // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
    +        break;
    +    }
    +  } else if (!arg_def.type_attr().empty()) {
    +    // resolve type from attribute (if already visited, retrieve its type)
    +    if (IsAttributeVisited(arg_def.type_attr())) {
    +      type = visited_attrs_.at(arg_def.type_attr());
    +    } else {
    +      for (const auto& attr_def : op_def_.attr()) {
    +        if (attr_def.name() == arg_def.type_attr()) {
    +          // The attribute overload resets its output flag, so resolve it
    +          // through a scratch flag; otherwise a `number_attr` argument would
    +          // lose the iterable state set above.
    +          bool attr_iterable = false;
    +          type = TypeOf(attr_def, &attr_iterable);
    +          *iterable_out = *iterable_out || attr_iterable;
    +          break;
    +        }
    +      }
    +    }
    +  } else if (!arg_def.type_list_attr().empty()) {
    +    // type is a list of tensors that can be of different data types, so leave
    +    // it as a list of wildcards
    +    *iterable_out = true;
    +    visited_attrs_.insert(std::make_pair(arg_def.type_list_attr(), type));
    +
    +  } else {
    +    LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
    +        << "\" in operation \"" << op_def_.name() << "\"";
    +  }
    +  return type;
    +}
    +
    +// Resolves the Java type used to represent the value of `attr_def` and
    +// records it as visited under the attribute's name.
    +//
    +// `*iterable_out` is set to true only for "list(...)" attributes. A single
    +// "type" attribute resolves to a new generic parameter (extending Number
    +// when only real number types are allowed), while a list of types resolves
    +// to the DataType enum. Aborts via LOG(FATAL) on an unknown attribute type.
    +Type TypeResolver::TypeOf(const OpDef_AttrDef& attr_def,
    +    bool* iterable_out) {
    +  *iterable_out = false;
    +  StringPiece kind = attr_def.type();
    +  const bool is_list = str_util::ConsumePrefix(&kind, "list(");
    +  if (is_list) {
    +    kind.remove_suffix(1);  // drop the ")" that closes "list("
    +    *iterable_out = true;
    +  }
    +
    +  // placeholder type, replaced below for every supported attribute type
    +  Type resolved = is_list ? Type::Wildcard() : Type::Class("Object");
    +  if (kind == "type") {
    +    if (is_list) {
    +      resolved = Type::Enum("DataType", "org.tensorflow");
    +    } else {
    +      // each single type attribute gets its own generic name; names advance
    +      // from the current letter and wrap around to 'A' after 'Z'
    +      resolved = Type::Generic(string(1, next_generic_));
    +      if (next_generic_ == 'Z') {
    +        next_generic_ = 'A';
    +      } else {
    +        ++next_generic_;
    +      }
    +      if (IsRealNumbers(attr_def.allowed_values())) {
    +        // restrict the generic to real numbers by extending java.lang.Number
    +        resolved.add_supertype(Type::Class("Number"));
    +      }
    +    }
    +  } else if (kind == "string") {
    +    resolved = Type::Class("String");
    +  } else if (kind == "int") {
    +    resolved = Type::Class("Integer");
    +  } else if (kind == "float") {
    +    resolved = Type::Class("Float");
    +  } else if (kind == "bool") {
    +    resolved = Type::Class("Boolean");
    +  } else if (kind == "shape") {
    +    resolved = Type::Class("Shape", "org.tensorflow");
    +  } else if (kind == "tensor") {
    +    resolved = Type::Class("Tensor", "org.tensorflow")
    +        .add_parameter(Type::Wildcard());
    +  } else {
    +    LOG(FATAL) << "Cannot resolve data type for attribute \"" << kind
    +        << "\" in operation \"" << op_def_.name() << "\"";
    +  }
    +  visited_attrs_.insert(std::make_pair(attr_def.name(), resolved));
    +  return resolved;
    +}
    +
    +// Converts a snake_case identifier to camelCase.
    +//
    +// Underscores are dropped and the first character following one is
    +// upper-cased. When `upper` is true, the first character of the result is
    +// upper-cased as well (UpperCamelCase). A trailing underscore is dropped.
    +string SnakeToCamelCase(const string& str, bool upper = false) {
    +  string result;
    +  result.reserve(str.size());
    +  bool cap = upper;
    +  for (const char c : str) {
    +    if (c == '_') {
    +      cap = true;
    +    } else if (cap) {
    +      // toupper() is undefined for negative values (other than EOF), so the
    +      // character must be passed as an unsigned char
    +      result += static_cast<char>(toupper(static_cast<unsigned char>(c)));
    +      cap = false;
    +    } else {
    +      result += c;
    +    }
    +  }
    +  return result;
    +}
    +
    +// Searches `input` for the first (unanchored) match of `expr`.
    +//
    +// On a match, `before_match` receives the text preceding the match, `input`
    +// is advanced past the end of the match and, if provided, `ret_match`
    +// receives the matched text. Without a match, `input` is left untouched,
    +// `before_match` receives the whole input and `ret_match`, if provided, is
    +// reset to an empty piece. Returns true if `expr` matched.
    +bool FindAndCut(re2::StringPiece* input, const RE2& expr,
    +    re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) {
    +  re2::StringPiece found;
    +  if (!expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &found, 1)) {
    +    *before_match = *input;
    +    if (ret_match != nullptr) {
    +      ret_match->set(nullptr, 0);
    +    }
    +    return false;
    +  }
    +  // text in front of the match, then skip both that text and the match
    +  before_match->set(input->data(), found.begin() - input->begin());
    +  input->remove_prefix(found.end() - before_match->begin());
    +  if (ret_match != nullptr) {
    +    *ret_match = found;
    +  }
    +  return true;
    +}
    +
    +string ParseDocumentation(const string& mdtext) {
    +  std::stringstream javadoc_text;
    +  re2::StringPiece input(mdtext);
    +  re2::StringPiece text;
    +  bool in_list = false;
    +  do {
    +    re2::StringPiece markup;
    +    FindAndCut(&input,
    +        "\n+\\*[[:blank:]]+|\n{2,}|`{3,}|`{1,2}|\\*{1,2}\\b|\\[",
    +        &text, &markup);
    +    javadoc_text << text;
    +    if (markup.empty()) {
    +      break;  // we are done parsing
    +    }
    +    if (markup.starts_with("\n")) {
    +      javadoc_text << "\n";
    +      if (markup.contains("* ")) {
    +        javadoc_text << (in_list ? "
  • \n" : "
      \n") << "
    • \n"; + in_list = true; + } else if (markup.starts_with("\n\n")) { + if (in_list) { + javadoc_text << "
    • \n
    \n"; + in_list = false; + } else if (!input.starts_with("```")) { + javadoc_text << "

    \n"; + } + } + } else if (markup.starts_with("```") && text.empty()) { + re2::StringPiece language; + RE2::Consume(&input, "[\\w\\+]+", &language); + if (FindAndCut(&input, markup.ToString() + "\n*", &text)) { + javadoc_text << "

    \n{@code" << text << "}\n
    \n"; + } else { + javadoc_text << markup << language; + } + } else if (markup.starts_with("`")) { + if (FindAndCut(&input, markup, &text)) { + javadoc_text << "{@code " << text << "}"; + } else { + javadoc_text << markup; + } + } else if (markup == "**") { + if (FindAndCut(&input, "\\b\\*{2}", &text)) { + javadoc_text << "" << text << ""; + } else { + javadoc_text << markup; + } + } else if (markup == "*") { + if (FindAndCut(&input, "\\b\\*{1}", &text)) { + javadoc_text << "" << text << ""; + } else { + javadoc_text << markup; + } + } else if (markup == "[") { + string label; + string link; + if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) { + javadoc_text << "" << label << ""; + } else { + javadoc_text << markup; + } + } + } while (!input.empty()); + + return javadoc_text.str(); +} + +ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, + const ApiDef::Arg& input_api_def, TypeResolver* type_resolver) { + bool iterable = false; + Type type = type_resolver->TypeOf(input_def, &iterable); + Type var_type = Type::Interface("Operand", "org.tensorflow") + .add_parameter(type); + if (iterable) { + var_type = Type::IterableOf(var_type); + } + return ArgumentSpec(input_api_def.name(), + Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type), + type, + ParseDocumentation(input_api_def.description()), + iterable); +} + +AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, + const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) { + bool iterable = false; + Type type = type_resolver->TypeOf(attr_def, &iterable); + // type attributes must be passed explicitly in methods as a Class<> parameter + bool is_explicit = type.kind() == Type::GENERIC && !iterable; + Type var_type = is_explicit ? 
Type::Class("Class").add_parameter(type) : type; + if (iterable) { + var_type = Type::ListOf(type); + } + return AttributeSpec(attr_api_def.name(), + Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type), + type, + ParseDocumentation(attr_api_def.description()), + iterable, + attr_api_def.has_default_value() && !is_explicit); +} + +ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, + const ApiDef::Arg& output_api, TypeResolver* type_resolver) { + bool iterable = false; + Type type = type_resolver->TypeOf(output_def, &iterable); + Type var_type = Type::Class("Output", "org.tensorflow") + .add_parameter(type); + if (iterable) { + var_type = Type::ListOf(var_type); + } + return ArgumentSpec(output_api.name(), + Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type), + type, + ParseDocumentation(output_api.description()), + iterable); +} + +EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, + const ApiDef_Endpoint& endpoint_def) { + + std::vector name_tokens = str_util::Split(endpoint_def.name(), "."); + string package; + string name; + if (name_tokens.size() > 1) { + package = name_tokens.at(0); + name = name_tokens.at(1); + } else { + package = "core"; // generate unclassified ops in the 'core' package + name = name_tokens.at(0); + } + return EndpointSpec(package, + name, + Javadoc::Create(ParseDocumentation(api_def.summary())) + .details(ParseDocumentation(api_def.description())), + endpoint_def.deprecation_version() > 0); +} + +} // namespace + +OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { + OpSpec op(api_def.graph_op_name(), + api_def.visibility() == ApiDef::HIDDEN, + op_def.deprecation().explanation()); + TypeResolver type_resolver(op_def); + for (const string& next_input_name : api_def.arg_order()) { + for (int i = 0; i < op_def.input_arg().size(); ++i) { + if (op_def.input_arg(i).name() == next_input_name) { + op.inputs_.push_back(CreateInput(op_def.input_arg(i), 
api_def.in_arg(i), + &type_resolver)); + break; + } + } + } + for (int i = 0; i < op_def.attr().size(); ++i) { + // do not parse attributes already visited, they have probably been inferred + // before as an input argument type + if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) { + AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i), + &type_resolver); + // attributes with a default value are optional + if (attr.optional()) { + op.optional_attributes_.push_back(attr); + } else { + op.attributes_.push_back(attr); + } + } + } + for (int i = 0; i < op_def.output_arg().size(); ++i) { + op.outputs_.push_back(CreateOutput(op_def.output_arg(i), api_def.out_arg(i), + &type_resolver)); + } + for (const auto& endpoint_def : api_def.endpoint()) { + op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def)); + } + return op; +} + +} // namespace java +} // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h new file mode 100644 index 0000000000..55c2c3f307 --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_ +#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_ + +#include +#include + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" + +namespace tensorflow { +namespace java { + +class EndpointSpec { + public: + // A specification for an operation endpoint + // + // package: package of this endpoint (from which also derives its package) + // name: name of this endpoint class + // javadoc: the endpoint class documentation + // deprecated: true if this endpoint is now deprecated + EndpointSpec(const string& package, const string& name, + const Javadoc& javadoc, bool deprecated) + : package_(package), name_(name), javadoc_(javadoc), + deprecated_(deprecated) {} + + const string& package() const { return package_; } + const string& name() const { return name_; } + const Javadoc& javadoc() const { return javadoc_; } + bool deprecated() const { return deprecated_; } + + private: + const string package_; + const string name_; + const Javadoc javadoc_; + const bool deprecated_; +}; + +class ArgumentSpec { + public: + // A specification for an operation argument + // + // op_def_name: argument name, as known by TensorFlow core + // var: a variable to represent this argument in Java + // type: the tensor type of this argument + // description: a description of this argument, in javadoc + // iterable: true if this argument is a list + ArgumentSpec(const string& op_def_name, const Variable& var, + const Type& type, const string& description, bool iterable) + : op_def_name_(op_def_name), var_(var), type_(type), + description_(description), iterable_(iterable) {} + virtual ~ArgumentSpec() = default; + + const string& op_def_name() const { return op_def_name_; } + const Variable& var() const { return var_; } + const Type& type() const { return type_; } + const 
string& description() const { return description_; } + bool iterable() const { return iterable_; } + + private: + const string op_def_name_; + const Variable var_; + const Type type_; + const string description_; + const bool iterable_; +}; + +class AttributeSpec : public ArgumentSpec { + public: + // A specification for an operation attribute + // + // op_def_name: attribute name, as known by TensorFlow core + // var: a variable to represent this attribute in Java + // type: the type of this attribute + // description: a description of this attribute, in javadoc + // iterable: true if this attribute is a list + // optional: true if this attribute does not require to be set explicitly + AttributeSpec(const string& op_def_name, const Variable& var, + const Type& type, const string& description, bool iterable, + bool optional) + : ArgumentSpec(op_def_name, var, type, description, iterable), + optional_(optional) {} + virtual ~AttributeSpec() = default; + + bool optional() const { return optional_; } + + private: + const bool optional_; +}; + +class OpSpec { + public: + // Parses an op definition and its API to produce a specification used for + // rendering its Java wrapper + // + // op_def: Op definition + // api_def: Op API definition + static OpSpec Create(const OpDef& op_def, const ApiDef& api_def); + + const string& graph_op_name() const { return graph_op_name_; } + bool hidden() const { return hidden_; } + const string& deprecation_explanation() const { + return deprecation_explanation_; + } + const std::vector endpoints() const { return endpoints_; } + const std::vector& inputs() const { return inputs_; } + const std::vector& outputs() const { return outputs_; } + const std::vector& attributes() const { return attributes_; } + const std::vector& optional_attributes() const { + return optional_attributes_; + } + + private: + // A specification for an operation + // + // graph_op_name: name of this op, as known by TensorFlow core engine + // hidden: true if this 
op should not be visible through the Graph Ops API + // deprecation_explanation: message to show if all endpoints are deprecated + explicit OpSpec(const string& graph_op_name, bool hidden, + const string& deprecation_explanation) + : graph_op_name_(graph_op_name), hidden_(hidden), + deprecation_explanation_(deprecation_explanation) {} + + const string graph_op_name_; + const bool hidden_; + const string deprecation_explanation_; + std::vector endpoints_; + std::vector inputs_; + std::vector outputs_; + std::vector attributes_; + std::vector optional_attributes_; +}; + +} // namespace java +} // namespace tensorflow + +#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_ diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index b1de5af6ba..7e427787f9 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index 1f0febe9a3..bcae33ccce 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index 8bd42d9d0e..875ad99ae2 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -313,8 +313,7 @@ TEST(WriteType, SimpleClassWithDependencies) { TEST(WriteType, AnnotatedAndDocumentedClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - Javadoc clazz_doc; - clazz_doc.brief("Javadoc test") + Javadoc clazz_doc = Javadoc::Create("Javadoc test") .details("This is a\nmultiline description."); clazz.add_annotation(Annotation::Create("Bean")); clazz.add_annotation(Annotation::Create("SuppressWarnings") @@ -329,7 +328,7 @@ TEST(WriteType, AnnotatedAndDocumentedClass) { " *

    \n" " * This is a\n" " * multiline description.\n" - " **/\n" + " */\n" "@Bean\n" "@SuppressWarnings(\"rawtypes\")\n" "public class Test {\n}\n"; @@ -378,8 +377,7 @@ TEST(WriteType, ParameterizedClassFields) { Variable field1 = Variable::Create("field1", Type::Class("String")); Variable field2 = Variable::Create("field2", Type::Class("String")); Variable field3 = Variable::Create("field3", type_t); - Javadoc field3_doc; - field3_doc.brief("This variable is documented"); + Javadoc field3_doc = Javadoc::Create("This variable is documented"); writer.BeginType(clazz, PUBLIC) .WriteField(field1, STATIC | PUBLIC | FINAL) @@ -464,8 +462,7 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - Javadoc method_doc; - method_doc.brief("Javadoc test") + Javadoc method_doc = Javadoc::Create("Javadoc test") .details("This method has a\nmultiline description."); method.add_annotation(Annotation::Create("Override")); method.add_annotation(Annotation::Create("SuppressWarnings") @@ -484,7 +481,7 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { " *

    \n" " * This method has a\n" " * multiline description.\n" - " **/\n" + " */\n" " @Override\n" " @SuppressWarnings(\"rawtypes\")\n" " public void doNothing() {\n" @@ -500,8 +497,7 @@ TEST(WriteMethod, DocumentedMethodWithArguments) { Method method = Method::Create("boolToInt", Type::Int()); method.add_argument(Variable::Create("b", Type::Boolean())); method.add_argument(reverse); - Javadoc method_doc; - method_doc.brief("Converts a boolean to an int") + Javadoc method_doc = Javadoc::Create("Converts a boolean to an int") .details("This method will convert\na boolean to an int") .add_param_tag(reverse.name(), "if true, value is reversed") .add_tag("return", "int value for this boolean"); @@ -528,7 +524,7 @@ TEST(WriteMethod, DocumentedMethodWithArguments) { " * \n" " * @param reverse if true, value is reversed\n" " * @return int value for this boolean\n" - " **/\n" + " */\n" " public int boolToInt(boolean b, boolean reverse) {\n" " if (b && !reverse) {\n" " return 1;\n" diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index 1e7899cf7a..7017b52649 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -32,50 +32,52 @@ def tf_java_op_gen_srcjar(name, api_def_srcs=[], visibility=["//tensorflow/java:__pkg__"]): - gen_tools = [] gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files srcs = api_def_srcs[:] - # Construct an op generator binary for each ops library. - for ops_lib in ops_libs: - gen_lib = ops_lib[:ops_lib.rfind("_")] - out_gen_tool = out_dir + ops_lib + "_gen_tool" + if not api_def_srcs: + api_def_args_str = "," + else: + api_def_args = [] + for api_def_src in api_def_srcs: + # Add directory of the first ApiDef source to args. + # We are assuming all ApiDefs in a single api_def_src are in the + # same directory. 
+ api_def_args.append( + "$$(dirname $$(echo $(locations " + api_def_src + + ") | cut -d\" \" -f1))") + api_def_args_str = ",".join(api_def_args) - if not api_def_srcs: - api_def_args_str = "," - else: - api_def_args = [] - for api_def_src in api_def_srcs: - # Add directory of the first ApiDef source to args. - # We are assuming all ApiDefs in a single api_def_src are in the - # same directory. - api_def_args.append( - " $$(dirname $$(echo $(locations " + api_def_src + - ") | cut -d\" \" -f1))") - api_def_args_str = ",".join(api_def_args) + gen_tool_deps = [":java_op_gen_lib"] + for ops_lib in ops_libs: + gen_tool_deps.append(ops_libs_pkg + ":" + ops_lib + "_op_lib") - tf_cc_binary( - name=out_gen_tool, - copts=tf_copts(), - linkopts=["-lm"], - linkstatic=1, # Faster to link this one-time-use binary dynamically - deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"]) + tf_cc_binary( + name=gen_tool, + srcs=[ + "src/gen/cc/op_gen_main.cc", + ], + copts=tf_copts(), + linkopts=["-lm"], + linkstatic=1, # Faster to link this one-time-use binary dynamically + deps = gen_tool_deps) - gen_tools += [":" + out_gen_tool] - gen_cmds += ["$(location :" + out_gen_tool + ")" + - " --output_dir=$(@D)/" + out_src_dir + - " --lib_name=" + gen_lib + - " --base_package=" + gen_base_package + - " " + api_def_args_str] + gen_cmds += ["$(location :" + gen_tool + ")" + + " --output_dir=$(@D)/" + out_src_dir + + " --base_package=" + gen_base_package + + " --api_dirs=" + api_def_args_str] # Generate a source archive containing generated code for these ops. 
gen_srcjar = out_dir + name + ".srcjar" gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"] - gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"] - gen_tools += tf_binary_additional_srcs() + native.genrule( name=name, srcs=srcs, outs=[gen_srcjar], - tools=gen_tools, - cmd="&&".join(gen_cmds)) + tools=[ + "@local_jdk//:jar", + "@local_jdk//:jdk", + gen_tool + ] + tf_binary_additional_srcs(), + cmd=" && ".join(gen_cmds)) diff --git a/tensorflow/java/src/gen/resources/license.snippet.java b/tensorflow/java/src/gen/resources/license.java.snippet similarity index 100% rename from tensorflow/java/src/gen/resources/license.snippet.java rename to tensorflow/java/src/gen/resources/license.java.snippet -- GitLab From 6fee70dd4c82502fefa8259f0d8dbefcece58c60 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Tue, 24 Apr 2018 09:09:01 -0400 Subject: [PATCH 027/755] Comments and little improvements to documentation parser --- tensorflow/java/src/gen/cc/op_specs.cc | 65 ++++++++++++++------------ 1 file changed, 36 insertions(+), 29 deletions(-) diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index a727f7ae90..3645fcf836 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -192,51 +192,51 @@ string SnakeToCamelCase(const string& str, bool upper = false) { bool FindAndCut(re2::StringPiece* input, const RE2& expr, re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) { re2::StringPiece match; - bool matches = - expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &match, 1); - if (matches) { - before_match->set(input->data(), match.begin() - input->begin()); - input->remove_prefix(match.end() - before_match->begin()); - if (ret_match != nullptr) { - *ret_match = match; - } - } else { - *before_match = *input; - if (ret_match != nullptr) { - ret_match->set(nullptr, 0); - } + if (!expr.Match(*input, 0, input->size(), 
RE2::UNANCHORED, &match, 1)) { + return false; } - return matches; + before_match->set(input->data(), match.begin() - input->begin()); + input->remove_prefix(match.end() - before_match->begin()); + if (ret_match != nullptr) { + *ret_match = match; + } + return true; } -string ParseDocumentation(const string& mdtext) { +string ParseDocumentation(re2::StringPiece input) { + // TODO(karllessard) This is a very minimalist utility method for converting + // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check + // for alternatives to increase the level of support for markups. std::stringstream javadoc_text; - re2::StringPiece input(mdtext); - re2::StringPiece text; bool in_list = false; - do { + const RE2 markup_expr( + "\n+\\*[[:blank:]]+|\n{2,}|`{3,}|`{1,2}|\\*{1,2}\\b|\\["); + while (true) { + re2::StringPiece text; re2::StringPiece markup; - FindAndCut(&input, - "\n+\\*[[:blank:]]+|\n{2,}|`{3,}|`{1,2}|\\*{1,2}\\b|\\[", - &text, &markup); - javadoc_text << text; - if (markup.empty()) { - break; // we are done parsing + if (!FindAndCut(&input, markup_expr, &text, &markup)) { + javadoc_text << input; + break; // end of loop } + javadoc_text << text; if (markup.starts_with("\n")) { javadoc_text << "\n"; if (markup.contains("* ")) { + // starts a list item javadoc_text << (in_list ? "\n" : "

      \n") << "
    • \n"; in_list = true; } else if (markup.starts_with("\n\n")) { if (in_list) { + // ends the current list javadoc_text << "
    • \n
    \n"; in_list = false; } else if (!input.starts_with("```")) { + // starts new paragraph (not required if a
     block follows)
               javadoc_text << "

    \n"; } } } else if (markup.starts_with("```") && text.empty()) { + // create a multiline code block re2::StringPiece language; RE2::Consume(&input, "[\\w\\+]+", &language); if (FindAndCut(&input, markup.ToString() + "\n*", &text)) { @@ -245,34 +245,41 @@ string ParseDocumentation(const string& mdtext) { javadoc_text << markup << language; } } else if (markup.starts_with("`")) { + // write inlined code if (FindAndCut(&input, markup, &text)) { javadoc_text << "{@code " << text << "}"; } else { javadoc_text << markup; } } else if (markup == "**") { + // emphase text (strong) if (FindAndCut(&input, "\\b\\*{2}", &text)) { - javadoc_text << "" << text << ""; + javadoc_text << "" << ParseDocumentation(text) << ""; } else { javadoc_text << markup; } } else if (markup == "*") { + // emphase text (light) if (FindAndCut(&input, "\\b\\*{1}", &text)) { - javadoc_text << "" << text << ""; + javadoc_text << "" << ParseDocumentation(text) << ""; } else { javadoc_text << markup; } } else if (markup == "[") { + // add an external link string label; string link; if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) { - javadoc_text << "" << label << ""; + javadoc_text << "" + << ParseDocumentation(label) + << ""; } else { javadoc_text << markup; } + } else { + javadoc_text << markup; } - } while (!input.empty()); - + } return javadoc_text.str(); } -- GitLab From 4bfedb4f2edc4bd71984d79145ab6b0293fe8096 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Wed, 25 Apr 2018 00:39:37 -0400 Subject: [PATCH 028/755] Improve again javadoc readability and quality --- tensorflow/java/src/gen/cc/op_generator.cc | 7 +-- tensorflow/java/src/gen/cc/op_specs.cc | 57 ++++++++++++---------- 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index c32ad3b109..00f84bc9cd 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ 
-305,10 +305,11 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, } } -void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) { +void RenderOptionsClass(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { Type options_class = Type::Class("Options"); Javadoc options_doc = Javadoc::Create( - "Class holding optional attributes of this operation"); + "Optional attributes for {@link " + op_class.full_name() + "}"); writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); for (const AttributeSpec& attribute : op.optional_attributes()) { Method setter = Method::Create(attribute.var().name(), options_class) @@ -410,7 +411,7 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, .EndLine() .BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc); if (!op.optional_attributes().empty()) { - RenderOptionsClass(op, &writer); + RenderOptionsClass(op, op_class, &writer); } RenderFactoryMethod(op, op_class, &writer); RenderGettersAndSetters(op, &writer); diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index 3645fcf836..a0e7a180f2 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -204,13 +204,21 @@ bool FindAndCut(re2::StringPiece* input, const RE2& expr, } string ParseDocumentation(re2::StringPiece input) { + std::stringstream javadoc_text; + // TODO(karllessard) This is a very minimalist utility method for converting // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check // for alternatives to increase the level of support for markups. 
- std::stringstream javadoc_text; + std::vector markups_subexpr; + markups_subexpr.push_back("\n+\\*\\s+"); // lists + markups_subexpr.push_back("\n{2,}"); // paragraphs + markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n"); // code blocks + markups_subexpr.push_back("`+"); // inlined code and code blocks + markups_subexpr.push_back("\\*{1,2}\\b"); // text emphasis + markups_subexpr.push_back("\\["); // hyperlinks + const RE2 markup_expr(str_util::Join(markups_subexpr, "|")); + bool in_list = false; - const RE2 markup_expr( - "\n+\\*[[:blank:]]+|\n{2,}|`{3,}|`{1,2}|\\*{1,2}\\b|\\["); while (true) { re2::StringPiece text; re2::StringPiece markup; @@ -221,52 +229,48 @@ string ParseDocumentation(re2::StringPiece input) { javadoc_text << text; if (markup.starts_with("\n")) { javadoc_text << "\n"; - if (markup.contains("* ")) { - // starts a list item + if (markup.contains("*")) { + // new list item javadoc_text << (in_list ? "\n" : "

      \n") << "
    • \n"; in_list = true; - } else if (markup.starts_with("\n\n")) { - if (in_list) { - // ends the current list - javadoc_text << "
    • \n
    \n"; - in_list = false; - } else if (!input.starts_with("```")) { - // starts new paragraph (not required if a
     block follows)
    -          javadoc_text << "

    \n"; - } + } else if (in_list) { + // end of list + javadoc_text << "\n

\n"; + in_list = false; + } else if (!input.starts_with("```")) { + // new paragraph (not required if a
 block follows)
+        javadoc_text << "

\n"; } - } else if (markup.starts_with("```") && text.empty()) { - // create a multiline code block - re2::StringPiece language; - RE2::Consume(&input, "[\\w\\+]+", &language); - if (FindAndCut(&input, markup.ToString() + "\n*", &text)) { - javadoc_text << "

\n{@code" << text << "}\n
\n"; + } else if (markup.starts_with("```")) { + // code blocks + if (FindAndCut(&input, "```\\s*\n*", &text)) { + javadoc_text << "
{@code\n" << text << "}
\n"; } else { - javadoc_text << markup << language; + javadoc_text << markup; } } else if (markup.starts_with("`")) { - // write inlined code + // inlined code if (FindAndCut(&input, markup, &text)) { javadoc_text << "{@code " << text << "}"; } else { javadoc_text << markup; } } else if (markup == "**") { - // emphase text (strong) + // text emphasis (strong) if (FindAndCut(&input, "\\b\\*{2}", &text)) { javadoc_text << "" << ParseDocumentation(text) << ""; } else { javadoc_text << markup; } } else if (markup == "*") { - // emphase text (light) + // text emphasis (normal) if (FindAndCut(&input, "\\b\\*{1}", &text)) { javadoc_text << "" << ParseDocumentation(text) << ""; } else { javadoc_text << markup; } - } else if (markup == "[") { - // add an external link + } else if (markup.starts_with("[")) { + // hyperlinks string label; string link; if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) { @@ -277,6 +281,7 @@ string ParseDocumentation(re2::StringPiece input) { javadoc_text << markup; } } else { + // safe fallback javadoc_text << markup; } } -- GitLab From eac1479f04181fb107c85af29a709eb369831972 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Mon, 30 Apr 2018 07:38:48 -0400 Subject: [PATCH 029/755] Simplify and improve generics handling in generator --- tensorflow/java/build_defs.bzl | 1 + tensorflow/java/src/gen/cc/op_gen_main.cc | 4 +- tensorflow/java/src/gen/cc/op_generator.cc | 155 +++++++++------------ tensorflow/java/src/gen/cc/op_generator.h | 13 +- tensorflow/java/src/gen/cc/op_specs.cc | 81 ++++++----- tensorflow/java/src/gen/cc/op_specs.h | 16 ++- 6 files changed, 132 insertions(+), 138 deletions(-) diff --git a/tensorflow/java/build_defs.bzl b/tensorflow/java/build_defs.bzl index ab7f60d03d..e1916ca4d9 100644 --- a/tensorflow/java/build_defs.bzl +++ b/tensorflow/java/build_defs.bzl @@ -15,6 +15,7 @@ JAVA_VERSION_OPTS = [ XLINT_OPTS = [ "-Werror", "-Xlint:all", + "-Xlint:-processing", "-Xlint:-serial", "-Xlint:-try", 
"-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index 458141b877..a508c96516 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -67,10 +67,10 @@ int main(int argc, char* argv[]) { QCHECK(parsed_flags_ok && !output_dir.empty()) << usage; std::vector api_dirs = tensorflow::str_util::Split( api_dirs_str, ",", tensorflow::str_util::SkipEmpty()); - tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs); + tensorflow::java::OpGenerator generator(api_dirs); tensorflow::OpList ops; tensorflow::OpRegistry::Global()->Export(false, &ops); - TF_CHECK_OK(generator.Run(ops)); + TF_CHECK_OK(generator.Run(ops, base_package, output_dir)); return 0; } diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 00f84bc9cd..2327a4daf1 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include #include +#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -38,23 +39,18 @@ namespace { const char* kLicenseSnippet = "tensorflow/java/src/gen/resources/license.java.snippet"; -const std::map kPrimitiveAttrTypes = { - { "Boolean", Type::Boolean() }, - { "Byte", Type::Byte() }, - { "Character", Type::Byte() }, - { "Float", Type::Float() }, - { "Integer", Type::Long() }, - { "Long", Type::Long() }, - { "Short", Type::Long() }, - { "Double", Type::Float() }, -}; - enum RenderMode { DEFAULT, SINGLE_OUTPUT, SINGLE_LIST_OUTPUT }; +inline void AddArgument(const Variable& var, const string& description, + Method* method_out, Javadoc* javadoc_out) { + method_out->add_argument(var); + javadoc_out->add_param_tag(var.name(), description); +} + void CollectOpDependencies(const OpSpec& op, RenderMode mode, std::list* out) { out->push_back(Type::Class("Operation", "org.tensorflow")); @@ -81,9 +77,7 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, } for (const AttributeSpec& attribute : op.attributes()) { out->push_back(attribute.var().type()); - if (attribute.var().type().name() == "Class") { - out->push_back(Type::Enum("DataType", "org.tensorflow")); - } + out->push_back(attribute.jni_type()); } for (const AttributeSpec& optional_attribute : op.optional_attributes()) { out->push_back(optional_attribute.var().type()); @@ -92,45 +86,38 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, void WriteSetAttrDirective(const AttributeSpec& attr, bool optional, SourceWriter* writer) { - string var = optional ? "opts." + attr.var().name() : attr.var().name(); + string var_name = optional ? "opts." 
+ attr.var().name() : attr.var().name(); if (attr.iterable()) { - const Type& type = attr.type(); - std::map::const_iterator it = - kPrimitiveAttrTypes.find(type.name()); - if (it != kPrimitiveAttrTypes.end()) { - string array = attr.var().name() + "Array"; - writer->AppendType(it->second) - .Append("[] " + array + " = new ") - .AppendType(it->second) - .Append("[" + var + ".size()];") - .EndLine(); - writer->BeginBlock("for (int i = 0; i < " + array + ".length; ++i)") - .Append(array + "[i] = " + var + ".get(i);") - .EndLine() - .EndBlock() - .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + array) - .Append(");") - .EndLine(); + string array_name = attr.var().name() + "Array"; + writer->AppendType(attr.jni_type()) + .Append("[] " + array_name + " = new ") + .AppendType(attr.jni_type()) + .Append("[" + var_name + ".size()];") + .EndLine() + .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)") + .Append(array_name + "[i] = "); + if (attr.type().kind() == Type::GENERIC) { + writer->Append("DataType.fromClass(" + var_name + ".get(i));"); } else { - writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + var) - .Append(".toArray(new ") - .AppendType(type) - .Append("[" + var + ".size()]));") - .EndLine(); + writer->Append(var_name + ".get(i);"); } + writer->EndLine() + .EndBlock() + .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ") + .Append(array_name + ");") + .EndLine(); } else { - Type type = attr.var().type(); writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", "); - if (type.name() == "Class") { - writer->Append("DataType.fromClass(" + attr.var().name() + "));"); + if (attr.var().type().name() == "Class") { + writer->Append("DataType.fromClass(" + var_name + "));"); } else { - writer->Append(var + ");"); + writer->Append(var_name + ");"); } writer->EndLine(); } } -void RenderFactoryMethod(const OpSpec& op, const Type& op_class, +void RenderFactoryMethods(const OpSpec& op, const Type& op_class, 
SourceWriter* writer) { Method factory = Method::Create("create", op_class); Javadoc factory_doc = Javadoc::Create( @@ -138,27 +125,24 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class, + " operation to the graph."); Variable scope = Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); - factory.add_argument(scope); - factory_doc.add_param_tag(scope.name(), "Current graph scope"); + AddArgument(scope, "current graph scope", &factory, &factory_doc); for (const ArgumentSpec& input : op.inputs()) { - factory.add_argument(input.var()); - factory_doc.add_param_tag(input.var().name(), input.description()); + AddArgument(input.var(), input.description(), &factory, &factory_doc); } - for (const AttributeSpec& attribute : op.attributes()) { - factory.add_argument(attribute.var()); - factory_doc.add_param_tag(attribute.var().name(), attribute.description()); + for (const AttributeSpec& attr : op.attributes()) { + AddArgument(attr.var(), attr.description(), &factory, &factory_doc); } if (!op.optional_attributes().empty()) { - factory.add_argument(Variable::Varargs("options", Type::Class("Options"))); - factory_doc.add_param_tag("options", "carries optional attributes values"); + AddArgument(Variable::Varargs("options", Type::Class("Options")), + "carries optional attributes values", &factory, &factory_doc); } factory_doc.add_tag("return", "a new instance of " + op_class.name()); + writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc); writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" + op.graph_op_name() + "\", scope.makeOpName(\"" + op_class.name() + "\"));"); writer->EndLine(); - for (const ArgumentSpec& input : op.inputs()) { if (input.iterable()) { writer->Append("opBuilder.addInputList(Operands.asOutputs(" @@ -192,10 +176,9 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class, void RenderConstructor(const OpSpec& op, const Type& op_class, SourceWriter* writer) { - Method constructor = 
Method::ConstructorFor(op_class) - .add_argument( - Variable::Create("operation", - Type::Class("Operation", "org.tensorflow"))); + Variable operation = + Variable::Create("operation", Type::Class("Operation", "org.tensorflow")); + Method constructor = Method::ConstructorFor(op_class).add_argument(operation); for (const ArgumentSpec& output : op.outputs()) { if (output.iterable() && !output.type().unknown()) { constructor.add_annotation( @@ -237,15 +220,14 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, } void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { - for (const AttributeSpec& attribute : op.optional_attributes()) { + for (const AttributeSpec& attr : op.optional_attributes()) { Method setter = - Method::Create(attribute.var().name(), Type::Class("Options")) - .add_argument(attribute.var()); - Javadoc setter_doc = Javadoc::Create() - .add_param_tag(attribute.var().name(), attribute.description()); + Method::Create(attr.var().name(), Type::Class("Options")); + Javadoc setter_doc = Javadoc::Create(); + AddArgument(attr.var(), attr.description(), &setter, &setter_doc); writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc) - .Append("return new Options()." + attribute.var().name() + "(" - + attribute.var().name() + ");") + .Append("return new Options()." 
+ attr.var().name() + "(" + + attr.var().name() + ");") .EndLine() .EndMethod(); } @@ -311,14 +293,12 @@ void RenderOptionsClass(const OpSpec& op, const Type& op_class, Javadoc options_doc = Javadoc::Create( "Optional attributes for {@link " + op_class.full_name() + "}"); writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); - for (const AttributeSpec& attribute : op.optional_attributes()) { - Method setter = Method::Create(attribute.var().name(), options_class) - .add_argument(attribute.var()); - Javadoc setter_doc = Javadoc::Create() - .add_param_tag(attribute.var().name(), attribute.description()); + for (const AttributeSpec& attr : op.optional_attributes()) { + Method setter = Method::Create(attr.var().name(), options_class); + Javadoc setter_doc = Javadoc::Create(); + AddArgument(attr.var(), attr.description(), &setter, &setter_doc); writer->BeginMethod(setter, PUBLIC, &setter_doc) - .Append("this." + attribute.var().name() + " = " - + attribute.var().name() + ";") + .Append("this." 
+ attr.var().name() + " = " + attr.var().name() + ";") .EndLine() .Append("return this;") .EndLine() @@ -339,12 +319,13 @@ inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) { } void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, - const string& base_package, const string& output_dir, Env* env) { + const string& base_package, const string& output_dir, Env* env, + const std::tm* timestamp) { Type op_class(ClassOf(endpoint, base_package) .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); Javadoc op_javadoc(endpoint.javadoc()); - // implement Operand (or Iterable) if the op has only one output + // op interfaces RenderMode mode = DEFAULT; if (op.outputs().size() == 1) { const ArgumentSpec& output = op.outputs().front(); @@ -360,18 +341,22 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, op_class.add_supertype(operand_inf); } } - // declare all outputs generics at the op class level + // op generic parameters std::set generics; for (const ArgumentSpec& output : op.outputs()) { if (output.type().kind() == Type::GENERIC && !output.type().unknown() && generics.find(output.type().name()) == generics.end()) { op_class.add_parameter(output.type()); op_javadoc.add_param_tag("<" + output.type().name() + ">", - "data type of output {@code " + output.var().name() + "}"); + "data type for {@code " + output.var().name() + "()} output"); generics.insert(output.type().name()); } } - // handle endpoint deprecation + // op annotations + char date[20]; + strftime(date, sizeof date, "%FT%TZ", timestamp); + op_class.add_annotation(Annotation::Create("Generated", "javax.annotation") + .attributes(string("value = \"op_generator\", date = \"") + date + "\"")); if (endpoint.deprecated()) { op_class.add_annotation(Annotation::Create("Deprecated")); string explanation; @@ -384,8 +369,8 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } op_javadoc.add_tag("deprecated", explanation); } - // expose the op in 
the Ops Graph API only if it is visible if (!op.hidden()) { + // expose the op in the Ops Graph API only if it is visible op_class.add_annotation( Annotation::Create("Operator", "org.tensorflow.op.annotation") .attributes("group = \"" + endpoint.package() + "\"")); @@ -405,15 +390,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, std::list dependencies; CollectOpDependencies(op, mode, &dependencies); writer.WriteFromFile(kLicenseSnippet) - .EndLine() - .Append("// This file is machine generated, DO NOT EDIT!") - .EndLine() .EndLine() .BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc); if (!op.optional_attributes().empty()) { RenderOptionsClass(op, op_class, &writer); } - RenderFactoryMethod(op, op_class, &writer); + RenderFactoryMethods(op, op_class, &writer); RenderGettersAndSetters(op, &writer); if (mode != DEFAULT) { RenderInterfaceImpl(op, mode, &writer); @@ -428,13 +410,8 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } // namespace -OpGenerator::OpGenerator(const string& base_package, const string& output_dir, - const std::vector& api_dirs, Env* env) - : base_package_(base_package), output_dir_(output_dir), api_dirs_(api_dirs), - env_(env) { -} - -Status OpGenerator::Run(const OpList& op_list) { +Status OpGenerator::Run(const OpList& op_list, const string& base_package, + const string& output_dir) { ApiDefMap api_map(op_list); if (!api_dirs_.empty()) { // Only load api files that correspond to the requested "op_list" @@ -449,12 +426,14 @@ Status OpGenerator::Run(const OpList& op_list) { } } api_map.UpdateDocs(); + time_t now; + time(&now); for (const auto& op_def : op_list.op()) { const ApiDef* api_def = api_map.GetApiDef(op_def.name()); if (api_def->visibility() != ApiDef::SKIP) { OpSpec op(OpSpec::Create(op_def, *api_def)); for (const EndpointSpec& endpoint : op.endpoints()) { - GenerateOp(op, endpoint, base_package_, output_dir_, env_); + GenerateOp(op, endpoint, base_package, output_dir, env_, 
gmtime(&now)); } } } diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 06b08e852a..b789e11fa9 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -36,18 +36,17 @@ namespace java { // ops definitions. class OpGenerator { public: - OpGenerator(const string& base_package, const string& output_dir, - const std::vector& api_dirs, Env* env = Env::Default()); + explicit OpGenerator(const std::vector& api_dirs, + Env* env = Env::Default()) : api_dirs_(api_dirs), env_(env) {} // Generates wrappers for the given list of 'ops'. // - // Output files are generated in //, - // where 'lib_package' is derived from ops endpoints. - Status Run(const OpList& op_list); + // Output files are generated in //, + // where 'op_package' is derived from ops endpoints. + Status Run(const OpList& op_list, const string& base_package, + const string& output_dir); private: - const string base_package_; - const string output_dir_; const std::vector api_dirs_; Env* env_; }; diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index a0e7a180f2..dcc6388614 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -46,14 +46,30 @@ class TypeResolver { explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {} Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out); - Type TypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out); + std::pair TypeOf(const OpDef_AttrDef& attr_def, + bool *iterable_out); bool IsAttributeVisited(const string& attr_name) { return visited_attrs_.find(attr_name) != visited_attrs_.cend(); } + private: const OpDef op_def_; std::map visited_attrs_; - char next_generic_ = 'T'; + char next_generic_letter_ = 'T'; + + std::pair MakeTypePair(const Type& type, const Type& jni_type) { + return std::make_pair(type, jni_type); + } + std::pair MakeTypePair(const Type& type) { + return 
std::make_pair(type, type); + } + Type NextGeneric() { + char generic_letter = next_generic_letter_++; + if (next_generic_letter_ > 'Z') { + next_generic_letter_ = 'A'; + } + return Type::Generic(string(1, generic_letter)); + } }; Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, @@ -107,7 +123,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, } else { for (const auto& attr_def : op_def_.attr()) { if (attr_def.name() == arg_def.type_attr()) { - type = TypeOf(attr_def, iterable_out); + type = TypeOf(attr_def, iterable_out).first; break; } } @@ -125,51 +141,47 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, return type; } -Type TypeResolver::TypeOf(const OpDef_AttrDef& attr_def, +std::pair TypeResolver::TypeOf(const OpDef_AttrDef& attr_def, bool* iterable_out) { + std::pair types = MakeTypePair(Type::Wildcard()); *iterable_out = false; StringPiece attr_type = attr_def.type(); if (str_util::ConsumePrefix(&attr_type, "list(")) { attr_type.remove_suffix(1); // remove closing brace *iterable_out = true; } - Type type = *iterable_out ? Type::Wildcard() : Type::Class("Object"); - if (attr_type == "type") { - if (*iterable_out) { - type = Type::Enum("DataType", "org.tensorflow"); - } else { - type = Type::Generic(string(1, next_generic_)); - next_generic_ = (next_generic_ == 'Z') ? 
'A' : next_generic_ + 1; - if (IsRealNumbers(attr_def.allowed_values())) { - // enforce real numbers datasets by extending java.lang.Number - type.add_supertype(Type::Class("Number")); - } - } - } else if (attr_type == "string") { - type = Type::Class("String"); + if (attr_type == "string") { + types = MakeTypePair(Type::Class("String")); } else if (attr_type == "int") { - type = Type::Class("Integer"); + types = MakeTypePair(Type::Class("Long"), Type::Long()); } else if (attr_type == "float") { - type = Type::Class("Float"); + types = MakeTypePair(Type::Class("Float"), Type::Float()); } else if (attr_type == "bool") { - type = Type::Class("Boolean"); + types = MakeTypePair(Type::Class("Boolean"), Type::Boolean()); } else if (attr_type == "shape") { - type = Type::Class("Shape", "org.tensorflow"); + types = MakeTypePair(Type::Class("Shape", "org.tensorflow")); } else if (attr_type == "tensor") { - type = Type::Class("Tensor", "org.tensorflow") - .add_parameter(Type::Wildcard()); + types = MakeTypePair(Type::Class("Tensor", "org.tensorflow") + .add_parameter(Type::Wildcard())); + + } else if (attr_type == "type") { + Type type = *iterable_out ? 
Type::Wildcard() : NextGeneric(); + if (IsRealNumbers(attr_def.allowed_values())) { + type.add_supertype(Type::Class("Number")); + } + types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow")); } else { LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type << "\" in operation \"" << op_def_.name() << "\""; } - visited_attrs_.insert(std::make_pair(attr_def.name(), type)); - return type; + visited_attrs_.insert(std::make_pair(attr_def.name(), types.first)); + return types; } string SnakeToCamelCase(const string& str, bool upper = false) { @@ -307,19 +319,19 @@ ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) { bool iterable = false; - Type type = type_resolver->TypeOf(attr_def, &iterable); - // type attributes must be passed explicitly in methods as a Class<> parameter - bool is_explicit = type.kind() == Type::GENERIC && !iterable; - Type var_type = is_explicit ? Type::Class("Class").add_parameter(type) : type; + std::pair types = type_resolver->TypeOf(attr_def, &iterable); + Type var_type = types.first.kind() == Type::GENERIC ? 
+ Type::Class("Class").add_parameter(types.first) : types.first; if (iterable) { - var_type = Type::ListOf(type); + var_type = Type::ListOf(var_type); } return AttributeSpec(attr_api_def.name(), Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type), - type, + types.first, + types.second, ParseDocumentation(attr_api_def.description()), iterable, - attr_api_def.has_default_value() && !is_explicit); + attr_api_def.has_default_value()); } ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, @@ -340,7 +352,6 @@ ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, const ApiDef_Endpoint& endpoint_def) { - std::vector name_tokens = str_util::Split(endpoint_def.name(), "."); string package; string name; @@ -381,7 +392,7 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver); // attributes with a default value are optional - if (attr.optional()) { + if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) { op.optional_attributes_.push_back(attr); } else { op.attributes_.push_back(attr); diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 55c2c3f307..7d64391446 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/java/src/gen/cc/java_defs.h" namespace tensorflow { @@ -87,20 +88,23 @@ class AttributeSpec : public ArgumentSpec { // op_def_name: attribute name, as known by TensorFlow core // var: a variable to represent this attribute in Java // type: the type of this attribute + // jni_type: the type of this attribute in JNI layer (see OperationBuilder) // description: a description of this attribute, in javadoc // iterable: true if this attribute is a list - // optional: true if this attribute does not require to be set explicitly + // has_default_value: true if this attribute has a default value if not set AttributeSpec(const string& op_def_name, const Variable& var, - const Type& type, const string& description, bool iterable, - bool optional) + const Type& type, const Type& jni_type, const string& description, + bool iterable, bool has_default_value) : ArgumentSpec(op_def_name, var, type, description, iterable), - optional_(optional) {} + jni_type_(jni_type), has_default_value_(has_default_value) {} virtual ~AttributeSpec() = default; - bool optional() const { return optional_; } + const Type& jni_type() const { return jni_type_; } + bool has_default_value() const { return has_default_value_; } private: - const bool optional_; + const Type jni_type_; + const bool has_default_value_; }; class OpSpec { -- GitLab From dd1ef8fa8f6861e53e8a7953c171b3e9253043ed Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Thu, 3 May 2018 22:39:35 -0400 Subject: [PATCH 030/755] Second code review --- tensorflow/core/api_def/BUILD | 7 ++ .../java_api/api_def_FilterDataset.pbtxt | 4 + .../java_api/api_def_FlatMapDataset.pbtxt | 4 + .../core/api_def/java_api/api_def_For.pbtxt | 4 + .../java_api/api_def_GeneratorDataset.pbtxt | 4 + .../api_def_GroupByWindowDataset.pbtxt | 4 + .../core/api_def/java_api/api_def_If.pbtxt | 4 
+ .../java_api/api_def_InterleaveDataset.pbtxt | 4 + .../java_api/api_def_MapAndBatchDataset.pbtxt | 4 + .../api_def/java_api/api_def_MapDataset.pbtxt | 4 + .../java_api/api_def_OneShotIterator.pbtxt | 4 + .../api_def_ParallelInterleaveDataset.pbtxt | 4 + .../java_api/api_def_ParallelMapDataset.pbtxt | 4 + .../api_def/java_api/api_def_RemoteCall.pbtxt | 4 + .../java_api/api_def_ScanDataset.pbtxt | 4 + .../java_api/api_def_SymbolicGradient.pbtxt | 4 + .../core/api_def/java_api/api_def_While.pbtxt | 4 + tensorflow/java/BUILD | 39 ++++------ tensorflow/java/src/gen/cc/java_defs.h | 6 +- tensorflow/java/src/gen/cc/op_gen_main.cc | 2 +- tensorflow/java/src/gen/cc/op_generator.cc | 77 +++++++++++-------- tensorflow/java/src/gen/cc/op_generator.h | 2 +- tensorflow/java/src/gen/cc/op_specs.cc | 25 +++++- tensorflow/java/src/gen/cc/op_specs.h | 17 +++- tensorflow/java/src/gen/cc/source_writer.cc | 20 +++-- tensorflow/java/src/gen/cc/source_writer.h | 2 +- .../java/src/gen/cc/source_writer_test.cc | 2 +- tensorflow/java/src/gen/gen_ops.bzl | 41 +++------- 28 files changed, 195 insertions(+), 109 deletions(-) create mode 100644 tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_For.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_If.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt create mode 100644 
tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt create mode 100644 tensorflow/core/api_def/java_api/api_def_While.pbtxt diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 19d6438809..06b797e32e 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -4,6 +4,7 @@ # The following targets can be used to access ApiDefs: # :base_api_def # :python_api_def +# :java_api_def package( default_visibility = ["//visibility:private"], @@ -29,6 +30,12 @@ filegroup( visibility = ["//tensorflow:internal"], ) +filegroup( + name = "java_api_def", + srcs = glob(["java_api/*"]), + visibility = ["//tensorflow:internal"], +) + cc_library( name = "excluded_ops_lib", srcs = ["excluded_ops.cc"], diff --git a/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt new file mode 100644 index 0000000000..debd7e5709 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "FilterDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt new file mode 100644 index 0000000000..329ab15ef5 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "FlatMapDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_For.pbtxt b/tensorflow/core/api_def/java_api/api_def_For.pbtxt new file mode 100644 index 0000000000..caabc947bb --- /dev/null +++ 
b/tensorflow/core/api_def/java_api/api_def_For.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "For" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt new file mode 100644 index 0000000000..a6e5167c30 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "GeneratorDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt new file mode 100644 index 0000000000..4c0b2084a8 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "GroupByWindowDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_If.pbtxt b/tensorflow/core/api_def/java_api/api_def_If.pbtxt new file mode 100644 index 0000000000..13b8635ca7 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_If.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "If" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt new file mode 100644 index 0000000000..ed748d4d2a --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "InterleaveDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt new file mode 100644 index 0000000000..cb96bf63d8 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "MapAndBatchDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt 
b/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt new file mode 100644 index 0000000000..e0ab8dd9db --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "MapDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt b/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt new file mode 100644 index 0000000000..13130e6882 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "OneShotIterator" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt new file mode 100644 index 0000000000..6a985d24fa --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ParallelInterleaveDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt new file mode 100644 index 0000000000..64f25b9e5e --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ParallelMapDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt b/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt new file mode 100644 index 0000000000..2ccb5c8cf3 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "RemoteCall" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt new file mode 100644 index 0000000000..3463e60049 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt @@ -0,0 +1,4 @@ +op { + 
graph_op_name: "ScanDataset" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt b/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt new file mode 100644 index 0000000000..88c3acea74 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "SymbolicGradient" + visibility: SKIP +} diff --git a/tensorflow/core/api_def/java_api/api_def_While.pbtxt b/tensorflow/core/api_def/java_api/api_def_While.pbtxt new file mode 100644 index 0000000000..33756682c3 --- /dev/null +++ b/tensorflow/core/api_def/java_api/api_def_While.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "While" + visibility: SKIP +} \ No newline at end of file diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 17566e1a9c..7cd0208dbf 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -68,34 +68,27 @@ filegroup( ], ) -# Build the gen tool as a library, as it will be linked to a core/ops binary -# files before making it an executable. 
tf_java_op_gen_srcjar( name = "java_op_gen_sources", api_def_srcs = [ "//tensorflow/core/api_def:base_api_def", + "//tensorflow/core/api_def:java_api_def", ], - gen_base_package = "org.tensorflow.op", - gen_tool = "java_op_gen_tool", - ops_libs = [ - "array_ops", - "candidate_sampling_ops", - "control_flow_ops", - "data_flow_ops", - "image_ops", - "io_ops", - "linalg_ops", - "logging_ops", - "math_ops", - "nn_ops", - "no_op", - "parsing_ops", - "random_ops", - "sparse_ops", - "state_ops", - "string_ops", - "training_ops", - "user_ops", + base_package = "org.tensorflow.op", + gen_tool = ":java_op_gen_tool", +) + +tf_cc_binary( + name = "java_op_gen_tool", + srcs = [ + "src/gen/cc/op_gen_main.cc", + ], + copts = tf_copts(), + linkopts = ["-lm"], + linkstatic = 1, + deps = [ + ":java_op_gen_lib", + "//tensorflow/core:ops", ], ) diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 81ac67eb2f..62575f6683 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -102,10 +102,10 @@ class Type { const Kind& kind() const { return kind_; } const string& name() const { return name_; } const string& package() const { return package_; } - const string full_name() const { + const string canonical_name() const { return package_.empty() ? name_ : package_ + "." 
+ name_; } - bool unknown() const { return name_.empty(); } // only wildcards has no name + bool wildcard() const { return name_.empty(); } // only wildcards has no name const std::list& parameters() const { return parameters_; } Type& add_parameter(const Type& parameter) { parameters_.push_back(parameter); diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index a508c96516..6c35cd9595 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 2327a4daf1..7355b3a395 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ limitations under the License. 
#include #include #include -#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -39,13 +38,26 @@ namespace { const char* kLicenseSnippet = "tensorflow/java/src/gen/resources/license.java.snippet"; +// There is three different modes to render an op class, depending on the +// number and type of outputs it has: +// +// DEFAULT: This mode does not provide any specialization for the op class, it +// is applied when the operation does not comply with any other mode +// +// OPERAND: The op class implements the Operand interface, allowing an +// instance to be passed directly in input to another operation +// +// LIST_OPERAND: The op class implements the Iterable> interface, +// allowing an instance to be passed directly as a list input to +// another operation +// enum RenderMode { DEFAULT, - SINGLE_OUTPUT, - SINGLE_LIST_OUTPUT + OPERAND, + LIST_OPERAND }; -inline void AddArgument(const Variable& var, const string& description, +void AddArgument(const Variable& var, const string& description, Method* method_out, Javadoc* javadoc_out) { method_out->add_argument(var); javadoc_out->add_param_tag(var.name(), description); @@ -56,9 +68,9 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, out->push_back(Type::Class("Operation", "org.tensorflow")); out->push_back(Type::Class("OperationBuilder", "org.tensorflow")); out->push_back(Type::Class("Scope", "org.tensorflow.op")); - if (mode == SINGLE_OUTPUT) { + if (mode == OPERAND) { out->push_back(Type::Class("Output", "org.tensorflow")); - } else if (mode == SINGLE_LIST_OUTPUT) { + } else if (mode == LIST_OPERAND) { out->push_back(Type::Interface("Iterator", "java.util")); } // Don't pay attention to duplicate types in the dependency list, they will @@ -180,7 +192,7 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, Variable::Create("operation", Type::Class("Operation", "org.tensorflow")); Method constructor = 
Method::ConstructorFor(op_class).add_argument(operation); for (const ArgumentSpec& output : op.outputs()) { - if (output.iterable() && !output.type().unknown()) { + if (output.iterable() && !output.type().wildcard()) { constructor.add_annotation( Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); break; @@ -200,7 +212,7 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, + "\");") .EndLine() .Append(output.var().name() + " = Arrays.asList("); - if (!output.type().unknown()) { + if (!output.type().wildcard()) { writer->Append("(") .AppendType(output.var().type().parameters().front()) .Append("[])"); @@ -245,8 +257,8 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, SourceWriter* writer) { ArgumentSpec output = op.outputs().front(); - if (mode == SINGLE_OUTPUT) { - bool cast2obj = output.type().unknown(); + if (mode == OPERAND) { + bool cast2obj = output.type().wildcard(); Type return_type = Type::Class("Output", "org.tensorflow") .add_parameter(cast2obj ? 
Type::Class("Object") : output.type()); Method as_output = Method::Create("asOutput", return_type) @@ -265,9 +277,9 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, .EndLine() .EndMethod(); - } else if (mode == SINGLE_LIST_OUTPUT) { + } else if (mode == LIST_OPERAND) { Type operand = Type::Interface("Operand", "org.tensorflow"); - if (output.type().unknown()) { + if (output.type().wildcard()) { operand.add_parameter(Type::Class("Object")); } else { operand.add_parameter(output.type()); @@ -291,7 +303,7 @@ void RenderOptionsClass(const OpSpec& op, const Type& op_class, SourceWriter* writer) { Type options_class = Type::Class("Options"); Javadoc options_doc = Javadoc::Create( - "Optional attributes for {@link " + op_class.full_name() + "}"); + "Optional attributes for {@link " + op_class.canonical_name() + "}"); writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); for (const AttributeSpec& attr : op.optional_attributes()) { Method setter = Method::Create(attr.var().name(), options_class); @@ -319,8 +331,7 @@ inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) { } void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, - const string& base_package, const string& output_dir, Env* env, - const std::tm* timestamp) { + const string& base_package, const string& output_dir, Env* env) { Type op_class(ClassOf(endpoint, base_package) .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); Javadoc op_javadoc(endpoint.javadoc()); @@ -329,22 +340,22 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, RenderMode mode = DEFAULT; if (op.outputs().size() == 1) { const ArgumentSpec& output = op.outputs().front(); - Type operand_type(output.type().unknown() ? + Type operand_type(output.type().wildcard() ? 
Type::Class("Object") : output.type()); Type operand_inf(Type::Interface("Operand", "org.tensorflow") .add_parameter(operand_type)); if (output.iterable()) { - mode = SINGLE_LIST_OUTPUT; + mode = LIST_OPERAND; op_class.add_supertype(Type::IterableOf(operand_inf)); } else { - mode = SINGLE_OUTPUT; + mode = OPERAND; op_class.add_supertype(operand_inf); } } // op generic parameters std::set generics; for (const ArgumentSpec& output : op.outputs()) { - if (output.type().kind() == Type::GENERIC && !output.type().unknown() + if (output.type().kind() == Type::GENERIC && !output.type().wildcard() && generics.find(output.type().name()) == generics.end()) { op_class.add_parameter(output.type()); op_javadoc.add_param_tag("<" + output.type().name() + ">", @@ -353,16 +364,15 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } } // op annotations - char date[20]; - strftime(date, sizeof date, "%FT%TZ", timestamp); - op_class.add_annotation(Annotation::Create("Generated", "javax.annotation") - .attributes(string("value = \"op_generator\", date = \"") + date + "\"")); + op_class.add_annotation( + Annotation::Create("Generated", "javax.annotation") + .attributes("value = \"TensorFlow Java Op Generator\"")); if (endpoint.deprecated()) { op_class.add_annotation(Annotation::Create("Deprecated")); string explanation; if (!op.endpoints().front().deprecated()) { explanation = "use {@link " + - ClassOf(op.endpoints().front(), base_package).full_name() + ClassOf(op.endpoints().front(), base_package).canonical_name() + "} instead"; } else { explanation = op.deprecation_explanation(); @@ -376,14 +386,16 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, .attributes("group = \"" + endpoint.package() + "\"")); } // create op class file - string op_dir = io::JoinPath(output_dir, + const string op_dir_name = io::JoinPath(output_dir, str_util::StringReplace(op_class.package(), ".", "/", true)); - if (!env->FileExists(op_dir).ok()) { - 
TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir)); + if (!env->FileExists(op_dir_name).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name)) + << op_dir_name; } + const string op_file_name = op_class.name() + ".java"; std::unique_ptr op_file; TF_CHECK_OK(env->NewWritableFile( - io::JoinPath(op_dir, op_class.name() + ".java"), &op_file)); + io::JoinPath(op_dir_name, op_file_name), &op_file)) << op_file_name; // render endpoint source code SourceFileWriter writer(op_file.get()); @@ -420,20 +432,19 @@ Status OpGenerator::Run(const OpList& op_list, const string& base_package, const std::string api_def_file_pattern = io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); if (env_->FileExists(api_def_file_pattern).ok()) { - TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)); + TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)) + << api_def_file_pattern; } } } } api_map.UpdateDocs(); - time_t now; - time(&now); for (const auto& op_def : op_list.op()) { const ApiDef* api_def = api_map.GetApiDef(op_def.name()); if (api_def->visibility() != ApiDef::SKIP) { OpSpec op(OpSpec::Create(op_def, *api_def)); for (const EndpointSpec& endpoint : op.endpoints()) { - GenerateOp(op, endpoint, base_package, output_dir, env_, gmtime(&now)); + GenerateOp(op, endpoint, base_package, output_dir, env_); } } } diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index b789e11fa9..cfe842070a 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index dcc6388614..081062ceaf 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -45,9 +45,26 @@ class TypeResolver { public: explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {} + // Returns the class type of an input/output argument + // + // For example, if the argument's datatype is DT_STRING, this method will + // return "java.lang.String", so the argument can become "Operand" + // in the Ops API Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out); - std::pair TypeOf(const OpDef_AttrDef& attr_def, + + // Returns types of an input attribute + // + // The first element of the pair is the class type of this attribute while + // the second is its JNI/primitive type equivalent, required for explicit + // unboxing. + // + // For example, if the attribute is of type "float", this method will return + // , so the attribute can be used as a "Float" object + // in the Ops API and casted to a "float" when passing through the JNI layer. 
+ std::pair TypesOf(const OpDef_AttrDef& attr_def, bool *iterable_out); + + // Returns true if the type of this attribute has already been resolved bool IsAttributeVisited(const string& attr_name) { return visited_attrs_.find(attr_name) != visited_attrs_.cend(); } @@ -123,7 +140,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, } else { for (const auto& attr_def : op_def_.attr()) { if (attr_def.name() == arg_def.type_attr()) { - type = TypeOf(attr_def, iterable_out).first; + type = TypesOf(attr_def, iterable_out).first; break; } } @@ -141,7 +158,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, return type; } -std::pair TypeResolver::TypeOf(const OpDef_AttrDef& attr_def, +std::pair TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, bool* iterable_out) { std::pair types = MakeTypePair(Type::Wildcard()); *iterable_out = false; @@ -319,7 +336,7 @@ ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) { bool iterable = false; - std::pair types = type_resolver->TypeOf(attr_def, &iterable); + std::pair types = type_resolver->TypesOf(attr_def, &iterable); Type var_type = types.first.kind() == Type::GENERIC ? 
Type::Class("Class").add_parameter(types.first) : types.first; if (iterable) { diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 7d64391446..81582ea207 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -65,7 +65,6 @@ class ArgumentSpec { const Type& type, const string& description, bool iterable) : op_def_name_(op_def_name), var_(var), type_(type), description_(description), iterable_(iterable) {} - virtual ~ArgumentSpec() = default; const string& op_def_name() const { return op_def_name_; } const Variable& var() const { return var_; } @@ -81,7 +80,7 @@ class ArgumentSpec { const bool iterable_; }; -class AttributeSpec : public ArgumentSpec { +class AttributeSpec { public: // A specification for an operation attribute // @@ -95,14 +94,24 @@ class AttributeSpec : public ArgumentSpec { AttributeSpec(const string& op_def_name, const Variable& var, const Type& type, const Type& jni_type, const string& description, bool iterable, bool has_default_value) - : ArgumentSpec(op_def_name, var, type, description, iterable), + : op_def_name_(op_def_name), var_(var), type_(type), + description_(description), iterable_(iterable), jni_type_(jni_type), has_default_value_(has_default_value) {} - virtual ~AttributeSpec() = default; + const string& op_def_name() const { return op_def_name_; } + const Variable& var() const { return var_; } + const Type& type() const { return type_; } + const string& description() const { return description_; } + bool iterable() const { return iterable_; } const Type& jni_type() const { return jni_type_; } bool has_default_value() const { return has_default_value_; } private: + const string op_def_name_; + const Variable var_; + const Type type_; + const string description_; + const bool iterable_; const Type jni_type_; const bool has_default_value_; }; diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index 
7e427787f9..56806cbb6d 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -83,17 +83,19 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) { } SourceWriter& SourceWriter::AppendType(const Type& type) { - if (type.unknown()) { + if (type.wildcard()) { Append("?"); } else { Append(type.name()); if (!type.parameters().empty()) { Append("<"); + bool first = true; for (const Type& t : type.parameters()) { - if (&t != &type.parameters().front()) { + if (!first) { Append(", "); } AppendType(t); + first = false; } Append(">"); } @@ -145,11 +147,13 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers, AppendType(method.return_type()).Append(" "); } Append(method.name()).Append("("); + bool first = true; for (const Variable& v : method.arguments()) { - if (&v != &method.arguments().front()) { + if (!first) { Append(", "); } AppendType(v.type()).Append(v.variadic() ? "... 
" : " ").Append(v.name()); + first = false; } return Append(")").BeginBlock(); } @@ -294,14 +298,16 @@ SourceWriter& SourceWriter::WriteAnnotations( SourceWriter& SourceWriter::WriteGenerics( const std::list& generics) { Append("<"); + bool first = true; for (const Type* pt : generics) { - if (pt != generics.front()) { + if (!first) { Append(", "); } Append(pt->name()); if (!pt->supertypes().empty()) { Append(" extends ").AppendType(pt->supertypes().front()); } + first = false; } return Append(">"); } @@ -339,7 +345,7 @@ void SourceWriter::TypeVisitor::Visit(const Type& type) { void SourceWriter::GenericNamespace::DoVisit(const Type& type) { // ignore non-generic parameters, wildcards and generics already declared - if (type.kind() == Type::GENERIC && !type.unknown() + if (type.kind() == Type::GENERIC && !type.wildcard() && generic_names_.find(type.name()) == generic_names_.end()) { declared_types_.push_back(&type); generic_names_.insert(type.name()); @@ -348,7 +354,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) { void SourceWriter::TypeImporter::DoVisit(const Type& type) { if (!type.package().empty() && type.package() != current_package_) { - imports_.insert(type.full_name()); + imports_.insert(type.canonical_name()); } } diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index bcae33ccce..1f0febe9a3 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index 875ad99ae2..b9a5fee9be 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index 7017b52649..f4ff34ea03 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -3,33 +3,26 @@ load( "//tensorflow:tensorflow.bzl", "tf_binary_additional_srcs", - "tf_cc_binary", - "tf_copts", ) -# Given a list of "ops_libs" (a list of files in the core/ops directory -# without their .cc extensions), generate Java wrapper code for all operations -# found in the ops files. -# Then, combine all those source files into a single archive (.srcjar). +# Generate Java wrapper classes for all registered core operations and package +# them into a single source archive (.srcjar). 
# # For example: -# tf_java_op_gen_srcjar("gen_sources", "gen_tool", "my.package", [ "array_ops", "math_ops" ]) +# tf_java_op_gen_srcjar("gen_sources", ":gen_tool", "my.package") # -# will create a genrule named "gen_sources" that first generate source files: -# ops/src/main/java/my/package/array/*.java -# ops/src/main/java/my/package/math/*.java +# will create a genrule named "gen_sources" that generates source files under +# ops/src/main/java/my/package/**/*.java # -# and then archive those source files in: +# and then archive those source files into # ops/gen_sources.srcjar # def tf_java_op_gen_srcjar(name, gen_tool, - gen_base_package, - ops_libs=[], - ops_libs_pkg="//tensorflow/core", + base_package, + api_def_srcs=[], out_dir="ops/", out_src_dir="src/main/java/", - api_def_srcs=[], visibility=["//tensorflow/java:__pkg__"]): gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files @@ -48,23 +41,9 @@ def tf_java_op_gen_srcjar(name, ") | cut -d\" \" -f1))") api_def_args_str = ",".join(api_def_args) - gen_tool_deps = [":java_op_gen_lib"] - for ops_lib in ops_libs: - gen_tool_deps.append(ops_libs_pkg + ":" + ops_lib + "_op_lib") - - tf_cc_binary( - name=gen_tool, - srcs=[ - "src/gen/cc/op_gen_main.cc", - ], - copts=tf_copts(), - linkopts=["-lm"], - linkstatic=1, # Faster to link this one-time-use binary dynamically - deps = gen_tool_deps) - - gen_cmds += ["$(location :" + gen_tool + ")" + + gen_cmds += ["$(location " + gen_tool + ")" + " --output_dir=$(@D)/" + out_src_dir + - " --base_package=" + gen_base_package + + " --base_package=" + base_package + " --api_dirs=" + api_def_args_str] # Generate a source archive containing generated code for these ops. 
-- GitLab From aaa345f5a662aab524bbee3912c605919239bef6 Mon Sep 17 00:00:00 2001 From: wangsiyu Date: Fri, 4 May 2018 10:52:26 +0800 Subject: [PATCH 031/755] refine by using iterator of partitioned variable --- tensorflow/python/layers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index c050e6be04..f7b2e471b2 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -358,7 +358,7 @@ def _add_elements_to_collection(elements, collection_list): def _should_add_regularizer(variable, existing_variable_set): result = True if isinstance(variable, tf_variables.PartitionedVariable): - for var in variable._get_variable_list(): + for var in variable: if var in existing_variable_set: result = False break -- GitLab From a2cba4a627f880cf8160de624fc1ad947c01e973 Mon Sep 17 00:00:00 2001 From: mbhuiyan Date: Fri, 4 May 2018 12:02:28 -0700 Subject: [PATCH 032/755] if MKL is used allocation id is set to 9 and 10 --- .../direct_session_with_tracking_alloc_test.cc | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index 0ff022a8bc..29c8c8daec 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -101,18 +101,21 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { EXPECT_EQ(2, shape.dim_size()); EXPECT_EQ(2, shape.dim(0).size()); EXPECT_EQ(1, shape.dim(1).size()); -#ifndef INTEL_MKL +#ifdef INTEL_MKL // if MKL is used, it goes through various additional // graph rewrite pass. In TF, everytime a graph pass // happens, "constant" nodes are allocated // and deallocated. 
Each allocation calls the - // (FindChunkPtr of BFCAllocator) - // , which increments the value of AllocationId. + // (FindChunkPtr of BFCAllocator), + // which increments the value of AllocationId. // Thus AllocationId becomes more than 3 and 4 if - // MKL is used, they can be 10 and 11 or - // other numbers. If MKL is used - // following check will not hold. - // Thus, skipping the check if MKL is used. + // MKL is used. Now they are 9 and 10 for MKL. + if (node->name() == y->name()) { + EXPECT_EQ(9, cm->AllocationId(node, 0)); + } else { + EXPECT_EQ(10, cm->AllocationId(node, 0)); + } +#else if (node->name() == y->name()) { EXPECT_EQ(3, cm->AllocationId(node, 0)); } else { -- GitLab From 2314acf98fb874317dd17ef3daf438d7af87f900 Mon Sep 17 00:00:00 2001 From: Anya Petrova Date: Fri, 4 May 2018 13:22:03 -0700 Subject: [PATCH 033/755] Fix a small typo. --- SECURITY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/SECURITY.md b/SECURITY.md index a5ce3a62ee..01886b613e 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -173,7 +173,7 @@ the progress being made towards a fix and announcement. In addition, please include the following information along with your report: * Your name and affiliation (if any). -* A description the technical details of the vulnerabilities. It is very +* A description of the technical details of the vulnerabilities. It is very important to let us know how we can reproduce your findings. * An explanation who can exploit this vulnerability, and what they gain when doing so -- write an attack scenario. This will help us evaluate your report -- GitLab From f368558429f5ebdbc0a187c3801dccf1ca6963c7 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 4 May 2018 11:40:01 -0700 Subject: [PATCH 034/755] [XLA] Cleanup client_library_test_base: move definition of CreateParameterAndTransferLiteral to .cc file PiperOrigin-RevId: 195446864 --- .../xla/tests/client_library_test_base.cc | 29 +++++++++++++++++++ .../xla/tests/client_library_test_base.h | 29 ------------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index c09e7eaf2b..41f9a5f666 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -565,4 +565,33 @@ XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); } +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + XlaBuilder* builder, + XlaOp* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 
e58979a303..16e838e60f 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -616,35 +616,6 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( return result; } -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, - const Literal& literal, - const string& name, - XlaBuilder* builder, - XlaOp* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, XlaBuilder* builder, - XlaOp* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) - .ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ -- GitLab From be9b87375adecad9bd8bb12c81b2566c77a68ad7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 11:40:20 -0700 Subject: [PATCH 035/755] [XLA] Redesign: migrate the SWIG wrapped xla client. Added LocalOp that wraps XlaOp, so that it's fully visible to swig. 
PiperOrigin-RevId: 195446939 --- tensorflow/compiler/xla/python/BUILD | 3 +- .../xla/python/local_computation_builder.cc | 315 ++++++++------- .../xla/python/local_computation_builder.h | 206 +++++----- .../xla/python/local_computation_builder.i | 53 +-- tensorflow/compiler/xla/python/xla_client.py | 362 +++++++----------- 5 files changed, 415 insertions(+), 524 deletions(-) diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index ecb87bd889..932cce943f 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -49,9 +49,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 044458164f..df262c97bf 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/python/local_computation_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/default/thread_annotations.h" @@ -248,7 +249,7 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( return new LocalShapedBuffer(std::move(result_buffer)); } -LocalComputation::LocalComputation(Computation computation) +LocalComputation::LocalComputation(XlaComputation computation) : computation_(std::move(computation)) {} StatusOr LocalComputation::Compile( @@ -271,7 +272,7 @@ StatusOr LocalComputation::Compile( return new CompiledLocalComputation(std::move(local_executable)); } -const Computation& LocalComputation::computation() const { +const XlaComputation& LocalComputation::computation() const { return computation_; } @@ -281,8 +282,12 @@ StatusOr LocalComputation::GetReturnValueShape() const { return std::move(*program_shape.mutable_result()); } +LocalOp::LocalOp(const XlaOp& op) : op_(op) {} + +const XlaOp& LocalOp::op() const { return op_; } + LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) - : builder_(GetOrCreateLocalClient(), computation_name) {} + : builder_(computation_name) {} void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); @@ -291,19 +296,21 @@ void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } StatusOr LocalComputationBuilder::Build() { - TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); return new LocalComputation(std::move(computation)); } -ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp 
LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { return builder_.Parameter(parameter_number, shape, name); } std::unique_ptr LocalComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - return builder_.GetShape(operand).ConsumeValueOrDie(); + const LocalOp& operand) { + auto result = MakeUnique(); + *result = builder_.GetShape(operand.op()).ValueOrDie(); + return result; } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -311,222 +318,236 @@ StatusOr LocalComputationBuilder::GetReturnValueShape() { return program_shape.result(); } -ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } -void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand, +void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand, shape, outfeed_config); + builder_.Outfeed(operand.op(), shape, outfeed_config); } -ComputationDataHandle LocalComputationBuilder::ConstantLiteral( - const Literal& literal) { +LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return builder_.ConstantLiteral(literal); } -ComputationDataHandle LocalComputationBuilder::Broadcast( - const ComputationDataHandle& operand, +LocalOp LocalComputationBuilder::Broadcast( + const LocalOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return builder_.Broadcast(operand, broadcast_sizes); + return builder_.Broadcast(operand.op(), broadcast_sizes); } -ComputationDataHandle LocalComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - return builder_.Pad(operand, padding_value, padding_config); +LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const 
PaddingConfig& padding_config) { + return builder_.Pad(operand.op(), padding_value.op(), padding_config); } -ComputationDataHandle LocalComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, +LocalOp LocalComputationBuilder::Reshape( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return builder_.Reshape(operand, dimensions, new_sizes); + return builder_.Reshape(operand.op(), dimensions, new_sizes); } -ComputationDataHandle LocalComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Collapse(operand, dimensions); +LocalOp LocalComputationBuilder::Collapse( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Collapse(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - return builder_.CrossReplicaSum(operand); +LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { + return builder_.CrossReplicaSum(operand.op()); } -ComputationDataHandle LocalComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, +LocalOp LocalComputationBuilder::Slice( + const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return builder_.Slice(operand, start_indices, limit_indices, strides); + return builder_.Slice(operand.op(), start_indices, limit_indices, strides); } -ComputationDataHandle LocalComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - return builder_.SliceInDim(operand, start_index, limit_index, stride, dimno); +LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, + int64 
limit_index, int64 stride, + int64 dimno) { + return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, + dimno); } -ComputationDataHandle LocalComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, +LocalOp LocalComputationBuilder::DynamicSlice( + const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return builder_.DynamicSlice(operand, start_indices, slice_sizes); + return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - return builder_.DynamicUpdateSlice(operand, update, start_indices); +LocalOp LocalComputationBuilder::DynamicUpdateSlice( + const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices) { + return builder_.DynamicUpdateSlice(operand.op(), update.op(), + start_indices.op()); } -ComputationDataHandle LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - return builder_.ConcatInDim(operands, dimension); +LocalOp LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, int64 dimension) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.ConcatInDim(xla_ops, dimension); } -ComputationDataHandle -LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, +LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const 
ComputationDataHandle& init_value, const LocalComputation& scatter) { + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter) { return builder_.SelectAndScatterWithGeneralPadding( - operand, select.computation(), window_dimensions, window_strides, padding, - source, init_value, scatter.computation()); + operand.op(), select.computation(), window_dimensions, window_strides, + padding, source.op(), init_value.op(), scatter.computation()); } -ComputationDataHandle LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - return builder_.Tuple(elements); +LocalOp LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + std::vector xla_ops; + xla_ops.reserve(elements.size()); + for (const auto& op : elements) { + xla_ops.push_back(op.op()); + } + + return builder_.Tuple(xla_ops); } -ComputationDataHandle LocalComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data, index); +LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { + return builder_.GetTupleElement(tuple_data.op(), index); } -ComputationDataHandle LocalComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return builder_.Dot(lhs, rhs); +LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { + return builder_.Dot(lhs.op(), rhs.op()); } -ComputationDataHandle LocalComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::DotGeneral( + const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& dimension_numbers) { - return builder_.DotGeneral(lhs, rhs, dimension_numbers); + return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& 
rhs, +LocalOp LocalComputationBuilder::ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, + return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, + padding, lhs_dilation, rhs_dilation, dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand, new_element_type); +LocalOp LocalComputationBuilder::ConvertElementType( + const LocalOp& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand.op(), new_element_type); } -ComputationDataHandle LocalComputationBuilder::Call( +LocalOp LocalComputationBuilder::Call( const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { - return builder_.Call(local_computation.computation(), operands); + tensorflow::gtl::ArraySlice operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.Call(local_computation.computation(), xla_ops); } -ComputationDataHandle LocalComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - return builder_.Transpose(operand, permutation); +LocalOp LocalComputationBuilder::Transpose( + const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + return builder_.Transpose(operand.op(), permutation); } -ComputationDataHandle LocalComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Rev(operand, dimensions); +LocalOp 
LocalComputationBuilder::Rev( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Rev(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, +LocalOp LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - return builder_.Map(operands, local_computation.computation(), dimensions, - static_operands); + tensorflow::gtl::ArraySlice static_operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + + std::vector static_xla_ops; + static_xla_ops.reserve(static_operands.size()); + for (const auto& op : static_operands) { + static_xla_ops.push_back(op.op()); + } + + return builder_.Map(xla_ops, local_computation.computation(), dimensions, + static_xla_ops); } -ComputationDataHandle LocalComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::Reduce( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return builder_.Reduce(operand, init_value, local_computation.computation(), - dimensions_to_reduce); + return builder_.Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } -ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { 
return builder_.ReduceWindowWithGeneralPadding( - operand, init_value, local_computation.computation(), window_dimensions, - window_strides, padding); + operand.op(), init_value.op(), local_computation.computation(), + window_dimensions, window_strides, padding); } -ComputationDataHandle LocalComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return builder_.RngNormal(mu, sigma, shape); +LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, + const LocalOp& sigma, + const Shape& shape) { + return builder_.RngNormal(mu.op(), sigma.op(), shape); } -ComputationDataHandle LocalComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return builder_.RngUniform(a, b, shape); +LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { + return builder_.RngUniform(a.op(), b.op(), shape); } -ComputationDataHandle LocalComputationBuilder::While( - const LocalComputation& condition, const LocalComputation& body, - const ComputationDataHandle& init) { - return builder_.While(condition.computation(), body.computation(), init); +LocalOp LocalComputationBuilder::While(const LocalComputation& condition, + const LocalComputation& body, + const LocalOp& init) { + return builder_.While(condition.computation(), body.computation(), init.op()); } -ComputationDataHandle LocalComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, +LocalOp LocalComputationBuilder::Conditional( + const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional(predicate, true_operand, - true_computation.computation(), 
false_operand, - false_computation.computation()); + return builder_.Conditional( + predicate.op(), true_operand.op(), true_computation.computation(), + false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - return builder_.IsConstant(operand, num_parameters); +StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { + return builder_.IsConstant(operand.op()); } -StatusOr> LocalComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - return builder_.ComputeConstant(operand, output_layout, parameters); +StatusOr LocalComputationBuilder::BuildConstantSubGraph( + const LocalOp& operand) { + TF_ASSIGN_OR_RETURN(XlaComputation computation, + builder_.BuildConstantSubGraph(operand.op())); + return new LocalComputation(std::move(computation)); } #define _FORWARD(method_name, return_sig, args_sig, args) \ @@ -534,23 +555,19 @@ StatusOr> LocalComputationBuilder::ComputeConstant( return builder_.method_name args; \ } -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand), (operand)) - -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ - (lhs, rhs, broadcast_dimensions)) - -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs), \ - (lhs, rhs, ehs)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& 
rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions), \ + (lhs.op(), rhs.op(), broadcast_dimensions)) + +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \ + (lhs.op(), rhs.op(), ehs.op())) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 5ec097846a..a06b85b4ea 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -17,9 +17,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -97,25 +98,37 @@ class CompiledLocalComputation { std::unique_ptr executable_; }; -// Wraps a Computation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a LocalComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. 
class LocalComputation { public: - LocalComputation(Computation computation); + LocalComputation(XlaComputation computation); StatusOr Compile( const std::vector& argument_shapes, const ExecutableBuildOptions* build_options); - const Computation& computation() const; + const XlaComputation& computation() const; // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; private: - Computation computation_; + XlaComputation computation_; +}; + +// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// to be made available to Python via SWIG. +class LocalOp { + public: + LocalOp(const XlaOp& op); + + const XlaOp& op() const; + + private: + XlaOp op_; }; // Wraps the ComputationBuilder API in order to: @@ -135,166 +148,137 @@ class LocalComputationBuilder { // Returns an owned LocalComputation to the caller on success. StatusOr Build(); - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); + LocalOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); - std::unique_ptr GetShape(const ComputationDataHandle& operand); + std::unique_ptr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. 
StatusOr GetReturnValueShape(); - ComputationDataHandle Infeed(const Shape& shape); + LocalOp Infeed(const Shape& shape); - void Outfeed(const ComputationDataHandle& operand, const Shape& shape, + void Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config); - ComputationDataHandle ConstantLiteral(const Literal& literal); + LocalOp ConstantLiteral(const Literal& literal); - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + LocalOp Broadcast(const LocalOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); + LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, + const PaddingConfig& padding_config); - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Reshape(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Collapse(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + LocalOp CrossReplicaSum(const LocalOp& operand); - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp Slice(const LocalOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); + 
LocalOp SliceInDim(const LocalOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); + LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); + LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices); - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, + LocalOp SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter); + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter); - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(tensorflow::gtl::ArraySlice elements); - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); + LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); + LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs); - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const 
DotDimensionNumbers& dimension_numbers); + LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs, + const DotDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + LocalOp ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); + LocalOp ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type); - ComputationDataHandle Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + LocalOp Call(const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands); - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation); + LocalOp Transpose(const LocalOp& operand, + tensorflow::gtl::ArraySlice permutation); - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + LocalOp Map(tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + LocalOp 
Reduce(const LocalOp& operand, const LocalOp& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, + LocalOp ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding); - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); + LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape); - ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); + LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - ComputationDataHandle While(const LocalComputation& condition, - const LocalComputation& body, - const ComputationDataHandle& init); + LocalOp While(const LocalComputation& condition, const LocalComputation& body, + const LocalOp& init); - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, - const LocalComputation& false_computation); + LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, + const LocalOp& false_operand, + const LocalComputation& false_computation); - StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters); + StatusOr IsConstant(const LocalOp& operand); - StatusOr > ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters); + StatusOr BuildConstantSubGraph(const LocalOp& 
operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand)) -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions)) -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs)) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) @@ -338,7 +322,7 @@ class LocalComputationBuilder { #undef _FORWARD_TRIOP private: - ComputationBuilder builder_; + XlaBuilder builder_; }; // Functions for freeing resources from the Python side. diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index b8cce5a5f7..04c56bbba9 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,9 +22,8 @@ limitations under the License. 
// // C++ Python // -------------------------------------+--------------------------------------- -// ComputationDataHandle <-> int // ArraySlice <- sequence of int -// ArraySlice <- sequence of int +// ArraySlice <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) @@ -91,12 +90,9 @@ limitations under the License. // One central reason for the Python-side indirection is that the // Python-side objects produced by the typemaps in this file are // further packaged up by xla_client before being passed on. For -// instance, xla_client wraps the long produced for a C++ -// ComputationDataHandle in a Python ComputationDataHandle proto, -// rather than exposing a raw long outside of the client. Similarly, -// the Python pair produced for a C++ Shape is further wrapped in a -// Python class (xla_client.Shape) so as not to expose the raw pair -// externally. +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. // // Other SWIG object wrappers (e.g. 
of LocalComputation) are further // wrapped by xla_client in order to set up a custom destructor that @@ -124,6 +120,7 @@ using namespace xla; using namespace xla::swig; namespace xla { + namespace swig { bool GetIntAttr(PyObject* o, const char* field, int64* result) { @@ -177,21 +174,6 @@ bool HandleStringAttribute(PyObject* o, tensorflow::ImportNumpy(); %} -// ComputationDataHandle - -%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { - const int64 handle = numpy::PyIntOrPyLongToLong($input); - if (handle == -1 && PyErr_Occurred()) { - SWIG_fail; - } - temp.set_handle(handle); - $1 = &temp; -} - -%typemap(out) ComputationDataHandle { - $result = numpy::LongToPyIntOrPyLong($1.handle()); -} - %typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); @@ -301,33 +283,23 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ComputationDataHandle +// ArraySlice -%typemap(in) tensorflow::gtl::ArraySlice - (std::vector temps) { +%typemap(in) tensorflow::gtl::ArraySlice( + std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; } const int size = PySequence_Size($input); - temps.resize(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); - SWIG_fail; - } - const int64 handle = numpy::PyIntOrPyLongToLong(py_int); - if (handle == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); + LocalOp* op; + if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), + SWIG_POINTER_EXCEPTION)) == -1) { SWIG_fail; } - temps[i].set_handle(handle); - Py_DECREF(py_int); + temps.push_back(*op); Py_DECREF(o); } $1 = temps; @@ -934,6 +906,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore 
xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::Build; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index f6809b6b87..1d5b75d1be 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -335,20 +335,6 @@ def _wrap_shape(shape_info): return Shape.array_shape(dtype, dims) -def _wrap_data_handle(handle): - cdh = xla_data_pb2.ComputationDataHandle() - cdh.handle = handle - return cdh - - -def _unwrap_data_handle(handle_proto): - return handle_proto.handle - - -def _unwrap_data_handles(handle_protos): - return [_unwrap_data_handle(cdh) for cdh in handle_protos] - - def require_numpy_array_layout(value): if isinstance(value, tuple): return tuple(require_numpy_array_layout(x) for x in value) @@ -535,9 +521,9 @@ class ComputationBuilder(object): queue for subsequent use in the computation. Returns: - A ComputationDataHandle message. + A LocalOp. """ - return _wrap_data_handle(self._client.Infeed(shape)) + return self._client.Infeed(shape) def Outfeed(self, operand): """Enqueues an outfeed op onto the computation. @@ -545,9 +531,7 @@ class ComputationBuilder(object): Outfeed operations enqueue data, using the given operand, onto the XLA outfeed queue for subsequent dequeue via the client API. """ - self._client.Outfeed( - _unwrap_data_handle(operand), self.GetShape(operand), - ''.encode('utf-8')) + self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8')) def Constant(self, value): """Enqueues a constant op onto the computation. @@ -557,10 +541,10 @@ class ComputationBuilder(object): to one of the supported types. Returns: - A ComputationDataHandle message. + A LocalOp. 
""" value = require_numpy_array_layout(value) - return _wrap_data_handle(self._client.ConstantLiteral(value)) + return self._client.ConstantLiteral(value) def ConstantF32Scalar(self, value): """Convenience method to enqueue a scalar F32 constant op. @@ -569,7 +553,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float32)) @@ -580,7 +564,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float64)) @@ -591,7 +575,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int32)) @@ -602,7 +586,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int64)) @@ -613,7 +597,7 @@ class ComputationBuilder(object): value: a boolean value. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.bool)) @@ -629,15 +613,14 @@ class ComputationBuilder(object): parameters, use it for *all* parameters to avoid clashes. Returns: - A ComputationDataHandle message. + A LocalOp. """ if name is None: name = '' if parameter_num is None: parameter_num = next(self._parameter_numbering) - return _wrap_data_handle( - self._client.Parameter(parameter_num, shape, name.encode('utf8'))) + return self._client.Parameter(parameter_num, shape, name.encode('utf8')) def ParameterFromNumpy(self, value, name=None, parameter_num=None): """Enqueues a Parameter op onto the computation. @@ -649,7 +632,7 @@ class ComputationBuilder(object): parameter_num: as in ParameterWithShape. Returns: - A ComputationDataHandle message. + A LocalOp. 
""" return self.ParameterWithShape( Shape.from_pyval(value), name=name, parameter_num=parameter_num) @@ -658,14 +641,13 @@ class ComputationBuilder(object): """Enqueues a broadcast operation onto the computation. Args: - operand: the operand ComputationDataHandle to broadcast. + operand: the operand LocalOp to broadcast. sizes: an iterable of broadcast sizes. Returns: - A ComputationDataHandle representing the added broadcast op. + A LocalOp representing the added broadcast op. """ - return _wrap_data_handle( - self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + return self._client.Broadcast(operand, sizes) def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -675,10 +657,9 @@ class ComputationBuilder(object): dimension: the dimension in which to perform the concatenation. Returns: - A ComputationDataHandle representing the added concatenate op. + A LocalOp representing the added concatenate op. """ - return _wrap_data_handle( - self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + return self._client.ConcatInDim(operands, dimension) def ConvertElementType(self, operand, new_element_type): """Enqueues an element type conversion operation onto the computation. @@ -688,14 +669,12 @@ class ComputationBuilder(object): new_element_type: the target primitive type. Returns: - A ComputationDataHandle representing the added conversion op. + A LocalOp representing the added conversion op. 
""" - return _wrap_data_handle( - self._client.ConvertElementType( - _unwrap_data_handle(operand), new_element_type)) + return self._client.ConvertElementType(operand, new_element_type) def GetShape(self, operand): - return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + return _wrap_shape(self._client.GetShape(operand)) def GetReturnValueShape(self): return _wrap_shape(self._client.GetReturnValueShape()) @@ -707,40 +686,35 @@ class ComputationBuilder(object): """Enqueues a Pad operation onto the computation. Args: - operand: ComputationDataHandle representing the array to pad. - padding_value: ComputationDataHandle representing the scalar pad value. + operand: LocalOp representing the array to pad. + padding_value: LocalOp representing the scalar pad value. padding_config: either an xla_data_pb2.PaddingConfig or a list of integer triples (edge_padding_low, edge_padding_high, interior_padding) representing the configuration of the padding operation. Returns: - A ComputationDataHandle representing the added Pad op. + A LocalOp representing the added Pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): padding_config = GetPaddingConfigFromTriples(padding_config) - return _wrap_data_handle( - self._client.Pad(_unwrap_data_handle(operand), - _unwrap_data_handle(padding_value), - padding_config)) + return self._client.Pad(operand, padding_value, padding_config) def Reshape(self, operand, dimensions, new_sizes): """Enqueues a reshape op onto the computation. Args: - operand: ComputationDataHandle representing the array to be reshaped. + operand: LocalOp representing the array to be reshaped. dimensions: sequence of integers encoding the order in which dimensions are collapsed or None, in which case dimensions are flattened in order. new_sizes: sequence of integers encoding the new dimension sizes (shape). Returns: - A ComputationDataHandle representing the added Reshape op. + A LocalOp representing the added Reshape op. 
""" if dimensions is None: ndim = len(self.GetShape(operand).dimensions()) dimensions = tuple(range(ndim)) - return _wrap_data_handle( - self._client.Reshape( - _unwrap_data_handle(operand), dimensions, new_sizes)) + return self._client.Reshape(operand, dimensions, new_sizes) def CrossReplicaSum(self, operand): """CrossReplicaSum op. @@ -749,67 +723,56 @@ class ComputationBuilder(object): operand: the operand to sum across replica instances. Returns: - A ComputationDataHandle that has the sum of the value among all replicas. + A LocalOp that has the sum of the value among all replicas. """ - return _wrap_data_handle( - self._client.CrossReplicaSum(_unwrap_data_handle(operand))) + return self._client.CrossReplicaSum(operand) def Collapse(self, operand, dimensions): """Collapse op.""" - return _wrap_data_handle( - self._client.Collapse(_unwrap_data_handle(operand), dimensions)) + return self._client.Collapse(operand, dimensions) def Trans(self, operand): """Specialized matrix transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + return self._client.Transpose(operand, [1, 0]) def Transpose(self, operand, permutation): """Transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), permutation)) + return self._client.Transpose(operand, permutation) def Rev(self, operand, dimensions): """Rev op.""" - return _wrap_data_handle( - self._client.Rev(_unwrap_data_handle(operand), dimensions)) + return self._client.Rev(operand, dimensions) def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin """Clamp op.""" - return _wrap_data_handle( - self._client.Clamp(_unwrap_data_handle(min), - _unwrap_data_handle(operand), - _unwrap_data_handle(max))) + return self._client.Clamp(min, operand, max) def SelectAndScatter(self, operand, select, window_dimensions, window_strides, padding, source, init_value, scatter): """Select and scatter op, used by the gradient of 
ReduceWindow. Args: - operand: ComputationDataHandle for array of dimension N and type T over + operand: LocalOp for array of dimension N and type T over which the windows slide. select: Computation of type (T, T) -> Pred to apply to the elements of each window to indicate which element is selected. window_dimensions: sequence of N integers for dimensions of the window. window_strides: sequence of N integers for the strides of the window. padding: PaddingType representing either 'SAME' or 'VALID ' padding. - source: ComputationDataHandle for array of type T with values to scatter. - init_value: ComputationDataHandle of scalar type T for initial out value. + source: LocalOp for array of type T with values to scatter. + init_value: LocalOp of scalar type T for initial out value. scatter: Computation of type (T, T) -> T to apply to each scatter source element with its destination element. Returns: - A ComputationDataHandle representing the added SelectAndScatter op. + A LocalOp representing the added SelectAndScatter op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.SelectAndScatterWithGeneralPadding( - _unwrap_data_handle(operand), select.c_local_computation, - window_dimensions, window_strides, pads, - _unwrap_data_handle(source), _unwrap_data_handle(init_value), - scatter.c_local_computation)) + return self._client.SelectAndScatterWithGeneralPadding( + operand, select.c_local_computation, window_dimensions, window_strides, + pads, source, init_value, scatter.c_local_computation) def Select(self, pred, on_true, on_false): """Element-wise selection op. @@ -817,17 +780,13 @@ class ComputationBuilder(object): Constructs an output array from elements of two input arrays, based on the values of a predicate array. 
""" - return _wrap_data_handle( - self._client.Select( - _unwrap_data_handle(pred), - _unwrap_data_handle(on_true), - _unwrap_data_handle(on_false))) + return self._client.Select(pred, on_true, on_false) def Slice(self, operand, start_indices, limit_indices, strides=None): """Enqueues a slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_indices: iterable of N integers containing the starting indices of the slice for each dimension. limit_indices: iterable of N integers containing the ending indices @@ -836,207 +795,177 @@ class ComputationBuilder(object): each dimension. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ if strides is None: start_indices = list(start_indices) strides = [1] * len(start_indices) - return _wrap_data_handle( - self._client.Slice( - _unwrap_data_handle(operand), start_indices, limit_indices, - strides)) + return self._client.Slice(operand, start_indices, limit_indices, strides) def SliceInDim(self, operand, start_index, limit_index, stride, dimno): """Enqueues a slice-in-dimension operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_index: an integer containing the start index of the slice. limit_index: an integer containing the end index of the slice. stride: an integer containing the stride size for the slice. dimno: an integer indicating the dimension along which to slice. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. 
""" - return _wrap_data_handle( - self._client.SliceInDim( - _unwrap_data_handle(operand), start_index, limit_index, stride, - dimno)) + return self._client.SliceInDim(operand, start_index, limit_index, stride, + dimno) def DynamicSlice(self, operand, start_indices, slice_sizes): """Enqueues a slice op with dynamic start indices onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. - start_indices: ComputationDataHandle for the 1D array of N integers + operand: LocalOp for the N dimensional array to be sliced. + start_indices: LocalOp for the 1D array of N integers containing the starting indices of the slice. slice_sizes: iterable of N integers containing the slice sizes in each dimension. Returns: - A ComputationDataHandle representing the added DynamicSlice op. + A LocalOp representing the added DynamicSlice op. """ - return _wrap_data_handle( - self._client.DynamicSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(start_indices), - slice_sizes)) + return self._client.DynamicSlice(operand, start_indices, slice_sizes) def DynamicUpdateSlice(self, operand, update, start_indices): """Enqueues a dynamic update slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be updated. + operand: LocalOp for the N dimensional array to be updated. update: N dimensional array comprising the slice update. start_indices: Rank-1 array of N integers comprising the starting indices of the slice along each dimension. Returns: - A ComputationDataHandle representing the added DynamicUpdateSlice op. + A LocalOp representing the added DynamicUpdateSlice op. """ - return _wrap_data_handle( - self._client.DynamicUpdateSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(update), - _unwrap_data_handle(start_indices))) + return self._client.DynamicUpdateSlice(operand, update, start_indices) def Tuple(self, *ops): """Enqueues a tuple operation onto the computation. 
Args: - ops: a sequence of tuple operands (each a ComputationDataHandle). + ops: a sequence of tuple operands (each a LocalOp). Returns: - A ComputationDataHandle representing the added Tuple op. + A LocalOp representing the added Tuple op. """ - return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + return self._client.Tuple(ops) def GetTupleElement(self, tup, index): """Enqueues a 'get tuple element' operation onto the computation. Args: - tup: the tuple operand (a ComputationDataHandle). + tup: the tuple operand (a LocalOp). index: numeric index to select from the tuple. Returns: - A ComputationDataHandle representing the added GetTupleElement op. + A LocalOp representing the added GetTupleElement op. """ - return _wrap_data_handle( - self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + return self._client.GetTupleElement(tup, index) def Call(self, computation_to_apply, operands): """Enqueues a call operation onto the computation. Args: computation_to_apply: a Computation object. - operands: an iterable of ComputationDataHandle. The number and types of + operands: an iterable of LocalOp. The number and types of operands must match the arity of computation_to_apply. Returns: - A ComputationDataHandle representing the added call op. + A LocalOp representing the added call op. """ - return _wrap_data_handle( - self._client.Call(computation_to_apply.c_local_computation, - _unwrap_data_handles(operands))) + return self._client.Call(computation_to_apply.c_local_computation, operands) def Map(self, operands, computation_to_apply, dimensions, static_operands=()): """Enqueues a map operation onto the computation. Args: - operands: an iterable of ComputationDataHandle. + operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. static_operands: auxiliary arguments passed to the applied computation. 
Returns: - A ComputationDataHandle representing the added Map op. + A LocalOp representing the added Map op. """ - return _wrap_data_handle( - self._client.Map( - _unwrap_data_handles(operands), - computation_to_apply.c_local_computation, - dimensions, - _unwrap_data_handles(static_operands))) + return self._client.Map(operands, computation_to_apply.c_local_computation, + dimensions, static_operands) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a Computation object - binary reduction function. dimensions: sequence of dimensions (integers) to reduce on. Returns: - A ComputationDataHandle representing the added Reduce op. + A LocalOp representing the added Reduce op. """ - return _wrap_data_handle( - self._client.Reduce( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - dimensions)) + return self._client.Reduce(operand, init_value, + computation_to_apply.c_local_computation, + dimensions) def ReduceWindow(self, operand, init_value, computation_to_apply, window_dimensions, window_strides, padding): """Enqueues a windowed reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a binary reduction function (Computation). window_dimensions: dimensions of window (sequence of integers). window_strides: strides for window (sequence of integers). padding: PaddingType representing either 'SAME' or 'VALID' padding. Returns: - A ComputationDataHandle representing the added ReduceWindow op. 
+ A LocalOp representing the added ReduceWindow op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.ReduceWindowWithGeneralPadding( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads)) + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, pads) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. Args: - mu: A ComputationDataHandle to an F32 scalar specifying the mean. - sigma: A ComputationDataHandle to an F32 scalar specifying the standard + mu: A LocalOp to an F32 scalar specifying the mean. + sigma: A LocalOp to an F32 scalar specifying the standard deviation. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of F32 values. + Returns: a LocalOp to the generated array of F32 values. """ shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) - return _wrap_data_handle( - self._client.RngNormal( - _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) + return self._client.RngNormal(mu, sigma, shape) def RngUniform(self, a, b, dims): """Enqueues an RngUniform operation onto the computation. Args: - a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) specifying the low end of the interval [a, b) over which values are generated. - b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) specifying the high end of the interval [a, b) over which values are generated. 
dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of values with the + Returns: a LocalOp to the generated array of values with the same numeric type (F32, S32, or U32) as the arguments a and b. """ shape = Shape.array_shape(self.GetShape(a).element_type(), dims) - return _wrap_data_handle( - self._client.RngUniform( - _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) + return self._client.RngUniform(a, b, shape) def While(self, cond, body, init): """Enqueues a While operation onto the computation. @@ -1044,112 +973,105 @@ class ComputationBuilder(object): Args: cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T - init: a ComputationDataHandle for the initial parameter, which has type T + init: a LocalOp for the initial parameter, which has type T - Returns: a ComputationDataHandle representing the While operation. + Returns: a LocalOp representing the While operation. """ - return _wrap_data_handle( - self._client.While(cond.c_local_computation, - body.c_local_computation, - _unwrap_data_handle(init))) + return self._client.While(cond.c_local_computation, + body.c_local_computation, init) def Conditional(self, pred, true_operand, true_computation, false_operand, false_computation): """Enqueues a Conditional operation onto the computation. Args: - predicate: a ComputationDataHandle to test, which has scalar type PRED - true_operand: a ComputationDataHandle of type T_0 + predicate: a LocalOp to test, which has scalar type PRED + true_operand: a LocalOp of type T_0 true_computation: a Computation to apply to true_operand, type T_0 -> S false_operand: a ComputationDatahandle of type T_1 false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a ComputationDataHandle representing the Conditional operation. + Returns: a LocalOp representing the Conditional operation. 
""" - return _wrap_data_handle( - self._client.Conditional( - _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), - true_computation.c_local_computation, - _unwrap_data_handle(false_operand), - false_computation.c_local_computation)) + return self._client.Conditional( + pred, true_operand, true_computation.c_local_computation, false_operand, + false_computation.c_local_computation) - def IsConstant(self, operand, num_parameters=0): - """Enqueues an IsConstant operation onto the computation. + def IsConstant(self, operand): + """Checks whether the given operand is a compile-time constant. Args: operand: a ComputationDataHandle to test. - num_parameters: optional int, number of computation parameters to treat as - constant (default 0). Returns: bool indicating whether `operand` is a compile-time constant, - meaning its value does not depend on parameters with index greater than or - equal to `num_parameters`. + meaning its value does not depend on any parametersor, or on stateful + operators such as `RngNormal` or `Infeed`. + """ + return self._client.IsConstant(operand) + + def BuildConstantSubGraph(self, operand): + """Builds a constant sub graph. + + Args: + operand: a LocalOp to test. + Returns: a LocalComputation that is rooted on the given `operand` which is a + compile-time constant. """ - return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters) + return self._client.BuildConstantSubGraph(operand) def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. Args: - lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. - rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. + rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. - Returns: a ComputationDataHandle representing the Dot operation. + Returns: a LocalOp representing the Dot operation. 
""" - return _wrap_data_handle( - self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + return self._client.Dot(lhs, rhs) def DotGeneral(self, lhs, rhs, dimension_numbers): """Enqueues a general dot operation onto the computation. Args: - lhs: ComputationDataHandle for the left-hand-side array. - rhs: ComputationDataHandle for the right-hand-side array. + lhs: LocalOp for the left-hand-side array. + rhs: LocalOp for the right-hand-side array. dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of integers representing the dimensions to treat as contracting dimensions and batch dimensions on each input operand. - Returns: a ComputationDataHandle representing the DotGeneral operation. + Returns: a LocalOp representing the DotGeneral operation. """ if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) - return _wrap_data_handle( - self._client.DotGeneral( - _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), - dimension_numbers)) + return self._client.DotGeneral(lhs, rhs, dimension_numbers) def Conv(self, lhs, rhs, window_strides, padding): """Enqueues a Conv operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. - Returns: a ComputationDataHandle representing the Conv operation. + Returns: a LocalOp representing the Conv operation. 
""" pads = _convert_padding_type_to_pad_values( padding, self.GetShape(lhs).dimensions()[2:], self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - pads, - (), - (), - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), + (), dimension_numbers) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of kernel strides. padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. @@ -1159,14 +1081,9 @@ class ComputationBuilder(object): A ComputationdataHandle representing the added ConvWithGeneralPadding op. 
""" dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - padding, - lhs_dilation, - rhs_dilation, - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1196,15 +1113,14 @@ def _forward_methods_to_local_builder(): """Generate a forwarding method that wraps/unwraps data handles.""" def forward(self, *args, **kwargs): - unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + arg_list = list(args) - if is_binop and len(unwrapped_args) < 3: - unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + if is_binop and len(arg_list) < 3: + arg_list.append(kwargs.get('broadcast_dimensions', ())) - return _wrap_data_handle( - target_method( - self._client, # pylint: disable=protected-access - *unwrapped_args)) + return target_method( + self._client, # pylint: disable=protected-access + *arg_list) return forward -- GitLab From 5ca373b4b64167f8b0fcab96d7d2e7886ea31b6a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 12:28:42 -0700 Subject: [PATCH 036/755] Some fixes to support another TF graph: 1. Fix ResolveBatchNormalization to avoid deleting arrays that may still be used. 2. Correctly count the number of ops using a given array, even when some ops use the same array as more than one of their inputs. 3. In PropagateFixedSizes for Concatenation ops, when resolving a -1 wildcard to a fixed value, we were doing so in a local 'axis' variable without actually updating op->axis! The resulting -1 value still in op->axis tripped runtime code, causing the concatenation to misbehave during inference. 
PiperOrigin-RevId: 195454037 --- .../graph_transformations/propagate_fixed_sizes.cc | 11 +++++------ .../resolve_batch_normalization.cc | 6 +++--- tensorflow/contrib/lite/toco/tooling_util.cc | 4 ++++ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 4923f83d91..b02b02c5be 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -670,8 +670,7 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { const auto& first_input_array = model->GetArray(op->inputs[0]); output_array.copy_shape(first_input_array.shape()); // Negative axis means the count starts at the back of the dims(). - int axis = op->axis; - if (axis < 0) axis += first_input_array.shape().dims().size(); + if (op->axis < 0) op->axis += first_input_array.shape().dims().size(); // Determine the concat size, and enfore that all inputs have // the same dimensions count. int concat_size = 0; @@ -684,14 +683,14 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { CHECK_EQ(input_array.shape().dimensions_count(), output_array.shape().dimensions_count()); const std::vector& input_dims = input_array.shape().dims(); - CHECK_LT(axis, input_dims.size()); - concat_size += input_dims[axis]; + CHECK_LT(op->axis, input_dims.size()); + concat_size += input_dims[op->axis]; } // Write out the concat_size on the output array shape. 
auto& output_shape = *output_array.mutable_shape(); auto& output_dims = *output_shape.mutable_dims(); - CHECK_LT(axis, output_shape.dimensions_count()); - output_dims[axis] = concat_size; + CHECK_LT(op->axis, output_shape.dimensions_count()); + output_dims[op->axis] = concat_size; } void ProcessRangeOperator(Model* model, RangeOperator* op) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc index 2b3ee36ad1..8f2c1f8162 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -134,9 +134,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { } // Remove the old param arrays - model->EraseArray(bn_op->inputs[1]); - model->EraseArray(bn_op->inputs[2]); - model->EraseArray(bn_op->inputs[3]); + DeleteArrayIfUsedOnce(bn_op->inputs[1], model); + DeleteArrayIfUsedOnce(bn_op->inputs[2], model); + DeleteArrayIfUsedOnce(bn_op->inputs[3], model); // Remove the old operator DCHECK_EQ(bn_it->get(), bn_op); diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 86ee1f3761..341d45e753 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -143,6 +143,10 @@ int CountOpsWithInput(const Model& model, const string& array_name) { for (auto& input : op->inputs) { if (input == array_name) { count++; + // Breaking here is important: some graphs have ops that use the + // same array as more than one of their inputs, and in that case + // we want it counted only once. + break; } } } -- GitLab From 67b5e724121c5874425936fe01318642508d9975 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 4 May 2018 14:40:02 -0700 Subject: [PATCH 037/755] [XLA:GPU] Mark floating-point division as an inexpensive op. 
"Expensive" really means "so expensive you'd choose not to fuse in order to avoid doing it twice". FP division definitely isn't that expensive. PiperOrigin-RevId: 195473524 --- .../xla/service/gpu/instruction_fusion.cc | 13 +++++ .../xla/service/gpu/instruction_fusion.h | 2 + .../service/gpu/instruction_fusion_test.cc | 56 +++++++++++++++++++ 3 files changed, 71 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 85ecbe8fdb..c5eb721185 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -48,6 +48,19 @@ bool IsFusile(const HloInstruction& hlo) { } // namespace +/*static*/ bool GpuInstructionFusion::IsExpensive( + const HloInstruction& instruction) { + switch (instruction.opcode()) { + // We say that floating-point division is cheap on the GPU. + case HloOpcode::kDivide: + return !ShapeUtil::ElementIsFloating(instruction.shape()) && + InstructionFusion::IsExpensive(instruction); + + default: + return InstructionFusion::IsExpensive(instruction); + } +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index bb2990e6df..9fb06b0a24 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,6 +27,8 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + static bool IsExpensive(const HloInstruction& instruction); + bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; HloInstruction::FusionKind ChooseKind( diff --git 
a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 4b231c449f..6c9a805ad6 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -253,5 +253,61 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); } +// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is +// duplicated and fused into both reduces. +TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = f32[] constant(0) + one = f32[] constant(1) + p0 = f32[100] parameter(0) + recip = f32[100] divide(one, p0) + sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT root = (f32[], f32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())); +} + +// Compute sum(100/p0), where p0 has type s32, twice. Check that the division +// is *not* duplicated and fused into both reduces, because we say that integer +// division is not cheap. 
+TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = s32[] constant(0) + one_hundred = s32[] constant(100) + p0 = s32[100] parameter(0) + recip = s32[100] divide(one_hundred, p0) + sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT mul = (s32[], s32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + } // namespace gpu } // namespace xla -- GitLab From 4d0388d22060a61f40965127c153c681b2412c50 Mon Sep 17 00:00:00 2001 From: James Qin Date: Fri, 4 May 2018 14:53:58 -0700 Subject: [PATCH 038/755] Fix build failure for macos py3 PiperOrigin-RevId: 195475780 --- tensorflow/python/debug/examples/debug_tflearn_iris.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py index 00090b21fe..7cbaae46b4 100644 --- a/tensorflow/python/debug/examples/debug_tflearn_iris.py +++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py @@ -140,7 +140,7 @@ def main(_): # Make predictions, using tfdbg hook. 
predict_results = classifier.predict(test_input_fn, hooks=hooks) - print("A prediction result: %s" % predict_results.next()) + print("A prediction result: %s" % next(predict_results)) if __name__ == "__main__": -- GitLab From cb1775e9525ae621d23708a3d64a6cad897be95e Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 4 May 2018 15:14:00 -0700 Subject: [PATCH 039/755] Identify and prune nodes that can never be executed PiperOrigin-RevId: 195478951 --- tensorflow/core/grappler/optimizers/BUILD | 1 + .../grappler/optimizers/loop_optimizer.cc | 140 ++++++++++++++++++ .../core/grappler/optimizers/loop_optimizer.h | 1 + .../optimizers/loop_optimizer_test.cc | 107 +++++++++++++ tensorflow/core/grappler/utils.h | 4 +- 5 files changed, 251 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 5b5e1e024e..900dfa95c5 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -604,6 +604,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index 5adc5b9227..7d3520febc 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" @@ -504,6 +505,140 @@ Status RemoveStackOps(const std::unordered_set& nodes_to_preserve, return Status::OK(); } +Status RemoveDeadBranches(const std::unordered_set& nodes_to_preserve, + GraphDef* optimized_graph) { + std::unordered_set dead_nodes; + std::unordered_map> dead_merge_inputs; + // TODO(bsteiner): also rewrite switches as identity. For now we just record + // them + std::unordered_set + identity_switches; + + GraphView view(optimized_graph); + for (const NodeDef& node : optimized_graph->node()) { + if (!IsSwitch(node)) { + continue; + } + if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) { + continue; + } + GraphView::InputPort ctrl_port(&node, 1); + GraphView::OutputPort ctrl_node = view.GetRegularFanin(ctrl_port); + if (!IsConstant(*ctrl_node.node)) { + continue; + } + Tensor selector; + CHECK(selector.FromProto(ctrl_node.node->attr().at("value").tensor())); + const int dead_fanout = selector.scalar()() ? 0 : 1; + GraphView::OutputPort dead(const_cast(&node), dead_fanout); + identity_switches.insert(dead); + + SetVector zombie_inputs; + for (const GraphView::InputPort& port : view.GetFanout(dead)) { + if (dead_nodes.find(port.node) == dead_nodes.end()) { + zombie_inputs.PushBack(port); + } + } + // If we encounter a single node that must be preserved in the fanout of the + // switch node we need to preserve the entire switch fanout: we therefore + // work on a local copy that only gets committed to the master copy once the + // whole fanout has been explored. 
+ std::unordered_set local_dead_nodes = dead_nodes; + std::unordered_map> local_dead_merge_inputs = + dead_merge_inputs; + bool found_node_to_preserve = false; + while (!found_node_to_preserve && !zombie_inputs.Empty()) { + GraphView::InputPort dead = zombie_inputs.PopBack(); + if (nodes_to_preserve.find(dead.node->name()) != + nodes_to_preserve.end()) { + found_node_to_preserve = true; + break; + } + + if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) { + continue; + } + + if (IsMerge(*dead.node)) { + const int fanout = dead.node->attr().at("N").i(); + if (fanout > 2) { + // This never happens in practice, so we'll just skip these to + // simplify the code for now. + found_node_to_preserve = true; + break; + } + GraphView::OutputPort value_index(dead.node, 1); + const std::unordered_set& + index_fanout = view.GetFanout(value_index); + if (!index_fanout.empty()) { + // The 2nd output (that indicates which input is propagated) is + // connected. This never happens in practice, so we'll just skip this + // case to simplify the code for now. + found_node_to_preserve = true; + break; + } + + bool fully_dead = false; + if (dead.port_id < 0) { + // If the control dependency never gets triggered the merge will also + // never get triggered. 
+ local_dead_nodes.insert(dead.node); + fully_dead = true; + } else { + local_dead_merge_inputs[dead.node].insert(dead.port_id); + if (local_dead_merge_inputs[dead.node].size() == + dead.node->attr().at("N").i()) { + fully_dead = true; + } + if (fully_dead) { + local_dead_nodes.insert(dead.node); + for (const GraphView::InputPort& port : + view.GetFanouts(*dead.node, true)) { + zombie_inputs.PushBack(port); + } + } + } + } else { + if (local_dead_nodes.insert(dead.node).second) { + for (const GraphView::InputPort& dead_fanout : + view.GetFanouts(*dead.node, true)) { + zombie_inputs.PushBack(dead_fanout); + } + } + } + } + if (!found_node_to_preserve) { + std::swap(dead_nodes, local_dead_nodes); + std::swap(dead_merge_inputs, local_dead_merge_inputs); + } + } + + int last = optimized_graph->node_size() - 1; + for (int i = optimized_graph->node_size() - 1; i >= 0; --i) { + NodeDef* node = optimized_graph->mutable_node(i); + if (dead_nodes.find(node) != dead_nodes.end()) { + optimized_graph->mutable_node()->SwapElements(i, last); + last--; + } + } + optimized_graph->mutable_node()->DeleteSubrange(last + 1, dead_nodes.size()); + + for (const auto& itr : dead_merge_inputs) { + NodeDef* dead_node = itr.first; + if (dead_nodes.find(dead_node) != dead_nodes.end()) { + // The node has been pruned since all its inputs are dead. 
+ continue; + } + const std::set& dead_inputs = itr.second; + for (int index : dead_inputs) { + dead_node->mutable_input()->DeleteSubrange(index, 1); + } + dead_node->set_op("Identity"); + dead_node->mutable_attr()->erase("N"); + } + return Status::OK(); +} + } // namespace Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -517,6 +652,11 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (options_.enable_stack_push_removal) { TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph)); } + if (opt_level_ == RewriterConfig::AGGRESSIVE && + options_.enable_dead_branch_removal) { + TF_RETURN_IF_ERROR( + RemoveDeadBranches(item.NodesToPreserve(), optimized_graph)); + } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h index 764506f7c1..85b8e65543 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.h +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -54,6 +54,7 @@ class LoopOptimizer : public GraphOptimizer { struct LoopOptimizerOptions { bool enable_loop_invariant_node_motion = false; bool enable_stack_push_removal = true; + bool enable_dead_branch_removal = true; static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) { LoopOptimizerOptions options; diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc index 10ec544424..6fd177b710 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -589,5 +589,112 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) { } } +TEST_F(LoopOptimizerTest, RemoveDeadBranches) { + Scope scope = Scope::NewRootScope(); + Output v_in = ops::Variable(scope.WithOpName("v_in"), {3}, DT_FLOAT); + + Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), false, TensorShape({})); + 
ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1); + Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false); + Output sqrt1 = ops::Sqrt(scope.WithOpName("sqrt1"), s1.output_true); + + Output ctrl2 = ops::Const(scope.WithOpName("ctrl2"), true, TensorShape({})); + ops::Switch s2(scope.WithOpName("switch2"), v_in, ctrl2); + Output square2 = ops::Square(scope.WithOpName("square2"), s2.output_false); + Output sqrt2 = ops::Sqrt(scope.WithOpName("sqrt2"), s2.output_true); + + Output ctrl3 = ops::Const(scope.WithOpName("ctrl3"), false, TensorShape({})); + ops::Switch s3(scope.WithOpName("switch3"), v_in, ctrl3); + Output square3 = ops::Square(scope.WithOpName("square3"), s3.output_false); + Output sqrt3 = ops::Sqrt(scope.WithOpName("sqrt3"), s3.output_true); + + Output ctrl4 = ops::Const(scope.WithOpName("ctrl4"), false, TensorShape({})); + ops::Switch s4(scope.WithOpName("switch4"), v_in, ctrl4); + Output square4 = ops::Square(scope.WithOpName("square4"), s4.output_false); + Output sqrt4 = ops::Sqrt(scope.WithOpName("sqrt4"), s4.output_true); + + ops::Merge m1(scope.WithOpName("m1"), {square1, sqrt1}); + ops::Merge m2(scope.WithOpName("m2"), {v_in, square1}); + ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1}); + ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2}); + ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1}); + ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2), + {v_in, square1}); + ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1), + {v_in, square1}); + + ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1); + Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false); + Output id2 = ops::Identity(scope.WithOpName("id2"), s5.output_true); + ops::Merge m8(scope.WithOpName("m8"), {id1, id2}); + + ops::Switch s6(scope.WithOpName("switch6"), v_in, ctrl1); + Output id3 = ops::Identity(scope.WithOpName("id3"), s6.output_false); + Output id4 = ops::Identity(scope.WithOpName("id4"), 
s6.output_true); + ops::Merge m9(scope.WithOpName("m9"), {id3, id4}); + + GrapplerItem item; + item.fetch.push_back("m8"); + item.fetch.push_back("id4"); + + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_CHECK_OK(status); + + for (const NodeDef& node : output.node()) { + // These nodes should have been pruned + EXPECT_NE("Square1", node.name()); + EXPECT_NE("Sqrt2", node.name()); + EXPECT_NE("m5", node.name()); + EXPECT_NE("m7", node.name()); + + if (node.name() == "m1") { + // sqrt1 is dead + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("square1", node.input(0)); + } else if (node.name() == "m2") { + // both inputs are alive + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("v_in", node.input(0)); + EXPECT_EQ("square1", node.input(1)); + } else if (node.name() == "m3") { + // sqrt1 is dead + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("v_in", node.input(0)); + } else if (node.name() == "m4") { + // both inputs are alive + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("square1", node.input(0)); + EXPECT_EQ("sqrt2", node.input(1)); + } else if (node.name() == "m6") { + // both inputs are alive and the control dependency can get triggered + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("v_in", node.input(0)); + EXPECT_EQ("square1", node.input(1)); + EXPECT_EQ("^sqrt2", node.input(2)); + } else if (node.name() == "m8") { + // The node is to be preserved because of a fetch + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("id1", node.input(0)); + EXPECT_EQ("id2", node.input(1)); + } else if (node.name() == "m9") { + // The node is to be preserved because of a fetch + EXPECT_EQ("Merge", node.op()); + EXPECT_EQ(2, node.input_size()); + 
EXPECT_EQ("id3", node.input(0)); + EXPECT_EQ("id4", node.input(1)); + } + } +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index b87ae05546..1c6fef59ea 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -65,7 +65,7 @@ class NodeMap { // A vector with a set. The set stores the same elements as the vector, and // quickly answers whether a value is in the vector. Duplicated elements are not // allowed for now. -template +template > class SetVector { public: // Returns false if value already existed in the set, true otherwise. @@ -91,7 +91,7 @@ class SetVector { void Reserve(int64 size) { vector_.reserve(size); } private: - std::unordered_set set_; + std::unordered_set set_; std::vector vector_; }; -- GitLab From 77a866ced3ca76c96b74af2759e432bfe250566f Mon Sep 17 00:00:00 2001 From: manhyuk Date: Sat, 5 May 2018 21:01:01 +0900 Subject: [PATCH 040/755] fix typo --- .../hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 60281951dd..66939fbb0f 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -115,7 +115,7 @@ static void CheckOpsSupport(const GraphDef& graph_def, HexagonOpsDefinitions::getInstance(); LOG(INFO) << "Checking " << graph_def.node_size() << " nodes"; LOG(INFO) << "dump_all_nodes = " << dump_all_nodes - << ", dump_shape_and_tpye = " << dump_shape_and_type; + << ", dump_shape_and_type = " << dump_shape_and_type; std::unordered_set unsupported_ops; bool all_supported = true; -- GitLab From cad5d6694aced77ab3c9141be2eea121bc6c9cb7 Mon Sep 17 00:00:00 2001 From: manhyuk Date: 
Sat, 5 May 2018 21:02:58 +0900 Subject: [PATCH 041/755] fix typo --- tensorflow/compiler/xla/shape_util.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index cb8bf5a2b9..82c75f85d8 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -231,7 +231,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that that they have the same element type + // point types; otherwise, checks that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { -- GitLab From c92de2f3fc81c701ab29408a8a84cd6e41e96fe5 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Sat, 5 May 2018 10:44:20 -0400 Subject: [PATCH 042/755] Skip all ops with function attribute by default --- tensorflow/core/api_def/BUILD | 6 ------ .../api_def/java_api/api_def_FilterDataset.pbtxt | 4 ---- .../api_def/java_api/api_def_FlatMapDataset.pbtxt | 4 ---- tensorflow/core/api_def/java_api/api_def_For.pbtxt | 4 ---- .../java_api/api_def_GeneratorDataset.pbtxt | 4 ---- .../java_api/api_def_GroupByWindowDataset.pbtxt | 4 ---- tensorflow/core/api_def/java_api/api_def_If.pbtxt | 4 ---- .../java_api/api_def_InterleaveDataset.pbtxt | 4 ---- .../java_api/api_def_MapAndBatchDataset.pbtxt | 4 ---- .../core/api_def/java_api/api_def_MapDataset.pbtxt | 4 ---- .../api_def/java_api/api_def_OneShotIterator.pbtxt | 4 ---- .../api_def_ParallelInterleaveDataset.pbtxt | 4 ---- .../java_api/api_def_ParallelMapDataset.pbtxt | 4 ---- .../core/api_def/java_api/api_def_RemoteCall.pbtxt | 4 ---- .../api_def/java_api/api_def_ScanDataset.pbtxt | 4 ---- .../java_api/api_def_SymbolicGradient.pbtxt | 4 ---- .../core/api_def/java_api/api_def_While.pbtxt | 4 ---- tensorflow/java/BUILD | 1 - tensorflow/java/src/gen/cc/op_generator.cc | 14 +++++++++++++- 19 files 
changed, 13 insertions(+), 72 deletions(-) delete mode 100644 tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_For.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_If.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt delete mode 100644 tensorflow/core/api_def/java_api/api_def_While.pbtxt diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 06b797e32e..1454a1d9b2 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -30,12 +30,6 @@ filegroup( visibility = ["//tensorflow:internal"], ) -filegroup( - name = "java_api_def", - srcs = glob(["java_api/*"]), - visibility = ["//tensorflow:internal"], -) - cc_library( name = "excluded_ops_lib", srcs = ["excluded_ops.cc"], diff --git a/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt deleted file mode 100644 index debd7e5709..0000000000 --- 
a/tensorflow/core/api_def/java_api/api_def_FilterDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "FilterDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt deleted file mode 100644 index 329ab15ef5..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_FlatMapDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "FlatMapDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_For.pbtxt b/tensorflow/core/api_def/java_api/api_def_For.pbtxt deleted file mode 100644 index caabc947bb..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_For.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "For" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt deleted file mode 100644 index a6e5167c30..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_GeneratorDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "GeneratorDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt deleted file mode 100644 index 4c0b2084a8..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_GroupByWindowDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "GroupByWindowDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_If.pbtxt b/tensorflow/core/api_def/java_api/api_def_If.pbtxt deleted file mode 100644 index 13b8635ca7..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_If.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "If" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt 
b/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt deleted file mode 100644 index ed748d4d2a..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_InterleaveDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "InterleaveDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt deleted file mode 100644 index cb96bf63d8..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_MapAndBatchDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "MapAndBatchDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt deleted file mode 100644 index e0ab8dd9db..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_MapDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "MapDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt b/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt deleted file mode 100644 index 13130e6882..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_OneShotIterator.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "OneShotIterator" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt deleted file mode 100644 index 6a985d24fa..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_ParallelInterleaveDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "ParallelInterleaveDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt deleted file mode 100644 index 64f25b9e5e..0000000000 --- 
a/tensorflow/core/api_def/java_api/api_def_ParallelMapDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "ParallelMapDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt b/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt deleted file mode 100644 index 2ccb5c8cf3..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_RemoteCall.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "RemoteCall" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt b/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt deleted file mode 100644 index 3463e60049..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_ScanDataset.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "ScanDataset" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt b/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt deleted file mode 100644 index 88c3acea74..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_SymbolicGradient.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "SymbolicGradient" - visibility: SKIP -} diff --git a/tensorflow/core/api_def/java_api/api_def_While.pbtxt b/tensorflow/core/api_def/java_api/api_def_While.pbtxt deleted file mode 100644 index 33756682c3..0000000000 --- a/tensorflow/core/api_def/java_api/api_def_While.pbtxt +++ /dev/null @@ -1,4 +0,0 @@ -op { - graph_op_name: "While" - visibility: SKIP -} \ No newline at end of file diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 7cd0208dbf..0cc8e7c3e2 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -72,7 +72,6 @@ tf_java_op_gen_srcjar( name = "java_op_gen_sources", api_def_srcs = [ "//tensorflow/core/api_def:base_api_def", - "//tensorflow/core/api_def:java_api_def", ], base_package = "org.tensorflow.op", gen_tool = ":java_op_gen_tool", diff --git a/tensorflow/java/src/gen/cc/op_generator.cc 
b/tensorflow/java/src/gen/cc/op_generator.cc index 7355b3a395..f4cefbe933 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -420,6 +420,18 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, writer.EndType(); } +bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) { + if (api_def.visibility() == ApiDef::SKIP) { + return false; + } + for (const auto& attr : op_def.attr()) { + if (attr.type() == "func") { + return false; // TODO(karllessard) add support for function attributes + } + } + return true; +} + } // namespace Status OpGenerator::Run(const OpList& op_list, const string& base_package, @@ -441,7 +453,7 @@ Status OpGenerator::Run(const OpList& op_list, const string& base_package, api_map.UpdateDocs(); for (const auto& op_def : op_list.op()) { const ApiDef* api_def = api_map.GetApiDef(op_def.name()); - if (api_def->visibility() != ApiDef::SKIP) { + if (CanGenerateOp(op_def, *api_def)) { OpSpec op(OpSpec::Create(op_def, *api_def)); for (const EndpointSpec& endpoint : op.endpoints()) { GenerateOp(op, endpoint, base_package, output_dir, env_); -- GitLab From 90bbbdcc42a67c93ba8dcbc66f9c1d06909c48cb Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Sat, 5 May 2018 10:48:22 -0400 Subject: [PATCH 043/755] Remove comment left-over --- tensorflow/core/api_def/BUILD | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 1454a1d9b2..19d6438809 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -4,7 +4,6 @@ # The following targets can be used to access ApiDefs: # :base_api_def # :python_api_def -# :java_api_def package( default_visibility = ["//visibility:private"], -- GitLab From ab48fb528221152299fb08da8116d2eca54b8423 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 4 May 2018 15:40:07 -0700 Subject: [PATCH 044/755] [XLA] Print allowed attributes when the user specifies an invalid attr. 
PiperOrigin-RevId: 195482974 --- .../compiler/xla/tools/parser/hlo_parser.cc | 30 +++++++++++++------ .../xla/tools/parser/hlo_parser_test.cc | 2 +- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 3a945fb3b1..40dc0730ce 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -30,6 +30,7 @@ namespace { using tensorflow::StringPiece; using tensorflow::gtl::optional; +using tensorflow::str_util::Join; using tensorflow::str_util::Split; using tensorflow::str_util::SplitAndParseAsInts; using tensorflow::strings::Printf; @@ -53,7 +54,7 @@ class HloParser { std::unique_ptr ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + string GetError() const { return Join(error_, "\n"); } private: // ParseXXX returns false if an error occurred. @@ -245,7 +246,7 @@ bool HloParser::Error(LocTy loc, StringPiece msg) { error_lines.push_back(std::string(lexer_.GetLine(loc))); error_lines.push_back(col == 0 ? 
"" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + error_.push_back(Join(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } @@ -1488,11 +1489,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - tensorflow::str_util::Join( - elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); - }), + Join(elems_seen_until_dim, ",", + [](string* out, const int64& num_elems) { + tensorflow::strings::StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1680,7 +1680,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", tensorflow::str_util::Join(index, ", "), "]")); + ": [", Join(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -1848,7 +1848,19 @@ bool HloParser::ParseAttributeHelper( } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - return Error(loc, Printf("unexpected attribute %s", name.c_str())); + string allowed_attrs; + if (attrs.empty()) { + allowed_attrs = "No attributes are allowed here."; + } else { + allowed_attrs = StrCat( + "Allowed attributes: ", + Join(attrs, ", ", + [&](string* out, const std::pair& kv) { + StrAppend(out, kv.first); + })); + } + return Error(loc, Printf("unexpected attribute \"%s\". 
%s", name.c_str(), + allowed_attrs.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 4e085bc89c..d38d8907a6 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -1138,7 +1138,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { )"; ExpectHasSubstr(Parse(original).status().error_message(), - "unexpected attribute calls"); + "unexpected attribute \"calls\""); } TEST_F(HloParserTest, MissingAttribute) { -- GitLab From 008a3b69a601dc68fd940eb8a03b0c445714a339 Mon Sep 17 00:00:00 2001 From: Karmel Allison Date: Fri, 4 May 2018 16:01:02 -0700 Subject: [PATCH 045/755] Add the ability to export separate SavedModels for train and eval mode to Estimator with two new methods, available in tf.contrib: export_all_saved_models and export_saved_model_for_mode. 
PiperOrigin-RevId: 195485922 --- tensorflow/contrib/estimator/BUILD | 38 ++ tensorflow/contrib/estimator/__init__.py | 3 + .../estimator/python/estimator/export.py | 216 ++++++++++ .../estimator/python/estimator/export_test.py | 391 ++++++++++++++++++ tensorflow/python/estimator/BUILD | 1 + tensorflow/python/estimator/estimator.py | 346 +++++++++++++--- tensorflow/python/estimator/estimator_test.py | 336 ++++++++++++++- tensorflow/python/estimator/export/export.py | 325 +++++++++++---- .../python/estimator/export/export_output.py | 223 +++++++++- .../estimator/export/export_output_test.py | 110 +++++ .../python/estimator/export/export_test.py | 253 +++++++++++- tensorflow/python/estimator/model_fn.py | 8 + tensorflow/python/saved_model/builder_impl.py | 54 ++- tensorflow/python/saved_model/constants.py | 6 + .../python/saved_model/saved_model_test.py | 90 ++++ .../python/saved_model/signature_constants.py | 6 + .../python/saved_model/signature_def_utils.py | 2 + .../saved_model/signature_def_utils_impl.py | 56 +++ .../saved_model/signature_def_utils_test.py | 95 +++++ .../python/saved_model/tag_constants.py | 5 + 20 files changed, 2373 insertions(+), 191 deletions(-) create mode 100644 tensorflow/contrib/estimator/python/estimator/export.py create mode 100644 tensorflow/contrib/estimator/python/estimator/export_test.py diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 571e2e3a5d..e9a68801ef 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -17,6 +17,7 @@ py_library( ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":export", ":extenders", ":head", ":linear", @@ -180,6 +181,43 @@ py_test( ], ) +py_library( + name = "export", + srcs = [ + "python/estimator/export.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "export_test", + size = "medium", + srcs = ["python/estimator/export_test.py"], + srcs_version = 
"PY2AND3", + tags = ["notsan"], # b/62863147 + deps = [ + ":export", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", + ], +) + py_library( name = "head", srcs = [ diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index d43b3ea6bf..ec502f86dd 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * from tensorflow.contrib.estimator.python.estimator.linear import * @@ -56,6 +57,8 @@ _allowed_symbols = [ 'TowerOptimizer', 'RNNClassifier', 'RNNEstimator', + 'export_saved_model_for_mode', + 'export_all_saved_models', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py new file mode 100644 index 0000000000..e7e366a3f2 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export.py @@ -0,0 +1,216 @@ +# Copyright 2018 The TensorFlow 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper for methods to export train/eval graphs from Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import model_fn as model_fn_lib + + +def export_saved_model_for_mode( + estimator, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. 
+ + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn, steps=1000) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + export_dir = tf.contrib.estimator.export_saved_model_for_mode( + classifier, + export_dir_base='my_model/', + input_receiver_fn=train_rcvr_fn, + mode=model_fn_lib.ModeKeys.TRAIN) + + # export_dir is a timestamped directory with the SavedModel, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + ... + ``` + + This method takes an input_receiver_fn and mode. For the mode passed in, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Finally, it creates a timestamped export directory below the + export_dir_base, and writes a `SavedModel` into it containing + the `MetaGraphDef` for the given mode and its associated signatures. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. 
+ For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating which mode will be exported. + + Returns: + The string path to the exported directory. 
+ + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + # pylint: enable=protected-access + + +def export_all_saved_models( + estimator, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. + + This is a wrapper around export_saved_model_for_mode that accepts + multiple modes simultaneously and creates directories for each under + export_dir_base. See `Estimator.export_saved_model_for_mode` for + further details as to how the export works for each mode. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + serve_rcvr_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( + feature_spec) + + rcvr_fn_map = { + model_fn_lib.ModeKeys.TRAIN: train_rcvr_fn, + model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn, + } + + export_dirs = tf.contrib.estimator.export_all_saved_models( + classifier, + export_dir_base='my_model/', + input_receiver_fn_map=rcvr_fn_map) + + # export_dirs is a dict of directories with SavedModels, which + # can be used for serving, analysis with TFMA, or directly loaded in. 
+ with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], + export_dirs[tf.estimator.ModeKeys.TRAIN]) + ... + ``` + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. 
+ """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_all_saved_models( + export_dir_base, input_receiver_fn_map, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py new file mode 100644 index 0000000000..89d02582e1 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export_test.py @@ -0,0 +1,391 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib wrapping of export_saved_model_for_mode functionality. + +These are direct copies of the tests included in core, with import locations +changed. These should be removed when the functionality in core is part of the +public API. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.python.client import session +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import training +from tensorflow.python.util import compat + + +def _model_fn_for_export_tests(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + update_global_step = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([update_global_step]): + train_op = constant_op.constant(2.) 
+ return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=train_op, + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + + +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return 
export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +class EstimatorExportTest(test.TestCase): + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_saved_model_for_mode( + est, export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. 
+ export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. 
+ export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 2) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items(): + export_dir = export_dirs[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dirs = contrib_export.export_all_saved_models( + est, export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + for _, export_dir in export_dirs.items(): + self._validate_exported_files(export_dir) + + return export_dirs, tmpdir + + def _validate_exported_files(self, export_dir): + self.assertTrue(gfile.Exists(export_dir)) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 56dec1eaa1..b25cc7aa26 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -91,6 +91,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python/saved_model:tag_constants", "@six_archive//:six", ], ) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 530a4a24ef..9ae64d230e 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -37,9 +37,8 @@ from 
tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config from tensorflow.python.estimator import util -from tensorflow.python.estimator.export.export import build_all_signature_defs -from tensorflow.python.estimator.export.export import get_temp_export_dir -from tensorflow.python.estimator.export.export import get_timestamped_export_dir +from tensorflow.python.estimator.export import export as export_helpers +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -51,7 +50,6 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants -from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import device_setter @@ -609,73 +607,283 @@ class Estimator(object): are provided, or no checkpoint can be found. """ # pylint: enable=line-too-long + return self._export_saved_model_for_mode( + export_dir_base, + serving_input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=model_fn_lib.ModeKeys.PREDICT) + + def _export_all_saved_models( + self, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. 
+ + This is a wrapper around export_saved_model_for_mode that accepts + multiple modes simultaneously and creates directories for each under + export_dir_base. See `Estimator.export_saved_model_for_mode` for + further details as to how the export works for each mode. + + See tf.contrib.estimator.export_all_saved_models for the currently + exposed version of this function. + + Args: + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + # TODO(b/65561022): Consider allowing multiple input_receiver_fns per mode. 
+ exported = {} + for mode, input_receiver_fn in input_receiver_fn_map.items(): + export_mode_dir = os.path.join( + compat.as_bytes(export_dir_base), + compat.as_bytes(mode)) + gfile.MakeDirs(export_mode_dir) + + exported_path = self._export_saved_model_for_mode( + export_mode_dir, + input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + + exported[mode] = exported_path + + return exported + + def _export_saved_model_for_mode( + self, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + + See tf.contrib.estimator.export_saved_model_for_mode for the currently + exposed version of this function. + + This method takes an input_receiver_fn and mode. For the mode passed in, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Finally, it creates a timestamped export directory below the + export_dir_base, and writes a `SavedModel` into it containing + the `MetaGraphDef` for the given mode and its associated signatures. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. 
One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + + Args: + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating which mode will be exported. 
+ + Returns: + The string path to the exported directory. + + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long with context.graph_mode(): - if serving_input_receiver_fn is None: - raise ValueError('serving_input_receiver_fn must be defined.') + if not input_receiver_fn: + raise ValueError('An input_receiver_fn must be defined.') - with ops.Graph().as_default() as g: - self._create_and_assert_global_step(g) - random_seed.set_random_seed(self._config.tf_random_seed) - serving_input_receiver = serving_input_receiver_fn() + if not checkpoint_path: + # Locate the latest checkpoint + checkpoint_path = saver.latest_checkpoint(self._model_dir) + if not checkpoint_path: + raise ValueError("Couldn't find trained model at %s." % self._model_dir) - # Call the model_fn and collect the export_outputs. - estimator_spec = self._call_model_fn( - features=serving_input_receiver.features, - labels=None, - mode=model_fn_lib.ModeKeys.PREDICT, - config=self.config) - - # Build the SignatureDefs from receivers and all outputs - signature_def_map = build_all_signature_defs( - serving_input_receiver.receiver_tensors, - estimator_spec.export_outputs, - serving_input_receiver.receiver_tensors_alternatives) - - if not checkpoint_path: - # Locate the latest checkpoint - checkpoint_path = saver.latest_checkpoint(self._model_dir) - if not checkpoint_path: - raise ValueError( - "Couldn't find trained model at %s." 
% self._model_dir) - - export_dir = get_timestamped_export_dir(export_dir_base) - temp_export_dir = get_temp_export_dir(export_dir) - - # TODO(soergel): Consider whether MonitoredSession makes sense here - with tf_session.Session(config=self._session_config) as session: - - saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( - sharded=True) - saver_for_restore.restore(session, checkpoint_path) - - local_init_op = ( - estimator_spec.scaffold.local_init_op or - monitored_session.Scaffold.default_local_init_op()) - - # Perform the export - builder = saved_model_builder.SavedModelBuilder(temp_export_dir) - builder.add_meta_graph_and_variables( - session, [tag_constants.SERVING], - signature_def_map=signature_def_map, - assets_collection=ops.get_collection( - ops.GraphKeys.ASSET_FILEPATHS), - legacy_init_op=local_init_op, - strip_default_attrs=strip_default_attrs) - builder.save(as_text) - - # Add the extra assets - if assets_extra: - assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir), - compat.as_bytes('assets.extra')) - for dest_relative, source in assets_extra.items(): - dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), - compat.as_bytes(dest_relative)) - dest_path = os.path.dirname(dest_absolute) - gfile.MakeDirs(dest_path) - gfile.Copy(source, dest_absolute) - - gfile.Rename(temp_export_dir, export_dir) - return export_dir + export_dir = export_helpers.get_timestamped_export_dir(export_dir_base) + temp_export_dir = export_helpers.get_temp_export_dir(export_dir) + + builder = saved_model_builder.SavedModelBuilder(temp_export_dir) + + self._add_meta_graph_and_variables_for_mode( + builder, input_receiver_fn, checkpoint_path, + strip_default_attrs, mode) + + builder.save(as_text) + + # Add the extra assets + if assets_extra: + assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir), + compat.as_bytes('assets.extra')) + for dest_relative, source in assets_extra.items(): + dest_absolute = 
os.path.join(compat.as_bytes(assets_extra_path), + compat.as_bytes(dest_relative)) + dest_path = os.path.dirname(dest_absolute) + gfile.MakeDirs(dest_path) + gfile.Copy(source, dest_absolute) + + gfile.Rename(temp_export_dir, export_dir) + return export_dir + + def _add_meta_graph_and_variables_for_mode( + self, builder, input_receiver_fn, checkpoint_path, strip_default_attrs, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Loads variables and adds them along with a MetaGraphDef for saving. + + Args: + builder: instance of SavedModelBuilder that will be used for saving. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating which mode will be exported. + """ + # pylint: enable=line-too-long + with ops.Graph().as_default() as g: + self._create_and_assert_global_step(g) + random_seed.set_random_seed(self._config.tf_random_seed) + + input_receiver = input_receiver_fn() + + # Call the model_fn and collect the export_outputs. 
+ estimator_spec = self._call_model_fn( + features=input_receiver.features, + labels=getattr(input_receiver, 'labels', None), + mode=mode, + config=self.config) + + export_outputs = self._get_export_outputs_for_spec(estimator_spec) + + # Build the SignatureDefs from receivers and all outputs + signature_def_map = export_helpers.build_all_signature_defs( + input_receiver.receiver_tensors, + export_outputs, + getattr(input_receiver, 'receiver_tensors_alternatives', None), + serving_only=(mode == model_fn_lib.ModeKeys.PREDICT)) + + with tf_session.Session(config=self._session_config) as session: + + export_tags = model_fn_lib.EXPORT_TAG_MAP[mode] + + local_init_op = ( + estimator_spec.scaffold.local_init_op or + monitored_session.Scaffold.default_local_init_op()) + + saver_for_restore = estimator_spec.scaffold.saver or saver.Saver( + sharded=True) + saver_for_restore.restore(session, checkpoint_path) + + # We add the train op explicitly for now, so that we don't have to + # change the Builder public interface. Note that this is a no-op + # for prediction, where train_op is None. + builder._add_train_op(estimator_spec.train_op) # pylint: disable=protected-access + + builder.add_meta_graph_and_variables( + session, + tags=export_tags, + signature_def_map=signature_def_map, + assets_collection=ops.get_collection( + ops.GraphKeys.ASSET_FILEPATHS), + strip_default_attrs=strip_default_attrs, + legacy_init_op=local_init_op) + + def _get_export_outputs_for_spec(self, estimator_spec): + """Given an EstimatorSpec, determine what our export outputs should be. + + EstimatorSpecs contain export_outputs that are used for serving, but for + training and eval graphs, we must wrap the tensors of interest in + appropriate ExportOutput objects. + + Args: + estimator_spec: EstimatorSpec object that will be exported. + + Returns: + a dict mapping export_output_name to ExportOutput object. 
+ + Raises: + ValueError: if an appropriate ExportOutput cannot be found for the + passed EstimatorSpec.mode + """ + mode = estimator_spec.mode + if mode == model_fn_lib.ModeKeys.PREDICT: + outputs = estimator_spec.export_outputs + else: + if mode == model_fn_lib.ModeKeys.TRAIN: + output_class = export_output.TrainOutput + elif mode == model_fn_lib.ModeKeys.EVAL: + output_class = export_output.EvalOutput + else: + raise ValueError( + 'Export output type not found for mode: {}'.format(mode)) + + export_out = output_class( + loss=estimator_spec.loss, + predictions=estimator_spec.predictions, + metrics=estimator_spec.eval_metric_ops) + outputs = {mode: export_out} + + return outputs def _get_features_from_input_fn(self, input_fn, mode): """Extracts the `features` from return values of `input_fn`.""" @@ -1544,3 +1752,5 @@ def _get_default_warm_start_settings(warm_start_from): else: raise ValueError('warm_start_from must be a string or a WarmStartSettings, ' 'instead got {}'.format(type(warm_start_from))) + + diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 76b45b7f57..02088e5134 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -1865,6 +1865,41 @@ def _model_fn_for_export_tests(features, labels, mode): 'test': export_output.ClassificationOutput(scores, classes)}) +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, 
classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + def _model_fn_with_saveables_for_export_tests(features, labels, mode): _, _ = features, labels table = saver_test_utils.CheckpointedOp(name='v2') @@ -1881,21 +1916,41 @@ def _model_fn_with_saveables_for_export_tests(features, labels, mode): 'test': export_output.PredictOutput({'prediction': prediction})}) +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn(feature_spec, label_spec) + + _VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n' _EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n' class EstimatorExportTest(test.TestCase): - def test_export_savedmodel_proto_roundtrip(self): - tmpdir = tempfile.mkdtemp() - est = estimator.Estimator(model_fn=_model_fn_for_export_tests) - est.train(input_fn=dummy_input_fn, steps=1) + def test_export_savedmodel_proto_roundtrip_raw_receiver(self): feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), 'y': 
parsing_ops.VarLenFeature(dtype=dtypes.int64)} serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( feature_spec) + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=dummy_input_fn, steps=1) + # Perform the export. export_dir_base = os.path.join( compat.as_bytes(tmpdir), compat.as_bytes('export')) @@ -1904,6 +1959,266 @@ class EstimatorExportTest(test.TestCase): # Check that all the files are in the right places. self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertTrue('weight' in graph_ops) + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = est._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. 
+ self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. 
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 1) + # Restore, to validate that the export was well-formed. + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + self.assertEqual(len(export_dirs), 2) + # Restore, to validate that the export was well-formed. 
+ export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + export_dir = export_dirs[model_fn_lib.ModeKeys.EVAL] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for mode, tag_set in model_fn_lib.EXPORT_TAG_MAP.items(): + export_dir = export_dirs[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dirs, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + export_dir = export_dirs[model_fn_lib.ModeKeys.TRAIN] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + export_dir = export_dirs[model_fn_lib.ModeKeys.PREDICT] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dirs = est._export_all_saved_models( + export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + for _, export_dir in export_dirs.items(): + self._validate_exported_files(export_dir) + + return export_dirs, tmpdir + + def _validate_exported_files(self, export_dir): self.assertTrue(gfile.Exists(export_dir)) self.assertTrue(gfile.Exists(os.path.join( compat.as_bytes(export_dir), @@ -1918,18 +2233,6 @@ class EstimatorExportTest(test.TestCase): compat.as_bytes(export_dir), compat.as_bytes('variables/variables.data-00000-of-00001')))) - # Restore, to validate that the export was well-formed. - with ops.Graph().as_default() as graph: - with session.Session(graph=graph) as sess: - loader.load(sess, [tag_constants.SERVING], export_dir) - graph_ops = [x.name for x in graph.get_operations()] - self.assertTrue('input_example_tensor' in graph_ops) - self.assertTrue('ParseExample/ParseExample' in graph_ops) - self.assertTrue('weight' in graph_ops) - - # Clean up. 
- gfile.DeleteRecursively(tmpdir) - def test_export_savedmodel_with_saveables_proto_roundtrip(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator( @@ -2485,5 +2788,6 @@ class EstimatorIntegrationTest(test.TestCase): serving_input_receiver_fn) self.assertTrue(gfile.Exists(export_dir)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py index 41c1f5a2e2..9aafb56679 100644 --- a/tensorflow/python/estimator/export/export.py +++ b/tensorflow/python/estimator/export/export.py @@ -40,6 +40,60 @@ from tensorflow.python.util.tf_export import tf_export _SINGLE_FEATURE_DEFAULT_NAME = 'feature' _SINGLE_RECEIVER_DEFAULT_NAME = 'input' +_SINGLE_LABEL_DEFAULT_NAME = 'label' + + +def _wrap_and_check_receiver_tensors(receiver_tensors): + """Ensure that receiver_tensors is a dict of str to Tensor mappings. + + Args: + receiver_tensors: dict of str to Tensors, or a single Tensor. + + Returns: + dict of str to Tensors; this is the original dict if one was passed, or + the original tensor wrapped in a dictionary. 
+ + Raises: + ValueError: if receiver_tensors is None, or has non-string keys, + or non-Tensor values + """ + if receiver_tensors is None: + raise ValueError('receiver_tensors must be defined.') + if not isinstance(receiver_tensors, dict): + receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} + for name, tensor in receiver_tensors.items(): + _check_tensor_key(name, error_label='receiver_tensors') + _check_tensor(tensor, name, error_label='receiver_tensor') + return receiver_tensors + + +def _check_tensor(tensor, name, error_label='feature'): + """Check that passed `tensor` is a Tensor or SparseTensor.""" + if not (isinstance(tensor, ops.Tensor) + or isinstance(tensor, sparse_tensor.SparseTensor)): + fmt_name = ' {}'.format(name) if name else '' + value_error = ValueError( + '{}{} must be a Tensor or SparseTensor.'.format(error_label, fmt_name)) + # NOTE(ericmc): This if-else block is a specific carve-out for + # LabeledTensor, which has a `.tensor` attribute and which is + # convertible to tf.Tensor via ops.convert_to_tensor. + # Allowing all types convertible to tf.Tensor is considered by soergel@ + # to be too permissive. + # TODO(soergel): accept any type convertible to Tensor, + # as in cl/193238295 snapshot #6. + if hasattr(tensor, 'tensor'): + try: + ops.convert_to_tensor(tensor) + except TypeError: + raise value_error + else: + raise value_error + + +def _check_tensor_key(name, error_label='feature'): + if not isinstance(name, six.string_types): + raise ValueError( + '{} keys must be strings: {}.'.format(error_label, name)) @tf_export('estimator.export.ServingInputReceiver') @@ -51,16 +105,18 @@ class ServingInputReceiver(collections.namedtuple( The expected return values are: features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or `SparseTensor`, specifying the features to be passed to the model. 
- receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying - input nodes where this receiver expects to be fed by default. Typically, - this is a single placeholder expecting serialized `tf.Example` protos. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. receiver_tensors_alternatives: a dict of string to additional - groups of receiver tensors, each of which may be a `Tensor` or a dict of - string to `Tensor`. These named receiver tensor alternatives generate - additional serving signatures, which may be used to feed inputs at - different points within the input receiver subgraph. A typical usage is - to allow feeding raw feature `Tensor`s *downstream* of the - tf.parse_example() op. Defaults to None. + groups of receiver tensors, each of which may be a `Tensor`, + `SparseTensor`, or dict of string to `Tensor` or`SparseTensor`. + These named receiver tensor alternatives generate additional serving + signatures, which may be used to feed inputs at different points within + the input receiver subgraph. A typical usage is to allow feeding raw + feature `Tensor`s *downstream* of the tf.parse_example() op. + Defaults to None. 
""" def __new__(cls, features, receiver_tensors, @@ -70,36 +126,10 @@ class ServingInputReceiver(collections.namedtuple( if not isinstance(features, dict): features = {_SINGLE_FEATURE_DEFAULT_NAME: features} for name, tensor in features.items(): - if not isinstance(name, six.string_types): - raise ValueError('feature keys must be strings: {}.'.format(name)) - if not (isinstance(tensor, ops.Tensor) - or isinstance(tensor, sparse_tensor.SparseTensor)): - value_error = ValueError( - 'feature {} must be a Tensor or SparseTensor.'.format(name)) - # NOTE(ericmc): This if-else block is a specific carve-out for - # LabeledTensor, which has a `.tensor` attribute and which is - # convertible to tf.Tensor via ops.convert_to_tensor. - # Allowing all types convertible to tf.Tensor is considered by soergel@ - # to be too permissive. - if hasattr(tensor, 'tensor'): - try: - ops.convert_to_tensor(tensor) - except TypeError: - raise value_error - else: - raise value_error - - if receiver_tensors is None: - raise ValueError('receiver_tensors must be defined.') - if not isinstance(receiver_tensors, dict): - receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} - for name, tensor in receiver_tensors.items(): - if not isinstance(name, six.string_types): - raise ValueError( - 'receiver_tensors keys must be strings: {}.'.format(name)) - if not isinstance(tensor, ops.Tensor): - raise ValueError( - 'receiver_tensor {} must be a Tensor.'.format(name)) + _check_tensor_key(name) + _check_tensor(tensor, name) + + receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) if receiver_tensors_alternatives is not None: if not isinstance(receiver_tensors_alternatives, dict): @@ -115,14 +145,9 @@ class ServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives[alternative_name] = ( receiver_tensors_alt) for name, tensor in receiver_tensors_alt.items(): - if not isinstance(name, six.string_types): - raise ValueError( - 'receiver_tensors keys must be strings: 
{}.'.format(name)) - if not (isinstance(tensor, ops.Tensor) - or isinstance(tensor, sparse_tensor.SparseTensor)): - raise ValueError( - 'receiver_tensor {} must be a Tensor or SparseTensor.'.format( - name)) + _check_tensor_key(name, error_label='receiver_tensors_alternative') + _check_tensor( + tensor, name, error_label='receiver_tensors_alternative') return super(ServingInputReceiver, cls).__new__( cls, @@ -155,25 +180,25 @@ class TensorServingInputReceiver(collections.namedtuple( The expected return values are: features: A single `Tensor` or `SparseTensor`, representing the feature to be passed to the model. - receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying - input nodes where this receiver expects to be fed by default. Typically, - this is a single placeholder expecting serialized `tf.Example` protos. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. receiver_tensors_alternatives: a dict of string to additional - groups of receiver tensors, each of which may be a `Tensor` or a dict of - string to `Tensor`. These named receiver tensor alternatives generate - additional serving signatures, which may be used to feed inputs at - different points within the input receiver subgraph. A typical usage is - to allow feeding raw feature `Tensor`s *downstream* of the - tf.parse_example() op. Defaults to None. + groups of receiver tensors, each of which may be a `Tensor`, + `SparseTensor`, or dict of string to `Tensor` or`SparseTensor`. + These named receiver tensor alternatives generate additional serving + signatures, which may be used to feed inputs at different points within + the input receiver subgraph. A typical usage is to allow feeding raw + feature `Tensor`s *downstream* of the tf.parse_example() op. + Defaults to None. 
""" def __new__(cls, features, receiver_tensors, receiver_tensors_alternatives=None): if features is None: raise ValueError('features must be defined.') - if not (isinstance(features, ops.Tensor) - or isinstance(features, sparse_tensor.SparseTensor)): - raise ValueError('feature must be a Tensor or SparseTensor.') + _check_tensor(features, None) receiver = ServingInputReceiver( features=features, @@ -187,6 +212,49 @@ class TensorServingInputReceiver(collections.namedtuple( receiver_tensors_alternatives=receiver.receiver_tensors_alternatives) +class SupervisedInputReceiver(collections.namedtuple( + 'SupervisedInputReceiver', + ['features', 'labels', 'receiver_tensors'])): + """A return type for a training_input_receiver_fn or eval_input_receiver_fn. + + This differs from a ServingInputReceiver in that (1) this receiver expects + a set of labels to be passed in with features, and (2) this receiver does + not support receiver_tensors_alternatives, which are primarily used for + serving. + + The expected return values are: + features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the features to be passed to the model. + labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or + `SparseTensor`, specifying the labels to be passed to the model. + receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` + or `SparseTensor`, specifying input nodes where this receiver expects to + be fed by default. Typically, this is a single placeholder expecting + serialized `tf.Example` protos. + + """ + + def __new__(cls, features, labels, receiver_tensors): + # Both features and labels can be dicts or raw tensors. 
+ for input_vals, error_label in ((features, 'feature'), (labels, 'label')): + if input_vals is None: + raise ValueError('{}s must be defined.'.format(error_label)) + if isinstance(input_vals, dict): + for name, tensor in input_vals.items(): + _check_tensor_key(name, error_label=error_label) + _check_tensor(tensor, name, error_label=error_label) + else: + _check_tensor(input_vals, None, error_label=error_label) + + receiver_tensors = _wrap_and_check_receiver_tensors(receiver_tensors) + + return super(SupervisedInputReceiver, cls).__new__( + cls, + features=features, + labels=labels, + receiver_tensors=receiver_tensors) + + @tf_export('estimator.export.build_parsing_serving_input_receiver_fn') def build_parsing_serving_input_receiver_fn(feature_spec, default_batch_size=None): @@ -216,6 +284,23 @@ def build_parsing_serving_input_receiver_fn(feature_spec, return serving_input_receiver_fn +def _placeholder_from_tensor(t, default_batch_size=None): + shape_list = t.get_shape().as_list() + shape_list[0] = default_batch_size + shape = tensor_shape.TensorShape(shape_list) + + # Reuse the feature tensor's op name (t.op.name) for the placeholder, + # excluding the index from the tensor's name (t.name): + # t.name = "%s:%d" % (t.op.name, t._value_index) + return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name) + + +def _placeholders_from_receiver_tensors_dict( + input_vals, default_batch_size=None): + return {name: _placeholder_from_tensor(t, default_batch_size) + for name, t in input_vals.items()} + + @tf_export('estimator.export.build_raw_serving_input_receiver_fn') def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """Build a serving_input_receiver_fn expecting feature Tensors. 
@@ -233,17 +318,9 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None): """ def serving_input_receiver_fn(): """A serving_input_receiver_fn that expects features to be fed directly.""" - receiver_tensors = {} - for name, t in features.items(): - shape_list = t.get_shape().as_list() - shape_list[0] = default_batch_size - shape = tensor_shape.TensorShape(shape_list) - - # Reuse the feature tensor's op name (t.op.name) for the placeholder, - # excluding the index from the tensor's name (t.name): - # t.name = "%s:%d" % (t.op.name, t._value_index) - receiver_tensors[name] = array_ops.placeholder( - dtype=t.dtype, shape=shape, name=t.op.name) + receiver_tensors = _placeholders_from_receiver_tensors_dict( + features, default_batch_size) + # TODO(b/34885899): remove the unnecessary copy # The features provided are simply the placeholders, but we defensively copy # the dict because it may be mutated. @@ -252,13 +329,100 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None): return serving_input_receiver_fn +def build_raw_supervised_input_receiver_fn( + features, labels, default_batch_size=None): + """Build a supervised_input_receiver_fn for raw features and labels. + + This function wraps tensor placeholders in a supervised_receiver_fn + with the expectation that the features and labels appear precisely as + the model_fn expects them. Features and labels can therefore be dicts of + tensors, or raw tensors. + + Args: + features: a dict of string to `Tensor` or `Tensor`. + labels: a dict of string to `Tensor` or `Tensor`. + default_batch_size: the number of query examples expected per batch. + Leave unset for variable batch size (recommended). + + Returns: + A supervised_input_receiver_fn. + + Raises: + ValueError: if features and labels have overlapping keys. + """ + # Check for overlapping keys before beginning. 
+ try: + feat_keys = features.keys() + except AttributeError: + feat_keys = [_SINGLE_RECEIVER_DEFAULT_NAME] + try: + label_keys = labels.keys() + except AttributeError: + label_keys = [_SINGLE_LABEL_DEFAULT_NAME] + + overlap_keys = set(feat_keys) & set(label_keys) + if overlap_keys: + raise ValueError('Features and labels must have distinct keys. ' + 'Found overlapping keys: {}'.format(overlap_keys)) + + def supervised_input_receiver_fn(): + """A receiver_fn that expects pass-through features and labels.""" + if not isinstance(features, dict): + features_cp = _placeholder_from_tensor(features, default_batch_size) + receiver_features = {_SINGLE_RECEIVER_DEFAULT_NAME: features_cp} + else: + receiver_features = _placeholders_from_receiver_tensors_dict( + features, default_batch_size) + features_cp = receiver_features + + if not isinstance(labels, dict): + labels_cp = _placeholder_from_tensor(labels, default_batch_size) + receiver_labels = {_SINGLE_LABEL_DEFAULT_NAME: labels_cp} + else: + receiver_labels = _placeholders_from_receiver_tensors_dict( + labels, default_batch_size) + labels_cp = receiver_labels + + receiver_tensors = dict(receiver_features) + receiver_tensors.update(receiver_labels) + return SupervisedInputReceiver(features_cp, labels_cp, receiver_tensors) + + return supervised_input_receiver_fn + + ### Below utilities are specific to SavedModel exports. def build_all_signature_defs(receiver_tensors, export_outputs, - receiver_tensors_alternatives=None): - """Build `SignatureDef`s for all export outputs.""" + receiver_tensors_alternatives=None, + serving_only=True): + """Build `SignatureDef`s for all export outputs. + + Args: + receiver_tensors: a `Tensor`, or a dict of string to `Tensor`, specifying + input nodes where this receiver expects to be fed by default. Typically, + this is a single placeholder expecting serialized `tf.Example` protos. 
+ export_outputs: a dict of ExportOutput instances, each of which has + an as_signature_def instance method that will be called to retrieve + the signature_def for all export output tensors. + receiver_tensors_alternatives: a dict of string to additional + groups of receiver tensors, each of which may be a `Tensor` or a dict of + string to `Tensor`. These named receiver tensor alternatives generate + additional serving signatures, which may be used to feed inputs at + different points within the input receiver subgraph. A typical usage is + to allow feeding raw feature `Tensor`s *downstream* of the + tf.parse_example() op. Defaults to None. + serving_only: boolean; if true, resulting signature defs will only include + valid serving signatures. If false, all requested signatures will be + returned. + + Returns: + signature_def representing all passed args. + + Raises: + ValueError: if export_outputs is not a dict + """ if not isinstance(receiver_tensors, dict): receiver_tensors = {_SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors} if export_outputs is None or not isinstance(export_outputs, dict): @@ -293,17 +457,24 @@ def build_all_signature_defs(receiver_tensors, _log_signature_report(signature_def_map, excluded_signatures) # The above calls to export_output.as_signature_def should return only - # valid signatures; if there is a validity problem, they raise ValueError, - # which we ignore above. Consequently the call to is_valid_signature here - # should not remove anything else; it's just an extra sanity check. - return {k: v for k, v in signature_def_map.items() - if signature_def_utils.is_valid_signature(v)} + # valid signatures; if there is a validity problem, they raise a ValueError, + # in which case we exclude that signature from signature_def_map above. + # The is_valid_signature check ensures that the signatures produced are + # valid for serving, and acts as an additional sanity check for export + # signatures produced for serving. 
We skip this check for training and eval + # signatures, which are not intended for serving. + if serving_only: + signature_def_map = {k: v for k, v in signature_def_map.items() + if signature_def_utils.is_valid_signature(v)} + return signature_def_map _FRIENDLY_METHOD_NAMES = { signature_constants.CLASSIFY_METHOD_NAME: 'Classify', signature_constants.REGRESS_METHOD_NAME: 'Regress', signature_constants.PREDICT_METHOD_NAME: 'Predict', + signature_constants.SUPERVISED_TRAIN_METHOD_NAME: 'Train', + signature_constants.SUPERVISED_EVAL_METHOD_NAME: 'Eval', } diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 87b964be37..d387ea2940 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -38,6 +38,8 @@ class ExportOutput(object): __metaclass__ = abc.ABCMeta + _SEPARATOR_CHAR = '/' + @abc.abstractmethod def as_signature_def(self, receiver_tensors): """Generate a SignatureDef proto for inclusion in a MetaGraphDef. @@ -51,6 +53,52 @@ class ExportOutput(object): """ pass + def _check_output_key(self, key, error_label): + # For multi-head models, the key can be a tuple. + if isinstance(key, tuple): + key = self._SEPARATOR_CHAR.join(key) + + if not isinstance(key, six.string_types): + raise ValueError( + '{} output key must be a string; got {}.'.format(error_label, key)) + return key + + def _wrap_and_check_outputs( + self, outputs, single_output_default_name, error_label=None): + """Wraps raw tensors as dicts and checks type. + + Note that we create a new dict here so that we can overwrite the keys + if necessary. + + Args: + outputs: A `Tensor` or a dict of string to `Tensor`. + single_output_default_name: A string key for use in the output dict + if the provided `outputs` is a raw tensor. + error_label: descriptive string for use in error messages. If none, + single_output_default_name will be used. 
+ + Returns: + A dict of tensors + + Raises: + ValueError: if the outputs dict keys are not strings or tuples of strings + or the values are not Tensors. + """ + if not isinstance(outputs, dict): + outputs = {single_output_default_name: outputs} + + output_dict = {} + for key, value in outputs.items(): + error_name = error_label or single_output_default_name + key = self._check_output_key(key, error_name) + if not isinstance(value, ops.Tensor): + raise ValueError( + '{} output value must be a Tensor; got {}.'.format( + error_name, value)) + + output_dict[key] = value + return output_dict + @tf_export('estimator.export.ClassificationOutput') class ClassificationOutput(ExportOutput): @@ -154,9 +202,6 @@ class RegressionOutput(ExportOutput): return signature_def_utils.regression_signature_def(examples, self.value) -_SINGLE_OUTPUT_DEFAULT_NAME = 'output' - - @tf_export('estimator.export.PredictOutput') class PredictOutput(ExportOutput): """Represents the output of a generic prediction head. @@ -165,6 +210,7 @@ class PredictOutput(ExportOutput): Named outputs must be provided as a dict from string to `Tensor`, """ + _SINGLE_OUTPUT_DEFAULT_NAME = 'output' def __init__(self, outputs): """Constructor for PredictOutput. @@ -177,16 +223,9 @@ class PredictOutput(ExportOutput): ValueError: if the outputs is not dict, or any of its keys are not strings, or any of its values are not `Tensor`s. 
""" - if not isinstance(outputs, dict): - outputs = {_SINGLE_OUTPUT_DEFAULT_NAME: outputs} - for key, value in outputs.items(): - if not isinstance(key, six.string_types): - raise ValueError( - 'Prediction output key must be a string; got {}.'.format(key)) - if not isinstance(value, ops.Tensor): - raise ValueError( - 'Prediction output value must be a Tensor; got {}.'.format(value)) - self._outputs = outputs + + self._outputs = self._wrap_and_check_outputs( + outputs, self._SINGLE_OUTPUT_DEFAULT_NAME, error_label='Prediction') @property def outputs(self): @@ -195,3 +234,161 @@ class PredictOutput(ExportOutput): def as_signature_def(self, receiver_tensors): return signature_def_utils.predict_signature_def(receiver_tensors, self.outputs) + + +class _SupervisedOutput(ExportOutput): + """Represents the output of a supervised training or eval process.""" + __metaclass__ = abc.ABCMeta + + LOSS_NAME = 'loss' + PREDICTIONS_NAME = 'predictions' + METRICS_NAME = 'metrics' + + METRIC_VALUE_SUFFIX = 'value' + METRIC_UPDATE_SUFFIX = 'update_op' + + _loss = None + _predictions = None + _metrics = None + + def __init__(self, loss=None, predictions=None, metrics=None): + """Constructor for SupervisedOutput (ie, Train or Eval output). + + Args: + loss: dict of Tensors or single Tensor representing calculated loss. + predictions: dict of Tensors or single Tensor representing model + predictions. + metrics: dict of (metric_value, update_op) tuples, or a single tuple. + metric_value must be a Tensor, and update_op must be a Tensor or Op. + + Raises: + ValueError: if any of the outputs' dict keys are not strings or tuples of + strings or the values are not Tensors (or Operations in the case of + update_op). 
+ """ + + if loss is not None: + loss_dict = self._wrap_and_check_outputs(loss, self.LOSS_NAME) + self._loss = self._prefix_output_keys(loss_dict, self.LOSS_NAME) + if predictions is not None: + pred_dict = self._wrap_and_check_outputs( + predictions, self.PREDICTIONS_NAME) + self._predictions = self._prefix_output_keys( + pred_dict, self.PREDICTIONS_NAME) + if metrics is not None: + self._metrics = self._wrap_and_check_metrics(metrics) + + def _prefix_output_keys(self, output_dict, output_name): + """Prepend output_name to the output_dict keys if it doesn't exist. + + This produces predictable prefixes for the pre-determined outputs + of SupervisedOutput. + + Args: + output_dict: dict of string to Tensor, assumed valid. + output_name: prefix string to prepend to existing keys. + + Returns: + dict with updated keys and existing values. + """ + + new_outputs = {} + for key, val in output_dict.items(): + key = self._prefix_key(key, output_name) + new_outputs[key] = val + return new_outputs + + def _prefix_key(self, key, output_name): + if key.find(output_name) != 0: + key = output_name + self._SEPARATOR_CHAR + key + return key + + def _wrap_and_check_metrics(self, metrics): + """Handle the saving of metrics. + + Metrics is either a tuple of (value, update_op), or a dict of such tuples. + Here, we separate out the tuples and create a dict with names to tensors. + + Args: + metrics: dict of (metric_value, update_op) tuples, or a single tuple. + + Returns: + dict of output_names to tensors + + Raises: + ValueError: if the dict key is not a string, or the metric values or ops + are not tensors. 
+ """ + if not isinstance(metrics, dict): + metrics = {self.METRICS_NAME: metrics} + + outputs = {} + for key, (metric_val, metric_op) in metrics.items(): + key = self._check_output_key(key, self.METRICS_NAME) + key = self._prefix_key(key, self.METRICS_NAME) + + val_name = key + self._SEPARATOR_CHAR + self.METRIC_VALUE_SUFFIX + op_name = key + self._SEPARATOR_CHAR + self.METRIC_UPDATE_SUFFIX + if not isinstance(metric_val, ops.Tensor): + raise ValueError( + '{} output value must be a Tensor; got {}.'.format( + key, metric_val)) + if (not isinstance(metric_op, ops.Tensor) and + not isinstance(metric_op, ops.Operation)): + raise ValueError( + '{} update_op must be a Tensor or Operation; got {}.'.format( + key, metric_op)) + outputs[val_name] = metric_val + outputs[op_name] = metric_op + + return outputs + + @property + def loss(self): + return self._loss + + @property + def predictions(self): + return self._predictions + + @property + def metrics(self): + return self._metrics + + @abc.abstractmethod + def _get_signature_def_fn(self): + """Returns a function that produces a SignatureDef given desired outputs.""" + pass + + def as_signature_def(self, receiver_tensors): + signature_def_fn = self._get_signature_def_fn() + return signature_def_fn( + receiver_tensors, self.loss, self.predictions, self.metrics) + + +class TrainOutput(_SupervisedOutput): + """Represents the output of a supervised training process. + + This class generates the appropriate signature def for exporting + training output by type-checking and wrapping loss, predictions, and metrics + values. + """ + + def _get_signature_def_fn(self): + return signature_def_utils.supervised_train_signature_def + + +class EvalOutput(_SupervisedOutput): + """Represents the output of a supervised eval process. + + This class generates the appropriate signature def for exporting + eval output by type-checking and wrapping loss, predictions, and metrics + values. 
+ """ + + def _get_signature_def_fn(self): + return signature_def_utils.supervised_eval_signature_def + + + + diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index 7090e53d80..b21ba91b0f 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -225,5 +225,115 @@ class ExportOutputTest(test.TestCase): }) +class MockSupervisedOutput(export_output_lib._SupervisedOutput): + """So that we can test the abstract class methods directly.""" + + def _get_signature_def_fn(self): + pass + + +class SupervisedOutputTest(test.TestCase): + + def test_supervised_outputs_valid(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), + constant_op.constant([10])), + "metrics2": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(outputter.loss["loss/my_loss"], loss["my_loss"]) + self.assertEqual( + outputter.predictions["predictions/output1"], predictions["output1"]) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + self.assertEqual( + outputter.metrics["metrics2/update_op"], metrics["metrics2"][1]) + + # Single Tensor is OK too + outputter = MockSupervisedOutput( + loss["my_loss"], predictions["output1"], metrics["metrics"]) + self.assertEqual(outputter.loss, {"loss": loss["my_loss"]}) + self.assertEqual( + outputter.predictions, {"predictions": predictions["output1"]}) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + + def test_supervised_outputs_none(self): + outputter = MockSupervisedOutput( + constant_op.constant([0]), None, None) + self.assertEqual(len(outputter.loss), 1) + 
self.assertEqual(outputter.predictions, None) + self.assertEqual(outputter.metrics, None) + + def test_supervised_outputs_invalid(self): + with self.assertRaisesRegexp(ValueError, "predictions output value must"): + MockSupervisedOutput(constant_op.constant([0]), [3], None) + with self.assertRaisesRegexp(ValueError, "loss output value must"): + MockSupervisedOutput("str", None, None) + with self.assertRaisesRegexp(ValueError, "metrics output value must"): + MockSupervisedOutput(None, None, (15.3, 4)) + with self.assertRaisesRegexp(ValueError, "loss output key must"): + MockSupervisedOutput({25: "Tensor"}, None, None) + + def test_supervised_outputs_tuples(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {("my", "loss"): constant_op.constant([0])} + predictions = {(u"output1", "2"): constant_op.constant(["foo"])} + metrics = {("metrics", "twice"): (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(set(outputter.loss.keys()), set(["loss/my/loss"])) + self.assertEqual(set(outputter.predictions.keys()), + set(["predictions/output1/2"])) + self.assertEqual(set(outputter.metrics.keys()), + set(["metrics/twice/value", "metrics/twice/update_op"])) + + def test_supervised_outputs_no_prepend(self): + """Tests that no errors are raised when provided outputs are valid.""" + loss = {"loss": constant_op.constant([0])} + predictions = {u"predictions": constant_op.constant(["foo"])} + metrics = {u"metrics": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(set(outputter.loss.keys()), set(["loss"])) + self.assertEqual(set(outputter.predictions.keys()), set(["predictions"])) + self.assertEqual(set(outputter.metrics.keys()), + set(["metrics/value", "metrics/update_op"])) + + def test_train_signature_def(self): + loss = {"my_loss": constant_op.constant([0])} + 
predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), + constant_op.constant([10]))} + + outputter = export_output_lib.TrainOutput(loss, predictions, metrics) + + receiver = {u"features": constant_op.constant(100, shape=(100, 2)), + "labels": constant_op.constant(100, shape=(100, 1))} + sig_def = outputter.as_signature_def(receiver) + + self.assertTrue("loss/my_loss" in sig_def.outputs) + self.assertTrue("metrics/value" in sig_def.outputs) + self.assertTrue("predictions/output1" in sig_def.outputs) + self.assertTrue("features" in sig_def.inputs) + + def test_eval_signature_def(self): + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + + outputter = export_output_lib.EvalOutput(loss, predictions, None) + + receiver = {u"features": constant_op.constant(100, shape=(100, 2)), + "labels": constant_op.constant(100, shape=(100, 1))} + sig_def = outputter.as_signature_def(receiver) + + self.assertTrue("loss/my_loss" in sig_def.outputs) + self.assertFalse("metrics/value" in sig_def.outputs) + self.assertTrue("predictions/output1" in sig_def.outputs) + self.assertTrue("features" in sig_def.inputs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py index c203be7dac..0af587f2a8 100644 --- a/tensorflow/python/estimator/export/export_test.py +++ b/tensorflow/python/estimator/export/export_test.py @@ -54,7 +54,7 @@ ops.register_tensor_conversion_function(LabeledTensorMock, _convert_labeled_tensor_mock_to_tensor) -class ExportTest(test_util.TensorFlowTestCase): +class ServingInputReceiverTest(test_util.TensorFlowTestCase): def test_serving_input_receiver_constructor(self): """Tests that no errors are raised when input is expected.""" @@ -161,6 +161,165 @@ class ExportTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): _ = 
export.ServingInputReceiver(feature, receiver_tensor) + +class SupervisedInputReceiverTest(test_util.TensorFlowTestCase): + + def test_input_receiver_constructor(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + labels = { + "classes": constant_op.constant([0] * 100), + } + + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + export.SupervisedInputReceiver(features, labels, receiver_tensors) + + def test_input_receiver_raw_values(self): + """Tests that no errors are raised when input is expected.""" + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + + labels = { + "classes": constant_op.constant([0] * 100), + } + + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + rec = export.SupervisedInputReceiver( + features["feature2"], labels, receiver_tensors) + self.assertIsInstance(rec.features, sparse_tensor.SparseTensor) + + rec = export.SupervisedInputReceiver( + features, labels["classes"], receiver_tensors) + self.assertIsInstance(rec.labels, ops.Tensor) + + def test_input_receiver_features_invalid(self): + features = constant_op.constant([0] * 100) + labels = constant_op.constant([0]) + receiver_tensors = { + "example0": array_ops.placeholder(dtypes.string, name="example0"), + u"example1": array_ops.placeholder(dtypes.string, name="example1"), + } + + with self.assertRaisesRegexp(ValueError, "features must be defined"): + export.SupervisedInputReceiver( + features=None, + 
labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "feature keys must be strings"): + export.SupervisedInputReceiver( + features={1: constant_op.constant([1])}, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp(ValueError, "label keys must be strings"): + export.SupervisedInputReceiver( + features=features, + labels={1: constant_op.constant([1])}, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature feature1 must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features={"feature1": [1]}, + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "feature must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features=[1], + labels=labels, + receiver_tensors=receiver_tensors) + + with self.assertRaisesRegexp( + ValueError, "label must be a Tensor or SparseTensor"): + export.SupervisedInputReceiver( + features=features, + labels=100, + receiver_tensors=receiver_tensors) + + def test_input_receiver_receiver_tensors_invalid(self): + features = { + "feature0": constant_op.constant([0]), + u"feature1": constant_op.constant([1]), + "feature2": sparse_tensor.SparseTensor( + indices=[[0, 0]], values=[1], dense_shape=[1, 1]), + } + labels = constant_op.constant([0]) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors must be defined"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors=None) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensors keys must be strings"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + receiver_tensors={ + 1: array_ops.placeholder(dtypes.string, name="example0")}) + + with self.assertRaisesRegexp( + ValueError, "receiver_tensor example1 must be a Tensor"): + export.SupervisedInputReceiver( + features=features, + labels=labels, + 
receiver_tensors={"example1": [1]}) + + def test_single_feature_single_receiver(self): + feature = constant_op.constant(5) + label = constant_op.constant(5) + receiver_tensor = array_ops.placeholder(dtypes.string) + input_receiver = export.SupervisedInputReceiver( + feature, label, receiver_tensor) + + # single receiver is automatically named + receiver_key, = input_receiver.receiver_tensors.keys() + self.assertEqual("input", receiver_key) + + def test_multi_feature_single_receiver(self): + features = {"foo": constant_op.constant(5), + "bar": constant_op.constant(6)} + labels = {"value": constant_op.constant(5)} + receiver_tensor = array_ops.placeholder(dtypes.string) + _ = export.SupervisedInputReceiver(features, labels, receiver_tensor) + + def test_multi_feature_multi_receiver(self): + features = {"foo": constant_op.constant(5), + "bar": constant_op.constant(6)} + labels = {"value": constant_op.constant(5)} + receiver_tensors = {"baz": array_ops.placeholder(dtypes.int64), + "qux": array_ops.placeholder(dtypes.float32)} + _ = export.SupervisedInputReceiver(features, labels, receiver_tensors) + + def test_feature_labeled_tensor(self): + feature = LabeledTensorMock() + label = constant_op.constant(5) + receiver_tensor = array_ops.placeholder(dtypes.string) + _ = export.SupervisedInputReceiver(feature, label, receiver_tensor) + + +class ExportTest(test_util.TensorFlowTestCase): + def test_build_parsing_serving_input_receiver_fn(self): feature_spec = {"int_feature": parsing_ops.VarLenFeature(dtypes.int64), "float_feature": parsing_ops.VarLenFeature(dtypes.float32)} @@ -237,6 +396,69 @@ class ExportTest(test_util.TensorFlowTestCase): dtypes.int32, serving_input_receiver.receiver_tensors["feature_2"].dtype) + def test_build_raw_supervised_input_receiver_fn(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn = 
export.build_raw_supervised_input_receiver_fn( + features, labels) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual(set(["feature_1", "feature_2"]), + set(input_receiver.features.keys())) + self.assertEqual(set(["foo", "bar"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["feature_1", "feature_2", "foo", "bar"]), + set(input_receiver.receiver_tensors.keys())) + self.assertEqual( + dtypes.string, input_receiver.receiver_tensors["feature_1"].dtype) + self.assertEqual( + dtypes.int32, input_receiver.receiver_tensors["feature_2"].dtype) + + def test_build_raw_supervised_input_receiver_fn_raw_tensors(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn1 = export.build_raw_supervised_input_receiver_fn( + features["feature_1"], labels) + input_receiver_fn2 = export.build_raw_supervised_input_receiver_fn( + features["feature_1"], labels["foo"]) + with ops.Graph().as_default(): + input_receiver = input_receiver_fn1() + self.assertIsInstance(input_receiver.features, ops.Tensor) + self.assertEqual(set(["foo", "bar"]), + set(input_receiver.labels.keys())) + self.assertEqual(set(["input", "foo", "bar"]), + set(input_receiver.receiver_tensors.keys())) + + input_receiver = input_receiver_fn2() + self.assertIsInstance(input_receiver.features, ops.Tensor) + self.assertIsInstance(input_receiver.labels, ops.Tensor) + self.assertEqual(set(["input", "label"]), + set(input_receiver.receiver_tensors.keys())) + + def test_build_raw_supervised_input_receiver_fn_batch_size(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"foo": constant_op.constant([5]), + "bar": constant_op.constant([6])} + input_receiver_fn = export.build_raw_supervised_input_receiver_fn( + features, labels, default_batch_size=10) + with 
ops.Graph().as_default(): + input_receiver = input_receiver_fn() + self.assertEqual([10], input_receiver.receiver_tensors["feature_1"].shape) + self.assertEqual([10], input_receiver.features["feature_1"].shape) + + def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self): + features = {"feature_1": constant_op.constant(["hello"]), + "feature_2": constant_op.constant([42])} + labels = {"feature_1": constant_op.constant([5]), + "bar": constant_op.constant([6])} + with self.assertRaises(ValueError): + export.build_raw_supervised_input_receiver_fn(features, labels) + def test_build_all_signature_defs_without_receiver_alternatives(self): receiver_tensor = array_ops.placeholder(dtypes.string) output_1 = constant_op.constant([1.]) @@ -404,6 +626,35 @@ class ExportTest(test_util.TensorFlowTestCase): self.assertTrue(int(time_1) < int(time_2)) self.assertTrue(int(time_2) < int(time_3)) + def test_build_all_signature_defs_serving_only(self): + receiver_tensor = {"input": array_ops.placeholder(dtypes.string)} + output_1 = constant_op.constant([1.]) + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(outputs=output_1), + "train": export_output.TrainOutput(loss=output_1), + } + + signature_defs = export.build_all_signature_defs( + receiver_tensor, export_outputs) + + expected_signature_defs = { + "serving_default": signature_def_utils.predict_signature_def( + receiver_tensor, {"output": output_1}) + } + + self.assertDictEqual(expected_signature_defs, signature_defs) + + signature_defs = export.build_all_signature_defs( + receiver_tensor, export_outputs, serving_only=False) + + expected_signature_defs.update({ + "train": signature_def_utils.supervised_train_signature_def( + receiver_tensor, loss={"loss": output_1}) + }) + + self.assertDictEqual(expected_signature_defs, signature_defs) + class TensorServingReceiverTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/estimator/model_fn.py 
b/tensorflow/python/estimator/model_fn.py index 8111ab564c..4ab2578769 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest @@ -53,6 +54,13 @@ class ModeKeys(object): LOSS_METRIC_KEY = 'loss' AVERAGE_LOSS_METRIC_KEY = 'average_loss' +# Mapping of the modes to appropriate tag_constants that are used for saving. +EXPORT_TAG_MAP = { + ModeKeys.PREDICT: [tag_constants.SERVING], + ModeKeys.TRAIN: [tag_constants.TRAINING], + ModeKeys.EVAL: [tag_constants.EVAL], +} + @tf_export('estimator.EstimatorSpec') class EstimatorSpec( diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 3447d917e9..071033b066 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -168,6 +168,25 @@ class SavedModelBuilder(object): raise TypeError("main_op needs to be an Operation: %r" % main_op) ops.add_to_collection(constants.MAIN_OP_KEY, main_op) + def _add_train_op(self, train_op): + """Add train op to the SavedModel. + + Note that this functionality is in development, and liable to be + moved elsewhere. + + Args: + train_op: Op or group of ops that are used for training. These are + stored as a collection with key TRAIN_OP_KEY, but not executed. + + Raises: + TypeError if Train op is not of type `Operation`. 
+ """ + if train_op is not None: + if (not isinstance(train_op, ops.Tensor) and + not isinstance(train_op, ops.Operation)): + raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op) + ops.add_to_collection(constants.TRAIN_OP_KEY, train_op) + def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map): """Tags the meta graph def and adds it to the SavedModel. @@ -238,6 +257,20 @@ class SavedModelBuilder(object): for outputs_key in outputs: self._validate_tensor_info(outputs[outputs_key]) + def _add_collections( + self, assets_collection, legacy_init_op, main_op, train_op): + """Add asset and op collections to be saved.""" + # Save asset files and write them to disk, if any. + self._save_and_write_assets(assets_collection) + + if main_op is None: + # Add legacy init op to the SavedModel. + self._maybe_add_legacy_init_op(legacy_init_op) + else: + self._add_main_op(main_op) + + self._add_train_op(train_op) + def add_meta_graph(self, tags, signature_def_map=None, @@ -285,14 +318,8 @@ class SavedModelBuilder(object): # properly populated. self._validate_signature_def_map(signature_def_map) - # Save asset files and write them to disk, if any. - self._save_and_write_assets(assets_collection) - - if main_op is None: - # Add legacy init op to the SavedModel. - self._maybe_add_legacy_init_op(legacy_init_op) - else: - self._add_main_op(main_op) + # Add assets and ops + self._add_collections(assets_collection, legacy_init_op, main_op, None) # Initialize a saver to generate a sharded output for all saveables in the # current scope. @@ -351,6 +378,7 @@ class SavedModelBuilder(object): strip_default_attrs: Boolean. If `True`, default-valued attributes will be removed from the NodeDefs. For a detailed guide, see [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). 
+ """ # pylint: enable=line-too-long if self._has_saved_variables: @@ -362,8 +390,8 @@ class SavedModelBuilder(object): # properly populated. self._validate_signature_def_map(signature_def_map) - # Save asset files and write them to disk, if any. - self._save_and_write_assets(assets_collection) + # Add assets and ops + self._add_collections(assets_collection, legacy_init_op, main_op, None) # Create the variables sub-directory, if it does not exist. variables_dir = os.path.join( @@ -376,12 +404,6 @@ class SavedModelBuilder(object): compat.as_text(variables_dir), compat.as_text(constants.VARIABLES_FILENAME)) - if main_op is None: - # Add legacy init op to the SavedModel. - self._maybe_add_legacy_init_op(legacy_init_op) - else: - self._add_main_op(main_op) - # Initialize a saver to generate a sharded output for all saveables in the # current scope. saver = tf_saver.Saver( diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py index 34206c6f6d..61c6ffbd0d 100644 --- a/tensorflow/python/saved_model/constants.py +++ b/tensorflow/python/saved_model/constants.py @@ -41,6 +41,10 @@ MAIN_OP_KEY = "saved_model_main_op" tf_export("saved_model.constants.MAIN_OP_KEY").export_constant( __name__, "MAIN_OP_KEY") +# CollectionDef key for the SavedModel train op. +# Not exported while export_all_saved_models is in contrib. +TRAIN_OP_KEY = "saved_model_train_op" + # Schema version for SavedModel. 
SAVED_MODEL_SCHEMA_VERSION = 1 tf_export("saved_model.constants.SAVED_MODEL_SCHEMA_VERSION").export_constant( @@ -65,3 +69,5 @@ tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant( VARIABLES_FILENAME = "variables" tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant( __name__, "VARIABLES_FILENAME") + + diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 804255375e..a4d994fd43 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -734,6 +734,96 @@ class SavedModelTest(test.TestCase): builder.add_meta_graph_and_variables( sess, ["foo"], legacy_init_op=legacy_init_op) + def testTrainOp(self): + export_dir = self._get_export_dir("test_train_op") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + # Add `v1` and `v2` variables to the graph. + v1 = variables.Variable(1, name="v1") + ops.add_to_collection("v", v1) + v2 = variables.Variable(2, name="v2") + ops.add_to_collection("v", v2) + + sess.run(variables.global_variables_initializer()) + train_op = state_ops.assign_add(v1, v2) + + sess.run(train_op) + # TODO(karmel): remove explicit call when in the public method. + builder._add_train_op(train_op) + builder.add_meta_graph_and_variables(sess, ["foo"]) + + # Save the SavedModel to disk. 
+ builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo"], export_dir) + self.assertEqual(3, ops.get_collection("v")[0].eval()) + self.assertEqual(2, ops.get_collection("v")[1].eval()) + self.assertIsInstance( + ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor) + + def testTrainOpGroup(self): + export_dir = self._get_export_dir("test_train_op_group") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + # Add `v1` and `v2` variables to the graph. + v1 = variables.Variable(1, name="v1") + ops.add_to_collection("v", v1) + v2 = variables.Variable(2, name="v2") + ops.add_to_collection("v", v2) + + sess.run(variables.global_variables_initializer()) + train_op = control_flow_ops.group() + + sess.run(train_op) + # TODO(karmel): remove explicit call when in the public method. + builder._add_train_op(train_op) + builder.add_meta_graph_and_variables(sess, ["foo"]) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo"], export_dir) + self.assertEqual(1, ops.get_collection("v")[0].eval()) + self.assertEqual(2, ops.get_collection("v")[1].eval()) + self.assertIsInstance( + ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Operation) + + def testTrainOpAfterVariables(self): + export_dir = self._get_export_dir("test_train_op_after_variables") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + # Add `v1` and `v2` variables to the graph. + v1 = variables.Variable(1, name="v1") + ops.add_to_collection("v", v1) + v2 = variables.Variable(2, name="v2") + ops.add_to_collection("v", v2) + + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph_and_variables(sess, ["pre_foo"]) + + train_op = state_ops.assign_add(v1, v2) + sess.run(train_op) + # TODO(karmel): remove explicit call when in the public method. 
+ builder._add_train_op(train_op) + builder.add_meta_graph(["foo"]) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo"], export_dir) + self.assertIsInstance( + ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor) + + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["pre_foo"], export_dir) + self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY)) + def testMultipleAssets(self): export_dir = self._get_export_dir("test_multiple_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py index 819f351291..99007a9634 100644 --- a/tensorflow/python/saved_model/signature_constants.py +++ b/tensorflow/python/saved_model/signature_constants.py @@ -94,3 +94,9 @@ tf_export("saved_model.signature_constants.REGRESS_OUTPUTS").export_constant( __name__, "REGRESS_OUTPUTS") ################################################################################ +# Train/Eval API constants. +# Not exported while export_all_saved_models is in contrib. 
+ +SUPERVISED_TRAIN_METHOD_NAME = "tensorflow/supervised/training" + +SUPERVISED_EVAL_METHOD_NAME = "tensorflow/supervised/eval" diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py index ea0f52f17e..27d6b70e9d 100644 --- a/tensorflow/python/saved_model/signature_def_utils.py +++ b/tensorflow/python/saved_model/signature_def_utils.py @@ -26,6 +26,8 @@ from tensorflow.python.saved_model.signature_def_utils_impl import classificatio from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def +from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def +from tensorflow.python.saved_model.signature_def_utils_impl import supervised_train_signature_def # pylint: enable=unused-import del absolute_import diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py index d033159188..f8ad788f77 100644 --- a/tensorflow/python/saved_model/signature_def_utils_impl.py +++ b/tensorflow/python/saved_model/signature_def_utils_impl.py @@ -185,6 +185,62 @@ def predict_signature_def(inputs, outputs): return signature_def +def supervised_train_signature_def( + inputs, loss, predictions=None, metrics=None): + return _supervised_signature_def( + signature_constants.SUPERVISED_TRAIN_METHOD_NAME, inputs, loss=loss, + predictions=predictions, metrics=metrics) + + +def supervised_eval_signature_def( + inputs, loss, predictions=None, metrics=None): + return _supervised_signature_def( + signature_constants.SUPERVISED_EVAL_METHOD_NAME, inputs, loss=loss, + predictions=predictions, metrics=metrics) + + +def _supervised_signature_def( + method_name, inputs, loss=None, predictions=None, + metrics=None): + """Creates a 
signature for training and eval data. + + This function produces signatures that describe the inputs and outputs + of a supervised process, such as training or evaluation, that + results in loss, metrics, and the like. Note that this function only requires + inputs to be not None. + + Args: + method_name: Method name of the SignatureDef as a string. + inputs: dict of string to `Tensor`. + loss: dict of string to `Tensor` representing computed loss. + predictions: dict of string to `Tensor` representing the output predictions. + metrics: dict of string to `Tensor` representing metric ops. + + Returns: + A train- or eval-flavored signature_def. + + Raises: + ValueError: If inputs or outputs is `None`. + """ + if inputs is None or not inputs: + raise ValueError('{} inputs cannot be None or empty.'.format(method_name)) + + signature_inputs = {key: utils.build_tensor_info(tensor) + for key, tensor in inputs.items()} + + signature_outputs = {} + for output_set in (loss, predictions, metrics): + if output_set is not None: + sig_out = {key: utils.build_tensor_info(tensor) + for key, tensor in output_set.items()} + signature_outputs.update(sig_out) + + signature_def = build_signature_def( + signature_inputs, signature_outputs, method_name) + + return signature_def + + @tf_export('saved_model.signature_def_utils.is_valid_signature') def is_valid_signature(signature_def): """Determine whether a SignatureDef can be served by TensorFlow Serving.""" diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py index b2bd14db8c..ebc5450633 100644 --- a/tensorflow/python/saved_model/signature_def_utils_test.py +++ b/tensorflow/python/saved_model/signature_def_utils_test.py @@ -180,6 +180,101 @@ class SignatureDefUtilsTest(test.TestCase): self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype) self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim)) + def testTrainSignatureDef(self): + 
self._testSupervisedSignatureDef( + signature_def_utils_impl.supervised_train_signature_def, + signature_constants.SUPERVISED_TRAIN_METHOD_NAME) + + def testEvalSignatureDef(self): + self._testSupervisedSignatureDef( + signature_def_utils_impl.supervised_eval_signature_def, + signature_constants.SUPERVISED_EVAL_METHOD_NAME) + + def _testSupervisedSignatureDef(self, fn_to_test, method_name): + inputs = { + "input-1": constant_op.constant("a", name="input-1"), + "input-2": constant_op.constant("b", name="input-2"), + } + loss = {"loss-1": constant_op.constant(0.45, name="loss-1")} + predictions = { + "classes": constant_op.constant([100], name="classes"), + } + metrics_val = constant_op.constant(100.0, name="metrics_val") + metrics = { + "metrics/value": metrics_val, + "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"), + } + + signature_def = fn_to_test(inputs, loss, predictions, metrics) + + self.assertEqual(method_name, signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(2, len(signature_def.inputs)) + input1_tensor_info_actual = (signature_def.inputs["input-1"]) + self.assertEqual("input-1:0", input1_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype) + self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim)) + input2_tensor_info_actual = (signature_def.inputs["input-2"]) + self.assertEqual("input-2:0", input2_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype) + self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim)) + + # Check outputs in signature def. 
+ self.assertEqual(4, len(signature_def.outputs)) + self.assertEqual("loss-1:0", signature_def.outputs["loss-1"].name) + self.assertEqual(types_pb2.DT_FLOAT, signature_def.outputs["loss-1"].dtype) + + self.assertEqual("classes:0", signature_def.outputs["classes"].name) + self.assertEqual(1, len(signature_def.outputs["classes"].tensor_shape.dim)) + + self.assertEqual( + "metrics_val:0", signature_def.outputs["metrics/value"].name) + self.assertEqual( + types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype) + + self.assertEqual( + "metrics_op:0", signature_def.outputs["metrics/update_op"].name) + self.assertEqual( + types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype) + + def testTrainSignatureDefMissingInputs(self): + self._testSupervisedSignatureDefMissingInputs( + signature_def_utils_impl.supervised_train_signature_def, + signature_constants.SUPERVISED_TRAIN_METHOD_NAME) + + def testEvalSignatureDefMissingInputs(self): + self._testSupervisedSignatureDefMissingInputs( + signature_def_utils_impl.supervised_eval_signature_def, + signature_constants.SUPERVISED_EVAL_METHOD_NAME) + + def _testSupervisedSignatureDefMissingInputs(self, fn_to_test, method_name): + inputs = { + "input-1": constant_op.constant("a", name="input-1"), + "input-2": constant_op.constant("b", name="input-2"), + } + loss = {"loss-1": constant_op.constant(0.45, name="loss-1")} + predictions = { + "classes": constant_op.constant([100], name="classes"), + } + metrics_val = constant_op.constant(100, name="metrics_val") + metrics = { + "metrics/value": metrics_val, + "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"), + } + + with self.assertRaises(ValueError): + signature_def = fn_to_test( + {}, loss=loss, predictions=predictions, metrics=metrics) + + signature_def = fn_to_test(inputs, loss=loss) + self.assertEqual(method_name, signature_def.method_name) + self.assertEqual(1, len(signature_def.outputs)) + + signature_def = fn_to_test(inputs, metrics=metrics, 
loss=loss) + self.assertEqual(method_name, signature_def.method_name) + self.assertEqual(3, len(signature_def.outputs)) + def testGetShapeAndTypes(self): inputs = { "input-1": constant_op.constant(["a", "b"]), diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py index 5a797da791..c82154e7b9 100644 --- a/tensorflow/python/saved_model/tag_constants.py +++ b/tensorflow/python/saved_model/tag_constants.py @@ -32,6 +32,9 @@ TRAINING = "train" tf_export("saved_model.tag_constants.TRAINING").export_constant( __name__, "TRAINING") +# Tag for the `eval` graph. Not exported while the export logic is in contrib. +EVAL = "eval" + # Tag for the `gpu` graph. GPU = "gpu" tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU") @@ -39,3 +42,5 @@ tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU") # Tag for the `tpu` graph. TPU = "tpu" tf_export("saved_model.tag_constants.TPU").export_constant(__name__, "TPU") + + -- GitLab From 037e52e20157985d3f385f8e0426cdde3f5aae2b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 May 2018 16:37:27 -0700 Subject: [PATCH 046/755] Expose read-only versions of tensors in tflite. PiperOrigin-RevId: 195491701 --- tensorflow/contrib/lite/interpreter.h | 37 ++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 1074f64263..0450e86ae7 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -201,7 +201,7 @@ class Interpreter { // Overrides execution plan. This bounds checks indices sent in. TfLiteStatus SetExecutionPlan(const std::vector& new_plan); - // Get a tensor data structure. + // Get a mutable tensor data structure. 
// TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { @@ -210,9 +210,14 @@ class Interpreter { return &context_.tensors[tensor_index]; } + // Get an immutable tensor data structure. + const TfLiteTensor* tensor(int tensor_index) const { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + // Get a pointer to an operation and registration data structure if in bounds. - // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this - // read/write access to structure const std::pair* node_and_registration( int node_index) const { if (node_index >= nodes_and_registration_.size() || node_index < 0) @@ -220,7 +225,8 @@ class Interpreter { return &nodes_and_registration_[node_index]; } - // Perform a checked cast to the appropriate tensor type. + // Perform a checked cast to the appropriate tensor type (mutable pointer + // version). template T* typed_tensor(int tensor_index) { if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { @@ -231,6 +237,18 @@ class Interpreter { return nullptr; } + // Perform a checked cast to the appropriate tensor type (immutable pointer + // version). + template + const T* typed_tensor(int tensor_index) const { + if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType()) { + return reinterpret_cast(tensor_ptr->data.raw); + } + } + return nullptr; + } + // Return a pointer into the data of a given input tensor. The given index // must be between 0 and inputs().size(). template @@ -238,13 +256,20 @@ class Interpreter { return typed_tensor(inputs_[index]); } - // Return a pointer into the data of a given output tensor. The given index - // must be between 0 and outputs().size(). + // Return a mutable pointer into the data of a given output tensor. The given + // index must be between 0 and outputs().size(). 
template T* typed_output_tensor(int index) { return typed_tensor(outputs_[index]); } + // Return an immutable pointer into the data of a given output tensor. The + // given index must be between 0 and outputs().size(). + template + const T* typed_output_tensor(int index) const { + return typed_tensor(outputs_[index]); + } + // Change the dimensionality of a given tensor. Note, this is only acceptable // for tensor indices that are inputs. // Returns status of failure or success. -- GitLab From fa1d92f70adf52d9258384e8528f9a7203a141dd Mon Sep 17 00:00:00 2001 From: Bjarke Hammersholt Roune Date: Fri, 4 May 2018 16:51:06 -0700 Subject: [PATCH 047/755] Add infrastructure for a backend-specific configuration for each op. This is intentionally not exposed in ComputationBuilder and is not intended for use or to be set at all prior to the last backend-specific part of compilation. PiperOrigin-RevId: 195493500 --- tensorflow/compiler/xla/service/hlo.proto | 3 + .../compiler/xla/service/hlo_computation.cc | 52 ++++----- .../compiler/xla/service/hlo_computation.h | 20 ++-- .../compiler/xla/service/hlo_graph_dumper.cc | 43 +++++--- .../compiler/xla/service/hlo_graph_dumper.h | 5 +- .../compiler/xla/service/hlo_instruction.cc | 100 +++++------------- .../compiler/xla/service/hlo_instruction.h | 59 +++++++++-- tensorflow/compiler/xla/service/hlo_module.cc | 12 +++ tensorflow/compiler/xla/service/hlo_module.h | 19 ++++ .../compiler/xla/service/hlo_verifier.cc | 71 ++++++------- tensorflow/compiler/xla/statusor.h | 11 +- tensorflow/compiler/xla/statusor_test.cc | 8 ++ .../compiler/xla/tools/parser/hlo_parser.cc | 10 +- .../xla/tools/parser/hlo_parser_test.cc | 22 +++- 14 files changed, 259 insertions(+), 176 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index aa6860880b..1f7c1cffd3 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -147,6 +147,9 @@ message 
HloInstructionProto { repeated int64 called_computation_ids = 38; xla.OpSharding sharding = 40; + + // Backend configuration for the instruction. Has backend-specific meaning. + string backend_config = 43; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 594413e88f..17e43c3cb8 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -347,6 +347,11 @@ std::list HloComputation::MakeEmbeddedComputationsList() // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after // construction. + // + // TODO(b/78350259): This violates const-correctness, since while the original + // computation is not returned, we still retrieve non-const computations from + // a const one. Consider also avoiding const for HloComputation, or review XLA + // for const-correctness of non-HloInstruction* types like this. ComputeComputationPostOrder(const_cast(this), &visited, &post_order); @@ -723,18 +728,25 @@ Status HloComputation::Accept( return this->Accept(&visitor); } -std::unique_ptr HloComputation::Clone(const string& suffix, - HloModule* module) { +std::unique_ptr HloComputation::Clone( + const string& suffix, HloModule* module, + HloInstruction::CloneMap* clone_map) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, suffix); + module, clone_map, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, const string& suffix) { + HloModule* module, HloInstruction::CloneMap* clone_map, + const string& suffix) { + HloInstruction::CloneMap local_clone_map; + if (clone_map == nullptr) { + clone_map = &local_clone_map; + } + // Look up instr in the replacements map, and return either the replacement, // or instr, if the replacement isn't present. 
// @@ -756,24 +768,19 @@ std::unique_ptr HloComputation::CloneWithReplacements( } } - std::unordered_map clone_map; std::vector> instructions; std::unique_ptr new_instr = nullptr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); - // If replaced_operand is null, that means 'replacements' asked us not to - // include operand in the new computation. But we can't do that, because - // operand is used by instr. CHECK_NE(replaced_operand, nullptr) - << "replacements map tried to eliminate a used instruction " - << operand->ToString() << ", used by " << instr->ToString(); - new_operands.push_back(FindOrDie(clone_map, replaced_operand)); + << "Replacements map specifies to leave out " << operand->ToString() + << ", but it is used by " << instr->ToString() << "."; + new_operands.push_back(FindOrDie(*clone_map, replaced_operand)); } - new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, module); - InsertOrDie(&clone_map, instr, new_instr.get()); + new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands, + module, clone_map); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -781,27 +788,24 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction()))); + /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(clone_map, instr); + HloInstruction* new_instr = FindOrDie(*clone_map, instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - - // successor may not be in clone_map, because it might have been - // removed by the replacements map. 
- if (replaced_successor == nullptr) { - continue; - } + CHECK_NE(replaced_successor, nullptr) + << "Replacements map specifies to leave out " << successor->ToString() + << ", but it is control-depended-on by " << instr->ToString() << "."; TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(clone_map, replaced_successor))); + FindOrDie(*clone_map, replaced_successor))); } } // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before + // destroyed. HloInstructions need to be detached from their operands before // they're destroyed, otherwise they stick around in the operands' users lists // and cause use-after-frees. for (auto& kv : replacements) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 9d3f6e9a2c..9898355625 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -291,11 +291,17 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr); + // + // If the module pointer is not nullptr, then the cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the computation is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. 
+ std::unique_ptr Clone( + const string& suffix = "clone", HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -305,7 +311,9 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, const string& suffix = "clone"); + HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index bb4db89f0a..794f1b4682 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -322,11 +322,13 @@ class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, const DebugOptions& debug_options, bool show_metadata, - const HloExecutionProfile* profile, NodeFilter filter) + bool show_backend_config, const HloExecutionProfile* profile, + NodeFilter filter) : computation_(computation), label_(label.ToString()), debug_options_(debug_options), show_metadata_(show_metadata), + show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -365,6 +367,7 @@ class HloDotDumper { string GetInstructionNodeShape(const HloInstruction* instr); string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeMetadata(const HloInstruction* instr); + string GetInstructionNodeBackendConfig(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void 
AddInstructionIncomingEdges(const HloInstruction* instr); @@ -393,6 +396,7 @@ class HloDotDumper { const string label_; // overall name for the graph const DebugOptions& debug_options_; const bool show_metadata_; + const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -611,6 +615,10 @@ tooltip = " "; if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); } + string node_backend_config = GetInstructionNodeBackendConfig(parent_instr); + if (!node_backend_config.empty()) { + StrAppend(&subcomp_label, "
", node_backend_config); + } bool highlight = filter_.Highlight(parent_instr); const char* fillcolor; @@ -765,6 +773,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string node_shape = GetInstructionNodeShape(instr); string node_label = GetInstructionNodeLabel(instr); string node_metadata = GetInstructionNodeMetadata(instr); + string node_backend_config = GetInstructionNodeBackendConfig(instr); string extra_info = GetInstructionNodeExtraInfo(instr); string inlined_constants = GetInstructionNodeInlinedOperands(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); @@ -782,8 +791,8 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. string node_body = node_label; - for (const string& s : - {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_metadata, + node_backend_config, extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } @@ -1078,6 +1087,15 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { return Join(lines, "
"); } +string HloDotDumper::GetInstructionNodeBackendConfig( + const HloInstruction* instr) { + if (!show_backend_config_ || instr->backend_config().empty()) { + return ""; + } + + return StrCat("backend_config=\"", instr->backend_config(), "\""); +} + string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { std::vector lines; @@ -1404,7 +1422,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1414,9 +1432,10 @@ string DumpGraph(const HloComputation& computation, const string& label, &graph)); graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = HloDotDumper(&computation, label, debug_options, show_metadata, - hlo_execution_profile, NodeFilter()) - .Dump(); + graph = + HloDotDumper(&computation, label, debug_options, show_metadata, + show_backend_config, hlo_execution_profile, NodeFilter()) + .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1427,15 +1446,15 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); - string graph = - HloDotDumper(node.parent(), label, debug_options, show_metadata, - /*profile=*/nullptr, filter) - .Dump(); + string graph = HloDotDumper(node.parent(), label, debug_options, + show_metadata, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); return ExportGraph(graph, 
GraphRendererInterface::DOT_GRAPH, debug_options); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 2704aae1e3..fc8e1468ac 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_metadata = false); + bool show_metadata = false, bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,8 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false); + bool show_metadata = false, + bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a714d0e114..2c733726a6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -109,6 +109,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->name_ = proto.name(); instruction->metadata_ = proto.metadata(); + instruction->set_backend_config(proto.backend_config()); if (proto.has_literal()) { TF_ASSIGN_OR_RETURN(instruction->literal_, Literal::CreateFromProto(proto.literal())); @@ -1231,12 +1232,15 @@ bool HloInstruction::HasSideEffect() const { std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module) const { + HloModule* module, CloneMap* clone_map) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " %" << new_operand->name(); } + if (module == nullptr) { + module = GetModule(); + } std::unique_ptr clone; @@ -1342,7 +1346,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kFft: CHECK_EQ(new_operands.size(), 1); - return CreateFft(shape, new_operands[0], fft_type_, fft_length_); + clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); + break; case HloOpcode::kCrossReplicaSum: clone = CreateCrossReplicaSum(shape, new_operands); break; @@ -1415,9 +1420,15 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kConstant: clone = CreateConstant(literal_->CloneToUnique()); break; - case HloOpcode::kFusion: - clone = CloneFusionWithNewOperands(shape, new_operands, module); + case HloOpcode::kFusion: { + CHECK_NE(module, nullptr); + auto new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", module, clone_map)); + clone = CreateFusion(/*shape=*/shape, 
/*fusion_kind=*/fusion_kind(), + /*operands=*/new_operands, + /*fusion_computation=*/new_fused_computation); break; + } case HloOpcode::kParameter: clone = CreateParameter(parameter_number_, shape, name_); break; @@ -1481,15 +1492,19 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); + clone->set_backend_config(backend_config()); + if (clone_map != nullptr) { + InsertOrDie(clone_map, this, clone.get()); + } return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr HloInstruction::Clone(const string& suffix, - HloModule* module) const { +std::unique_ptr HloInstruction::Clone( + const string& suffix, HloModule* module, CloneMap* clone_map) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module); + CloneWithNewOperands(shape_, operands_, module, clone_map); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1526,71 +1541,6 @@ std::unique_ptr HloInstruction::Clone(const string& suffix, return clone; } -std::unique_ptr HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(parent() != nullptr); - - auto new_instruction = - WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - // Add the operands to our new fusion instruction. - for (HloInstruction* new_operand : operands) { - new_instruction->AppendOperand(new_operand); - } - // Clone all the fused instructions for the new fusion instruction. - HloInstructionMap old_to_new; - std::list> new_fused_instructions; - // Create the list of fused parameters by mapping through the cloned, - // fused instructions. 
- for (HloInstruction* old_fused_parameter : - fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back( - old_fused_parameter->Clone("clone", module)); - HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); - InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); - } - for (auto old_fused_instruction : - fused_instructions_computation()->MakeInstructionPostOrder()) { - if (old_fused_instruction->opcode() == HloOpcode::kParameter) { - FindOrDie(old_to_new, old_fused_instruction); - continue; - } - std::vector new_operands; - for (int64 operand_idx = 0; - operand_idx < old_fused_instruction->operand_count(); ++operand_idx) { - HloInstruction* old_operand = - old_fused_instruction->mutable_operand(operand_idx); - new_operands.push_back(FindOrDie(old_to_new, old_operand)); - } - new_fused_instructions.push_back( - old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands, module)); - HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); - new_fused_instruction->set_parent(parent_); - InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); - } - new_instruction->fusion_kind_ = fusion_kind_; - auto computation_builder = HloComputation::Builder( - fused_instructions_computation()->name() + ".clone", - new_instruction.get()); - // We iterated the fusion instructions in reverse post order which means - // that we must reverse our new list of fusion instructions. 
- for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); - new_fused_instruction_iter != new_fused_instructions.rend(); - ++new_fused_instruction_iter) { - computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); - } - if (module == nullptr) { - module = GetModule(); - } - auto fused_root_ = fused_expression_root(); - new_instruction->called_computations_.push_back( - CHECK_NOTNULL(module)->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - return new_instruction; -} - std::pair HloInstruction::LatestNonGteAncestorAndIndex() const { const HloInstruction* hlo = this; @@ -2172,6 +2122,9 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } + if (options.print_backend_config() && !backend_config().empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + } return result; } @@ -2357,6 +2310,7 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); } + return extra; } @@ -2386,6 +2340,7 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; + proto.set_backend_config(backend_config()); if (literal_ != nullptr) { *proto.mutable_literal() = literal_->ToProto(); } @@ -2971,6 +2926,7 @@ Status HloInstruction::AcceptOrdered( continue; } + // TODO(b/78350259): Eliminate const laundering. 
HloInstruction* instruction = const_cast(const_instruction); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index a5e9aecb9e..19c8c11453 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -66,6 +66,7 @@ class HloPrintOptions { : print_large_constants_(false), print_subcomputation_references_(true), print_metadata_(true), + print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), print_program_shape_(true), @@ -77,6 +78,7 @@ class HloPrintOptions { .set_print_large_constants(true) .set_print_subcomputation_references(true) .set_print_metadata(false) + .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) .set_print_percent(false); @@ -99,12 +101,18 @@ class HloPrintOptions { return *this; } - // If true, metatdata will be printed. + // If true, metadata will be printed. HloPrintOptions& set_print_metadata(bool value) { print_metadata_ = value; return *this; } + // If true, backend_config will be printed. + HloPrintOptions& set_print_backend_config(bool value) { + print_backend_config_ = value; + return *this; + } + // If true, operands' shapes will be printed. 
HloPrintOptions& set_print_operand_shape(bool value) { print_operand_shape_ = value; @@ -141,6 +149,7 @@ class HloPrintOptions { return print_subcomputation_references_; } bool print_metadata() const { return print_metadata_; } + bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } @@ -151,6 +160,7 @@ class HloPrintOptions { bool print_large_constants_; bool print_subcomputation_references_; bool print_metadata_; + bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; bool print_program_shape_; @@ -643,6 +653,8 @@ class HloInstruction { // Detaches an instruction from its operands. That is, remove the instruction // from each operand's user set. This should only be called prior to // deallocating the instruction. + // + // TODO(b/78305363): Make this automatic when deleting an instruction. void DetachFromOperands(); // Performs a postorder DFS visit using this node as the root. If @@ -1157,23 +1169,30 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kRng RandomDistribution random_distribution() const; + // See documentation for Clone(). + using CloneMap = std::unordered_map; + // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. If the module - // pointer is not nullptr, it will be the module where the cloned computations - // will be added to (in order to support deep cloning). Ignores the control - // predecessors and successors of this HLO instruction. + // the instruction to form the name of the cloned instruction.
Ignores the + // control predecessors and successors of this HLO instruction. + // + // If the module pointer is not nullptr, then any cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the instruction is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr) const; + HloModule* module = nullptr, + CloneMap* clone_map = nullptr) const; - // Clones the HLO instruction as above but with new shape and operands. If - // the module pointer is not nullptr, it will be the module where the cloned - // computations will be added to (in order to support deep cloning). Ignores - // the control predecessors and successors of this HLO instruction. + // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; + HloModule* module = nullptr, CloneMap* clone_map = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -1262,6 +1281,19 @@ class HloInstruction { // if no id has been assigned yet). int unique_id() const { return unique_id_; } + // Returns the backend-specific configuration for how a backend should compile + // this HLO. The meaning of the field is backend specific. Not for use before + // or during general HLO optimization, since HLO optimizations do not preserve + // this field and they cannot interpret it due to its meaning being backend + // specific. + // + // TODO(b/78194644): Introduce structured configuration format as per + // go/xla-heuristics. 
+ const string& backend_config() const { return backend_config_; } + void set_backend_config(string backend_config) { + backend_config_ = std::move(backend_config); + } + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1283,6 +1315,7 @@ class HloInstruction { // Get/Set the number of partitions per outer dimension (in order, starting // with outer-most dimension first). Currently used by the parallel cpu // backend to partition HLOs into parallel tasks. + // // TODO(b/62783254) Replace these methods with a more general way to // annotate HLOs with backend-specific information. const std::vector& outer_dimension_partitions() const { @@ -1510,6 +1543,10 @@ class HloInstruction { // The string representation of the infeed configuration. string infeed_config_; + // The backend-specific configuration for how a backend should compile this + // HLO. See the documentation on backend_config(). + string backend_config_; + // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index c7a7192867..5308fb5848 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -46,6 +46,18 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) config_(config), unique_id_(next_unique_module_id_++) {} +StatusOr HloModule::LaunderConstInstructionFromModule( + const HloInstruction* hlo) { + if (hlo == nullptr) { + return nullptr; + } + + TF_RET_CHECK(hlo->GetModule() == this); + + // TODO(b/78350259): Eliminate const laundering. 
+ return const_cast(hlo); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, bool uniquify_names) { diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index f9674df812..1604a72612 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -217,6 +217,25 @@ class HloModule { // the lifetime of this process. int unique_id() const { return unique_id_; } + // Returns a non-const version of the passed-in const HloInstruction*. This is + // safe on the argument that if you have a non-const module, then you can + // access all instructions in the module as non-const. + // + // Returns an error if the passed-in instruction is not from this module, + // except that it is allowed to pass in a null pointer. + // + // TODO(b/78350259): Eliminate const laundering. The argument above is not + // reliable since at any time someone could add or discover a way for a + // non-const module to transitively contain a const HloInstruction. The + // reliable way to do this would be to create a const laundering map from a + // module, mapping each encountered HloInstruction to its non-const version + // and then look up each instruction in need of laundering in that map, but + // this is much more expensive and complicated. This returns a Status instead + // of doing a CHECK-failure in part to make it strongly apparent that this is + // something that can fail. 
+ StatusOr LaunderConstInstructionFromModule( + const HloInstruction* hlo); + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 8a30cbf9cd..096ebb7946 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -116,7 +116,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // produces no HLO value in the graph. if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { - return InvalidArgument( + return InternalError( "Expected outfeed to have shape compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), @@ -200,7 +200,7 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { transpose->operand(0)->shape(), transpose->dimensions())); } -Status ShapeVerifier::HandleParameter(HloInstruction*) { +Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { return tensorflow::Status::OK(); } @@ -410,7 +410,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { if (fp_type == PRIMITIVE_TYPE_INVALID) { fp_type = subshape.element_type(); } else if (fp_type != subshape.element_type()) { - return FailedPrecondition( + return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", instruction->ToString().c_str()); @@ -490,7 +490,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } } if (!compatible) { - return InvalidArgument( + return InternalError( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", ShapeUtil::HumanString(inferred_shape).c_str(), @@ -541,7 +541,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { Status ShapeVerifier::CheckSameChannel(const 
HloInstruction* instr1, const HloInstruction* instr2) { if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( + return InternalError( "Expected to have the same channel id, actual channel ids are: %s " "(%lld), %s (%lld)", instr1->ToString().c_str(), instr1->channel_id(), @@ -571,22 +571,22 @@ string ComputationsToString( Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { - return FailedPrecondition("Computation %s has a null parent pointer", - computation->name().c_str()); + return InternalError("Computation %s has a null parent pointer", + computation->name().c_str()); } if (computation->parent() != module) { - return FailedPrecondition( + return InternalError( "Computation %s parent() does not point to parent module", computation->name().c_str()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { - return FailedPrecondition("Instruction %s has a null parent pointer", - instruction->name().c_str()); + return InternalError("Instruction %s has a null parent pointer", + instruction->name().c_str()); } if (instruction->parent() != computation) { - return FailedPrecondition( + return InternalError( "Instruction %s parent() does not point to parent computation", instruction->name().c_str()); } @@ -602,7 +602,7 @@ Status VerifyHloStructure(HloModule* module) { for (int i = 0; i < instruction->operand_count(); ++i) { const HloInstruction* operand = instruction->operand(i); if (operand->parent() != instruction->parent()) { - return FailedPrecondition( + return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", i, operand->name().c_str(), instruction->name().c_str(), @@ -619,7 +619,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. 
HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { - return FailedPrecondition( + return InternalError( "Instruction of fused computation does not match expected instruction " "%s.", fusion->ToString().c_str()); @@ -635,37 +635,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto* instruction : fused_computation->instructions()) { if (fused_root == instruction) { if (root_owned) { - return FailedPrecondition("Root appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Root appears more than once in %s.", + fusion->ToString().c_str()); } root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { - return FailedPrecondition("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Parameter appears more than once in %s.", + fusion->ToString().c_str()); } parameter_owned[i] = true; } } } if (!root_owned) { - return FailedPrecondition("Root not found in computation of %s.", - fusion->ToString().c_str()); + return InternalError("Root not found in computation of %s.", + fusion->ToString().c_str()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { - return FailedPrecondition("Parameter %d not found in computation of %s.", - i, fusion->ToString().c_str()); + return InternalError("Parameter %d not found in computation of %s.", i, + fusion->ToString().c_str()); } } // Fused root must have no users. 
if (fused_root->user_count() != 0) { - return FailedPrecondition("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", + fusion->ToString().c_str()); } // All uses of fused instructions must be in the fusion computation, and every @@ -674,13 +674,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { if (instruction->user_count() == 0) { - return FailedPrecondition( - "Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + return InternalError("Non-root instruction %s in %s must have users.", + instruction->ToString().c_str(), + fusion->ToString().c_str()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { - return FailedPrecondition( + return InternalError( "Non-root instruction %s in %s may not have external users.", instruction->ToString().c_str(), fusion->ToString().c_str()); } @@ -695,34 +695,33 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return FailedPrecondition( - "Unexpected negative parameter number %lld in %s.", param_no, - fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %lld in %s.", + param_no, fusion->ToString().c_str()); } if (param_no >= fused_parameters.size()) { - return FailedPrecondition( + return InternalError( "Unexpected parameter number %lld in %s: higher then number of " "parameters %lu.", param_no, fusion->ToString().c_str(), fused_parameters.size()); } if (parameter_numbers[param_no]) { - return FailedPrecondition( + return InternalError( "Did not expect parameter number %lld more than once in %s.", param_no, fusion->ToString().c_str()); } parameter_numbers[param_no] = 
true; if (!ShapeUtil::Compatible(fused_param->shape(), fusion->operand(param_no)->shape())) { - return FailedPrecondition( + return InternalError( "Shape mismatch between parameter number %lld and its operand in %s.", param_no, fusion->ToString().c_str()); } } - // Make sure all the parameter_numbers entries were seen + // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { - return FailedPrecondition("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + return InternalError("Did not see parameter number %d in %s.", i, + fusion->ToString().c_str()); } } diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index cccbce5fc8..0e1387c939 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -13,13 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// StatusOr is the union of a Status object and a T -// object. StatusOr models the concept of an object that is either a -// usable value, or an error Status explaining why such a value is -// not present. To this end, StatusOr does not allow its Status -// value to be Status::OK. Furthermore, the value of a StatusOr -// must not be null. This is enforced by a debug check in most cases, -// but even when it is not, clients must not set the value to null. +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. // // The primary use-case for StatusOr is as the return value of a // function which may fail. 
diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index f9d25945bc..7d76370e85 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) { static_assert(std::is_same::element_type, char>(), ""); } +TEST(StatusOr, NullPointerStatusOr) { + // As a very special case, null-plain-pointer StatusOr used to be an + // error. Test that it no longer is. + StatusOr null_status(nullptr); + EXPECT_TRUE(null_status.ok()); + EXPECT_EQ(null_status.ValueOrDie(), nullptr); +} + TEST(StatusOr, TestNoDefaultConstructorInitialization) { // Explicitly initialize it with an error code. StatusOr statusor(tensorflow::errors::Cancelled("")); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 40dc0730ce..156a06c596 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -440,6 +440,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + optional backend_config; + attrs["backend_config"] = {/*required=*/false, AttrTy::kString, + &backend_config}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -1094,8 +1098,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_name(name); - // Add common attrs (sharding, control predecessors) to the instruction, if - // they were seen. + // Add shared attributes like metadata to the instruction, if they were seen. 
if (sharding) { instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); @@ -1112,6 +1115,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (metadata) { instruction->set_metadata(*metadata); } + if (backend_config) { + instruction->set_backend_config(std::move(*backend_config)); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index d38d8907a6..e100d8cda1 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -65,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar" } )" @@ -81,13 +81,14 @@ ENTRY %constant_s32 () -> s32[] { )" }, -// f32 constant, but the value is not a decimal +// f32 constant, but the value is not a decimal and there is a backend +// configuration { "ConstantF32", R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { - ROOT %constant = f32[] constant(42) + ROOT %constant = f32[] constant(42), backend_config="this is a configuration" } )" @@ -1013,6 +1014,19 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { // but the constant names will not be exactly the same. 
} +TEST_F(HloParserTest, ConfigurationField) { + const string original = R"(HloModule AModule +ENTRY %configuration_test() -> s32[] { + %constant = s32[] constant(42), backend_config="foo bar" +})"; + auto result = Parse(original); + TF_ASSERT_OK(result.status()); + EXPECT_EQ("foo bar", result.ValueOrDie() + ->entry_computation() + ->root_instruction() + ->backend_config()); +} + TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { const string original = R"(HloModule some_2_module @@ -1092,7 +1106,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; -- GitLab From bf228e1435da0032d2529de93661b742ee8a7048 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 4 May 2018 17:03:52 -0700 Subject: [PATCH 048/755] [tf.data] Adding `num_parallel_calls` to `map_and_batch`. 
PiperOrigin-RevId: 195495206 --- .../kernel_tests/batch_dataset_op_test.py | 44 +- .../contrib/data/python/ops/batching.py | 47 +- .../base_api/api_def_MapAndBatchDataset.pbtxt | 35 +- .../api_def_MapAndBatchDatasetV2.pbtxt | 54 ++ .../kernels/data/map_and_batch_dataset_op.cc | 773 +++++++++--------- tensorflow/core/ops/dataset_ops.cc | 13 + 6 files changed, 538 insertions(+), 428 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_MapAndBatchDatasetV2.pbtxt diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 6588fd04ac..2568b899d7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -427,7 +427,9 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, + num_parallel_calls=None, + num_parallel_batches=None): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). 
@@ -446,6 +448,7 @@ class BatchDatasetTest(test.TestCase): batching.map_and_batch( map_func=_map_fn, batch_size=batch_size, + num_parallel_calls=num_parallel_calls, num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer @@ -497,12 +500,18 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatchDataset(self): + def testMapAndBatch(self): return self._testMapAndBatchDatasetHelper() - def testMapAndBatchDatasetWithParallelBatching(self): + def testMapAndBatchWithParallelBatches(self): return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + def testMapAndBatchWithSequentialCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) + + def testMapAndBatchWithParallelCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): iterator = ( dataset_ops.Dataset.range(10).apply( @@ -682,7 +691,7 @@ class UnbatchDatasetSerializationTest( class MapAndBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): - def testSerializationCore(self): + def testNumParallelBatches(self): range_size = 11 num_repeats = 2 batch_size = 5 @@ -709,6 +718,33 @@ class MapAndBatchDatasetSerializationTest( self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), num_outputs_drop_remainder) + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + 
range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 42ec2b0b01..b9393de4e9 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -466,14 +466,14 @@ def assert_element_shape(expected_shapes): class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, + def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, drop_remainder): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches_t = ops.convert_to_tensor( - num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._num_parallel_calls_t = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") self._drop_remainder_t = ops.convert_to_tensor( drop_remainder, dtype=dtypes.bool, name="drop_remainder") @@ -483,12 +483,12 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def _as_variant_tensor(self): # pylint: disable=protected-access input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.map_and_batch_dataset( + return gen_dataset_ops.map_and_batch_dataset_v2( input_resource, 
self._map_func.captured_inputs, f=self._map_func, batch_size=self._batch_size_t, - num_parallel_batches=self._num_parallel_batches_t, + num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), @@ -511,8 +511,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def map_and_batch(map_func, batch_size, - num_parallel_batches=1, - drop_remainder=False): + num_parallel_batches=None, + drop_remainder=False, + num_parallel_calls=None): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -528,21 +529,37 @@ def map_and_batch(map_func, nested structure of tensors. batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. - num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the - number of batches to create in parallel. On one hand, higher values can - help mitigate the effect of stragglers. On the other hand, higher values - can increase contention if CPU is scarce. - drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the - last batch should be dropped in case its size is smaller than desired; - the default behavior is not to drop the smaller batch. + num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, + representing the number of batches to create in parallel. On one hand, + higher values can help mitigate the effect of stragglers. On the other + hand, higher values can increase contention if CPU is scarce. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in case its size is smaller than + desired; the default behavior is not to drop the smaller batch. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number of elements to process in parallel. 
If not + specified, `batch_size * num_parallel_batches` elements will be + processed in parallel. Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. + + Raises: + ValueError: If both `num_parallel_batches` and `num_parallel_calls` are + specified. """ + if num_parallel_batches is None and num_parallel_calls is None: + num_parallel_calls = batch_size + elif num_parallel_batches is not None and num_parallel_calls is None: + num_parallel_calls = batch_size * num_parallel_batches + elif num_parallel_batches is not None and num_parallel_calls is not None: + raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " + "arguments are mutually exclusive.") + def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches, drop_remainder) + num_parallel_calls, drop_remainder) return _apply_fn diff --git a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt index bf544703de..e230c51edf 100644 --- a/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MapAndBatchDataset.pbtxt @@ -1,5 +1,19 @@ op { graph_op_name: "MapAndBatchDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <